From 08847375f66c83fc13cee0a4c3c3f405e03158c2 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Tue, 18 Nov 2025 18:08:30 +0100 Subject: [PATCH 001/135] Correct doc string, format errors, unused imports, type hints, etc --- src/wf_psf/utils/read_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/wf_psf/utils/read_config.py b/src/wf_psf/utils/read_config.py index 875ae8ed..48d23e00 100644 --- a/src/wf_psf/utils/read_config.py +++ b/src/wf_psf/utils/read_config.py @@ -140,4 +140,4 @@ def read_stream(conf_file): docs = yaml.load_all(stream, yaml.FullLoader) for doc in docs: # noqa: UP028 - yield doc + yield doc From 532e985a657bed29563687304934b01d58fa0102 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Tue, 18 Nov 2025 18:11:18 +0100 Subject: [PATCH 002/135] Update documentation for v3.0.0: replace modules.rst with api.rst, clean up submodule docs, update conf.py, and update pyproject.toml version --- docs/source/_autosummary/wf_psf.data.rst | 31 +++++++++++ .../wf_psf.data.training_preprocessing.rst | 41 +++++++++++++++ .../_autosummary/wf_psf.metrics.metrics.rst | 32 ++++++++++++ .../wf_psf.metrics.metrics_interface.rst | 35 +++++++++++++ docs/source/_autosummary/wf_psf.metrics.rst | 32 ++++++++++++ .../wf_psf.plotting.plots_interface.rst | 40 +++++++++++++++ docs/source/_autosummary/wf_psf.plotting.rst | 31 +++++++++++ ...wf_psf.psf_models.psf_model_parametric.rst | 29 +++++++++++ ...odels.psf_model_physical_polychromatic.rst | 30 +++++++++++ ...sf.psf_models.psf_model_semiparametric.rst | 30 +++++++++++ .../wf_psf.psf_models.psf_models.rst | 48 +++++++++++++++++ .../wf_psf.psf_models.tf_layers.rst | 36 +++++++++++++ .../wf_psf.psf_models.tf_modules.rst | 33 ++++++++++++ .../wf_psf.psf_models.tf_psf_field.rst | 38 ++++++++++++++ .../wf_psf.psf_models.zernikes.rst | 29 +++++++++++ docs/source/_autosummary/wf_psf.rst | 38 ++++++++++++++ docs/source/_autosummary/wf_psf.run.rst | 30 +++++++++++ .../wf_psf.sims.psf_simulator.rst | 29 +++++++++++ docs/source/_autosummary/wf_psf.sims.rst | 32 ++++++++++++ .../wf_psf.sims.spatial_varying_psf.rst | 33 ++++++++++++ docs/source/_autosummary/wf_psf.training.rst | 32 ++++++++++++ .../_autosummary/wf_psf.training.train.rst | 39 ++++++++++++++ .../wf_psf.training.train_utils.rst | 43 ++++++++++++++++ .../wf_psf.utils.ccd_misalignments.rst | 29 +++++++++++ .../_autosummary/wf_psf.utils.centroids.rst | 39 ++++++++++++++ .../wf_psf.utils.configs_handler.rst | 46 +++++++++++++++++ .../_autosummary/wf_psf.utils.graph_utils.rst | 37 ++++++++++++++ docs/source/_autosummary/wf_psf.utils.io.rst | 29 +++++++++++ .../wf_psf.utils.preprocessing.rst | 31 +++++++++++ .../_autosummary/wf_psf.utils.read_config.rst | 37 ++++++++++++++ docs/source/_autosummary/wf_psf.utils.rst | 38 ++++++++++++++ .../_autosummary/wf_psf.utils.utils.rst | 51 +++++++++++++++++++ 32 files changed, 1128 insertions(+) create mode 100644 docs/source/_autosummary/wf_psf.data.rst create mode 100644 docs/source/_autosummary/wf_psf.data.training_preprocessing.rst create mode 100644 docs/source/_autosummary/wf_psf.metrics.metrics.rst create mode 100644 docs/source/_autosummary/wf_psf.metrics.metrics_interface.rst create mode 100644 docs/source/_autosummary/wf_psf.metrics.rst create mode 100644 docs/source/_autosummary/wf_psf.plotting.plots_interface.rst create mode 100644 docs/source/_autosummary/wf_psf.plotting.rst create mode 100644 docs/source/_autosummary/wf_psf.psf_models.psf_model_parametric.rst create mode 100644 
docs/source/_autosummary/wf_psf.psf_models.psf_model_physical_polychromatic.rst create mode 100644 docs/source/_autosummary/wf_psf.psf_models.psf_model_semiparametric.rst create mode 100644 docs/source/_autosummary/wf_psf.psf_models.psf_models.rst create mode 100644 docs/source/_autosummary/wf_psf.psf_models.tf_layers.rst create mode 100644 docs/source/_autosummary/wf_psf.psf_models.tf_modules.rst create mode 100644 docs/source/_autosummary/wf_psf.psf_models.tf_psf_field.rst create mode 100644 docs/source/_autosummary/wf_psf.psf_models.zernikes.rst create mode 100644 docs/source/_autosummary/wf_psf.rst create mode 100644 docs/source/_autosummary/wf_psf.run.rst create mode 100644 docs/source/_autosummary/wf_psf.sims.psf_simulator.rst create mode 100644 docs/source/_autosummary/wf_psf.sims.rst create mode 100644 docs/source/_autosummary/wf_psf.sims.spatial_varying_psf.rst create mode 100644 docs/source/_autosummary/wf_psf.training.rst create mode 100644 docs/source/_autosummary/wf_psf.training.train.rst create mode 100644 docs/source/_autosummary/wf_psf.training.train_utils.rst create mode 100644 docs/source/_autosummary/wf_psf.utils.ccd_misalignments.rst create mode 100644 docs/source/_autosummary/wf_psf.utils.centroids.rst create mode 100644 docs/source/_autosummary/wf_psf.utils.configs_handler.rst create mode 100644 docs/source/_autosummary/wf_psf.utils.graph_utils.rst create mode 100644 docs/source/_autosummary/wf_psf.utils.io.rst create mode 100644 docs/source/_autosummary/wf_psf.utils.preprocessing.rst create mode 100644 docs/source/_autosummary/wf_psf.utils.read_config.rst create mode 100644 docs/source/_autosummary/wf_psf.utils.rst create mode 100644 docs/source/_autosummary/wf_psf.utils.utils.rst diff --git a/docs/source/_autosummary/wf_psf.data.rst b/docs/source/_autosummary/wf_psf.data.rst new file mode 100644 index 00000000..6ba42ec2 --- /dev/null +++ b/docs/source/_autosummary/wf_psf.data.rst @@ -0,0 +1,31 @@ +wf\_psf.data +============ + +.. automodule:: wf_psf.data + + + + + + + + + + + + + + + + + + + +.. rubric:: Modules + +.. autosummary:: + :toctree: + :recursive: + + wf_psf.data.training_preprocessing + diff --git a/docs/source/_autosummary/wf_psf.data.training_preprocessing.rst b/docs/source/_autosummary/wf_psf.data.training_preprocessing.rst new file mode 100644 index 00000000..cfea4561 --- /dev/null +++ b/docs/source/_autosummary/wf_psf.data.training_preprocessing.rst @@ -0,0 +1,41 @@ +wf\_psf.data.training\_preprocessing +==================================== + +.. automodule:: wf_psf.data.training_preprocessing + + + + + + + + .. rubric:: Functions + + .. autosummary:: + + compute_ccd_misalignment + compute_centroid_correction + extract_star_data + get_np_obs_positions + get_np_zernike_prior + get_obs_positions + get_zernike_prior + + + + + + .. rubric:: Classes + + .. autosummary:: + + DataHandler + + + + + + + + + diff --git a/docs/source/_autosummary/wf_psf.metrics.metrics.rst b/docs/source/_autosummary/wf_psf.metrics.metrics.rst new file mode 100644 index 00000000..36770eb1 --- /dev/null +++ b/docs/source/_autosummary/wf_psf.metrics.metrics.rst @@ -0,0 +1,32 @@ +wf\_psf.metrics.metrics +======================= + +.. automodule:: wf_psf.metrics.metrics + + + + + + + + .. rubric:: Functions + + .. 
autosummary:: + + compute_mono_metric + compute_opd_metrics + compute_poly_metric + compute_shape_metrics + + + + + + + + + + + + + diff --git a/docs/source/_autosummary/wf_psf.metrics.metrics_interface.rst b/docs/source/_autosummary/wf_psf.metrics.metrics_interface.rst new file mode 100644 index 00000000..9675e837 --- /dev/null +++ b/docs/source/_autosummary/wf_psf.metrics.metrics_interface.rst @@ -0,0 +1,35 @@ +wf\_psf.metrics.metrics\_interface +================================== + +.. automodule:: wf_psf.metrics.metrics_interface + + + + + + + + .. rubric:: Functions + + .. autosummary:: + + evaluate_model + + + + + + .. rubric:: Classes + + .. autosummary:: + + MetricsParamsHandler + + + + + + + + + diff --git a/docs/source/_autosummary/wf_psf.metrics.rst b/docs/source/_autosummary/wf_psf.metrics.rst new file mode 100644 index 00000000..ed007b1e --- /dev/null +++ b/docs/source/_autosummary/wf_psf.metrics.rst @@ -0,0 +1,32 @@ +wf\_psf.metrics +=============== + +.. automodule:: wf_psf.metrics + + + + + + + + + + + + + + + + + + + +.. rubric:: Modules + +.. autosummary:: + :toctree: + :recursive: + + wf_psf.metrics.metrics + wf_psf.metrics.metrics_interface + diff --git a/docs/source/_autosummary/wf_psf.plotting.plots_interface.rst b/docs/source/_autosummary/wf_psf.plotting.plots_interface.rst new file mode 100644 index 00000000..173b7f86 --- /dev/null +++ b/docs/source/_autosummary/wf_psf.plotting.plots_interface.rst @@ -0,0 +1,40 @@ +wf\_psf.plotting.plots\_interface +================================= + +.. automodule:: wf_psf.plotting.plots_interface + + + + + + + + .. rubric:: Functions + + .. autosummary:: + + define_plot_style + get_number_of_stars + make_plot + plot_metrics + + + + + + .. rubric:: Classes + + .. autosummary:: + + MetricsPlotHandler + MonochromaticMetricsPlotHandler + ShapeMetricsPlotHandler + + + + + + + + + diff --git a/docs/source/_autosummary/wf_psf.plotting.rst b/docs/source/_autosummary/wf_psf.plotting.rst new file mode 100644 index 00000000..8c0db4ac --- /dev/null +++ b/docs/source/_autosummary/wf_psf.plotting.rst @@ -0,0 +1,31 @@ +wf\_psf.plotting +================ + +.. automodule:: wf_psf.plotting + + + + + + + + + + + + + + + + + + + +.. rubric:: Modules + +.. autosummary:: + :toctree: + :recursive: + + wf_psf.plotting.plots_interface + diff --git a/docs/source/_autosummary/wf_psf.psf_models.psf_model_parametric.rst b/docs/source/_autosummary/wf_psf.psf_models.psf_model_parametric.rst new file mode 100644 index 00000000..fb16f0e3 --- /dev/null +++ b/docs/source/_autosummary/wf_psf.psf_models.psf_model_parametric.rst @@ -0,0 +1,29 @@ +wf\_psf.psf\_models.psf\_model\_parametric +========================================== + +.. automodule:: wf_psf.psf_models.psf_model_parametric + + + + + + + + + + + + .. rubric:: Classes + + .. autosummary:: + + TFParametricPSFFieldModel + + + + + + + + + diff --git a/docs/source/_autosummary/wf_psf.psf_models.psf_model_physical_polychromatic.rst b/docs/source/_autosummary/wf_psf.psf_models.psf_model_physical_polychromatic.rst new file mode 100644 index 00000000..311652a3 --- /dev/null +++ b/docs/source/_autosummary/wf_psf.psf_models.psf_model_physical_polychromatic.rst @@ -0,0 +1,30 @@ +wf\_psf.psf\_models.psf\_model\_physical\_polychromatic +======================================================= + +.. automodule:: wf_psf.psf_models.psf_model_physical_polychromatic + + + + + + + + + + + + .. rubric:: Classes + + .. 
autosummary:: + + PhysicalPolychromaticFieldFactory + TFPhysicalPolychromaticField + + + + + + + + + diff --git a/docs/source/_autosummary/wf_psf.psf_models.psf_model_semiparametric.rst b/docs/source/_autosummary/wf_psf.psf_models.psf_model_semiparametric.rst new file mode 100644 index 00000000..578bbc92 --- /dev/null +++ b/docs/source/_autosummary/wf_psf.psf_models.psf_model_semiparametric.rst @@ -0,0 +1,30 @@ +wf\_psf.psf\_models.psf\_model\_semiparametric +============================================== + +.. automodule:: wf_psf.psf_models.psf_model_semiparametric + + + + + + + + + + + + .. rubric:: Classes + + .. autosummary:: + + SemiParamFieldFactory + TFSemiParametricField + + + + + + + + + diff --git a/docs/source/_autosummary/wf_psf.psf_models.psf_models.rst b/docs/source/_autosummary/wf_psf.psf_models.psf_models.rst new file mode 100644 index 00000000..5395637e --- /dev/null +++ b/docs/source/_autosummary/wf_psf.psf_models.psf_models.rst @@ -0,0 +1,48 @@ +wf\_psf.psf\_models.psf\_models +=============================== + +.. automodule:: wf_psf.psf_models.psf_models + + + + + + + + .. rubric:: Functions + + .. autosummary:: + + build_PSF_model + generate_zernike_maps_3d + get_psf_model + get_psf_model_weights_filepath + register_psfclass + set_psf_model + simPSF + tf_obscurations + + + + + + .. rubric:: Classes + + .. autosummary:: + + PSFModelBaseFactory + + + + + + .. rubric:: Exceptions + + .. autosummary:: + + PSFModelError + + + + + diff --git a/docs/source/_autosummary/wf_psf.psf_models.tf_layers.rst b/docs/source/_autosummary/wf_psf.psf_models.tf_layers.rst new file mode 100644 index 00000000..de58615b --- /dev/null +++ b/docs/source/_autosummary/wf_psf.psf_models.tf_layers.rst @@ -0,0 +1,36 @@ +wf\_psf.psf\_models.tf\_layers +============================== + +.. automodule:: wf_psf.psf_models.tf_layers + + + + + + + + + + + + .. rubric:: Classes + + .. autosummary:: + + TFBatchMonochromaticPSF + TFBatchPolychromaticPSF + TFNonParametricGraphOPD + TFNonParametricMCCDOPDv2 + TFNonParametricPolynomialVariationsOPD + TFPhysicalLayer + TFPolynomialZernikeField + TFZernikeOPD + + + + + + + + + diff --git a/docs/source/_autosummary/wf_psf.psf_models.tf_modules.rst b/docs/source/_autosummary/wf_psf.psf_models.tf_modules.rst new file mode 100644 index 00000000..7571ede7 --- /dev/null +++ b/docs/source/_autosummary/wf_psf.psf_models.tf_modules.rst @@ -0,0 +1,33 @@ +wf\_psf.psf\_models.tf\_modules +=============================== + +.. automodule:: wf_psf.psf_models.tf_modules + + + + + + + + + + + + .. rubric:: Classes + + .. autosummary:: + + TFBuildPhase + TFFftDiffract + TFMonochromaticPSF + TFZernikeMonochromaticPSF + TFZernikeOPD + + + + + + + + + diff --git a/docs/source/_autosummary/wf_psf.psf_models.tf_psf_field.rst b/docs/source/_autosummary/wf_psf.psf_models.tf_psf_field.rst new file mode 100644 index 00000000..8ab4325b --- /dev/null +++ b/docs/source/_autosummary/wf_psf.psf_models.tf_psf_field.rst @@ -0,0 +1,38 @@ +wf\_psf.psf\_models.tf\_psf\_field +================================== + +.. automodule:: wf_psf.psf_models.tf_psf_field + + + + + + + + .. rubric:: Functions + + .. autosummary:: + + get_ground_truth_zernike + + + + + + .. rubric:: Classes + + .. 
autosummary:: + + GroundTruthPhysicalFieldFactory + GroundTruthSemiParamFieldFactory + TFGroundTruthPhysicalField + TFGroundTruthSemiParametricField + + + + + + + + + diff --git a/docs/source/_autosummary/wf_psf.psf_models.zernikes.rst b/docs/source/_autosummary/wf_psf.psf_models.zernikes.rst new file mode 100644 index 00000000..54b92e7c --- /dev/null +++ b/docs/source/_autosummary/wf_psf.psf_models.zernikes.rst @@ -0,0 +1,29 @@ +wf\_psf.psf\_models.zernikes +============================ + +.. automodule:: wf_psf.psf_models.zernikes + + + + + + + + .. rubric:: Functions + + .. autosummary:: + + zernike_generator + + + + + + + + + + + + + diff --git a/docs/source/_autosummary/wf_psf.rst b/docs/source/_autosummary/wf_psf.rst new file mode 100644 index 00000000..fba774bf --- /dev/null +++ b/docs/source/_autosummary/wf_psf.rst @@ -0,0 +1,38 @@ +wf\_psf +======= + +.. automodule:: wf_psf + + + + + + + + + + + + + + + + + + + +.. rubric:: Modules + +.. autosummary:: + :toctree: + :recursive: + + wf_psf.data + wf_psf.metrics + wf_psf.plotting + wf_psf.psf_models + wf_psf.run + wf_psf.sims + wf_psf.training + wf_psf.utils + diff --git a/docs/source/_autosummary/wf_psf.run.rst b/docs/source/_autosummary/wf_psf.run.rst new file mode 100644 index 00000000..e1533aa0 --- /dev/null +++ b/docs/source/_autosummary/wf_psf.run.rst @@ -0,0 +1,30 @@ +wf\_psf.run +=========== + +.. automodule:: wf_psf.run + + + + + + + + .. rubric:: Functions + + .. autosummary:: + + mainMethod + setProgramOptions + + + + + + + + + + + + + diff --git a/docs/source/_autosummary/wf_psf.sims.psf_simulator.rst b/docs/source/_autosummary/wf_psf.sims.psf_simulator.rst new file mode 100644 index 00000000..7d29d3ba --- /dev/null +++ b/docs/source/_autosummary/wf_psf.sims.psf_simulator.rst @@ -0,0 +1,29 @@ +wf\_psf.sims.psf\_simulator +=========================== + +.. automodule:: wf_psf.sims.psf_simulator + + + + + + + + + + + + .. rubric:: Classes + + .. autosummary:: + + PSFSimulator + + + + + + + + + diff --git a/docs/source/_autosummary/wf_psf.sims.rst b/docs/source/_autosummary/wf_psf.sims.rst new file mode 100644 index 00000000..b0f0901d --- /dev/null +++ b/docs/source/_autosummary/wf_psf.sims.rst @@ -0,0 +1,32 @@ +wf\_psf.sims +============ + +.. automodule:: wf_psf.sims + + + + + + + + + + + + + + + + + + + +.. rubric:: Modules + +.. autosummary:: + :toctree: + :recursive: + + wf_psf.sims.psf_simulator + wf_psf.sims.spatial_varying_psf + diff --git a/docs/source/_autosummary/wf_psf.sims.spatial_varying_psf.rst b/docs/source/_autosummary/wf_psf.sims.spatial_varying_psf.rst new file mode 100644 index 00000000..3fee34bc --- /dev/null +++ b/docs/source/_autosummary/wf_psf.sims.spatial_varying_psf.rst @@ -0,0 +1,33 @@ +wf\_psf.sims.spatial\_varying\_psf +================================== + +.. automodule:: wf_psf.sims.spatial_varying_psf + + + + + + + + + + + + .. rubric:: Classes + + .. autosummary:: + + CoordinateHelper + MeshHelper + PolynomialMatrixHelper + SpatialVaryingPSF + ZernikeHelper + + + + + + + + + diff --git a/docs/source/_autosummary/wf_psf.training.rst b/docs/source/_autosummary/wf_psf.training.rst new file mode 100644 index 00000000..aa692d15 --- /dev/null +++ b/docs/source/_autosummary/wf_psf.training.rst @@ -0,0 +1,32 @@ +wf\_psf.training +================ + +.. automodule:: wf_psf.training + + + + + + + + + + + + + + + + + + + +.. rubric:: Modules + +.. 
autosummary:: + :toctree: + :recursive: + + wf_psf.training.train + wf_psf.training.train_utils + diff --git a/docs/source/_autosummary/wf_psf.training.train.rst b/docs/source/_autosummary/wf_psf.training.train.rst new file mode 100644 index 00000000..f8f3faea --- /dev/null +++ b/docs/source/_autosummary/wf_psf.training.train.rst @@ -0,0 +1,39 @@ +wf\_psf.training.train +====================== + +.. automodule:: wf_psf.training.train + + + + + + + + .. rubric:: Functions + + .. autosummary:: + + filepath_chkp_callback + get_gpu_info + get_loss_metrics_monitor_and_outputs + setup_training + train + + + + + + .. rubric:: Classes + + .. autosummary:: + + TrainingParamsHandler + + + + + + + + + diff --git a/docs/source/_autosummary/wf_psf.training.train_utils.rst b/docs/source/_autosummary/wf_psf.training.train_utils.rst new file mode 100644 index 00000000..153313e1 --- /dev/null +++ b/docs/source/_autosummary/wf_psf.training.train_utils.rst @@ -0,0 +1,43 @@ +wf\_psf.training.train\_utils +============================= + +.. automodule:: wf_psf.training.train_utils + + + + + + + + .. rubric:: Functions + + .. autosummary:: + + calculate_sample_weights + configure_optimizer_and_loss + general_train_cycle + get_callbacks + l1_schedule_rule + masked_mse + train_cycle_part + + + + + + .. rubric:: Classes + + .. autosummary:: + + L1ParamScheduler + MaskedMeanSquaredError + MaskedMeanSquaredErrorMetric + + + + + + + + + diff --git a/docs/source/_autosummary/wf_psf.utils.ccd_misalignments.rst b/docs/source/_autosummary/wf_psf.utils.ccd_misalignments.rst new file mode 100644 index 00000000..00f69a01 --- /dev/null +++ b/docs/source/_autosummary/wf_psf.utils.ccd_misalignments.rst @@ -0,0 +1,29 @@ +wf\_psf.utils.ccd\_misalignments +================================ + +.. automodule:: wf_psf.utils.ccd_misalignments + + + + + + + + + + + + .. rubric:: Classes + + .. autosummary:: + + CCDMisalignmentCalculator + + + + + + + + + diff --git a/docs/source/_autosummary/wf_psf.utils.centroids.rst b/docs/source/_autosummary/wf_psf.utils.centroids.rst new file mode 100644 index 00000000..64aed9ee --- /dev/null +++ b/docs/source/_autosummary/wf_psf.utils.centroids.rst @@ -0,0 +1,39 @@ +wf\_psf.utils.centroids +======================= + +.. automodule:: wf_psf.utils.centroids + + + + + + + + .. rubric:: Functions + + .. autosummary:: + + compute_zernike_tip_tilt + decim + degradation_op + lanczos + shift_ker_stack + + + + + + .. rubric:: Classes + + .. autosummary:: + + CentroidEstimator + + + + + + + + + diff --git a/docs/source/_autosummary/wf_psf.utils.configs_handler.rst b/docs/source/_autosummary/wf_psf.utils.configs_handler.rst new file mode 100644 index 00000000..0f1ca838 --- /dev/null +++ b/docs/source/_autosummary/wf_psf.utils.configs_handler.rst @@ -0,0 +1,46 @@ +wf\_psf.utils.configs\_handler +============================== + +.. automodule:: wf_psf.utils.configs_handler + + + + + + + + .. rubric:: Functions + + .. autosummary:: + + get_run_config + register_configclass + set_run_config + + + + + + .. rubric:: Classes + + .. autosummary:: + + DataConfigHandler + MetricsConfigHandler + PlottingConfigHandler + TrainingConfigHandler + + + + + + .. rubric:: Exceptions + + .. 
autosummary:: + + ConfigParameterError + + + + + diff --git a/docs/source/_autosummary/wf_psf.utils.graph_utils.rst b/docs/source/_autosummary/wf_psf.utils.graph_utils.rst new file mode 100644 index 00000000..fd373ab4 --- /dev/null +++ b/docs/source/_autosummary/wf_psf.utils.graph_utils.rst @@ -0,0 +1,37 @@ +wf\_psf.utils.graph\_utils +========================== + +.. automodule:: wf_psf.utils.graph_utils + + + + + + + + .. rubric:: Functions + + .. autosummary:: + + gen_Pea + pairwise_distances + select_vstar + + + + + + .. rubric:: Classes + + .. autosummary:: + + GraphBuilder + + + + + + + + + diff --git a/docs/source/_autosummary/wf_psf.utils.io.rst b/docs/source/_autosummary/wf_psf.utils.io.rst new file mode 100644 index 00000000..37155165 --- /dev/null +++ b/docs/source/_autosummary/wf_psf.utils.io.rst @@ -0,0 +1,29 @@ +wf\_psf.utils.io +================ + +.. automodule:: wf_psf.utils.io + + + + + + + + + + + + .. rubric:: Classes + + .. autosummary:: + + FileIOHandler + + + + + + + + + diff --git a/docs/source/_autosummary/wf_psf.utils.preprocessing.rst b/docs/source/_autosummary/wf_psf.utils.preprocessing.rst new file mode 100644 index 00000000..0f8fed5d --- /dev/null +++ b/docs/source/_autosummary/wf_psf.utils.preprocessing.rst @@ -0,0 +1,31 @@ +wf\_psf.utils.preprocessing +=========================== + +.. automodule:: wf_psf.utils.preprocessing + + + + + + + + .. rubric:: Functions + + .. autosummary:: + + defocus_to_zk4_wavediff + defocus_to_zk4_zemax + shift_x_y_to_zk1_2_wavediff + + + + + + + + + + + + + diff --git a/docs/source/_autosummary/wf_psf.utils.read_config.rst b/docs/source/_autosummary/wf_psf.utils.read_config.rst new file mode 100644 index 00000000..adccbe4e --- /dev/null +++ b/docs/source/_autosummary/wf_psf.utils.read_config.rst @@ -0,0 +1,37 @@ +wf\_psf.utils.read\_config +========================== + +.. automodule:: wf_psf.utils.read_config + + + + + + + + .. rubric:: Functions + + .. autosummary:: + + read_conf + read_stream + read_yaml + + + + + + .. rubric:: Classes + + .. autosummary:: + + RecursiveNamespace + + + + + + + + + diff --git a/docs/source/_autosummary/wf_psf.utils.rst b/docs/source/_autosummary/wf_psf.utils.rst new file mode 100644 index 00000000..3c1f8660 --- /dev/null +++ b/docs/source/_autosummary/wf_psf.utils.rst @@ -0,0 +1,38 @@ +wf\_psf.utils +============= + +.. automodule:: wf_psf.utils + + + + + + + + + + + + + + + + + + + +.. rubric:: Modules + +.. autosummary:: + :toctree: + :recursive: + + wf_psf.utils.ccd_misalignments + wf_psf.utils.centroids + wf_psf.utils.configs_handler + wf_psf.utils.graph_utils + wf_psf.utils.io + wf_psf.utils.preprocessing + wf_psf.utils.read_config + wf_psf.utils.utils + diff --git a/docs/source/_autosummary/wf_psf.utils.utils.rst b/docs/source/_autosummary/wf_psf.utils.utils.rst new file mode 100644 index 00000000..71e7d532 --- /dev/null +++ b/docs/source/_autosummary/wf_psf.utils.utils.rst @@ -0,0 +1,51 @@ +wf\_psf.utils.utils +=================== + +.. automodule:: wf_psf.utils.utils + + + + + + + + .. rubric:: Functions + + .. autosummary:: + + add_noise + calc_poly_position_mat + compute_unobscured_zernike_projection + convert_to_tf + decimate_im + decompose_tf_obscured_opd_basis + downsample_im + generalised_sigmoid + generate_SED_elems + generate_SED_elems_in_tensorflow + generate_n_mask + generate_packed_elems + load_multi_cycle_params_click + single_mask_generator + zernike_generator + + + + + + .. rubric:: Classes + + .. 
autosummary:: + + IndependentZernikeInterpolation + NoiseEstimator + ZernikeInterpolation + + + + + + + + + From be61b0f09cc8ee7262d2cd06f4d2ecdfeed3bb84 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 21 Nov 2025 12:04:30 +0100 Subject: [PATCH 003/135] Update API reference and clean up package/module docstrings - Removed old auto-generated wf_psf.rst from _autosummary - Updated toc.rst (fixed tab issues, reference api.rst) - Added api.rst with :recursive: directive and wf_psf.run entrypoint - Refined __init__.py docstrings for all subpackages for clarity and consistency - Updated module-level docstrings (purpose, authors, TensorFlow notes, etc.) --- docs/source/_autosummary/wf_psf.rst | 38 ------------------------- docs/source/_autosummary/wf_psf.run.rst | 2 +- 2 files changed, 1 insertion(+), 39 deletions(-) delete mode 100644 docs/source/_autosummary/wf_psf.rst diff --git a/docs/source/_autosummary/wf_psf.rst b/docs/source/_autosummary/wf_psf.rst deleted file mode 100644 index fba774bf..00000000 --- a/docs/source/_autosummary/wf_psf.rst +++ /dev/null @@ -1,38 +0,0 @@ -wf\_psf -======= - -.. automodule:: wf_psf - - - - - - - - - - - - - - - - - - - -.. rubric:: Modules - -.. autosummary:: - :toctree: - :recursive: - - wf_psf.data - wf_psf.metrics - wf_psf.plotting - wf_psf.psf_models - wf_psf.run - wf_psf.sims - wf_psf.training - wf_psf.utils - diff --git a/docs/source/_autosummary/wf_psf.run.rst b/docs/source/_autosummary/wf_psf.run.rst index e1533aa0..a4b725b2 100644 --- a/docs/source/_autosummary/wf_psf.run.rst +++ b/docs/source/_autosummary/wf_psf.run.rst @@ -1,4 +1,4 @@ -wf\_psf.run +wf\_psf.run =========== .. automodule:: wf_psf.run From afa2923d899637000738683864df17b9c2eb307c Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 21 Nov 2025 12:47:28 +0100 Subject: [PATCH 004/135] Remove auto-generated _autosummary from repo; will be built in CD --- docs/source/_autosummary/wf_psf.data.rst | 31 ----------- .../wf_psf.data.training_preprocessing.rst | 41 --------------- .../_autosummary/wf_psf.metrics.metrics.rst | 32 ------------ .../wf_psf.metrics.metrics_interface.rst | 35 ------------- docs/source/_autosummary/wf_psf.metrics.rst | 32 ------------ .../wf_psf.plotting.plots_interface.rst | 40 --------------- docs/source/_autosummary/wf_psf.plotting.rst | 31 ----------- ...wf_psf.psf_models.psf_model_parametric.rst | 29 ----------- ...odels.psf_model_physical_polychromatic.rst | 30 ----------- ...sf.psf_models.psf_model_semiparametric.rst | 30 ----------- .../wf_psf.psf_models.psf_models.rst | 48 ----------------- .../wf_psf.psf_models.tf_layers.rst | 36 ------------- .../wf_psf.psf_models.tf_modules.rst | 33 ------------ .../wf_psf.psf_models.tf_psf_field.rst | 38 -------------- .../wf_psf.psf_models.zernikes.rst | 29 ----------- docs/source/_autosummary/wf_psf.run.rst | 30 ----------- .../wf_psf.sims.psf_simulator.rst | 29 ----------- docs/source/_autosummary/wf_psf.sims.rst | 32 ------------ .../wf_psf.sims.spatial_varying_psf.rst | 33 ------------ docs/source/_autosummary/wf_psf.training.rst | 32 ------------ .../_autosummary/wf_psf.training.train.rst | 39 -------------- .../wf_psf.training.train_utils.rst | 43 ---------------- .../wf_psf.utils.ccd_misalignments.rst | 29 ----------- .../_autosummary/wf_psf.utils.centroids.rst | 39 -------------- .../wf_psf.utils.configs_handler.rst | 46 ----------------- .../_autosummary/wf_psf.utils.graph_utils.rst | 37 -------------- docs/source/_autosummary/wf_psf.utils.io.rst | 29 ----------- 
.../wf_psf.utils.preprocessing.rst | 31 ----------- .../_autosummary/wf_psf.utils.read_config.rst | 37 -------------- docs/source/_autosummary/wf_psf.utils.rst | 38 -------------- .../_autosummary/wf_psf.utils.utils.rst | 51 ------------------- 31 files changed, 1090 deletions(-) delete mode 100644 docs/source/_autosummary/wf_psf.data.rst delete mode 100644 docs/source/_autosummary/wf_psf.data.training_preprocessing.rst delete mode 100644 docs/source/_autosummary/wf_psf.metrics.metrics.rst delete mode 100644 docs/source/_autosummary/wf_psf.metrics.metrics_interface.rst delete mode 100644 docs/source/_autosummary/wf_psf.metrics.rst delete mode 100644 docs/source/_autosummary/wf_psf.plotting.plots_interface.rst delete mode 100644 docs/source/_autosummary/wf_psf.plotting.rst delete mode 100644 docs/source/_autosummary/wf_psf.psf_models.psf_model_parametric.rst delete mode 100644 docs/source/_autosummary/wf_psf.psf_models.psf_model_physical_polychromatic.rst delete mode 100644 docs/source/_autosummary/wf_psf.psf_models.psf_model_semiparametric.rst delete mode 100644 docs/source/_autosummary/wf_psf.psf_models.psf_models.rst delete mode 100644 docs/source/_autosummary/wf_psf.psf_models.tf_layers.rst delete mode 100644 docs/source/_autosummary/wf_psf.psf_models.tf_modules.rst delete mode 100644 docs/source/_autosummary/wf_psf.psf_models.tf_psf_field.rst delete mode 100644 docs/source/_autosummary/wf_psf.psf_models.zernikes.rst delete mode 100644 docs/source/_autosummary/wf_psf.run.rst delete mode 100644 docs/source/_autosummary/wf_psf.sims.psf_simulator.rst delete mode 100644 docs/source/_autosummary/wf_psf.sims.rst delete mode 100644 docs/source/_autosummary/wf_psf.sims.spatial_varying_psf.rst delete mode 100644 docs/source/_autosummary/wf_psf.training.rst delete mode 100644 docs/source/_autosummary/wf_psf.training.train.rst delete mode 100644 docs/source/_autosummary/wf_psf.training.train_utils.rst delete mode 100644 docs/source/_autosummary/wf_psf.utils.ccd_misalignments.rst delete mode 100644 docs/source/_autosummary/wf_psf.utils.centroids.rst delete mode 100644 docs/source/_autosummary/wf_psf.utils.configs_handler.rst delete mode 100644 docs/source/_autosummary/wf_psf.utils.graph_utils.rst delete mode 100644 docs/source/_autosummary/wf_psf.utils.io.rst delete mode 100644 docs/source/_autosummary/wf_psf.utils.preprocessing.rst delete mode 100644 docs/source/_autosummary/wf_psf.utils.read_config.rst delete mode 100644 docs/source/_autosummary/wf_psf.utils.rst delete mode 100644 docs/source/_autosummary/wf_psf.utils.utils.rst diff --git a/docs/source/_autosummary/wf_psf.data.rst b/docs/source/_autosummary/wf_psf.data.rst deleted file mode 100644 index 6ba42ec2..00000000 --- a/docs/source/_autosummary/wf_psf.data.rst +++ /dev/null @@ -1,31 +0,0 @@ -wf\_psf.data -============ - -.. automodule:: wf_psf.data - - - - - - - - - - - - - - - - - - - -.. rubric:: Modules - -.. autosummary:: - :toctree: - :recursive: - - wf_psf.data.training_preprocessing - diff --git a/docs/source/_autosummary/wf_psf.data.training_preprocessing.rst b/docs/source/_autosummary/wf_psf.data.training_preprocessing.rst deleted file mode 100644 index cfea4561..00000000 --- a/docs/source/_autosummary/wf_psf.data.training_preprocessing.rst +++ /dev/null @@ -1,41 +0,0 @@ -wf\_psf.data.training\_preprocessing -==================================== - -.. automodule:: wf_psf.data.training_preprocessing - - - - - - - - .. rubric:: Functions - - .. 
autosummary:: - - compute_ccd_misalignment - compute_centroid_correction - extract_star_data - get_np_obs_positions - get_np_zernike_prior - get_obs_positions - get_zernike_prior - - - - - - .. rubric:: Classes - - .. autosummary:: - - DataHandler - - - - - - - - - diff --git a/docs/source/_autosummary/wf_psf.metrics.metrics.rst b/docs/source/_autosummary/wf_psf.metrics.metrics.rst deleted file mode 100644 index 36770eb1..00000000 --- a/docs/source/_autosummary/wf_psf.metrics.metrics.rst +++ /dev/null @@ -1,32 +0,0 @@ -wf\_psf.metrics.metrics -======================= - -.. automodule:: wf_psf.metrics.metrics - - - - - - - - .. rubric:: Functions - - .. autosummary:: - - compute_mono_metric - compute_opd_metrics - compute_poly_metric - compute_shape_metrics - - - - - - - - - - - - - diff --git a/docs/source/_autosummary/wf_psf.metrics.metrics_interface.rst b/docs/source/_autosummary/wf_psf.metrics.metrics_interface.rst deleted file mode 100644 index 9675e837..00000000 --- a/docs/source/_autosummary/wf_psf.metrics.metrics_interface.rst +++ /dev/null @@ -1,35 +0,0 @@ -wf\_psf.metrics.metrics\_interface -================================== - -.. automodule:: wf_psf.metrics.metrics_interface - - - - - - - - .. rubric:: Functions - - .. autosummary:: - - evaluate_model - - - - - - .. rubric:: Classes - - .. autosummary:: - - MetricsParamsHandler - - - - - - - - - diff --git a/docs/source/_autosummary/wf_psf.metrics.rst b/docs/source/_autosummary/wf_psf.metrics.rst deleted file mode 100644 index ed007b1e..00000000 --- a/docs/source/_autosummary/wf_psf.metrics.rst +++ /dev/null @@ -1,32 +0,0 @@ -wf\_psf.metrics -=============== - -.. automodule:: wf_psf.metrics - - - - - - - - - - - - - - - - - - - -.. rubric:: Modules - -.. autosummary:: - :toctree: - :recursive: - - wf_psf.metrics.metrics - wf_psf.metrics.metrics_interface - diff --git a/docs/source/_autosummary/wf_psf.plotting.plots_interface.rst b/docs/source/_autosummary/wf_psf.plotting.plots_interface.rst deleted file mode 100644 index 173b7f86..00000000 --- a/docs/source/_autosummary/wf_psf.plotting.plots_interface.rst +++ /dev/null @@ -1,40 +0,0 @@ -wf\_psf.plotting.plots\_interface -================================= - -.. automodule:: wf_psf.plotting.plots_interface - - - - - - - - .. rubric:: Functions - - .. autosummary:: - - define_plot_style - get_number_of_stars - make_plot - plot_metrics - - - - - - .. rubric:: Classes - - .. autosummary:: - - MetricsPlotHandler - MonochromaticMetricsPlotHandler - ShapeMetricsPlotHandler - - - - - - - - - diff --git a/docs/source/_autosummary/wf_psf.plotting.rst b/docs/source/_autosummary/wf_psf.plotting.rst deleted file mode 100644 index 8c0db4ac..00000000 --- a/docs/source/_autosummary/wf_psf.plotting.rst +++ /dev/null @@ -1,31 +0,0 @@ -wf\_psf.plotting -================ - -.. automodule:: wf_psf.plotting - - - - - - - - - - - - - - - - - - - -.. rubric:: Modules - -.. autosummary:: - :toctree: - :recursive: - - wf_psf.plotting.plots_interface - diff --git a/docs/source/_autosummary/wf_psf.psf_models.psf_model_parametric.rst b/docs/source/_autosummary/wf_psf.psf_models.psf_model_parametric.rst deleted file mode 100644 index fb16f0e3..00000000 --- a/docs/source/_autosummary/wf_psf.psf_models.psf_model_parametric.rst +++ /dev/null @@ -1,29 +0,0 @@ -wf\_psf.psf\_models.psf\_model\_parametric -========================================== - -.. automodule:: wf_psf.psf_models.psf_model_parametric - - - - - - - - - - - - .. rubric:: Classes - - .. 
autosummary:: - - TFParametricPSFFieldModel - - - - - - - - - diff --git a/docs/source/_autosummary/wf_psf.psf_models.psf_model_physical_polychromatic.rst b/docs/source/_autosummary/wf_psf.psf_models.psf_model_physical_polychromatic.rst deleted file mode 100644 index 311652a3..00000000 --- a/docs/source/_autosummary/wf_psf.psf_models.psf_model_physical_polychromatic.rst +++ /dev/null @@ -1,30 +0,0 @@ -wf\_psf.psf\_models.psf\_model\_physical\_polychromatic -======================================================= - -.. automodule:: wf_psf.psf_models.psf_model_physical_polychromatic - - - - - - - - - - - - .. rubric:: Classes - - .. autosummary:: - - PhysicalPolychromaticFieldFactory - TFPhysicalPolychromaticField - - - - - - - - - diff --git a/docs/source/_autosummary/wf_psf.psf_models.psf_model_semiparametric.rst b/docs/source/_autosummary/wf_psf.psf_models.psf_model_semiparametric.rst deleted file mode 100644 index 578bbc92..00000000 --- a/docs/source/_autosummary/wf_psf.psf_models.psf_model_semiparametric.rst +++ /dev/null @@ -1,30 +0,0 @@ -wf\_psf.psf\_models.psf\_model\_semiparametric -============================================== - -.. automodule:: wf_psf.psf_models.psf_model_semiparametric - - - - - - - - - - - - .. rubric:: Classes - - .. autosummary:: - - SemiParamFieldFactory - TFSemiParametricField - - - - - - - - - diff --git a/docs/source/_autosummary/wf_psf.psf_models.psf_models.rst b/docs/source/_autosummary/wf_psf.psf_models.psf_models.rst deleted file mode 100644 index 5395637e..00000000 --- a/docs/source/_autosummary/wf_psf.psf_models.psf_models.rst +++ /dev/null @@ -1,48 +0,0 @@ -wf\_psf.psf\_models.psf\_models -=============================== - -.. automodule:: wf_psf.psf_models.psf_models - - - - - - - - .. rubric:: Functions - - .. autosummary:: - - build_PSF_model - generate_zernike_maps_3d - get_psf_model - get_psf_model_weights_filepath - register_psfclass - set_psf_model - simPSF - tf_obscurations - - - - - - .. rubric:: Classes - - .. autosummary:: - - PSFModelBaseFactory - - - - - - .. rubric:: Exceptions - - .. autosummary:: - - PSFModelError - - - - - diff --git a/docs/source/_autosummary/wf_psf.psf_models.tf_layers.rst b/docs/source/_autosummary/wf_psf.psf_models.tf_layers.rst deleted file mode 100644 index de58615b..00000000 --- a/docs/source/_autosummary/wf_psf.psf_models.tf_layers.rst +++ /dev/null @@ -1,36 +0,0 @@ -wf\_psf.psf\_models.tf\_layers -============================== - -.. automodule:: wf_psf.psf_models.tf_layers - - - - - - - - - - - - .. rubric:: Classes - - .. autosummary:: - - TFBatchMonochromaticPSF - TFBatchPolychromaticPSF - TFNonParametricGraphOPD - TFNonParametricMCCDOPDv2 - TFNonParametricPolynomialVariationsOPD - TFPhysicalLayer - TFPolynomialZernikeField - TFZernikeOPD - - - - - - - - - diff --git a/docs/source/_autosummary/wf_psf.psf_models.tf_modules.rst b/docs/source/_autosummary/wf_psf.psf_models.tf_modules.rst deleted file mode 100644 index 7571ede7..00000000 --- a/docs/source/_autosummary/wf_psf.psf_models.tf_modules.rst +++ /dev/null @@ -1,33 +0,0 @@ -wf\_psf.psf\_models.tf\_modules -=============================== - -.. automodule:: wf_psf.psf_models.tf_modules - - - - - - - - - - - - .. rubric:: Classes - - .. 
autosummary:: - - TFBuildPhase - TFFftDiffract - TFMonochromaticPSF - TFZernikeMonochromaticPSF - TFZernikeOPD - - - - - - - - - diff --git a/docs/source/_autosummary/wf_psf.psf_models.tf_psf_field.rst b/docs/source/_autosummary/wf_psf.psf_models.tf_psf_field.rst deleted file mode 100644 index 8ab4325b..00000000 --- a/docs/source/_autosummary/wf_psf.psf_models.tf_psf_field.rst +++ /dev/null @@ -1,38 +0,0 @@ -wf\_psf.psf\_models.tf\_psf\_field -================================== - -.. automodule:: wf_psf.psf_models.tf_psf_field - - - - - - - - .. rubric:: Functions - - .. autosummary:: - - get_ground_truth_zernike - - - - - - .. rubric:: Classes - - .. autosummary:: - - GroundTruthPhysicalFieldFactory - GroundTruthSemiParamFieldFactory - TFGroundTruthPhysicalField - TFGroundTruthSemiParametricField - - - - - - - - - diff --git a/docs/source/_autosummary/wf_psf.psf_models.zernikes.rst b/docs/source/_autosummary/wf_psf.psf_models.zernikes.rst deleted file mode 100644 index 54b92e7c..00000000 --- a/docs/source/_autosummary/wf_psf.psf_models.zernikes.rst +++ /dev/null @@ -1,29 +0,0 @@ -wf\_psf.psf\_models.zernikes -============================ - -.. automodule:: wf_psf.psf_models.zernikes - - - - - - - - .. rubric:: Functions - - .. autosummary:: - - zernike_generator - - - - - - - - - - - - - diff --git a/docs/source/_autosummary/wf_psf.run.rst b/docs/source/_autosummary/wf_psf.run.rst deleted file mode 100644 index a4b725b2..00000000 --- a/docs/source/_autosummary/wf_psf.run.rst +++ /dev/null @@ -1,30 +0,0 @@ -wf\_psf.run -=========== - -.. automodule:: wf_psf.run - - - - - - - - .. rubric:: Functions - - .. autosummary:: - - mainMethod - setProgramOptions - - - - - - - - - - - - - diff --git a/docs/source/_autosummary/wf_psf.sims.psf_simulator.rst b/docs/source/_autosummary/wf_psf.sims.psf_simulator.rst deleted file mode 100644 index 7d29d3ba..00000000 --- a/docs/source/_autosummary/wf_psf.sims.psf_simulator.rst +++ /dev/null @@ -1,29 +0,0 @@ -wf\_psf.sims.psf\_simulator -=========================== - -.. automodule:: wf_psf.sims.psf_simulator - - - - - - - - - - - - .. rubric:: Classes - - .. autosummary:: - - PSFSimulator - - - - - - - - - diff --git a/docs/source/_autosummary/wf_psf.sims.rst b/docs/source/_autosummary/wf_psf.sims.rst deleted file mode 100644 index b0f0901d..00000000 --- a/docs/source/_autosummary/wf_psf.sims.rst +++ /dev/null @@ -1,32 +0,0 @@ -wf\_psf.sims -============ - -.. automodule:: wf_psf.sims - - - - - - - - - - - - - - - - - - - -.. rubric:: Modules - -.. autosummary:: - :toctree: - :recursive: - - wf_psf.sims.psf_simulator - wf_psf.sims.spatial_varying_psf - diff --git a/docs/source/_autosummary/wf_psf.sims.spatial_varying_psf.rst b/docs/source/_autosummary/wf_psf.sims.spatial_varying_psf.rst deleted file mode 100644 index 3fee34bc..00000000 --- a/docs/source/_autosummary/wf_psf.sims.spatial_varying_psf.rst +++ /dev/null @@ -1,33 +0,0 @@ -wf\_psf.sims.spatial\_varying\_psf -================================== - -.. automodule:: wf_psf.sims.spatial_varying_psf - - - - - - - - - - - - .. rubric:: Classes - - .. autosummary:: - - CoordinateHelper - MeshHelper - PolynomialMatrixHelper - SpatialVaryingPSF - ZernikeHelper - - - - - - - - - diff --git a/docs/source/_autosummary/wf_psf.training.rst b/docs/source/_autosummary/wf_psf.training.rst deleted file mode 100644 index aa692d15..00000000 --- a/docs/source/_autosummary/wf_psf.training.rst +++ /dev/null @@ -1,32 +0,0 @@ -wf\_psf.training -================ - -.. 
automodule:: wf_psf.training - - - - - - - - - - - - - - - - - - - -.. rubric:: Modules - -.. autosummary:: - :toctree: - :recursive: - - wf_psf.training.train - wf_psf.training.train_utils - diff --git a/docs/source/_autosummary/wf_psf.training.train.rst b/docs/source/_autosummary/wf_psf.training.train.rst deleted file mode 100644 index f8f3faea..00000000 --- a/docs/source/_autosummary/wf_psf.training.train.rst +++ /dev/null @@ -1,39 +0,0 @@ -wf\_psf.training.train -====================== - -.. automodule:: wf_psf.training.train - - - - - - - - .. rubric:: Functions - - .. autosummary:: - - filepath_chkp_callback - get_gpu_info - get_loss_metrics_monitor_and_outputs - setup_training - train - - - - - - .. rubric:: Classes - - .. autosummary:: - - TrainingParamsHandler - - - - - - - - - diff --git a/docs/source/_autosummary/wf_psf.training.train_utils.rst b/docs/source/_autosummary/wf_psf.training.train_utils.rst deleted file mode 100644 index 153313e1..00000000 --- a/docs/source/_autosummary/wf_psf.training.train_utils.rst +++ /dev/null @@ -1,43 +0,0 @@ -wf\_psf.training.train\_utils -============================= - -.. automodule:: wf_psf.training.train_utils - - - - - - - - .. rubric:: Functions - - .. autosummary:: - - calculate_sample_weights - configure_optimizer_and_loss - general_train_cycle - get_callbacks - l1_schedule_rule - masked_mse - train_cycle_part - - - - - - .. rubric:: Classes - - .. autosummary:: - - L1ParamScheduler - MaskedMeanSquaredError - MaskedMeanSquaredErrorMetric - - - - - - - - - diff --git a/docs/source/_autosummary/wf_psf.utils.ccd_misalignments.rst b/docs/source/_autosummary/wf_psf.utils.ccd_misalignments.rst deleted file mode 100644 index 00f69a01..00000000 --- a/docs/source/_autosummary/wf_psf.utils.ccd_misalignments.rst +++ /dev/null @@ -1,29 +0,0 @@ -wf\_psf.utils.ccd\_misalignments -================================ - -.. automodule:: wf_psf.utils.ccd_misalignments - - - - - - - - - - - - .. rubric:: Classes - - .. autosummary:: - - CCDMisalignmentCalculator - - - - - - - - - diff --git a/docs/source/_autosummary/wf_psf.utils.centroids.rst b/docs/source/_autosummary/wf_psf.utils.centroids.rst deleted file mode 100644 index 64aed9ee..00000000 --- a/docs/source/_autosummary/wf_psf.utils.centroids.rst +++ /dev/null @@ -1,39 +0,0 @@ -wf\_psf.utils.centroids -======================= - -.. automodule:: wf_psf.utils.centroids - - - - - - - - .. rubric:: Functions - - .. autosummary:: - - compute_zernike_tip_tilt - decim - degradation_op - lanczos - shift_ker_stack - - - - - - .. rubric:: Classes - - .. autosummary:: - - CentroidEstimator - - - - - - - - - diff --git a/docs/source/_autosummary/wf_psf.utils.configs_handler.rst b/docs/source/_autosummary/wf_psf.utils.configs_handler.rst deleted file mode 100644 index 0f1ca838..00000000 --- a/docs/source/_autosummary/wf_psf.utils.configs_handler.rst +++ /dev/null @@ -1,46 +0,0 @@ -wf\_psf.utils.configs\_handler -============================== - -.. automodule:: wf_psf.utils.configs_handler - - - - - - - - .. rubric:: Functions - - .. autosummary:: - - get_run_config - register_configclass - set_run_config - - - - - - .. rubric:: Classes - - .. autosummary:: - - DataConfigHandler - MetricsConfigHandler - PlottingConfigHandler - TrainingConfigHandler - - - - - - .. rubric:: Exceptions - - .. 
autosummary:: - - ConfigParameterError - - - - - diff --git a/docs/source/_autosummary/wf_psf.utils.graph_utils.rst b/docs/source/_autosummary/wf_psf.utils.graph_utils.rst deleted file mode 100644 index fd373ab4..00000000 --- a/docs/source/_autosummary/wf_psf.utils.graph_utils.rst +++ /dev/null @@ -1,37 +0,0 @@ -wf\_psf.utils.graph\_utils -========================== - -.. automodule:: wf_psf.utils.graph_utils - - - - - - - - .. rubric:: Functions - - .. autosummary:: - - gen_Pea - pairwise_distances - select_vstar - - - - - - .. rubric:: Classes - - .. autosummary:: - - GraphBuilder - - - - - - - - - diff --git a/docs/source/_autosummary/wf_psf.utils.io.rst b/docs/source/_autosummary/wf_psf.utils.io.rst deleted file mode 100644 index 37155165..00000000 --- a/docs/source/_autosummary/wf_psf.utils.io.rst +++ /dev/null @@ -1,29 +0,0 @@ -wf\_psf.utils.io -================ - -.. automodule:: wf_psf.utils.io - - - - - - - - - - - - .. rubric:: Classes - - .. autosummary:: - - FileIOHandler - - - - - - - - - diff --git a/docs/source/_autosummary/wf_psf.utils.preprocessing.rst b/docs/source/_autosummary/wf_psf.utils.preprocessing.rst deleted file mode 100644 index 0f8fed5d..00000000 --- a/docs/source/_autosummary/wf_psf.utils.preprocessing.rst +++ /dev/null @@ -1,31 +0,0 @@ -wf\_psf.utils.preprocessing -=========================== - -.. automodule:: wf_psf.utils.preprocessing - - - - - - - - .. rubric:: Functions - - .. autosummary:: - - defocus_to_zk4_wavediff - defocus_to_zk4_zemax - shift_x_y_to_zk1_2_wavediff - - - - - - - - - - - - - diff --git a/docs/source/_autosummary/wf_psf.utils.read_config.rst b/docs/source/_autosummary/wf_psf.utils.read_config.rst deleted file mode 100644 index adccbe4e..00000000 --- a/docs/source/_autosummary/wf_psf.utils.read_config.rst +++ /dev/null @@ -1,37 +0,0 @@ -wf\_psf.utils.read\_config -========================== - -.. automodule:: wf_psf.utils.read_config - - - - - - - - .. rubric:: Functions - - .. autosummary:: - - read_conf - read_stream - read_yaml - - - - - - .. rubric:: Classes - - .. autosummary:: - - RecursiveNamespace - - - - - - - - - diff --git a/docs/source/_autosummary/wf_psf.utils.rst b/docs/source/_autosummary/wf_psf.utils.rst deleted file mode 100644 index 3c1f8660..00000000 --- a/docs/source/_autosummary/wf_psf.utils.rst +++ /dev/null @@ -1,38 +0,0 @@ -wf\_psf.utils -============= - -.. automodule:: wf_psf.utils - - - - - - - - - - - - - - - - - - - -.. rubric:: Modules - -.. autosummary:: - :toctree: - :recursive: - - wf_psf.utils.ccd_misalignments - wf_psf.utils.centroids - wf_psf.utils.configs_handler - wf_psf.utils.graph_utils - wf_psf.utils.io - wf_psf.utils.preprocessing - wf_psf.utils.read_config - wf_psf.utils.utils - diff --git a/docs/source/_autosummary/wf_psf.utils.utils.rst b/docs/source/_autosummary/wf_psf.utils.utils.rst deleted file mode 100644 index 71e7d532..00000000 --- a/docs/source/_autosummary/wf_psf.utils.utils.rst +++ /dev/null @@ -1,51 +0,0 @@ -wf\_psf.utils.utils -=================== - -.. automodule:: wf_psf.utils.utils - - - - - - - - .. rubric:: Functions - - .. autosummary:: - - add_noise - calc_poly_position_mat - compute_unobscured_zernike_projection - convert_to_tf - decimate_im - decompose_tf_obscured_opd_basis - downsample_im - generalised_sigmoid - generate_SED_elems - generate_SED_elems_in_tensorflow - generate_n_mask - generate_packed_elems - load_multi_cycle_params_click - single_mask_generator - zernike_generator - - - - - - .. rubric:: Classes - - .. 
autosummary::
-
-      IndependentZernikeInterpolation
-      NoiseEstimator
-      ZernikeInterpolation
-
-
-
-
-
-
-
-
-

From 6ecb9ba60e899c0558f960267319ed4fd3c91293 Mon Sep 17 00:00:00 2001
From: Jennifer Pollack
Date: Mon, 24 Nov 2025 15:41:34 +0100
Subject: [PATCH 005/135] Add arduino formatting to directory structure example

---
 docs/source/basic_execution.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/source/basic_execution.md b/docs/source/basic_execution.md
index c29f30fe..55f9c089 100644
--- a/docs/source/basic_execution.md
+++ b/docs/source/basic_execution.md
@@ -34,7 +34,7 @@ WaveDiff creates an output directory at the location specified by `--outputdir`
 Inside this directory, the following subdirectories are generated:

 (wf-outputs)=
-```
+```arduino
 wf-outputs-20231119151932213823
 ├── checkpoint
 ├── config

From 11bad16c29e255d94f9d00b8153d939321395e77 Mon Sep 17 00:00:00 2001
From: Jennifer Pollack
Date: Tue, 25 Nov 2025 10:23:11 +0100
Subject: [PATCH 006/135] Revise WaveDiff how-to guide in doc (part 1)

- Shorten sections to make developer-friendly
- Added new structure for each section: Purpose, Key Fields, Notes, General Notes, etc, where applicable
- Partial completion - new PR is required to verify optional versus required settings

---
 docs/source/configuration.md | 14 +++++++++++---
 1 file changed, 11 insertions(+), 3 deletions(-)

diff --git a/docs/source/configuration.md b/docs/source/configuration.md
index ba87cef2..3d2f31bd 100644
--- a/docs/source/configuration.md
+++ b/docs/source/configuration.md
@@ -20,7 +20,13 @@ You configure these tasks by passing a configuration file to the `wavediff` comm

WaveDiff expects the following configuration files under the `config/` directory:

```
You configure these tasks by passing a configuration file to the `wavediff` command (e.g., `--config configs.yaml`).

## Configuration File Structure

WaveDiff expects the following configuration files under the `config/` directory:

```arduino
config
├── configs.yaml
├── data_config.yaml
@@ -220,7 +226,6 @@ training_hparams:
    n_epochs_non_params: [100, 120]
```

-
(metrics_config)=
## `metrics_config.yaml` — Metrics Configuration
@@ -402,7 +407,10 @@ plotting_params:

### 4. 
Example Directory Structure
Below is an example of three WaveDiff runs stored under a single parent directory:
**Example Directory Structure**
Below is an example of three WaveDiff runs stored under a single parent directory:

```arduino
wf-outputs/
├── wf-outputs-202305271829
│   ├── config

From 59dbccd5df4f104b2118b5584e793807309a047c Mon Sep 17 00:00:00 2001
From: Jennifer Pollack
Date: Tue, 25 Nov 2025 14:55:37 +0100
Subject: [PATCH 007/135] Fix myst error with invalid code format

---
 docs/source/basic_execution.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/source/basic_execution.md b/docs/source/basic_execution.md
index 55f9c089..c29f30fe 100644
--- a/docs/source/basic_execution.md
+++ b/docs/source/basic_execution.md
@@ -34,7 +34,7 @@ WaveDiff creates an output directory at the location specified by `--outputdir`
 Inside this directory, the following subdirectories are generated:

 (wf-outputs)=
-```arduino
+```
 wf-outputs-20231119151932213823
 ├── checkpoint
 ├── config

From 12b89697f75607527399900aaaf1d32bf9f7dd03 Mon Sep 17 00:00:00 2001
From: Jennifer Pollack
Date: Tue, 13 May 2025 11:33:38 +0200
Subject: [PATCH 008/135] Correct syntax in docstring and generalise exception message

---
 src/wf_psf/psf_models/psf_models.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/wf_psf/psf_models/psf_models.py b/src/wf_psf/psf_models/psf_models.py
index 463d1c52..4c44f698 100644
--- a/src/wf_psf/psf_models/psf_models.py
+++ b/src/wf_psf/psf_models/psf_models.py
@@ -187,24 +187,24 @@ def build_PSF_model(model_inst, optimizer=None, loss=None, metrics=None):
 def get_psf_model_weights_filepath(weights_filepath):
    """Get PSF model weights filepath.

-    A function to return the basename of the user-specified psf model weights path.
+    A function to return the basename of the user-specified PSF model weights path.

    Parameters
    ----------
    weights_filepath: str
-        Basename of the psf model weights to be loaded.
+        Basename of the PSF model weights to be loaded.

    Returns
    -------
    str
-        The absolute path concatenated to the basename of the psf model weights to be loaded.
+        The absolute path concatenated to the basename of the PSF model weights to be loaded.

    """
    try:
        return glob.glob(weights_filepath)[0].split(".")[0]
    except IndexError:
        logger.exception(
-            "PSF weights file not found. Check that you've specified the correct weights file in the metrics config file."
+            "PSF weights file not found. Check that you've specified the correct weights file in your config file."
) raise PSFModelError("PSF model weights error.") From eaa0e8e915ad2ce9c4f9a35fc78eafc53dc5bccc Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Tue, 13 May 2025 11:35:24 +0200 Subject: [PATCH 009/135] Add inference and test_inference packages --- src/wf_psf/inference/__init__.py | 0 src/wf_psf/inference/psf_inference.py | 0 src/wf_psf/tests/test_inference/test_psf_inference.py | 0 3 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 src/wf_psf/inference/__init__.py create mode 100644 src/wf_psf/inference/psf_inference.py create mode 100644 src/wf_psf/tests/test_inference/test_psf_inference.py diff --git a/src/wf_psf/inference/__init__.py b/src/wf_psf/inference/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py new file mode 100644 index 00000000..e69de29b diff --git a/src/wf_psf/tests/test_inference/test_psf_inference.py b/src/wf_psf/tests/test_inference/test_psf_inference.py new file mode 100644 index 00000000..e69de29b From 1aebacb4c1cc3acf2eb5dcd5170fb4181c0f8ec6 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Tue, 13 May 2025 14:19:39 +0200 Subject: [PATCH 010/135] Refactor: Encapsulate logic in psf_models package with subpackages: models and tf_modules, add/rm modules, update import statements and tests --- src/wf_psf/psf_models/models/__init__.py | 0 .../{ => models}/psf_model_parametric.py | 2 +- .../psf_model_physical_polychromatic.py | 8 ++++---- .../{ => models}/psf_model_semiparametric.py | 4 ++-- src/wf_psf/psf_models/tf_modules/__init__.py | 0 .../psf_models/{ => tf_modules}/tf_layers.py | 3 ++- .../psf_models/{ => tf_modules}/tf_modules.py | 0 .../psf_models/{ => tf_modules}/tf_psf_field.py | 4 ++-- .../psf_model_physical_polychromatic_test.py | 16 +++++++++------- .../tests/test_psf_models/psf_models_test.py | 6 +++--- 10 files changed, 23 insertions(+), 20 deletions(-) create mode 100644 src/wf_psf/psf_models/models/__init__.py rename src/wf_psf/psf_models/{ => models}/psf_model_parametric.py (99%) rename src/wf_psf/psf_models/{ => models}/psf_model_physical_polychromatic.py (99%) rename src/wf_psf/psf_models/{ => models}/psf_model_semiparametric.py (99%) create mode 100644 src/wf_psf/psf_models/tf_modules/__init__.py rename src/wf_psf/psf_models/{ => tf_modules}/tf_layers.py (99%) rename src/wf_psf/psf_models/{ => tf_modules}/tf_modules.py (100%) rename src/wf_psf/psf_models/{ => tf_modules}/tf_psf_field.py (99%) diff --git a/src/wf_psf/psf_models/models/__init__.py b/src/wf_psf/psf_models/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/wf_psf/psf_models/psf_model_parametric.py b/src/wf_psf/psf_models/models/psf_model_parametric.py similarity index 99% rename from src/wf_psf/psf_models/psf_model_parametric.py rename to src/wf_psf/psf_models/models/psf_model_parametric.py index 0cd703d7..3d85d2bc 100644 --- a/src/wf_psf/psf_models/psf_model_parametric.py +++ b/src/wf_psf/psf_models/models/psf_model_parametric.py @@ -9,7 +9,7 @@ import tensorflow as tf from wf_psf.psf_models.psf_models import register_psfclass -from wf_psf.psf_models.tf_layers import ( +from wf_psf.psf_models.tf_modules.tf_layers import ( TFPolynomialZernikeField, TFZernikeOPD, TFBatchPolychromaticPSF, diff --git a/src/wf_psf/psf_models/psf_model_physical_polychromatic.py b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py similarity index 99% rename from src/wf_psf/psf_models/psf_model_physical_polychromatic.py rename to 
src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py index ad61c1b5..f9ed8765 100644 --- a/src/wf_psf/psf_models/psf_model_physical_polychromatic.py +++ b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py @@ -10,11 +10,9 @@ from typing import Optional import tensorflow as tf from tensorflow.python.keras.engine import data_adapter -from wf_psf.psf_models import psf_models as psfm -from wf_psf.utils.read_config import RecursiveNamespace -from wf_psf.utils.configs_handler import DataConfigHandler from wf_psf.data.training_preprocessing import get_obs_positions, get_zernike_prior -from wf_psf.psf_models.tf_layers import ( +from wf_psf.psf_models import psf_models as psfm +from wf_psf.psf_models.tf_modules.tf_layers import ( TFPolynomialZernikeField, TFZernikeOPD, TFBatchPolychromaticPSF, @@ -22,6 +20,8 @@ TFNonParametricPolynomialVariationsOPD, TFPhysicalLayer, ) +from wf_psf.utils.read_config import RecursiveNamespace +from wf_psf.utils.configs_handler import DataConfigHandler import logging diff --git a/src/wf_psf/psf_models/psf_model_semiparametric.py b/src/wf_psf/psf_models/models/psf_model_semiparametric.py similarity index 99% rename from src/wf_psf/psf_models/psf_model_semiparametric.py rename to src/wf_psf/psf_models/models/psf_model_semiparametric.py index dc535204..c370956c 100644 --- a/src/wf_psf/psf_models/psf_model_semiparametric.py +++ b/src/wf_psf/psf_models/models/psf_model_semiparametric.py @@ -10,9 +10,9 @@ import numpy as np import tensorflow as tf from wf_psf.psf_models import psf_models as psfm -from wf_psf.psf_models import tf_layers as tfl +from wf_psf.psf_models.tf_modules import tf_layers as tfl from wf_psf.utils.utils import decompose_tf_obscured_opd_basis -from wf_psf.psf_models.tf_layers import ( +from wf_psf.psf_models.tf_modules.tf_layers import ( TFBatchPolychromaticPSF, TFBatchMonochromaticPSF, ) diff --git a/src/wf_psf/psf_models/tf_modules/__init__.py b/src/wf_psf/psf_models/tf_modules/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/wf_psf/psf_models/tf_layers.py b/src/wf_psf/psf_models/tf_modules/tf_layers.py similarity index 99% rename from src/wf_psf/psf_models/tf_layers.py rename to src/wf_psf/psf_models/tf_modules/tf_layers.py index eda43305..fcb5e8f9 100644 --- a/src/wf_psf/psf_models/tf_layers.py +++ b/src/wf_psf/psf_models/tf_modules/tf_layers.py @@ -7,7 +7,8 @@ """ import tensorflow as tf -from wf_psf.psf_models.tf_modules import TFMonochromaticPSF +import tensorflow_addons as tfa +from wf_psf.psf_models.tf_modules.tf_modules import TFMonochromaticPSF from wf_psf.utils.utils import calc_poly_position_mat import wf_psf.utils.utils as utils from wf_psf.utils.interpolation import tfa_interpolate_spline_rbf diff --git a/src/wf_psf/psf_models/tf_modules.py b/src/wf_psf/psf_models/tf_modules/tf_modules.py similarity index 100% rename from src/wf_psf/psf_models/tf_modules.py rename to src/wf_psf/psf_models/tf_modules/tf_modules.py diff --git a/src/wf_psf/psf_models/tf_psf_field.py b/src/wf_psf/psf_models/tf_modules/tf_psf_field.py similarity index 99% rename from src/wf_psf/psf_models/tf_psf_field.py rename to src/wf_psf/psf_models/tf_modules/tf_psf_field.py index 0c9ba2f7..e25657d9 100644 --- a/src/wf_psf/psf_models/tf_psf_field.py +++ b/src/wf_psf/psf_models/tf_modules/tf_psf_field.py @@ -9,13 +9,13 @@ import numpy as np import tensorflow as tf from tensorflow.python.keras.engine import data_adapter -from wf_psf.psf_models.tf_layers import ( +from wf_psf.psf_models.tf_modules.tf_layers import ( 
TFZernikeOPD, TFBatchPolychromaticPSF, TFBatchMonochromaticPSF, TFPhysicalLayer, ) -from wf_psf.psf_models.psf_model_semiparametric import TFSemiParametricField +from wf_psf.psf_models.models.psf_model_semiparametric import TFSemiParametricField from wf_psf.data.training_preprocessing import get_obs_positions from wf_psf.psf_models import psf_models as psfm import logging diff --git a/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py b/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py index cae7b141..e25d608d 100644 --- a/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py +++ b/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py @@ -9,7 +9,7 @@ import pytest import numpy as np import tensorflow as tf -from wf_psf.psf_models.psf_model_physical_polychromatic import ( +from wf_psf.psf_models.models.psf_model_physical_polychromatic import ( TFPhysicalPolychromaticField, ) from wf_psf.utils.configs_handler import DataConfigHandler @@ -54,7 +54,7 @@ def test_initialize_parameters(mocker, mock_data, mock_model_params, zks_prior): # Mock internal methods called during initialization mocker.patch( - "wf_psf.psf_models.psf_model_physical_polychromatic.get_zernike_prior", + "wf_psf.psf_models.models.psf_model_physical_polychromatic.get_zernike_prior", return_value=zks_prior, ) @@ -92,7 +92,7 @@ def test_initialize_zernike_parameters(mocker, mock_model_params, mock_data, zks # Mock internal methods called during initialization mocker.patch( - "wf_psf.psf_models.psf_model_physical_polychromatic.get_zernike_prior", + "wf_psf.psf_models.models.psf_model_physical_polychromatic.get_zernike_prior", return_value=zks_prior, ) @@ -146,13 +146,13 @@ def test_initialize_physical_layer_mocking( # Mock internal methods called during initialization mocker.patch( - "wf_psf.psf_models.psf_model_physical_polychromatic.get_zernike_prior", + "wf_psf.psf_models.models.psf_model_physical_polychromatic.get_zernike_prior", return_value=zks_prior, ) # Create a mock for the TFPhysicalLayer class mock_physical_layer_class = mocker.patch( - "wf_psf.psf_models.psf_model_physical_polychromatic.TFPhysicalLayer" + "wf_psf.psf_models.models.psf_model_physical_polychromatic.TFPhysicalLayer" ) # Create TFPhysicalPolychromaticField instance @@ -176,12 +176,14 @@ def physical_layer_instance(mocker, mock_model_params, mock_data, zks_prior): # Mock internal methods called during initialization mocker.patch( - "wf_psf.psf_models.psf_model_physical_polychromatic.get_zernike_prior", + "wf_psf.psf_models.models.psf_model_physical_polychromatic.get_zernike_prior", return_value=zks_prior, ) # Create a mock for the TFPhysicalLayer class - mocker.patch("wf_psf.psf_models.psf_model_physical_polychromatic.TFPhysicalLayer") + mocker.patch( + "wf_psf.psf_models.models.psf_model_physical_polychromatic.TFPhysicalLayer" + ) # Create TFPhysicalPolychromaticField instance psf_field_instance = TFPhysicalPolychromaticField( diff --git a/src/wf_psf/tests/test_psf_models/psf_models_test.py b/src/wf_psf/tests/test_psf_models/psf_models_test.py index 066e1328..b7c906f6 100644 --- a/src/wf_psf/tests/test_psf_models/psf_models_test.py +++ b/src/wf_psf/tests/test_psf_models/psf_models_test.py @@ -7,10 +7,10 @@ """ -from wf_psf.psf_models import ( - psf_models, +from wf_psf.psf_models import psf_models +from wf_psf.psf_models.models import ( psf_model_semiparametric, - psf_model_physical_polychromatic, + psf_model_physical_polychromatic ) import tensorflow as tf import 
numpy as np From d66d2c92e77ad98e05510071205a54dfbe4b7ff7 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Tue, 13 May 2025 14:29:13 +0200 Subject: [PATCH 011/135] Remove unused module with duplicate zernike_generator function --- src/wf_psf/psf_models/zernikes.py | 58 ------------------------------- 1 file changed, 58 deletions(-) delete mode 100644 src/wf_psf/psf_models/zernikes.py diff --git a/src/wf_psf/psf_models/zernikes.py b/src/wf_psf/psf_models/zernikes.py deleted file mode 100644 index dcfa6e39..00000000 --- a/src/wf_psf/psf_models/zernikes.py +++ /dev/null @@ -1,58 +0,0 @@ -"""Zernikes. - -A module to make Zernike maps. - -:Author: Tobias Liaudat and Jennifer Pollack - -""" - -import numpy as np -import zernike as zk -import logging - -logger = logging.getLogger(__name__) - - -def zernike_generator(n_zernikes, wfe_dim): - """ - Generate Zernike maps. - - Based on the zernike github repository. - https://github.com/jacopoantonello/zernike - - Parameters - ---------- - n_zernikes: int - Number of Zernike modes desired. - wfe_dim: int - Dimension of the Zernike map [wfe_dim x wfe_dim]. - - Returns - ------- - zernikes: list of np.ndarray - List containing the Zernike modes. - The values outside the unit circle are filled with NaNs. - """ - # Calculate which n (from the (n,m) Zernike convention) we need - # so that we have the desired total number of Zernike coefficients - min_n = (-3 + np.sqrt(1 + 8 * n_zernikes)) / 2 - n = int(np.ceil(min_n)) - - # Initialize the zernike generator - cart = zk.RZern(n) - # Create a [-1,1] mesh - ddx = np.linspace(-1.0, 1.0, wfe_dim) - ddy = np.linspace(-1.0, 1.0, wfe_dim) - xv, yv = np.meshgrid(ddx, ddy) - cart.make_cart_grid(xv, yv) - - c = np.zeros(cart.nk) - zernikes = [] - - # Extract each Zernike map one by one - for i in range(n_zernikes): - c *= 0.0 - c[i] = 1.0 - zernikes.append(cart.eval_grid(c, matrix=True)) - - return zernikes From b9cc93729ff2c3e6d8cbc83fe09db254f004b190 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Tue, 13 May 2025 14:31:47 +0200 Subject: [PATCH 012/135] Correct syntax in docstrings and logger messages --- src/wf_psf/utils/configs_handler.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/wf_psf/utils/configs_handler.py b/src/wf_psf/utils/configs_handler.py index 5d45ba79..07737da0 100644 --- a/src/wf_psf/utils/configs_handler.py +++ b/src/wf_psf/utils/configs_handler.py @@ -1,4 +1,4 @@ -"""Configs_Handler. +s"""Configs_Handler. A module which provides general utility methods to manage the parameters of the config files @@ -443,12 +443,12 @@ def _load_data_conf(self): def weights_basename_filepath(self): """Get PSF model weights filepath. - A function to return the basename of the user-specified psf model weights path. + A function to return the basename of the user-specified PSF model weights path. Returns ------- weights_basename: str - The basename of the psf model weights to be loaded. + The basename of the PSF model weights to be loaded. """ return os.path.join( @@ -502,7 +502,7 @@ def call_plot_config_handler_run(self, model_metrics): def run(self): """Run. - A function to run wave-diff according to the + A function to run WaveDiff according to the input configuration. 
""" From f72e0e5ad482efeb5463c71dbf477843e13be9a8 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Wed, 14 May 2025 17:22:44 +0200 Subject: [PATCH 013/135] Refactor file structure; update import statements in tests; remove unit test due to refactoring --- src/wf_psf/{utils => data}/centroids.py | 2 +- .../data_preprocessing.py} | 2 +- src/wf_psf/data/training_preprocessing.py | 4 +-- src/wf_psf/instrument/__init__.py | 0 .../ccd_misalignments.py | 0 .../centroids_test.py | 10 +++---- .../tests/test_utils/configs_handler_test.py | 27 +++++-------------- 7 files changed, 15 insertions(+), 30 deletions(-) rename src/wf_psf/{utils => data}/centroids.py (99%) rename src/wf_psf/{utils/preprocessing.py => data/data_preprocessing.py} (99%) create mode 100644 src/wf_psf/instrument/__init__.py rename src/wf_psf/{utils => instrument}/ccd_misalignments.py (100%) rename src/wf_psf/tests/{test_utils => test_data}/centroids_test.py (97%) diff --git a/src/wf_psf/utils/centroids.py b/src/wf_psf/data/centroids.py similarity index 99% rename from src/wf_psf/utils/centroids.py rename to src/wf_psf/data/centroids.py index 8b4522bf..e414520e 100644 --- a/src/wf_psf/utils/centroids.py +++ b/src/wf_psf/data/centroids.py @@ -8,7 +8,7 @@ import numpy as np import scipy.signal as scisig -from wf_psf.utils.preprocessing import shift_x_y_to_zk1_2_wavediff +from wf_psf.data.data_preprocessing import shift_x_y_to_zk1_2_wavediff from typing import Optional diff --git a/src/wf_psf/utils/preprocessing.py b/src/wf_psf/data/data_preprocessing.py similarity index 99% rename from src/wf_psf/utils/preprocessing.py rename to src/wf_psf/data/data_preprocessing.py index 210c03e5..44e18436 100644 --- a/src/wf_psf/utils/preprocessing.py +++ b/src/wf_psf/data/data_preprocessing.py @@ -1,4 +1,4 @@ -"""Preprocessing. +"""Data Preprocessing. A module with utils to preprocess data. 
diff --git a/src/wf_psf/data/training_preprocessing.py b/src/wf_psf/data/training_preprocessing.py index c1402f06..25180ebe 100644 --- a/src/wf_psf/data/training_preprocessing.py +++ b/src/wf_psf/data/training_preprocessing.py @@ -10,8 +10,8 @@ import numpy as np import wf_psf.utils.utils as utils import tensorflow as tf -from wf_psf.utils.ccd_misalignments import CCDMisalignmentCalculator -from wf_psf.utils.centroids import compute_zernike_tip_tilt +from wf_psf.instrument.ccd_misalignments import CCDMisalignmentCalculator +from wf_psf.data.centroids import compute_zernike_tip_tilt from fractions import Fraction import logging diff --git a/src/wf_psf/instrument/__init__.py b/src/wf_psf/instrument/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/wf_psf/utils/ccd_misalignments.py b/src/wf_psf/instrument/ccd_misalignments.py similarity index 100% rename from src/wf_psf/utils/ccd_misalignments.py rename to src/wf_psf/instrument/ccd_misalignments.py diff --git a/src/wf_psf/tests/test_utils/centroids_test.py b/src/wf_psf/tests/test_data/centroids_test.py similarity index 97% rename from src/wf_psf/tests/test_utils/centroids_test.py rename to src/wf_psf/tests/test_data/centroids_test.py index 8557704f..c55e18f9 100644 --- a/src/wf_psf/tests/test_utils/centroids_test.py +++ b/src/wf_psf/tests/test_data/centroids_test.py @@ -9,7 +9,7 @@ import numpy as np import pytest from unittest.mock import MagicMock, patch -from wf_psf.utils.centroids import compute_zernike_tip_tilt, CentroidEstimator +from wf_psf.data.centroids import compute_zernike_tip_tilt, CentroidEstimator # Function to compute centroid based on first-order moments @@ -133,7 +133,7 @@ def test_compute_zernike_tip_tilt_single_batch(mocker, simple_image, identity_ma """Test compute_zernike_tip_tilt with single batch input and mocks.""" # Mock the CentroidEstimator class mock_centroid_calc = mocker.patch( - "wf_psf.utils.centroids.CentroidEstimator", autospec=True + "wf_psf.data.centroids.CentroidEstimator", autospec=True ) # Create a mock instance and configure get_intra_pixel_shifts() @@ -144,7 +144,7 @@ def test_compute_zernike_tip_tilt_single_batch(mocker, simple_image, identity_ma # Mock shift_x_y_to_zk1_2_wavediff to return predictable values mock_shift_fn = mocker.patch( - "wf_psf.utils.centroids.shift_x_y_to_zk1_2_wavediff", + "wf_psf.data.centroids.shift_x_y_to_zk1_2_wavediff", side_effect=lambda shift: shift * 0.5, # Mocked conversion for test ) @@ -191,7 +191,7 @@ def test_compute_zernike_tip_tilt_batch(mocker, multiple_images): """Test compute_zernike_tip_tilt with batch input and mocks.""" # Mock the CentroidEstimator class mock_centroid_calc = mocker.patch( - "wf_psf.utils.centroids.CentroidEstimator", autospec=True + "wf_psf.data.centroids.CentroidEstimator", autospec=True ) # Create a mock instance and configure get_intra_pixel_shifts() @@ -202,7 +202,7 @@ def test_compute_zernike_tip_tilt_batch(mocker, multiple_images): # Mock shift_x_y_to_zk1_2_wavediff to return predictable values mock_shift_fn = mocker.patch( - "wf_psf.utils.centroids.shift_x_y_to_zk1_2_wavediff", + "wf_psf.data.centroids.shift_x_y_to_zk1_2_wavediff", side_effect=lambda shift: shift * 0.5, # Mocked conversion for test ) diff --git a/src/wf_psf/tests/test_utils/configs_handler_test.py b/src/wf_psf/tests/test_utils/configs_handler_test.py index d95761e9..9b4039a8 100644 --- a/src/wf_psf/tests/test_utils/configs_handler_test.py +++ b/src/wf_psf/tests/test_utils/configs_handler_test.py @@ -4,15 +4,18 @@ :Author: Jennifer Pollack - 
""" import pytest +from wf_psf.data.training_preprocessing import DataHandler from wf_psf.utils import configs_handler from wf_psf.utils.read_config import RecursiveNamespace from wf_psf.utils.io import FileIOHandler -from wf_psf.utils.configs_handler import TrainingConfigHandler, DataConfigHandler -from wf_psf.data.training_preprocessing import DataHandler +from wf_psf.utils.configs_handler import ( + TrainingConfigHandler, + MetricsConfigHandler, + DataConfigHandler, +) import os @@ -229,21 +232,3 @@ def test_run_method_calls_train_with_correct_arguments( mock_th.optimizer_dir, mock_th.psf_model_dir, ) - - -def test_MetricsConfigHandler_weights_basename_filepath( - path_to_tmp_output_dir, path_to_config_dir -): - test_file_handler = FileIOHandler(path_to_tmp_output_dir, path_to_config_dir) - - metrics_config_file = "validation/main_random_seed/config/metrics_config.yaml" - - metrics_object = configs_handler.MetricsConfigHandler( - os.path.join(path_to_config_dir, metrics_config_file), test_file_handler - ) - weights_filepath = metrics_object.weights_basename_filepath - - assert ( - weights_filepath - == "src/wf_psf/tests/data/validation/main_random_seed/checkpoint/checkpoint*_poly*_sample_w_bis1_2k_cycle2*" - ) From a86352caae93a143ef4c52f789f261ab0b8f77ec Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Wed, 14 May 2025 20:35:12 +0200 Subject: [PATCH 014/135] Update package name in import statement --- src/wf_psf/instrument/ccd_misalignments.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/wf_psf/instrument/ccd_misalignments.py b/src/wf_psf/instrument/ccd_misalignments.py index 1d2135ba..3b7a1eb3 100644 --- a/src/wf_psf/instrument/ccd_misalignments.py +++ b/src/wf_psf/instrument/ccd_misalignments.py @@ -10,7 +10,7 @@ import numpy as np import matplotlib.path as mpltPath from scipy.spatial import KDTree -from wf_psf.utils.preprocessing import defocus_to_zk4_wavediff +from wf_psf.data.data_preprocessing import defocus_to_zk4_wavediff class CCDMisalignmentCalculator: From 2b262288f368d5cbcc6b89530120177c6331ef34 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Wed, 14 May 2025 22:54:33 +0200 Subject: [PATCH 015/135] Reorder imports; Refactor MetricsConfigHandler class attributes, methods and variable names| --- src/wf_psf/utils/configs_handler.py | 231 ++++++++++++---------------- 1 file changed, 98 insertions(+), 133 deletions(-) diff --git a/src/wf_psf/utils/configs_handler.py b/src/wf_psf/utils/configs_handler.py index 07737da0..e6764873 100644 --- a/src/wf_psf/utils/configs_handler.py +++ b/src/wf_psf/utils/configs_handler.py @@ -12,12 +12,14 @@ import os import re import glob -from wf_psf.utils.read_config import read_conf from wf_psf.data.training_preprocessing import DataHandler -from wf_psf.training import train -from wf_psf.psf_models import psf_models from wf_psf.metrics.metrics_interface import evaluate_model from wf_psf.plotting.plots_interface import plot_metrics +from wf_psf.psf_models import psf_models +from wf_psf.psf_models.psf_model_loader import load_trained_psf_model +from wf_psf.training import train +from wf_psf.utils.read_config import read_conf + logger = logging.getLogger(__name__) @@ -253,9 +255,11 @@ class MetricsConfigHandler: def __init__(self, metrics_conf, file_handler, training_conf=None): self._metrics_conf = read_conf(metrics_conf) - self._file_handler = file_handler - self.trained_model_path = self._get_trained_model_path(training_conf) - self._training_conf = self._load_training_conf(training_conf) + self.data_conf = 
self._load_data_conf() + self._file_handler = file_handler + self.metrics_dir = self._file_handler.get_metrics_dir(self._file_handler._run_output_dir) + self.training_conf = training_conf + self.trained_psf_model = self.load_trained_psf_model(self.training_conf, self.data_conf ) @property def metrics_conf(self): @@ -270,32 +274,27 @@ def metrics_conf(self): """ return self._metrics_conf - @property - def metrics_dir(self): - """Get Metrics Directory. - - A function that returns path - of metrics directory. - - Returns - ------- - str - Absolute path to metrics directory - """ - return self._file_handler.get_metrics_dir(self._file_handler._run_output_dir) - @property def training_conf(self): - """Get Training Conf. - - A function to return the training configuration file name. + """Returns the loaded training configuration.""" + return self._training_conf - Returns - ------- - RecursiveNamespace - An instance of the training configuration file. + @training_conf.setter + def training_conf(self, training_conf): """ - return self._training_conf + Sets the training configuration. If None is provided, attempts to load it + from the trained_model_path in the metrics configuration. + """ + if training_conf is None: + try: + training_conf_path = self._get_training_conf_path_from_metrics() + logger.info(f"Loading training config from inferred path: {training_conf_path}") + self._training_conf = read_conf(training_conf_path) + except Exception as e: + logger.error(f"Failed to load training config: {e}") + raise + else: + self._training_conf = training_conf @property def plotting_conf(self): @@ -310,112 +309,99 @@ def plotting_conf(self): """ return self.metrics_conf.metrics.plotting_config - @property - def data_conf(self): - """Get Data Conf. - - A function to return an instance of the DataConfigHandler class. - - Returns - ------- - An instance of the DataConfigHandler class. - """ - return self._load_data_conf() - - @property - def psf_model(self): - """Get PSF Model. - - A function to return an instance of the PSF model - to be evaluated. - - Returns - ------- - psf_model: obj - An instance of the PSF model to be evaluated. - """ - return psf_models.get_psf_model( - self.training_conf.training.model_params, - self.training_conf.training.training_hparams, + def _load_trained_psf_model(self): + trained_model_path = self._get_trained_model_path() + try: + model_subdir = self.metrics_conf.metrics.model_save_path + cycle = self.metrics_conf.metrics.saved_training_cycle + except AttributeError as e: + raise KeyError("Missing required model config fields.") from e + + model_name = self.training_conf.training.model_params.model_name + id_name = self.training_conf.training.id_name + + weights_path_pattern = os.path.join( + trained_model_path, + model_subdir, + ( + f"{model_subdir}*_{model_name}" + f"*{id_name}_cycle{cycle}*" + ), + ) + return load_trained_psf_model( + self.training_conf, self.data_conf, + weights_path_pattern, ) - @property - def weights_path(self): - """Get Weights Path. - A function to return the full path - of the user-specified psf model weights to be loaded. + def _get_training_conf_path_from_metrics(self): + """ + Retrieves the full path to the training config based on the metrics configuration. Returns ------- str - A string representing the full path to the psf model weights to be loaded. + Full path to the training configuration file. + + Raises + ------ + KeyError + If 'trained_model_config' key is missing. 
+ FileNotFoundError + If the file does not exist at the constructed path. """ - return psf_models.get_psf_model_weights_filepath(self.weights_basename_filepath) - - def _get_trained_model_path(self, training_conf): - """Get Trained Model Path. - - Helper method to get the trained model path. + trained_model_path = self._get_trained_model_path() - Parameters - ---------- - training_conf: None or RecursiveNamespace - None type or RecursiveNamespace - - Returns - ------- - str - A string representing the path to the trained model output run directory. + try: + training_conf_filename = self._metrics_conf.metrics.trained_model_config + except AttributeError as e: + raise KeyError("Missing 'trained_model_config' key in metrics configuration.") from e - """ - if training_conf is None: - try: - return self._metrics_conf.metrics.trained_model_path + training_conf_path = os.path.join( + self._file_handler.get_config_dir(trained_model_path), training_conf_filename) - except TypeError as e: - logger.exception(e) - raise ConfigParameterError( - "Metrics config file trained model path or config values are empty." - ) - else: - return os.path.join( - self._file_handler.output_path, - self._file_handler.parent_output_dir, - self._file_handler.workdir, - ) + if not os.path.exists(training_conf_path): + raise FileNotFoundError(f"Training config file not found: {training_conf_path}") - def _load_training_conf(self, training_conf): - """Load Training Conf. + return training_conf_path - Load the training configuration if training_conf is not provided. - Parameters - ---------- - training_conf: None or RecursiveNamespace - None type or a RecursiveNamespace storing the training configuration parameter setttings. + def _get_trained_model_path(self): + """ + Determine the trained model path from either: + + 1. The metrics configuration file (i.e., for metrics-only runs after training), or + 2. The runtime-generated file handler paths (i.e., for single runs that perform both training and evaluation). Returns ------- - RecursiveNamespace storing the training configuration parameter settings. + str + Path to the trained model directory. + Raises + ------ + ConfigParameterError + If the path specified in the metrics config is invalid or missing. """ - if training_conf is None: - try: - return read_conf( - os.path.join( - self._file_handler.get_config_dir(self.trained_model_path), - self._metrics_conf.metrics.trained_model_config, - ) - ) - except TypeError as e: - logger.exception(e) + trained_model_path = getattr(self._metrics_conf.metrics, "trained_model_path", None) + + if trained_model_path: + if not os.path.isdir(trained_model_path): raise ConfigParameterError( - "Metrics config file trained model path or config values are empty." + f"The trained model path provided in the metrics config is not a valid directory: {trained_model_path}" ) - else: - return training_conf + logger.info(f"Using trained model path from metrics config: {trained_model_path}") + return trained_model_path + + # Fallback for single-run training + metrics evaluation mode + fallback_path = os.path.join( + self._file_handler.output_path, + self._file_handler.parent_output_dir, + self._file_handler.workdir, + ) + logger.info(f"Using fallback trained model path from runtime file handler: {fallback_path}") + return fallback_path def _load_data_conf(self): """Load Data Conf. 
@@ -439,26 +425,6 @@ def _load_data_conf(self): logger.exception(e) raise ConfigParameterError("Data configuration loading error.") - @property - def weights_basename_filepath(self): - """Get PSF model weights filepath. - - A function to return the basename of the user-specified PSF model weights path. - - Returns - ------- - weights_basename: str - The basename of the PSF model weights to be loaded. - - """ - return os.path.join( - self.trained_model_path, - self.metrics_conf.metrics.model_save_path, - ( - f"{self.metrics_conf.metrics.model_save_path}*_{self.training_conf.training.model_params.model_name}" - f"*{self.training_conf.training.id_name}_cycle{self.metrics_conf.metrics.saved_training_cycle}*" - ), - ) def call_plot_config_handler_run(self, model_metrics): """Make Metrics Plots. @@ -513,7 +479,6 @@ def run(self): self.training_conf.training, self.data_conf, self.psf_model, - self.weights_path, self.metrics_dir, ) From b8375ba3493cf59908360f999c8db73f3f96fc73 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Wed, 14 May 2025 22:55:47 +0200 Subject: [PATCH 016/135] Move psf_model weights loader to psf_model_loader.py module --- src/wf_psf/metrics/metrics_interface.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/wf_psf/metrics/metrics_interface.py b/src/wf_psf/metrics/metrics_interface.py index 3dff2c6c..c922a4fb 100644 --- a/src/wf_psf/metrics/metrics_interface.py +++ b/src/wf_psf/metrics/metrics_interface.py @@ -351,14 +351,6 @@ def evaluate_model( # Prepare np input simPSF_np = data.training_data.simPSF - ## Load the model's weights - try: - logger.info(f"Loading PSF model weights from {weights_path}") - psf_model.load_weights(weights_path) - except Exception as e: - logger.exception("An error occurred with the weights_path file: %s", e) - exit() - # Define datasets datasets = {"test": data.test_data.dataset, "train": data.training_data.dataset} From f3876fb0d028277b8add157305d3b536fa94d499 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Wed, 14 May 2025 22:56:13 +0200 Subject: [PATCH 017/135] Add psf_model_loader module --- src/wf_psf/psf_models/psf_model_loader.py | 52 +++++++++++++++++++++++ 1 file changed, 52 insertions(+) create mode 100644 src/wf_psf/psf_models/psf_model_loader.py diff --git a/src/wf_psf/psf_models/psf_model_loader.py b/src/wf_psf/psf_models/psf_model_loader.py new file mode 100644 index 00000000..26056e57 --- /dev/null +++ b/src/wf_psf/psf_models/psf_model_loader.py @@ -0,0 +1,52 @@ +"""PSF Model Loader. + +This module provides helper functions for loading trained PSF models. +It includes utilities to: +- Load a model from disk using its configuration and weights. +- Prepare inputs for inference or evaluation workflows. + +Author: Jennifer Pollack +""" +from wf_psf.psf_models.psf_models import ( + get_psf_model, + get_psf_model_weights_filepath +) + +def load_trained_psf_model(training_conf, data_conf, weights_path_pattern): + """ + Loads a trained PSF model and applies saved weights. + + Parameters + ---------- + training_conf : RecursiveNamespace + Configuration object containing model parameters and training hyperparameters. + Supports attribute-style access to nested fields. + data_conf : RecursiveNamespace + Configuration object containing data-related parameters. + weights_path_pattern : str + Glob-style pattern used to locate the model weights file. + + Returns + ------- + model : tf.keras.Model or compatible + The PSF model instance with loaded weights. 
+ + Raises + ------ + RuntimeError + If loading the model weights fails for any reason. + """ + model = get_psf_model(training_conf.training.model_params, + training_conf.training.training_hparams, + data_conf) + + weights_path = get_psf_model_weights_filepath(weights_path_pattern) + + try: + logger.info(f"Loading PSF model weights from {weights_path}") + model.load_weights(weights_path) + except Exception as e: + logger.exception("Failed to load model weights.") + raise RuntimeError("Model weight loading failed.") from e + return model + From ffb49e3f520d083ab93e3dd2794382494d6dbc3b Mon Sep 17 00:00:00 2001 From: jeipollack Date: Thu, 15 May 2025 13:13:38 +0200 Subject: [PATCH 018/135] Remove weights_path arg from evaluate_model method; Update logger.info statement and comment --- src/wf_psf/metrics/metrics_interface.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/wf_psf/metrics/metrics_interface.py b/src/wf_psf/metrics/metrics_interface.py index c922a4fb..bd6249e2 100644 --- a/src/wf_psf/metrics/metrics_interface.py +++ b/src/wf_psf/metrics/metrics_interface.py @@ -311,7 +311,6 @@ def evaluate_model( trained_model_params, data, psf_model, - weights_path, metrics_output, ): """Evaluate the trained model on both training and test datasets by computing various metrics. @@ -341,8 +340,8 @@ def evaluate_model( try: ## Load datasets # ----------------------------------------------------- - # Get training data - logger.info("Fetching and preprocessing training and test data...") + # Get training and test data + logger.info("Fetching training and test data...") # Initialize metrics_handler metrics_handler = MetricsParamsHandler(metrics_params, trained_model_params) From 100127706ea41073b5ac9f0e6278aa47815a3938 Mon Sep 17 00:00:00 2001 From: jeipollack Date: Thu, 15 May 2025 13:14:22 +0200 Subject: [PATCH 019/135] Update variable name and logger statement --- src/wf_psf/utils/configs_handler.py | 62 +++++++++++++++++------------ 1 file changed, 36 insertions(+), 26 deletions(-) diff --git a/src/wf_psf/utils/configs_handler.py b/src/wf_psf/utils/configs_handler.py index e6764873..4b08a723 100644 --- a/src/wf_psf/utils/configs_handler.py +++ b/src/wf_psf/utils/configs_handler.py @@ -1,4 +1,4 @@ -s"""Configs_Handler. +"""Configs_Handler. 
A module which provides general utility methods to manage the parameters of the config files @@ -255,11 +255,13 @@ class MetricsConfigHandler: def __init__(self, metrics_conf, file_handler, training_conf=None): self._metrics_conf = read_conf(metrics_conf) - self.data_conf = self._load_data_conf() - self._file_handler = file_handler - self.metrics_dir = self._file_handler.get_metrics_dir(self._file_handler._run_output_dir) + self._file_handler = file_handler self.training_conf = training_conf - self.trained_psf_model = self.load_trained_psf_model(self.training_conf, self.data_conf ) + self.data_conf = self._load_data_conf() + self.metrics_dir = self._file_handler.get_metrics_dir( + self._file_handler._run_output_dir + ) + self.trained_psf_model = self._load_trained_psf_model() @property def metrics_conf(self): @@ -288,7 +290,9 @@ def training_conf(self, training_conf): if training_conf is None: try: training_conf_path = self._get_training_conf_path_from_metrics() - logger.info(f"Loading training config from inferred path: {training_conf_path}") + logger.info( + f"Loading training config from inferred path: {training_conf_path}" + ) self._training_conf = read_conf(training_conf_path) except Exception as e: logger.error(f"Failed to load training config: {e}") @@ -319,14 +323,11 @@ def _load_trained_psf_model(self): model_name = self.training_conf.training.model_params.model_name id_name = self.training_conf.training.id_name - + weights_path_pattern = os.path.join( - trained_model_path, - model_subdir, - ( - f"{model_subdir}*_{model_name}" - f"*{id_name}_cycle{cycle}*" - ), + trained_model_path, + model_subdir, + (f"{model_subdir}*_{model_name}" f"*{id_name}_cycle{cycle}*"), ) return load_trained_psf_model( self.training_conf, @@ -334,7 +335,6 @@ def _load_trained_psf_model(self): weights_path_pattern, ) - def _get_training_conf_path_from_metrics(self): """ Retrieves the full path to the training config based on the metrics configuration. @@ -356,22 +356,27 @@ def _get_training_conf_path_from_metrics(self): try: training_conf_filename = self._metrics_conf.metrics.trained_model_config except AttributeError as e: - raise KeyError("Missing 'trained_model_config' key in metrics configuration.") from e + raise KeyError( + "Missing 'trained_model_config' key in metrics configuration." + ) from e training_conf_path = os.path.join( - self._file_handler.get_config_dir(trained_model_path), training_conf_filename) + self._file_handler.get_config_dir(trained_model_path), + training_conf_filename, + ) if not os.path.exists(training_conf_path): - raise FileNotFoundError(f"Training config file not found: {training_conf_path}") + raise FileNotFoundError( + f"Training config file not found: {training_conf_path}" + ) return training_conf_path - def _get_trained_model_path(self): """ Determine the trained model path from either: - - 1. The metrics configuration file (i.e., for metrics-only runs after training), or + + 1. The metrics configuration file (i.e., for metrics-only runs after training), or 2. The runtime-generated file handler paths (i.e., for single runs that perform both training and evaluation). Returns @@ -384,14 +389,18 @@ def _get_trained_model_path(self): ConfigParameterError If the path specified in the metrics config is invalid or missing. 
""" - trained_model_path = getattr(self._metrics_conf.metrics, "trained_model_path", None) + trained_model_path = getattr( + self._metrics_conf.metrics, "trained_model_path", None + ) if trained_model_path: if not os.path.isdir(trained_model_path): raise ConfigParameterError( f"The trained model path provided in the metrics config is not a valid directory: {trained_model_path}" ) - logger.info(f"Using trained model path from metrics config: {trained_model_path}") + logger.info( + f"Using trained model path from metrics config: {trained_model_path}" + ) return trained_model_path # Fallback for single-run training + metrics evaluation mode @@ -400,7 +409,9 @@ def _get_trained_model_path(self): self._file_handler.parent_output_dir, self._file_handler.workdir, ) - logger.info(f"Using fallback trained model path from runtime file handler: {fallback_path}") + logger.info( + f"Using fallback trained model path from runtime file handler: {fallback_path}" + ) return fallback_path def _load_data_conf(self): @@ -425,7 +436,6 @@ def _load_data_conf(self): logger.exception(e) raise ConfigParameterError("Data configuration loading error.") - def call_plot_config_handler_run(self, model_metrics): """Make Metrics Plots. @@ -472,13 +482,13 @@ def run(self): input configuration. """ - logger.info(f"Running metrics evaluation on psf model: {self.weights_path}") + logger.info("Running metrics evaluation on trained PSF model...") model_metrics = evaluate_model( self.metrics_conf.metrics, self.training_conf.training, self.data_conf, - self.psf_model, + self.trained_psf_model, self.metrics_dir, ) From 8e659d0ecb3ffc5bff6603230e333e6f4d9d88a8 Mon Sep 17 00:00:00 2001 From: jeipollack Date: Thu, 15 May 2025 13:15:53 +0200 Subject: [PATCH 020/135] Add import logging and create logger object --- src/wf_psf/psf_models/psf_model_loader.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/wf_psf/psf_models/psf_model_loader.py b/src/wf_psf/psf_models/psf_model_loader.py index 26056e57..1d2e267f 100644 --- a/src/wf_psf/psf_models/psf_model_loader.py +++ b/src/wf_psf/psf_models/psf_model_loader.py @@ -7,11 +7,15 @@ Author: Jennifer Pollack """ +import logging from wf_psf.psf_models.psf_models import ( get_psf_model, get_psf_model_weights_filepath ) + +logger = logging.getLogger(__name__) + def load_trained_psf_model(training_conf, data_conf, weights_path_pattern): """ Loads a trained PSF model and applies saved weights. From 6d9c40683620198c3c4524129c76668f36367cf5 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Thu, 15 May 2025 13:18:43 +0200 Subject: [PATCH 021/135] Create psf_inference.py module --- src/wf_psf/inference/psf_inference.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py index e69de29b..4f5b39e5 100644 --- a/src/wf_psf/inference/psf_inference.py +++ b/src/wf_psf/inference/psf_inference.py @@ -0,0 +1,21 @@ +"""Inference. + +A module which provides a set of functions to perform inference +on PSF models. It includes functions to load a trained model, +perform inference on a dataset of SEDs and positions, and generate a polychromatic PSF. + +:Authors: Jennifer Pollack + +""" + +import os +import glob +import logging +import numpy as np +from wf_psf.psf_models import psf_models, psf_model_loader +import tensorflow as tf + + +#def prepare_inputs(...): ... +#def generate_psfs(...): ... +#def run_pipeline(...): ... 
From 47e9c22dee1b21f64f72219d619a482a75fc9fed Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Thu, 15 May 2025 13:21:11 +0200 Subject: [PATCH 022/135] Remove arg from evaluate_model unit test --- src/wf_psf/tests/test_metrics/metrics_interface_test.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/wf_psf/tests/test_metrics/metrics_interface_test.py b/src/wf_psf/tests/test_metrics/metrics_interface_test.py index 44bd09bf..12d51447 100644 --- a/src/wf_psf/tests/test_metrics/metrics_interface_test.py +++ b/src/wf_psf/tests/test_metrics/metrics_interface_test.py @@ -209,7 +209,6 @@ def test_evaluate_model( trained_model_params=mock_trained_model_params, data=mock_data, psf_model=mock_psf_model, - weights_path="/mock/weights/path", metrics_output="/mock/metrics/output", ) From e3ad71ab065e5adf643b78a64765e3b0e88471bd Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 16 May 2025 14:41:17 +0200 Subject: [PATCH 023/135] Refactor: reorganise modules, relocate utility functions, rename modules, update import statements and unit tests --- src/wf_psf/data/centroids.py | 71 ++- src/wf_psf/data/data_handler.py | 245 ++++++++++ src/wf_psf/data/data_preprocessing.py | 102 ---- src/wf_psf/data/data_zernike_utils.py | 295 +++++++++++ src/wf_psf/data/training_preprocessing.py | 458 ------------------ src/wf_psf/instrument/ccd_misalignments.py | 40 ++ .../psf_models/tf_modules/tf_psf_field.py | 2 +- src/wf_psf/tests/__init__.py | 1 + src/wf_psf/tests/conftest.py | 2 +- src/wf_psf/tests/test_data/__init__.py | 0 src/wf_psf/tests/test_data/centroids_test.py | 235 ++++----- src/wf_psf/tests/test_data/conftest.py | 61 +++ ...rocessing_test.py => data_handler_test.py} | 214 +------- .../test_data/data_zernike_utils_test.py | 133 +++++ src/wf_psf/tests/test_data/test_data_utils.py | 30 ++ .../test_metrics/metrics_interface_test.py | 4 +- src/wf_psf/tests/test_psf_models/conftest.py | 2 +- .../psf_model_physical_polychromatic_test.py | 2 +- .../tests/test_utils/configs_handler_test.py | 2 +- src/wf_psf/utils/configs_handler.py | 2 +- 20 files changed, 992 insertions(+), 909 deletions(-) create mode 100644 src/wf_psf/data/data_handler.py delete mode 100644 src/wf_psf/data/data_preprocessing.py create mode 100644 src/wf_psf/data/data_zernike_utils.py delete mode 100644 src/wf_psf/data/training_preprocessing.py create mode 100644 src/wf_psf/tests/__init__.py create mode 100644 src/wf_psf/tests/test_data/__init__.py rename src/wf_psf/tests/test_data/{training_preprocessing_test.py => data_handler_test.py} (52%) create mode 100644 src/wf_psf/tests/test_data/data_zernike_utils_test.py create mode 100644 src/wf_psf/tests/test_data/test_data_utils.py diff --git a/src/wf_psf/data/centroids.py b/src/wf_psf/data/centroids.py index e414520e..27f7894b 100644 --- a/src/wf_psf/data/centroids.py +++ b/src/wf_psf/data/centroids.py @@ -8,10 +8,79 @@ import numpy as np import scipy.signal as scisig -from wf_psf.data.data_preprocessing import shift_x_y_to_zk1_2_wavediff +from wf_psf.data.data_handler import extract_star_data +from fractions import Fraction +from wf_psf.data.data_zernike_utils import shift_x_y_to_zk1_2_wavediff +import tensorflow as tf from typing import Optional +def compute_centroid_correction(model_params, data, batch_size: int=1) -> np.ndarray: + """Compute centroid corrections using Zernike polynomials. + + This function calculates the Zernike contributions required to match the centroid + of the WaveDiff PSF model to the observed star centroids, processing in batches. 
+ + Parameters + ---------- + model_params : RecursiveNamespace + An object containing parameters for the PSF model, including pixel sampling and initial centroid window parameters. + + data : DataConfigHandler + An object containing training and test datasets, including observed PSFs + and optional star masks. + + batch_size : int, optional + The batch size to use when processing the stars. Default is 1. + + + Returns + ------- + zernike_centroid_array : np.ndarray + A 2D NumPy array of shape `(n_stars, 3)`, where `n_stars` is the number of + observed stars. The array contains the computed Zernike contributions, + with zero padding applied to the first column to ensure a consistent shape. + """ + star_postage_stamps = extract_star_data(data=data, train_key="noisy_stars", test_key="stars") + + # Get star mask catalogue only if "masks" exist in both training and test datasets + star_masks = ( + extract_star_data(data=data, train_key="masks", test_key="masks") + if ( + data.training_data.dataset.get("masks") is not None + and data.test_data.dataset.get("masks") is not None + and tf.size(data.training_data.dataset["masks"]) > 0 + and tf.size(data.test_data.dataset["masks"]) > 0 + ) + else None + ) + + pix_sampling = model_params.pix_sampling * 1e-6 # Change units from [um] to [m] + + reference_shifts = [float(Fraction(value)) for value in model_params.reference_shifts] + + n_stars = len(star_postage_stamps) + zernike_centroid_array = [] + + # Batch process the stars + for i in range(0, n_stars, batch_size): + batch_postage_stamps = star_postage_stamps[i:i + batch_size] + batch_masks = star_masks[i:i + batch_size] if star_masks is not None else None + + # Compute Zernike 1 and Zernike 2 for the batch + zk1_2_batch = -1.0 * compute_zernike_tip_tilt( + batch_postage_stamps, batch_masks, pix_sampling, reference_shifts + ) + + # Zero pad array for each batch and append + zernike_centroid_array.append(np.pad(zk1_2_batch, pad_width=[(0, 0), (1, 0)], mode="constant", constant_values=0)) + + # Combine all batches into a single array + return np.concatenate(zernike_centroid_array, axis=0) def compute_zernike_tip_tilt( star_images: np.ndarray, star_masks: Optional[np.ndarray] = None, diff --git a/src/wf_psf/data/data_handler.py b/src/wf_psf/data/data_handler.py new file mode 100644 index 00000000..e2946b0d --- /dev/null +++ b/src/wf_psf/data/data_handler.py @@ -0,0 +1,245 @@ +"""Data Handler Module. + +Provides tools for loading, preprocessing, and managing data used in both training and inference workflows. + +Includes: +- The `DataHandler` class for managing datasets and associated metadata +- Utility functions for loading structured data products +- Preprocessing routines for spectral energy distributions (SEDs), including format conversion (e.g., to TensorFlow) and transformations + +This module serves as a central interface between raw data and modeling components. + +Authors: Jennifer Pollack , Tobias Liaudat +""" + +import os +import numpy as np +import wf_psf.utils.utils as utils +import tensorflow as tf +from fractions import Fraction +import logging + +logger = logging.getLogger(__name__) + + +class DataHandler: + """Data Handler. + + This class manages loading and processing of training and testing data for use during PSF model training and validation. + It provides methods to access and preprocess the data.
+ + Parameters + ---------- + dataset_type : str + Type of dataset ("training", "test", or "inference"). + data_params : RecursiveNamespace + Recursive Namespace object containing parameters for both 'train' and 'test' datasets. + simPSF : PSFSimulator + Instance of the PSFSimulator class for simulating PSF models. + n_bins_lambda : int + Number of wavelength bins for SED processing. + load_data : bool, optional + If True, data is loaded and processed during initialization. If False, data loading + is deferred until explicitly called. Default is True. + + Attributes + ---------- + dataset_type : str + Type of dataset ("training", "test", or "inference"). + data_params : RecursiveNamespace + Parameters for the current dataset type. + dataset : dict or None + Dictionary containing the loaded dataset, including positions and stars/noisy_stars. + simPSF : PSFSimulator + Instance of the PSFSimulator class for simulating PSF models. + n_bins_lambda : int + Number of wavelength bins. + sed_data : tf.Tensor or None + TensorFlow tensor containing processed SED data for training/testing. + load_data_on_init : bool + Flag controlling whether data is loaded during initialization. + """ + + def __init__( + self, dataset_type, data_params, simPSF, n_bins_lambda, load_data: bool = True + ): + """ + Initialize the dataset handler for PSF simulation. + + Parameters + ---------- + dataset_type : str + Type of dataset ("training", "test", or "inference"). + data_params : RecursiveNamespace + Recursive Namespace object containing parameters for both 'train' and 'test' datasets. + simPSF : PSFSimulator + Instance of the PSFSimulator class for simulating PSF models. + n_bins_lambda : int + Number of wavelength bins for SED processing. + load_data : bool, optional + If True, data is loaded and processed during initialization. If False, data loading + is deferred. Default is True. + """ + self.dataset_type = dataset_type + self.data_params = data_params.__dict__[dataset_type] + self.simPSF = simPSF + self.n_bins_lambda = n_bins_lambda + self.load_data_on_init = load_data + if self.load_data_on_init: + self.load_dataset() + self.process_sed_data() + else: + self.dataset = None + self.sed_data = None + + def load_dataset(self): + """Load dataset. + + Load the dataset based on the specified dataset type. + + """ + self.dataset = np.load( + os.path.join(self.data_params.data_dir, self.data_params.file), + allow_pickle=True, + )[()] + self.dataset["positions"] = tf.convert_to_tensor( + self.dataset["positions"], dtype=tf.float32 + ) + if self.dataset_type == "training": + if "noisy_stars" in self.dataset: + self.dataset["noisy_stars"] = tf.convert_to_tensor( + self.dataset["noisy_stars"], dtype=tf.float32 + ) + else: + logger.warning(f"Missing 'noisy_stars' in {self.dataset_type} dataset.") + elif self.dataset_type == "test": + if "stars" in self.dataset: + self.dataset["stars"] = tf.convert_to_tensor( + self.dataset["stars"], dtype=tf.float32 + ) + else: + logger.warning(f"Missing 'stars' in {self.dataset_type} dataset.") + elif self.dataset_type == "inference": + pass + + def process_sed_data(self): + """Process SED Data. + + A method to generate and process SED data.
+ + """ + self.sed_data = [ + utils.generate_SED_elems_in_tensorflow( + _sed, self.simPSF, n_bins=self.n_bins_lambda, tf_dtype=tf.float64 + ) + for _sed in self.dataset["SEDs"] + ] + self.sed_data = tf.convert_to_tensor(self.sed_data, dtype=tf.float32) + self.sed_data = tf.transpose(self.sed_data, perm=[0, 2, 1]) + + +def get_np_obs_positions(data): + """Get observed positions in numpy from the provided dataset. + + This method concatenates the positions of the stars from both the training + and test datasets to obtain the observed positions. + + Parameters + ---------- + data : DataConfigHandler + Object containing training and test datasets. + + Returns + ------- + np.ndarray + Numpy array containing the observed positions of the stars. + + Notes + ----- + The observed positions are obtained by concatenating the positions of stars + from both the training and test datasets along the 0th axis. + """ + obs_positions = np.concatenate( + ( + data.training_data.dataset["positions"], + data.test_data.dataset["positions"], + ), + axis=0, + ) + + return obs_positions + + +def get_obs_positions(data): + """Get observed positions from the provided dataset. + + Parameters + ---------- + data : DataConfigHandler + Object containing training and test datasets. + + Returns + ------- + tf.Tensor + Tensor containing the observed positions of the stars. + """ + obs_positions = get_np_obs_positions(data) + + return tf.convert_to_tensor(obs_positions, dtype=tf.float32) + + +def extract_star_data(data, train_key: str, test_key: str) -> np.ndarray: + """Extract specific star-related data from training and test datasets. + + This function retrieves and concatenates specific star-related data (e.g., stars, masks) from the + star training and test datasets such as star stamps or masks, based on the provided keys. + + Parameters + ---------- + data : DataConfigHandler + Object containing training and test datasets. + train_key : str + The key to retrieve data from the training dataset (e.g., 'noisy_stars', 'masks'). + test_key : str + The key to retrieve data from the test dataset (e.g., 'stars', 'masks'). + + Returns + ------- + np.ndarray + A NumPy array containing the concatenated data for the given keys. + + Raises + ------ + KeyError + If the specified keys do not exist in the training or test datasets. + + Notes + ----- + - If the dataset contains TensorFlow tensors, they will be converted to NumPy arrays. + - Ensure that eager execution is enabled when calling this function. + """ + # Ensure the requested keys exist in both training and test datasets + missing_keys = [ + key + for key, dataset in [ + (train_key, data.training_data.dataset), + (test_key, data.test_data.dataset), + ] + if key not in dataset + ] + + if missing_keys: + raise KeyError(f"Missing keys in dataset: {missing_keys}") + + # Retrieve data from training and test sets + train_data = data.training_data.dataset[train_key] + test_data = data.test_data.dataset[test_key] + + # Convert to NumPy if necessary + if tf.is_tensor(train_data): + train_data = train_data.numpy() + if tf.is_tensor(test_data): + test_data = test_data.numpy() + + # Concatenate and return + return np.concatenate((train_data, test_data), axis=0) diff --git a/src/wf_psf/data/data_preprocessing.py b/src/wf_psf/data/data_preprocessing.py deleted file mode 100644 index 44e18436..00000000 --- a/src/wf_psf/data/data_preprocessing.py +++ /dev/null @@ -1,102 +0,0 @@ -"""Data Preprocessing. - -A module with utils to preprocess data. 
- -:Author: Tobias Liaudat - -""" - -import numpy as np - - -def shift_x_y_to_zk1_2_wavediff(dxy, tel_focal_length=24.5, tel_diameter=1.2): - """Compute Zernike 1(2) for a given shifts in x(y) in WaveDifff conventions. - - All inputs should be in [m]. - A displacement of, for example, 0.5 pixels should be scaled with the corresponding pixel scale, - e.g. 12[um], to get a displacement in [m], which would be `dxy=0.5*12e-6`. - - The output zernike coefficient is in [um] units as expected by wavediff. - - To apply match the centroid with a `dx` that has a corresponding `zk1`, - the new PSF should be generated with `-zk1`. - - The same applies to `dy` and `zk2`. - - Parameters - ---------- - dxy : float - Centroid shift in [m]. It can be on the x-axis or the y-axis. - tel_focal_length : float - Telescope focal length in [m]. - tel_diameter : float - Telescope aperture diameter in [m]. - """ - reference_pix_sampling = 12e-6 - zernike_norm_factor = 2.0 - - # return zernike_norm_factor * (dx/reference_pix_sampling) / (tel_focal_length * tel_diameter / 2) - return ( - zernike_norm_factor - * (tel_diameter / 2) - * np.sin(np.arctan((dxy / reference_pix_sampling) / tel_focal_length)) - * 3.0 - ) - - -def defocus_to_zk4_zemax(dz, tel_focal_length=24.5, tel_diameter=1.2): - """Compute Zernike 4 value for a given defocus in zemax conventions. - - All inputs should be in [m]. - - Parameters - ---------- - dz : float - Shift in the z-axis, perpendicular to the focal plane. Units in [m]. - tel_focal_length : float - Telescope focal length in [m]. - tel_diameter : float - Telescope aperture diameter in [m]. - """ - # Base calculation - zk4 = dz / (8.0 * (tel_focal_length / tel_diameter) ** 2) - # Apply Z4 normalisation - # This step depends on the normalisation of the Zernike basis used - zk4 /= np.sqrt(3) - # Convert to waves with a reference of 800nm - zk4 /= 800e-9 - # Remove the peak to valley value - zk4 /= 2.0 - - return zk4 - - -def defocus_to_zk4_wavediff(dz, tel_focal_length=24.5, tel_diameter=1.2): - """Compute Zernike 4 value for a given defocus in WaveDifff conventions. - - All inputs should be in [m]. - - The output zernike coefficient is in [um] units as expected by wavediff. - - Parameters - ---------- - dz : float - Shift in the z-axis, perpendicular to the focal plane. Units in [m]. - tel_focal_length : float - Telescope focal length in [m]. - tel_diameter : float - Telescope aperture diameter in [m]. - """ - # Base calculation - zk4 = dz / (8.0 * (tel_focal_length / tel_diameter) ** 2) - # Apply Z4 normalisation - # This step depends on the normalisation of the Zernike basis used - zk4 /= np.sqrt(3) - - # Remove the peak to valley value - zk4 /= 2.0 - - # Change units to [um] as Wavediff uses - zk4 *= 1e6 - - return zk4 diff --git a/src/wf_psf/data/data_zernike_utils.py b/src/wf_psf/data/data_zernike_utils.py new file mode 100644 index 00000000..760adb11 --- /dev/null +++ b/src/wf_psf/data/data_zernike_utils.py @@ -0,0 +1,295 @@ +"""Utilities for Zernike Data Handling. + +This module provides utility functions for working with Zernike coefficients, including: +- Prior generation +- Data loading +- Conversions between physical displacements (e.g., defocus, centroid shifts) and modal Zernike coefficients + +Useful in contexts where Zernike representations are used to model optical aberrations or link physical misalignments to wavefront modes. 
+ +:Author: Tobias Liaudat + +""" + +import numpy as np +import tensorflow as tf +from typing import Optional +import logging + +logger = logging.getLogger(__name__) + + +def get_np_zernike_prior(data): + """Get the zernike prior from the provided dataset. + + This method concatenates the Zernike priors from both the training + and test datasets to obtain the full prior. + + Parameters + ---------- + data : DataConfigHandler + Object containing training and test datasets. + + Returns + ------- + zernike_prior : np.ndarray + Numpy array containing the full prior. + """ + zernike_prior = np.concatenate( + ( + data.training_data.dataset["zernike_prior"], + data.test_data.dataset["zernike_prior"], + ), + axis=0, + ) + + return zernike_prior + + +def get_zernike_prior(model_params, data, batch_size: int = 16): + """Get Zernike priors from the provided dataset. + + This method combines the Zernike priors from both the training and test + datasets with optional centroid-correction and CCD-misalignment contributions. + + Parameters + ---------- + model_params : RecursiveNamespace + Object containing parameters for this PSF model class. + data : DataConfigHandler + Object containing training and test datasets. + batch_size : int, optional + The batch size to use when processing the stars. Default is 16. + + Returns + ------- + tf.Tensor + Tensor containing the combined Zernike contributions for all stars. + + Notes + ----- + The Zernike priors are obtained by concatenating the Zernike priors + from both the training and test datasets along the 0th axis. + + """ + # List of Zernike contributions + zernike_contribution_list = [] + + if model_params.use_prior: + logger.info("Reading in Zernike prior into Zernike contribution list...") + zernike_contribution_list.append(get_np_zernike_prior(data)) + + if model_params.correct_centroids: + logger.info("Adding centroid correction to Zernike contribution list...") + zernike_contribution_list.append( + compute_centroid_correction(model_params, data, batch_size) + ) + + if model_params.add_ccd_misalignments: + logger.info("Adding CCD mis-alignments to Zernike contribution list...") + zernike_contribution_list.append(compute_ccd_misalignment(model_params, data)) + + if len(zernike_contribution_list) == 1: + zernike_contribution = zernike_contribution_list[0] + else: + # Get max zk order + max_zk_order = np.max( + np.array( + [ + zk_contribution.shape[1] + for zk_contribution in zernike_contribution_list + ] + ) + ) + + zernike_contribution = np.zeros( + (zernike_contribution_list[0].shape[0], max_zk_order) + ) + + # Pad arrays to get the same length and add the final contribution + for it in range(len(zernike_contribution_list)): + current_zk_order = zernike_contribution_list[it].shape[1] + current_zernike_contribution = np.pad( + zernike_contribution_list[it], + pad_width=[(0, 0), (0, int(max_zk_order - current_zk_order))], + mode="constant", + constant_values=0, + ) + + zernike_contribution += current_zernike_contribution + + return tf.convert_to_tensor(zernike_contribution, dtype=tf.float32) + + +def shift_x_y_to_zk1_2_wavediff(dxy, tel_focal_length=24.5, tel_diameter=1.2): + """Compute Zernike 1 (2) for a given shift in x (y) in WaveDiff conventions. + + All inputs should be in [m]. + A displacement of, for example, 0.5 pixels should be scaled with the corresponding pixel scale, + e.g. 12[um], to get a displacement in [m], which would be `dxy=0.5*12e-6`. + + The output Zernike coefficient is in [um] units as expected by WaveDiff.
+ + To match the centroid with a `dx` that has a corresponding `zk1`, + the new PSF should be generated with `-zk1`. + + The same applies to `dy` and `zk2`. + + Parameters + ---------- + dxy : float + Centroid shift in [m]. It can be on the x-axis or the y-axis. + tel_focal_length : float + Telescope focal length in [m]. + tel_diameter : float + Telescope aperture diameter in [m]. + """ + reference_pix_sampling = 12e-6 + zernike_norm_factor = 2.0 + + # return zernike_norm_factor * (dx/reference_pix_sampling) / (tel_focal_length * tel_diameter / 2) + return ( + zernike_norm_factor + * (tel_diameter / 2) + * np.sin(np.arctan((dxy / reference_pix_sampling) / tel_focal_length)) + * 3.0 + ) + +def compute_zernike_tip_tilt( + star_images: np.ndarray, + star_masks: Optional[np.ndarray] = None, + pixel_sampling: float = 12e-6, + reference_shifts: list[float] = [-1/3, -1/3], + sigma_init: float = 2.5, + n_iter: int = 20, +) -> np.ndarray: + """ + Compute Zernike tip-tilt corrections for a batch of PSF images. + + This function estimates the centroid shifts of multiple PSFs and computes + the corresponding Zernike tip-tilt corrections to align them with a reference. + + Parameters + ---------- + star_images : np.ndarray + A batch of PSF images (3D array of shape `(num_images, height, width)`). + star_masks : np.ndarray, optional + A batch of masks (same shape as `star_images`). Each mask can have: + - `0` to ignore the pixel. + - `1` to fully consider the pixel. + - Values in `(0,1]` as weights for partial consideration. + Defaults to None. + pixel_sampling : float, optional + The pixel size in meters. Defaults to `12e-6 m` (12 microns). + reference_shifts : list[float], optional + The target centroid shifts in pixels, specified as `[dy, dx]`. + Defaults to `[-1/3, -1/3]` (nominal Euclid conditions). + sigma_init : float, optional + Initial standard deviation for centroid estimation. Default is `2.5`. + n_iter : int, optional + Number of iterations for centroid refinement. Default is `20`. + + Returns + ------- + np.ndarray + An array of shape `(num_images, 2)`, where: + - Column 0 contains `Zk1` (tip) values. + - Column 1 contains `Zk2` (tilt) values. + + Notes + ----- + - This function processes all images at once using vectorized operations. + - The Zernike coefficients are computed in the WaveDiff convention.
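+ + Examples + -------- + A minimal, illustrative call (the random stamps below are stand-ins, not project data): + + >>> import numpy as np + >>> rng = np.random.default_rng(0) + >>> stamps = rng.random((4, 32, 32)) # hypothetical batch of 32x32 star images + >>> zk = compute_zernike_tip_tilt(stamps, pixel_sampling=12e-6) + >>> zk.shape # one (Zk1, Zk2) row per image + (4, 2)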
+ """ + from wf_psf.data.centroids import CentroidEstimator + + # Vectorize the centroid computation + centroid_estimator = CentroidEstimator( + im=star_images, + mask=star_masks, + sigma_init=sigma_init, + n_iter=n_iter + ) + + shifts = centroid_estimator.get_intra_pixel_shifts() + + # Ensure reference_shifts is a NumPy array (if it's not already) + reference_shifts = np.array(reference_shifts) + + # Reshape to ensure it's a column vector (1, 2) + reference_shifts = reference_shifts[None,:] + + # Broadcast reference_shifts to match the shape of shifts + reference_shifts = np.broadcast_to(reference_shifts, shifts.shape) + + # Compute displacements + displacements = (reference_shifts - shifts) # + + # Ensure the correct axis order for displacements (x-axis, then y-axis) + displacements_swapped = displacements[:, [1, 0]] # Adjust axis order if necessary + + # Call shift_x_y_to_zk1_2_wavediff directly on the vector of displacements + zk1_2_array = shift_x_y_to_zk1_2_wavediff(displacements_swapped.flatten() * pixel_sampling ) # vectorized call + + # Reshape the result back to the original shape of displacements + zk1_2_array = zk1_2_array.reshape(displacements.shape) + + return zk1_2_array + + +def defocus_to_zk4_zemax(dz, tel_focal_length=24.5, tel_diameter=1.2): + """Compute Zernike 4 value for a given defocus in zemax conventions. + + All inputs should be in [m]. + + Parameters + ---------- + dz : float + Shift in the z-axis, perpendicular to the focal plane. Units in [m]. + tel_focal_length : float + Telescope focal length in [m]. + tel_diameter : float + Telescope aperture diameter in [m]. + """ + # Base calculation + zk4 = dz / (8.0 * (tel_focal_length / tel_diameter) ** 2) + # Apply Z4 normalisation + # This step depends on the normalisation of the Zernike basis used + zk4 /= np.sqrt(3) + # Convert to waves with a reference of 800nm + zk4 /= 800e-9 + # Remove the peak to valley value + zk4 /= 2.0 + + return zk4 + + +def defocus_to_zk4_wavediff(dz, tel_focal_length=24.5, tel_diameter=1.2): + """Compute Zernike 4 value for a given defocus in WaveDifff conventions. + + All inputs should be in [m]. + + The output zernike coefficient is in [um] units as expected by wavediff. + + Parameters + ---------- + dz : float + Shift in the z-axis, perpendicular to the focal plane. Units in [m]. + tel_focal_length : float + Telescope focal length in [m]. + tel_diameter : float + Telescope aperture diameter in [m]. + """ + # Base calculation + zk4 = dz / (8.0 * (tel_focal_length / tel_diameter) ** 2) + # Apply Z4 normalisation + # This step depends on the normalisation of the Zernike basis used + zk4 /= np.sqrt(3) + + # Remove the peak to valley value + zk4 /= 2.0 + + # Change units to [um] as Wavediff uses + zk4 *= 1e6 + + return zk4 diff --git a/src/wf_psf/data/training_preprocessing.py b/src/wf_psf/data/training_preprocessing.py deleted file mode 100644 index 25180ebe..00000000 --- a/src/wf_psf/data/training_preprocessing.py +++ /dev/null @@ -1,458 +0,0 @@ -"""Training Data Processing. - -A module to load and preprocess training and validation test data. - -:Authors: Jennifer Pollack and Tobias Liaudat - -""" - -import os -import numpy as np -import wf_psf.utils.utils as utils -import tensorflow as tf -from wf_psf.instrument.ccd_misalignments import CCDMisalignmentCalculator -from wf_psf.data.centroids import compute_zernike_tip_tilt -from fractions import Fraction -import logging - -logger = logging.getLogger(__name__) - - -class DataHandler: - """Data Handler. 
- - This class manages loading and processing of training and testing data for use during PSF model training and validation. - It provides methods to access and preprocess the data. - - Parameters - ---------- - dataset_type : str - Type of dataset ("train" or "test"). - data_params : RecursiveNamespace - Recursive Namespace object containing parameters for both 'train' and 'test' datasets. - simPSF : PSFSimulator - Instance of the PSFSimulator class for simulating PSF models. - n_bins_lambda : int - Number of wavelength bins for SED processing. - load_data : bool, optional - If True, data is loaded and processed during initialization. If False, data loading - is deferred until explicitly called. Default is True. - - Attributes - ---------- - dataset_type : str - Type of dataset ("train" or "test"). - data_params : RecursiveNamespace - Parameters for the current dataset type. - dataset : dict or None - Dictionary containing the loaded dataset, including positions and stars/noisy_stars. - simPSF : PSFSimulator - Instance of the PSFSimulator class for simulating PSF models. - n_bins_lambda : int - Number of wavelength bins. - sed_data : tf.Tensor or None - TensorFlow tensor containing processed SED data for training/testing. - load_data_on_init : bool - Flag controlling whether data is loaded during initialization. - """ - - def __init__( - self, dataset_type, data_params, simPSF, n_bins_lambda, load_data: bool = True - ): - """ - Initialize the dataset handler for PSF simulation. - - Parameters - ---------- - dataset_type : str - Type of dataset ("train" or "test"). - data_params : RecursiveNamespace - Recursive Namespace object containing parameters for both 'train' and 'test' datasets. - simPSF : PSFSimulator - Instance of the PSFSimulator class for simulating PSF models. - n_bins_lambda : int - Number of wavelength bins for SED processing. - load_data : bool, optional - If True, data is loaded and processed during initialization. If False, data loading - is deferred. Default is True. - """ - self.dataset_type = dataset_type - self.data_params = data_params.__dict__[dataset_type] - self.simPSF = simPSF - self.n_bins_lambda = n_bins_lambda - self.dataset = None - self.sed_data = None - self.load_data_on_init = load_data - if self.load_data_on_init: - self.load_dataset() - self.process_sed_data() - - def load_dataset(self): - """Load dataset. - - Load the dataset based on the specified dataset type. - - """ - self.dataset = np.load( - os.path.join(self.data_params.data_dir, self.data_params.file), - allow_pickle=True, - )[()] - self.dataset["positions"] = tf.convert_to_tensor( - self.dataset["positions"], dtype=tf.float32 - ) - if self.dataset_type == "training": - if "noisy_stars" in self.dataset: - self.dataset["noisy_stars"] = tf.convert_to_tensor( - self.dataset["noisy_stars"], dtype=tf.float32 - ) - else: - logger.warning(f"Missing 'noisy_stars' in {self.dataset_type} dataset.") - elif self.dataset_type == "test": - if "stars" in self.dataset: - self.dataset["stars"] = tf.convert_to_tensor( - self.dataset["stars"], dtype=tf.float32 - ) - else: - logger.warning(f"Missing 'stars' in {self.dataset_type} dataset.") - - def process_sed_data(self): - """Process SED Data. - - A method to generate and process SED data. 
- - """ - self.sed_data = [ - utils.generate_SED_elems_in_tensorflow( - _sed, self.simPSF, n_bins=self.n_bins_lambda, tf_dtype=tf.float64 - ) - for _sed in self.dataset["SEDs"] - ] - self.sed_data = tf.convert_to_tensor(self.sed_data, dtype=tf.float32) - self.sed_data = tf.transpose(self.sed_data, perm=[0, 2, 1]) - - -def get_np_obs_positions(data): - """Get observed positions in numpy from the provided dataset. - - This method concatenates the positions of the stars from both the training - and test datasets to obtain the observed positions. - - Parameters - ---------- - data : DataConfigHandler - Object containing training and test datasets. - - Returns - ------- - np.ndarray - Numpy array containing the observed positions of the stars. - - Notes - ----- - The observed positions are obtained by concatenating the positions of stars - from both the training and test datasets along the 0th axis. - """ - obs_positions = np.concatenate( - ( - data.training_data.dataset["positions"], - data.test_data.dataset["positions"], - ), - axis=0, - ) - - return obs_positions - - -def get_obs_positions(data): - """Get observed positions from the provided dataset. - - Parameters - ---------- - data : DataConfigHandler - Object containing training and test datasets. - - Returns - ------- - tf.Tensor - Tensor containing the observed positions of the stars. - """ - obs_positions = get_np_obs_positions(data) - - return tf.convert_to_tensor(obs_positions, dtype=tf.float32) - - -def extract_star_data(data, train_key: str, test_key: str) -> np.ndarray: - """Extract specific star-related data from training and test datasets. - - This function retrieves and concatenates specific star-related data (e.g., stars, masks) from the - star training and test datasets such as star stamps or masks, based on the provided keys. - - Parameters - ---------- - data : DataConfigHandler - Object containing training and test datasets. - train_key : str - The key to retrieve data from the training dataset (e.g., 'noisy_stars', 'masks'). - test_key : str - The key to retrieve data from the test dataset (e.g., 'stars', 'masks'). - - Returns - ------- - np.ndarray - A NumPy array containing the concatenated data for the given keys. - - Raises - ------ - KeyError - If the specified keys do not exist in the training or test datasets. - - Notes - ----- - - If the dataset contains TensorFlow tensors, they will be converted to NumPy arrays. - - Ensure that eager execution is enabled when calling this function. - """ - # Ensure the requested keys exist in both training and test datasets - missing_keys = [ - key - for key, dataset in [ - (train_key, data.training_data.dataset), - (test_key, data.test_data.dataset), - ] - if key not in dataset - ] - - if missing_keys: - raise KeyError(f"Missing keys in dataset: {missing_keys}") - - # Retrieve data from training and test sets - train_data = data.training_data.dataset[train_key] - test_data = data.test_data.dataset[test_key] - - # Convert to NumPy if necessary - if tf.is_tensor(train_data): - train_data = train_data.numpy() - if tf.is_tensor(test_data): - test_data = test_data.numpy() - - # Concatenate and return - return np.concatenate((train_data, test_data), axis=0) - - -def get_np_zernike_prior(data): - """Get the zernike prior from the provided dataset. - - This method concatenates the stars from both the training - and test datasets to obtain the full prior. - - Parameters - ---------- - data : DataConfigHandler - Object containing training and test datasets. 
- - Returns - ------- - zernike_prior : np.ndarray - Numpy array containing the full prior. - """ - zernike_prior = np.concatenate( - ( - data.training_data.dataset["zernike_prior"], - data.test_data.dataset["zernike_prior"], - ), - axis=0, - ) - - return zernike_prior - - -def compute_centroid_correction(model_params, data, batch_size: int = 1) -> np.ndarray: - """Compute centroid corrections using Zernike polynomials. - - This function calculates the Zernike contributions required to match the centroid - of the WaveDiff PSF model to the observed star centroids, processing in batches. - - Parameters - ---------- - model_params : RecursiveNamespace - An object containing parameters for the PSF model, including pixel sampling and initial centroid window parameters. - - data : DataConfigHandler - An object containing training and test datasets, including observed PSFs - and optional star masks. - - batch_size : int, optional - The batch size to use when processing the stars. Default is 16. - - - Returns - ------- - zernike_centroid_array : np.ndarray - A 2D NumPy array of shape `(n_stars, 3)`, where `n_stars` is the number of - observed stars. The array contains the computed Zernike contributions, - with zero padding applied to the first column to ensure a consistent shape. - """ - star_postage_stamps = extract_star_data( - data=data, train_key="noisy_stars", test_key="stars" - ) - - # Get star mask catalogue only if "masks" exist in both training and test datasets - star_masks = ( - extract_star_data(data=data, train_key="masks", test_key="masks") - if ( - data.training_data.dataset.get("masks") is not None - and data.test_data.dataset.get("masks") is not None - and tf.size(data.training_data.dataset["masks"]) > 0 - and tf.size(data.test_data.dataset["masks"]) > 0 - ) - else None - ) - - pix_sampling = model_params.pix_sampling * 1e-6 # Change units from [um] to [m] - - # Ensure star_masks is properly handled - star_masks = star_masks if star_masks is not None else None - - reference_shifts = [ - float(Fraction(value)) for value in model_params.reference_shifts - ] - - n_stars = len(star_postage_stamps) - zernike_centroid_array = [] - - # Batch process the stars - for i in range(0, n_stars, batch_size): - batch_postage_stamps = star_postage_stamps[i : i + batch_size] - batch_masks = star_masks[i : i + batch_size] if star_masks is not None else None - - # Compute Zernike 1 and Zernike 2 for the batch - zk1_2_batch = -1.0 * compute_zernike_tip_tilt( - batch_postage_stamps, batch_masks, pix_sampling, reference_shifts - ) - - # Zero pad array for each batch and append - zernike_centroid_array.append( - np.pad( - zk1_2_batch, - pad_width=[(0, 0), (1, 0)], - mode="constant", - constant_values=0, - ) - ) - - # Combine all batches into a single array - return np.concatenate(zernike_centroid_array, axis=0) - - -def compute_ccd_misalignment(model_params, data): - """Compute CCD misalignment. - - Parameters - ---------- - model_params : RecursiveNamespace - Object containing parameters for this PSF model class. - data : DataConfigHandler - Object containing training and test datasets. - - Returns - ------- - zernike_ccd_misalignment_array : np.ndarray - Numpy array containing the Zernike contributions to model the CCD chip misalignments. 
- """ - obs_positions = get_np_obs_positions(data) - - ccd_misalignment_calculator = CCDMisalignmentCalculator( - tiles_path=model_params.ccd_misalignments_input_path, - x_lims=model_params.x_lims, - y_lims=model_params.y_lims, - tel_focal_length=model_params.tel_focal_length, - tel_diameter=model_params.tel_diameter, - ) - # Compute required zernike 4 for each position - zk4_values = np.array( - [ - ccd_misalignment_calculator.get_zk4_from_position(single_pos) - for single_pos in obs_positions - ] - ).reshape(-1, 1) - - # Zero pad array to get shape (n_stars, n_zernike=4) - zernike_ccd_misalignment_array = np.pad( - zk4_values, pad_width=[(0, 0), (3, 0)], mode="constant", constant_values=0 - ) - - return zernike_ccd_misalignment_array - - -def get_zernike_prior(model_params, data, batch_size: int = 16): - """Get Zernike priors from the provided dataset. - - This method concatenates the Zernike priors from both the training - and test datasets. - - Parameters - ---------- - model_params : RecursiveNamespace - Object containing parameters for this PSF model class. - data : DataConfigHandler - Object containing training and test datasets. - batch_size : int, optional - The batch size to use when processing the stars. Default is 16. - - Returns - ------- - tf.Tensor - Tensor containing the observed positions of the stars. - - Notes - ----- - The Zernike prior are obtained by concatenating the Zernike priors - from both the training and test datasets along the 0th axis. - - """ - # List of zernike contribution - zernike_contribution_list = [] - - if model_params.use_prior: - logger.info("Reading in Zernike prior into Zernike contribution list...") - zernike_contribution_list.append(get_np_zernike_prior(data)) - - if model_params.correct_centroids: - logger.info("Adding centroid correction to Zernike contribution list...") - zernike_contribution_list.append( - compute_centroid_correction(model_params, data, batch_size) - ) - - if model_params.add_ccd_misalignments: - logger.info("Adding CCD mis-alignments to Zernike contribution list...") - zernike_contribution_list.append(compute_ccd_misalignment(model_params, data)) - - if len(zernike_contribution_list) == 1: - zernike_contribution = zernike_contribution_list[0] - else: - # Get max zk order - max_zk_order = np.max( - np.array( - [ - zk_contribution.shape[1] - for zk_contribution in zernike_contribution_list - ] - ) - ) - - zernike_contribution = np.zeros( - (zernike_contribution_list[0].shape[0], max_zk_order) - ) - - # Pad arrays to get the same length and add the final contribution - for it in range(len(zernike_contribution_list)): - current_zk_order = zernike_contribution_list[it].shape[1] - current_zernike_contribution = np.pad( - zernike_contribution_list[it], - pad_width=[(0, 0), (0, int(max_zk_order - current_zk_order))], - mode="constant", - constant_values=0, - ) - - zernike_contribution += current_zernike_contribution - - return tf.convert_to_tensor(zernike_contribution, dtype=tf.float32) diff --git a/src/wf_psf/instrument/ccd_misalignments.py b/src/wf_psf/instrument/ccd_misalignments.py index 3b7a1eb3..35343886 100644 --- a/src/wf_psf/instrument/ccd_misalignments.py +++ b/src/wf_psf/instrument/ccd_misalignments.py @@ -13,6 +13,46 @@ from wf_psf.data.data_preprocessing import defocus_to_zk4_wavediff +def compute_ccd_misalignment(model_params, data): + """Compute CCD misalignment. + + Parameters + ---------- + model_params : RecursiveNamespace + Object containing parameters for this PSF model class. 
+ data : DataConfigHandler + Object containing training and test datasets. + + Returns + ------- + zernike_ccd_misalignment_array : np.ndarray + Numpy array containing the Zernike contributions to model the CCD chip misalignments. + """ + obs_positions = get_np_obs_positions(data) + + ccd_misalignment_calculator = CCDMisalignmentCalculator( + tiles_path=model_params.ccd_misalignments_input_path, + x_lims=model_params.x_lims, + y_lims=model_params.y_lims, + tel_focal_length=model_params.tel_focal_length, + tel_diameter=model_params.tel_diameter, + ) + # Compute required zernike 4 for each position + zk4_values = np.array( + [ + ccd_misalignment_calculator.get_zk4_from_position(single_pos) + for single_pos in obs_positions + ] + ).reshape(-1, 1) + + # Zero pad array to get shape (n_stars, n_zernike=4) + zernike_ccd_misalignment_array = np.pad( + zk4_values, pad_width=[(0, 0), (3, 0)], mode="constant", constant_values=0 + ) + + return zernike_ccd_misalignment_array + + class CCDMisalignmentCalculator: """CCD Misalignment Calculator. diff --git a/src/wf_psf/psf_models/tf_modules/tf_psf_field.py b/src/wf_psf/psf_models/tf_modules/tf_psf_field.py index e25657d9..23a19b8b 100644 --- a/src/wf_psf/psf_models/tf_modules/tf_psf_field.py +++ b/src/wf_psf/psf_models/tf_modules/tf_psf_field.py @@ -16,7 +16,7 @@ TFPhysicalLayer, ) from wf_psf.psf_models.models.psf_model_semiparametric import TFSemiParametricField -from wf_psf.data.training_preprocessing import get_obs_positions +from wf_psf.data.data_handler import get_obs_positions from wf_psf.psf_models import psf_models as psfm import logging diff --git a/src/wf_psf/tests/__init__.py b/src/wf_psf/tests/__init__.py new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/src/wf_psf/tests/__init__.py @@ -0,0 +1 @@ + diff --git a/src/wf_psf/tests/conftest.py b/src/wf_psf/tests/conftest.py index 5b617c63..beb6b9fb 100644 --- a/src/wf_psf/tests/conftest.py +++ b/src/wf_psf/tests/conftest.py @@ -13,7 +13,7 @@ from wf_psf.training.train import TrainingParamsHandler from wf_psf.utils.configs_handler import DataConfigHandler from wf_psf.psf_models import psf_models -from wf_psf.data.training_preprocessing import DataHandler +from wf_psf.data.data_handler import DataHandler training_config = RecursiveNamespace( id_name="-coherent_euclid_200stars", diff --git a/src/wf_psf/tests/test_data/__init__.py b/src/wf_psf/tests/test_data/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/wf_psf/tests/test_data/centroids_test.py b/src/wf_psf/tests/test_data/centroids_test.py index c55e18f9..85719f4f 100644 --- a/src/wf_psf/tests/test_data/centroids_test.py +++ b/src/wf_psf/tests/test_data/centroids_test.py @@ -8,8 +8,11 @@ import numpy as np import pytest +from wf_psf.data.centroids import compute_centroid_correction, CentroidEstimator +from wf_psf.data.data_handler import extract_star_data +from wf_psf.data.data_zernike_utils import compute_zernike_tip_tilt +from wf_psf.utils.read_config import RecursiveNamespace from unittest.mock import MagicMock, patch -from wf_psf.data.centroids import compute_zernike_tip_tilt, CentroidEstimator # Function to compute centroid based on first-order moments @@ -28,25 +31,6 @@ def calculate_centroid(image, mask=None): return (xc, yc) -@pytest.fixture -def simple_image(): - """Fixture for a batch of simple star images.""" - num_images = 1 # Change this to test with multiple images - image = np.zeros((num_images, 5, 5)) # Create a 3D array - image[:, 2, 2] = 1 # Place the star at the center for each image - 
return image - - -@pytest.fixture -def multiple_images(): - """Fixture for a batch of images with stars at different positions.""" - images = np.zeros((3, 5, 5)) # 3 images, each of size 5x5 - images[0, 2, 2] = 1 # Star at center of image 0 - images[1, 1, 3] = 1 # Star at (1, 3) in image 1 - images[2, 3, 1] = 1 # Star at (3, 1) in image 2 - return images - - @pytest.fixture def simple_star_and_mask(): """Fixture for an image with multiple non-zero pixels for centroid calculation.""" @@ -68,12 +52,6 @@ def simple_star_and_mask(): return image, mask -@pytest.fixture -def identity_mask(): - """Creates a mask where all pixels are fully considered.""" - return np.ones((5, 5)) - - @pytest.fixture def simple_image_with_mask(simple_image): """Fixture for a batch of star images with masks.""" @@ -129,133 +107,92 @@ def batch_images(): return images -def test_compute_zernike_tip_tilt_single_batch(mocker, simple_image, identity_mask): - """Test compute_zernike_tip_tilt with single batch input and mocks.""" - # Mock the CentroidEstimator class - mock_centroid_calc = mocker.patch( - "wf_psf.data.centroids.CentroidEstimator", autospec=True - ) - - # Create a mock instance and configure get_intra_pixel_shifts() - mock_instance = mock_centroid_calc.return_value - mock_instance.get_intra_pixel_shifts.return_value = np.array( - [[0.05, -0.02]] - ) # Shape (1, 2) - - # Mock shift_x_y_to_zk1_2_wavediff to return predictable values - mock_shift_fn = mocker.patch( - "wf_psf.data.centroids.shift_x_y_to_zk1_2_wavediff", - side_effect=lambda shift: shift * 0.5, # Mocked conversion for test - ) - - # Define test inputs (batch of 1 image) - pixel_sampling = 12e-6 - reference_shifts = [-1 / 3, -1 / 3] # Default Euclid conditions - - # Run the function - zernike_corrections = compute_zernike_tip_tilt( - simple_image, identity_mask, pixel_sampling, reference_shifts - ) - zernike_corrections = compute_zernike_tip_tilt( - simple_image, identity_mask, pixel_sampling, reference_shifts - ) - - # Expected shifts based on centroid calculation - expected_dx = reference_shifts[1] - (-0.02) # Expected x-axis shift in meters - expected_dy = reference_shifts[0] - 0.05 # Expected y-axis shift in meters - - # Expected calls to the mocked function - # Extract the arguments passed to mock_shift_fn - args, _ = mock_shift_fn.call_args_list[0] # Get the first call args - - # Compare expected values with the actual arguments passed to the mock function - np.testing.assert_allclose( - args[0][0], expected_dx * pixel_sampling, rtol=1e-7, atol=0 +def test_compute_centroid_correction_with_masks(mock_data): + """Test compute_centroid_correction function with masks present.""" + # Given that compute_centroid_correction expects a model_params and data object + model_params = RecursiveNamespace( + pix_sampling=12e-6, # Example pixel sampling in meters + correct_centroids=True, + reference_shifts=["-1/3", "-1/3"], ) - # Check dy values similarly - np.testing.assert_allclose( - args[0][1], expected_dy * pixel_sampling, rtol=1e-7, atol=0 - ) - - # Expected values based on mock side_effect (0.5 * shift) - np.testing.assert_allclose( - zernike_corrections[0, 0], expected_dx * pixel_sampling * 0.5 - ) # Zk1 - np.testing.assert_allclose( - zernike_corrections[0, 1], expected_dy * pixel_sampling * 0.5 - ) # Zk2 - - -def test_compute_zernike_tip_tilt_batch(mocker, multiple_images): - """Test compute_zernike_tip_tilt with batch input and mocks.""" - # Mock the CentroidEstimator class - mock_centroid_calc = mocker.patch( - 
"wf_psf.data.centroids.CentroidEstimator", autospec=True - ) - - # Create a mock instance and configure get_intra_pixel_shifts() - mock_instance = mock_centroid_calc.return_value - mock_instance.get_intra_pixel_shifts.return_value = np.array( - [[0.05, -0.02], [0.04, -0.01], [0.06, -0.03]] - ) # Shape (3, 2) - - # Mock shift_x_y_to_zk1_2_wavediff to return predictable values - mock_shift_fn = mocker.patch( - "wf_psf.data.centroids.shift_x_y_to_zk1_2_wavediff", - side_effect=lambda shift: shift * 0.5, # Mocked conversion for test - ) - - # Define test inputs (batch of 3 images) - pixel_sampling = 12e-6 - reference_shifts = [-1 / 3, -1 / 3] # Default Euclid conditions - - # Run the function - zernike_corrections = compute_zernike_tip_tilt( - star_images=multiple_images, - pixel_sampling=pixel_sampling, - reference_shifts=reference_shifts, - ) + # Mock the internal function calls: + with ( + patch("wf_psf.data.centroids.extract_star_data") as mock_extract_star_data, + patch( + "wf_psf.data.centroids.compute_zernike_tip_tilt" + ) as mock_compute_zernike_tip_tilt, + ): - # Check if the mock function was called once with the full batch - assert len(mock_shift_fn.call_args_list) == 1, ( - f"Expected 1 call, but got {len(mock_shift_fn.call_args_list)}" + # Mock the return values of extract_star_data and compute_zernike_tip_tilt + mock_extract_star_data.side_effect = lambda data, train_key, test_key: ( + np.array([[1, 2], [3, 4]]) + if train_key == "noisy_stars" + else np.array([[5, 6], [7, 8]]) + ) + mock_compute_zernike_tip_tilt.return_value = np.array([[0.1, 0.2], [0.3, 0.4]]) + + # Call the function under test + result = compute_centroid_correction(model_params, mock_data) + + # Ensure the result has the correct shape + assert result.shape == (4, 3) # Should be (n_stars, 3 Zernike components) + + assert np.allclose( + result[0, :], np.array([0, -0.1, -0.2]) + ) # First star Zernike coefficients + assert np.allclose( + result[1, :], np.array([0, -0.3, -0.4]) + ) # Second star Zernike coefficients + + +def test_compute_centroid_correction_without_masks(mock_data): + """Test compute_centroid_correction function when no masks are provided.""" + # Remove masks from mock_data + mock_data.test_data.dataset["masks"] = None + mock_data.training_data.dataset["masks"] = None + + # Define model parameters + model_params = RecursiveNamespace( + pix_sampling=12e-6, # Example pixel sampling in meters + correct_centroids=True, + reference_shifts=["-1/3", "-1/3"], ) - # Get the arguments passed to the mock function for the batch of images - args, _ = mock_shift_fn.call_args_list[0] + # Mock internal function calls + with ( + patch("wf_psf.data.centroids.extract_star_data") as mock_extract_star_data, + patch( + "wf_psf.data.centroids.compute_zernike_tip_tilt" + ) as mock_compute_zernike_tip_tilt, + ): - print("Shape of args[0]:", args[0].shape) - print("Contents of args[0]:", args[0]) - print("Mock function call args list:", mock_shift_fn.call_args_list) + # Mock extract_star_data to return synthetic star postage stamps + mock_extract_star_data.side_effect = lambda data, train_key, test_key: ( + np.array([[1, 2], [3, 4]]) + if train_key == "noisy_stars" + else np.array([[5, 6], [7, 8]]) + ) - # Reshape args[0] to (N, 2) for batch processing - args_array = np.array(args[0]).reshape(-1, 2) + # Mock compute_zernike_tip_tilt assuming no masks + mock_compute_zernike_tip_tilt.return_value = np.array([[0.1, 0.2], [0.3, 0.4]]) - # Process the displacements and expected values for each image in the batch - expected_dx = 
(
-        reference_shifts[1] - mock_instance.get_intra_pixel_shifts.return_value[:, 1]
-    )  # Expected x-axis shift in meters
-
-    expected_dy = (
-        reference_shifts[0] - mock_instance.get_intra_pixel_shifts.return_value[:, 0]
-    )  # Expected y-axis shift in meters
-
-    # Compare expected values with the actual arguments passed to the mock function
-    np.testing.assert_allclose(
-        args_array[:, 0], expected_dx * pixel_sampling, rtol=1e-7, atol=0
-    )
-    np.testing.assert_allclose(
-        args_array[:, 1], expected_dy * pixel_sampling, rtol=1e-7, atol=0
-    )
-
-    # Expected values based on mock side_effect (0.5 * shift)
-    np.testing.assert_allclose(
-        zernike_corrections[:, 0], expected_dx * pixel_sampling * 0.5
-    )  # Zk1 for each image
-    np.testing.assert_allclose(
-        zernike_corrections[:, 1], expected_dy * pixel_sampling * 0.5
-    )  # Zk2 for each image


 # Test for centroid calculation without mask
@@ -442,9 +379,9 @@ def test_intra_pixel_shifts(simple_image_with_centroid):
     expected_y_shift = 2.7 - 2.0  # yc - yc0

     # Check that the shifts are correct
-    assert np.isclose(shifts[0], expected_x_shift), (
-        f"Expected {expected_x_shift}, got {shifts[0]}"
-    )
-    assert np.isclose(shifts[1], expected_y_shift), (
-        f"Expected {expected_y_shift}, got {shifts[1]}"
-    )
+    assert np.isclose(
+        shifts[0], expected_x_shift
+    ), f"Expected {expected_x_shift}, got {shifts[0]}"
+    assert np.isclose(
+        shifts[1], expected_y_shift
+    ), f"Expected {expected_y_shift}, got {shifts[1]}"
diff --git a/src/wf_psf/tests/test_data/conftest.py b/src/wf_psf/tests/test_data/conftest.py
index 6159d53a..04a56893 100644
--- a/src/wf_psf/tests/test_data/conftest.py
+++ b/src/wf_psf/tests/test_data/conftest.py
@@ -9,8 +9,11 @@
"""

import pytest
+import numpy as np
+import tensorflow as tf
from wf_psf.utils.read_config import RecursiveNamespace
from wf_psf.psf_models import psf_models
+from wf_psf.tests.test_data.test_data_utils import MockData, MockDataset

training_config = RecursiveNamespace(
    id_name="-coherent_euclid_200stars",
@@ -93,6 +96,64 @@
)


+@pytest.fixture
+def mock_data():
+    """Fixture to provide mock data for testing."""
+    # Mock positions and Zernike priors
+    training_positions = np.array([[1, 2], [3, 4]])
+    test_positions = np.array([[5, 6], [7, 8]])
+    training_zernike_priors = np.array([[0.1, 0.2], [0.3, 0.4]])
+    test_zernike_priors = np.array([[0.5, 0.6], [0.7, 0.8]])
+
+    # Define dummy 5x5 star postage stamps with varied pixel values
+    noisy_stars = tf.constant([
+        np.arange(25).reshape(5, 5),
+        np.arange(25, 50).reshape(5, 5)
+    ], dtype=tf.float32)
+
+    noisy_masks = tf.constant([
+        np.eye(5),
+        np.ones((5, 5))
+    ], dtype=tf.float32)
+
+    stars = tf.constant([
+        np.full((5, 5), 100),
+        np.full((5, 5), 200)
+    ], dtype=tf.float32)
+
+    masks = tf.constant([
+        np.zeros((5, 5)),
+        np.tri(5)
+    ], dtype=tf.float32)
+
+    return MockData(
+        training_positions, test_positions, training_zernike_priors,
+        test_zernike_priors, noisy_stars, noisy_masks, stars, masks
+    )
+
+@pytest.fixture
+def simple_image():
+    """Fixture for a simple star image."""
+    num_images = 1  # Change this to test with multiple images
+    image = np.zeros((num_images, 5, 5))  # Create a 3D array
+    image[:, 2, 2] = 1  # Place the star at the center for each image
+    return image
+
+@pytest.fixture
+def identity_mask():
+    """Creates a mask where all pixels are fully considered."""
+    return np.ones((5, 5))
+
+@pytest.fixture
+def multiple_images():
+    """Fixture for a batch of images with stars at different positions."""
+    images = np.zeros((3, 5, 5))  # 3 images, each of size 5x5
+    images[0, 2, 2] = 1  # Star at center of image 0
+    images[1, 1, 3] = 1  # Star at (1, 3) in image 1
+    images[2, 3, 1] = 1  # Star at (3, 1) in image 2
+    return images
+
 @pytest.fixture(scope="module", params=[data])
 def data_params():
     return data
diff --git a/src/wf_psf/tests/test_data/training_preprocessing_test.py b/src/wf_psf/tests/test_data/data_handler_test.py
similarity index 52%
rename from src/wf_psf/tests/test_data/training_preprocessing_test.py
rename to src/wf_psf/tests/test_data/data_handler_test.py
index 3efc8272..5d167a6b 100644
--- a/src/wf_psf/tests/test_data/training_preprocessing_test.py
+++ b/src/wf_psf/tests/test_data/data_handler_test.py
@@ -1,81 +1,16 @@
 import pytest
 import numpy as np
 import tensorflow as tf
-from wf_psf.utils.read_config import RecursiveNamespace
-from wf_psf.data.training_preprocessing import (
+from wf_psf.data.data_handler import (
     DataHandler,
     get_obs_positions,
-    get_zernike_prior,
     extract_star_data,
-    compute_centroid_correction,
 )
+from wf_psf.utils.read_config import RecursiveNamespace
+import logging
 from unittest.mock import patch


-class MockData:
-    def __init__(
-        self,
-        training_positions,
-        test_positions,
-        training_zernike_priors,
-        test_zernike_priors,
-        noisy_stars=None,
-        noisy_masks=None,
-        stars=None,
-        masks=None,
-    ):
-        self.training_data = MockDataset(
-            positions=training_positions,
-            zernike_priors=training_zernike_priors,
-            star_type="noisy_stars",
-            stars=noisy_stars,
-            masks=noisy_masks,
-        )
-        self.test_data = MockDataset(
-            positions=test_positions,
-            zernike_priors=test_zernike_priors,
-            star_type="stars",
-            stars=stars,
-            masks=masks,
-        )
-
-
-class MockDataset:
-    def __init__(self, positions, zernike_priors, star_type, stars, masks):
-        self.dataset = {
-            "positions": positions,
-            "zernike_prior": zernike_priors,
-            star_type: stars,
-            "masks": masks,
-        }
-
-
-@pytest.fixture
-def mock_data():
-    # Mock data for testing
-    # Mock training and test positions and Zernike priors
-    training_positions = np.array([[1, 2], [3, 4]])
-    test_positions = np.array([[5, 6], [7, 8]])
-    training_zernike_priors = np.array([[0.1, 0.2], [0.3, 0.4]])
-    test_zernike_priors = np.array([[0.5, 0.6], [0.7, 0.8]])
-    # Mock noisy stars, stars and masks
-    noisy_stars = tf.constant([[1, 2], [3, 4]], dtype=tf.float32)
-    noisy_masks = tf.constant([[1], [0]], dtype=tf.float32)
-    stars = tf.constant([[5, 6], [7, 8]], dtype=tf.float32)
-    masks = tf.constant([[0], [1]], dtype=tf.float32)
-
-    return MockData(
-        training_positions,
-        test_positions,
-        training_zernike_priors,
-        test_zernike_priors,
-        noisy_stars,
-        noisy_masks,
-        stars,
-        masks,
-    )
-
-
 def test_load_train_dataset(tmp_path, data_params, simPSF):
     # Create a temporary directory and a temporary data file
     data_dir = tmp_path / "data"
@@ -170,7 +105,7 @@ def test_load_train_dataset_missing_noisy_stars(tmp_path, data_params, simPSF):
         "training", data_params, simPSF, n_bins_lambda, load_data=False
     )

-    with 
patch("wf_psf.data.training_preprocessing.logger.warning") as mock_warning: + with patch("wf_psf.data.data_handler.logger.warning") as mock_warning: data_handler.load_dataset() mock_warning.assert_called_with("Missing 'noisy_stars' in training dataset.") @@ -197,7 +132,7 @@ def test_load_test_dataset_missing_stars(tmp_path, data_params, simPSF): "test", data_params, simPSF, n_bins_lambda, load_data=False ) - with patch("wf_psf.data.training_preprocessing.logger.warning") as mock_warning: + with patch("wf_psf.data.data_handler.logger.warning") as mock_warning: data_handler.load_dataset() mock_warning.assert_called_with("Missing 'stars' in test dataset.") @@ -230,40 +165,21 @@ def test_get_obs_positions(mock_data): assert tf.reduce_all(tf.equal(observed_positions, expected_positions)) -def test_get_zernike_prior(model_params, mock_data): - zernike_priors = get_zernike_prior(model_params, mock_data) - expected_shape = ( - 4, - 2, - ) # Assuming 2 Zernike priors for each dataset (training and test) - assert zernike_priors.shape == expected_shape - - -def test_get_zernike_prior_dtype(model_params, mock_data): - zernike_priors = get_zernike_prior(model_params, mock_data) - assert zernike_priors.dtype == np.float32 - - -def test_get_zernike_prior_concatenation(model_params, mock_data): - zernike_priors = get_zernike_prior(model_params, mock_data) - expected_zernike_priors = tf.convert_to_tensor( - np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]]), dtype=tf.float32 - ) - - assert np.array_equal(zernike_priors, expected_zernike_priors) - - -def test_get_zernike_prior_empty_data(model_params): - empty_data = MockData(np.array([]), np.array([]), np.array([]), np.array([])) - zernike_priors = get_zernike_prior(model_params, empty_data) - assert zernike_priors.shape == tf.TensorShape([0]) # Check for empty array shape - - def test_extract_star_data_valid_keys(mock_data): """Test extracting valid data from the dataset.""" result = extract_star_data(mock_data, train_key="noisy_stars", test_key="stars") - expected = np.array([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=np.float32) + expected = tf.concat( + [ + tf.constant( + [np.arange(25).reshape(5, 5), np.arange(25, 50).reshape(5, 5)], + dtype=tf.float32, + ), + tf.constant([np.full((5, 5), 100), np.full((5, 5), 200)], dtype=tf.float32), + ], + axis=0, + ) + np.testing.assert_array_equal(result, expected) @@ -271,7 +187,13 @@ def test_extract_star_data_masks(mock_data): """Test extracting star masks from the dataset.""" result = extract_star_data(mock_data, train_key="masks", test_key="masks") - expected = np.array([[1], [0], [0], [1]], dtype=np.float32) + mask0 = np.eye(5, dtype=np.float32) + mask1 = np.ones((5, 5), dtype=np.float32) + mask2 = np.zeros((5, 5), dtype=np.float32) + mask3 = np.tri(5, dtype=np.float32) + + expected = np.array([mask0, mask1, mask2, mask3], dtype=np.float32) + np.testing.assert_array_equal(result, expected) @@ -297,96 +219,6 @@ def test_extract_star_data_tensor_conversion(mock_data): assert result.dtype == np.float32, "The NumPy array should have dtype float32" -def test_compute_centroid_correction_with_masks(mock_data): - """Test compute_centroid_correction function with masks present.""" - # Given that compute_centroid_correction expects a model_params and data object - model_params = RecursiveNamespace( - pix_sampling=12e-6, # Example pixel sampling in meters - correct_centroids=True, - reference_shifts=["-1/3", "-1/3"], - ) - - # Mock the internal function calls: - with ( - patch( - 
"wf_psf.data.training_preprocessing.extract_star_data" - ) as mock_extract_star_data, - patch( - "wf_psf.data.training_preprocessing.compute_zernike_tip_tilt" - ) as mock_compute_zernike_tip_tilt, - ): - # Mock the return values of extract_star_data and compute_zernike_tip_tilt - mock_extract_star_data.side_effect = lambda data, train_key, test_key: ( - np.array([[1, 2], [3, 4]]) - if train_key == "noisy_stars" - else np.array([[5, 6], [7, 8]]) - ) - mock_compute_zernike_tip_tilt.return_value = np.array([[0.1, 0.2], [0.3, 0.4]]) - - # Call the function under test - result = compute_centroid_correction(model_params, mock_data) - - # Ensure the result has the correct shape - assert result.shape == (4, 3) # Should be (n_stars, 3 Zernike components) - - assert np.allclose( - result[0, :], np.array([0, -0.1, -0.2]) - ) # First star Zernike coefficients - assert np.allclose( - result[1, :], np.array([0, -0.3, -0.4]) - ) # Second star Zernike coefficients - - -def test_compute_centroid_correction_without_masks(mock_data): - """Test compute_centroid_correction function when no masks are provided.""" - # Remove masks from mock_data - mock_data.test_data.dataset["masks"] = None - mock_data.training_data.dataset["masks"] = None - - # Define model parameters - model_params = RecursiveNamespace( - pix_sampling=12e-6, # Example pixel sampling in meters - correct_centroids=True, - reference_shifts=["-1/3", "-1/3"], - ) - - # Mock internal function calls - with ( - patch( - "wf_psf.data.training_preprocessing.extract_star_data" - ) as mock_extract_star_data, - patch( - "wf_psf.data.training_preprocessing.compute_zernike_tip_tilt" - ) as mock_compute_zernike_tip_tilt, - ): - # Mock extract_star_data to return synthetic star postage stamps - mock_extract_star_data.side_effect = lambda data, train_key, test_key: ( - np.array([[1, 2], [3, 4]]) - if train_key == "noisy_stars" - else np.array([[5, 6], [7, 8]]) - ) - - # Mock compute_zernike_tip_tilt assuming no masks - mock_compute_zernike_tip_tilt.return_value = np.array([[0.1, 0.2], [0.3, 0.4]]) - - # Call function under test - result = compute_centroid_correction(model_params, mock_data) - - # Validate result shape - assert result.shape == (4, 3) # (n_stars, 3 Zernike components) - - # Validate expected values (adjust based on behavior) - expected_result = np.array( - [ - [0, -0.1, -0.2], # From training data - [0, -0.3, -0.4], - [0, -0.1, -0.2], # From test data (reused mocked return) - [0, -0.3, -0.4], - ] - ) - assert np.allclose(result, expected_result) - - def test_reference_shifts_broadcasting(): reference_shifts = [-1 / 3, -1 / 3] # Example reference_shifts shifts = np.random.rand(2, 2400) # Example shifts array diff --git a/src/wf_psf/tests/test_data/data_zernike_utils_test.py b/src/wf_psf/tests/test_data/data_zernike_utils_test.py new file mode 100644 index 00000000..692624be --- /dev/null +++ b/src/wf_psf/tests/test_data/data_zernike_utils_test.py @@ -0,0 +1,133 @@ + +import pytest +import numpy as np +import tensorflow as tf +from wf_psf.data.data_zernike_utils import ( + get_zernike_prior, + compute_zernike_tip_tilt, +) +from wf_psf.tests.test_data.test_data_utils import MockData, MockDataset + +def test_get_zernike_prior(model_params, mock_data): + zernike_priors = get_zernike_prior(model_params, mock_data) + expected_shape = ( + 4, + 2, + ) # Assuming 2 Zernike priors for each dataset (training and test) + assert zernike_priors.shape == expected_shape + + +def test_get_zernike_prior_dtype(model_params, mock_data): + zernike_priors = 
get_zernike_prior(model_params, mock_data) + assert zernike_priors.dtype == np.float32 + + +def test_get_zernike_prior_concatenation(model_params, mock_data): + zernike_priors = get_zernike_prior(model_params, mock_data) + expected_zernike_priors = tf.convert_to_tensor( + np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]]), dtype=tf.float32 + ) + + assert np.array_equal(zernike_priors, expected_zernike_priors) + + +def test_get_zernike_prior_empty_data(model_params): + empty_data = MockData(np.array([]), np.array([]), np.array([]), np.array([])) + zernike_priors = get_zernike_prior(model_params, empty_data) + assert zernike_priors.shape == tf.TensorShape([0]) # Check for empty array shape + + +def test_compute_zernike_tip_tilt_single_batch(mocker, simple_image, identity_mask): + """Test compute_zernike_tip_tilt with single batch input and mocks.""" + + # Mock the CentroidEstimator class + mock_centroid_calc = mocker.patch("wf_psf.data.centroids.CentroidEstimator", autospec=True) + + # Create a mock instance and configure get_intra_pixel_shifts() + mock_instance = mock_centroid_calc.return_value + mock_instance.get_intra_pixel_shifts.return_value = np.array([[0.05, -0.02]]) # Shape (1, 2) + + # Mock shift_x_y_to_zk1_2_wavediff to return predictable values + mock_shift_fn = mocker.patch( + "wf_psf.data.data_zernike_utils.shift_x_y_to_zk1_2_wavediff", + side_effect=lambda shift: shift * 0.5 # Mocked conversion for test + ) + + # Define test inputs (batch of 1 image) + pixel_sampling = 12e-6 + reference_shifts = [-1 / 3, -1 / 3] # Default Euclid conditions + + # Run the function + zernike_corrections = compute_zernike_tip_tilt(simple_image, identity_mask, pixel_sampling, reference_shifts) + zernike_corrections = compute_zernike_tip_tilt(simple_image, identity_mask, pixel_sampling, reference_shifts) + + # Expected shifts based on centroid calculation + expected_dx = (reference_shifts[1] - (-0.02)) # Expected x-axis shift in meters + expected_dy = (reference_shifts[0] - 0.05) # Expected y-axis shift in meters + + # Expected calls to the mocked function + # Extract the arguments passed to mock_shift_fn + args, _ = mock_shift_fn.call_args_list[0] # Get the first call args + + # Compare expected values with the actual arguments passed to the mock function + np.testing.assert_allclose(args[0][0], expected_dx * pixel_sampling, rtol=1e-7, atol=0) + + # Check dy values similarly + np.testing.assert_allclose(args[0][1], expected_dy * pixel_sampling, rtol=1e-7, atol=0) + + # Expected values based on mock side_effect (0.5 * shift) + np.testing.assert_allclose(zernike_corrections[0, 0], expected_dx * pixel_sampling * 0.5) # Zk1 + np.testing.assert_allclose(zernike_corrections[0, 1], expected_dy * pixel_sampling * 0.5) # Zk2 + +def test_compute_zernike_tip_tilt_batch(mocker, multiple_images): + """Test compute_zernike_tip_tilt with batch input and mocks.""" + + # Mock the CentroidEstimator class + mock_centroid_calc = mocker.patch("wf_psf.data.centroids.CentroidEstimator", autospec=True) + + # Create a mock instance and configure get_intra_pixel_shifts() + mock_instance = mock_centroid_calc.return_value + mock_instance.get_intra_pixel_shifts.return_value = np.array([[0.05, -0.02], [0.04, -0.01], [0.06, -0.03]]) # Shape (3, 2) + + # Mock shift_x_y_to_zk1_2_wavediff to return predictable values + mock_shift_fn = mocker.patch( + "wf_psf.data.data_zernike_utils.shift_x_y_to_zk1_2_wavediff", + side_effect=lambda shift: shift * 0.5 # Mocked conversion for test + ) + + # Define test inputs (batch of 3 
images) + pixel_sampling = 12e-6 + reference_shifts = [-1 / 3, -1 / 3] # Default Euclid conditions + + # Run the function + zernike_corrections = compute_zernike_tip_tilt( + star_images=multiple_images, + pixel_sampling=pixel_sampling, + reference_shifts=reference_shifts + ) + + # Check if the mock function was called once with the full batch + assert len(mock_shift_fn.call_args_list) == 1, f"Expected 1 call, but got {len(mock_shift_fn.call_args_list)}" + + # Get the arguments passed to the mock function for the batch of images + args, _ = mock_shift_fn.call_args_list[0] + + print("Shape of args[0]:", args[0].shape) + print("Contents of args[0]:", args[0]) + print("Mock function call args list:", mock_shift_fn.call_args_list) + + # Reshape args[0] to (N, 2) for batch processing + args_array = np.array(args[0]).reshape(-1, 2) + + # Process the displacements and expected values for each image in the batch + expected_dx = reference_shifts[1] - mock_instance.get_intra_pixel_shifts.return_value[:, 1] # Expected x-axis shift in meters + + expected_dy = reference_shifts[0] - mock_instance.get_intra_pixel_shifts.return_value[:, 0] # Expected y-axis shift in meters + + # Compare expected values with the actual arguments passed to the mock function + np.testing.assert_allclose(args_array[:, 0], expected_dx * pixel_sampling, rtol=1e-7, atol=0) + np.testing.assert_allclose(args_array[:, 1], expected_dy * pixel_sampling, rtol=1e-7, atol=0) + + # Expected values based on mock side_effect (0.5 * shift) + np.testing.assert_allclose(zernike_corrections[:, 0], expected_dx * pixel_sampling * 0.5) # Zk1 for each image + np.testing.assert_allclose(zernike_corrections[:, 1], expected_dy * pixel_sampling * 0.5) # Zk2 for each image \ No newline at end of file diff --git a/src/wf_psf/tests/test_data/test_data_utils.py b/src/wf_psf/tests/test_data/test_data_utils.py new file mode 100644 index 00000000..a5ead298 --- /dev/null +++ b/src/wf_psf/tests/test_data/test_data_utils.py @@ -0,0 +1,30 @@ + +class MockDataset: + def __init__(self, positions, zernike_priors, star_type, stars, masks): + self.dataset = {"positions": positions, "zernike_prior": zernike_priors, star_type: stars, "masks": masks} + +class MockData: + def __init__( + self, + training_positions, + test_positions, + training_zernike_priors, + test_zernike_priors, + noisy_stars=None, + noisy_masks=None, + stars=None, + masks=None, + ): + self.training_data = MockDataset( + positions=training_positions, + zernike_priors=training_zernike_priors, + star_type="noisy_stars", + stars=noisy_stars, + masks=noisy_masks) + self.test_data = MockDataset( + positions=test_positions, + zernike_priors=test_zernike_priors, + star_type="stars", + stars=stars, + masks=masks) + diff --git a/src/wf_psf/tests/test_metrics/metrics_interface_test.py b/src/wf_psf/tests/test_metrics/metrics_interface_test.py index 12d51447..bf52f0aa 100644 --- a/src/wf_psf/tests/test_metrics/metrics_interface_test.py +++ b/src/wf_psf/tests/test_metrics/metrics_interface_test.py @@ -1,7 +1,7 @@ from unittest.mock import patch, MagicMock import pytest -from wf_psf.metrics.metrics_interface import evaluate_model -from wf_psf.data.training_preprocessing import DataHandler +from wf_psf.metrics.metrics_interface import evaluate_model, MetricsParamsHandler +from wf_psf.data.data_handler import DataHandler @pytest.fixture diff --git a/src/wf_psf/tests/test_psf_models/conftest.py b/src/wf_psf/tests/test_psf_models/conftest.py index cbaae8d9..4693b343 100644 --- 
a/src/wf_psf/tests/test_psf_models/conftest.py +++ b/src/wf_psf/tests/test_psf_models/conftest.py @@ -12,7 +12,7 @@ from wf_psf.utils.read_config import RecursiveNamespace from wf_psf.training.train import TrainingParamsHandler from wf_psf.psf_models import psf_models -from wf_psf.data.training_preprocessing import DataHandler +from wf_psf.data.data_handler import DataHandler training_config = RecursiveNamespace( id_name="_sample_w_bis1_2k", diff --git a/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py b/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py index e25d608d..7e967465 100644 --- a/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py +++ b/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py @@ -59,7 +59,7 @@ def test_initialize_parameters(mocker, mock_data, mock_model_params, zks_prior): ) mocker.patch( - "wf_psf.data.training_preprocessing.get_obs_positions", return_value=True + "wf_psf.data.data_handler.get_obs_positions", return_value=True ) # Create TFPhysicalPolychromaticField instance diff --git a/src/wf_psf/tests/test_utils/configs_handler_test.py b/src/wf_psf/tests/test_utils/configs_handler_test.py index 9b4039a8..2ca1d17d 100644 --- a/src/wf_psf/tests/test_utils/configs_handler_test.py +++ b/src/wf_psf/tests/test_utils/configs_handler_test.py @@ -7,7 +7,7 @@ """ import pytest -from wf_psf.data.training_preprocessing import DataHandler +from wf_psf.data.data_handler import DataHandler from wf_psf.utils import configs_handler from wf_psf.utils.read_config import RecursiveNamespace from wf_psf.utils.io import FileIOHandler diff --git a/src/wf_psf/utils/configs_handler.py b/src/wf_psf/utils/configs_handler.py index 4b08a723..b7ca6444 100644 --- a/src/wf_psf/utils/configs_handler.py +++ b/src/wf_psf/utils/configs_handler.py @@ -12,7 +12,7 @@ import os import re import glob -from wf_psf.data.training_preprocessing import DataHandler +from wf_psf.data.data_handler import DataHandler from wf_psf.metrics.metrics_interface import evaluate_model from wf_psf.plotting.plots_interface import plot_metrics from wf_psf.psf_models import psf_models From 0cfb8df7c49f2ede9ab8396c1d2c2dadba52414b Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 16 May 2025 14:44:07 +0200 Subject: [PATCH 024/135] Update import statements to new module names --- .../psf_models/models/psf_model_physical_polychromatic.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py index f9ed8765..baec8923 100644 --- a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py +++ b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py @@ -10,7 +10,8 @@ from typing import Optional import tensorflow as tf from tensorflow.python.keras.engine import data_adapter -from wf_psf.data.training_preprocessing import get_obs_positions, get_zernike_prior +from wf_psf.data.data_handler import get_obs_positions +from wf_psf.data.data_zernike_utils import get_zernike_prior from wf_psf.psf_models import psf_models as psfm from wf_psf.psf_models.tf_modules.tf_layers import ( TFPolynomialZernikeField, From 2946d47fb983eb0dc4f18db7ac9f4f33c6da351b Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 16 May 2025 14:57:57 +0200 Subject: [PATCH 025/135] Update DataHandler class docstring to include option for inference dataset handling --- src/wf_psf/data/data_handler.py | 38 
+++++++++++++++++---------------- 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/src/wf_psf/data/data_handler.py b/src/wf_psf/data/data_handler.py index e2946b0d..4e940115 100644 --- a/src/wf_psf/data/data_handler.py +++ b/src/wf_psf/data/data_handler.py @@ -25,7 +25,8 @@ class DataHandler: """Data Handler. - This class manages loading and processing of training and testing data for use during PSF model training and validation. + This class manages loading and processing of training, testing and inference data for use during PSF model training, inference, and validation. + It provides methods to access and preprocess the data. Parameters @@ -44,20 +45,21 @@ class DataHandler: Attributes ---------- - dataset_type : str - Type of dataset ("train" or "test"). - data_params : RecursiveNamespace - Parameters for the current dataset type. - dataset : dict or None - Dictionary containing the loaded dataset, including positions and stars/noisy_stars. - simPSF : PSFSimulator - Instance of the PSFSimulator class for simulating PSF models. - n_bins_lambda : int - Number of wavelength bins. - sed_data : tf.Tensor or None - TensorFlow tensor containing processed SED data for training/testing. - load_data_on_init : bool - Flag controlling whether data is loaded during initialization. + dataset_type: str + A string indicating the type of dataset ("train", "test" or "inference"). + data_params: Recursive Namespace object + A Recursive Namespace object containing training, test or inference data parameters. + dataset: dict + A dictionary containing the loaded dataset, including positions and stars/noisy_stars. + simPSF: object + An instance of the SimPSFToolkit class for simulating PSF. + n_bins_lambda: int + The number of bins in wavelength. + sed_data: tf.Tensor + A TensorFlow tensor containing the SED data for training/testing. + load_data_on_init: bool, optional + A flag used to control data loading steps. If True, data is loaded and processed + during initialization. If False, data loading is deferred until explicitly called. """ def __init__( @@ -69,9 +71,9 @@ def __init__( Parameters ---------- dataset_type : str - Type of dataset ("train" or "test"). - data_params : RecursiveNamespace - Recursive Namespace object containing parameters for both 'train' and 'test' datasets. + A string indicating the type of data ("train", "test", or "inference"). + data_params : Recursive Namespace object + A Recursive Namespace object containing parameters for both 'train', 'test', 'inference' datasets. simPSF : PSFSimulator Instance of the PSFSimulator class for simulating PSF models. n_bins_lambda : int From 0657e24614db4dad05ff6004f729891a89e49be5 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Sun, 18 May 2025 12:22:27 +0200 Subject: [PATCH 026/135] Refactor data_handler with new utility functions to validate and process datasets and update docstrings --- src/wf_psf/data/data_handler.py | 184 ++++++++++++++++++++++++-------- 1 file changed, 139 insertions(+), 45 deletions(-) diff --git a/src/wf_psf/data/data_handler.py b/src/wf_psf/data/data_handler.py index 4e940115..63f4650e 100644 --- a/src/wf_psf/data/data_handler.py +++ b/src/wf_psf/data/data_handler.py @@ -17,79 +17,122 @@ import wf_psf.utils.utils as utils import tensorflow as tf from fractions import Fraction +from typing import Optional, Union import logging logger = logging.getLogger(__name__) class DataHandler: - """Data Handler. 
- - This class manages loading and processing of training, testing and inference data for use during PSF model training, inference, and validation. + """ + DataHandler for WaveDiff PSF modeling. - It provides methods to access and preprocess the data. + This class manages loading, preprocessing, and TensorFlow conversion of datasets used + for PSF model training, testing, inference, and validation in the WaveDiff framework. Parameters ---------- dataset_type : str - Type of dataset ("train" or "test"). + Indicates the dataset mode ("train", "test", or "inference"). data_params : RecursiveNamespace - Recursive Namespace object containing parameters for both 'train' and 'test' datasets. + Configuration object containing dataset parameters (e.g., file paths, preprocessing flags). simPSF : PSFSimulator - Instance of the PSFSimulator class for simulating PSF models. + An instance of the PSFSimulator class used to encode SEDs into a TensorFlow-compatible format. n_bins_lambda : int - Number of wavelength bins for SED processing. + Number of wavelength bins used to discretize SEDs. load_data : bool, optional - If True, data is loaded and processed during initialization. If False, data loading - is deferred until explicitly called. Default is True. + If True (default), loads and processes data during initialization. If False, data loading + must be triggered explicitly. + dataset : dict or list, optional + If provided, uses this pre-loaded dataset instead of triggering automatic loading. + sed_data : dict or list, optional + If provided, uses this SED data directly instead of extracting it from the dataset. Attributes ---------- - dataset_type: str - A string indicating the type of dataset ("train", "test" or "inference"). - data_params: Recursive Namespace object - A Recursive Namespace object containing training, test or inference data parameters. - dataset: dict - A dictionary containing the loaded dataset, including positions and stars/noisy_stars. - simPSF: object - An instance of the SimPSFToolkit class for simulating PSF. - n_bins_lambda: int - The number of bins in wavelength. - sed_data: tf.Tensor - A TensorFlow tensor containing the SED data for training/testing. - load_data_on_init: bool, optional - A flag used to control data loading steps. If True, data is loaded and processed - during initialization. If False, data loading is deferred until explicitly called. + dataset_type : str + Indicates the dataset mode ("train", "test", or "inference"). + data_params : RecursiveNamespace + Configuration parameters for data access and structure. + simPSF : PSFSimulator + Simulator used to transform SEDs into TensorFlow-ready tensors. + n_bins_lambda : int + Number of wavelength bins in the SED representation. + load_data_on_init : bool + Whether data was loaded automatically during initialization. + dataset : dict + Loaded dataset including keys such as 'positions', 'stars', 'noisy_stars', or similar. + sed_data : tf.Tensor + TensorFlow-formatted SED data with shape [batch_size, n_bins_lambda, features]. """ def __init__( - self, dataset_type, data_params, simPSF, n_bins_lambda, load_data: bool = True + self, + dataset_type, + data_params, + simPSF, + n_bins_lambda, + load_data: bool = True, + dataset: Optional[Union[dict, list]] = None, + sed_data: Optional[Union[dict, list]] = None, ): """ - Initialize the dataset handler for PSF simulation. + Initialize the DataHandler for PSF dataset preparation. 
+
+        This constructor sets up the dataset handler used for PSF simulation tasks,
+        such as training, testing, or inference. It supports three modes of use:
+
+        1. **Manual mode** (`load_data=False`, no `dataset`): data loading and SED processing
+           must be triggered manually via `load_dataset()` and `process_sed_data()`.
+        2. **Pre-loaded dataset mode** (`dataset` is provided): the given dataset is used directly,
+           and `process_sed_data()` is called with either the given `sed_data` or `dataset["SEDs"]`.
+        3. **Automatic loading mode** (`load_data=True` and no `dataset`): the dataset is loaded
+           from disk using `data_params`, and SEDs are extracted and processed automatically.

         Parameters
         ----------
         dataset_type : str
-            A string indicating the type of data ("train", "test", or "inference").
-        data_params : Recursive Namespace object
-            A Recursive Namespace object containing parameters for both 'train', 'test', 'inference' datasets.
+            One of {"train", "test", "inference"} indicating dataset usage.
+        data_params : RecursiveNamespace
+            Configuration object with paths, preprocessing options, and metadata.
         simPSF : PSFSimulator
-            Instance of the PSFSimulator class for simulating PSF models.
+            Used to convert SEDs to TensorFlow format.
         n_bins_lambda : int
-            Number of wavelength bins for SED processing.
+            Number of wavelength bins for the SEDs.
         load_data : bool, optional
-            If True, data is loaded and processed during initialization. If False, data loading
-            is deferred. Default is True.
+            Whether to automatically load and process the dataset (default: True).
+        dataset : dict or list, optional
+            A pre-loaded dataset to use directly (overrides `load_data`).
+        sed_data : array-like, optional
+            Pre-loaded SED data to use directly. If not provided but `dataset` is,
+            SEDs are taken from `dataset["SEDs"]`.
+
+        Raises
+        ------
+        ValueError
+            If SEDs cannot be found in either `dataset` or as `sed_data`.
+
+        Notes
+        -----
+        - `self.dataset` and `self.sed_data` are both `None` if neither `dataset` nor
+          `load_data=True` is used.
+        - TensorFlow conversion is performed at the end of initialization via `validate_and_process_dataset()`.
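The three modes map onto constructor calls roughly as follows (a sketch only; `data_params` and `simPSF` stand in for the RecursiveNamespace config and PSFSimulator instance built elsewhere in wf_psf, and `my_dataset` is a placeholder dict):

    # 1. Manual mode: nothing is loaded until explicitly requested.
    handler = DataHandler("train", data_params, simPSF, n_bins_lambda=8, load_data=False)
    handler.load_dataset()
    handler.process_sed_data(handler.dataset["SEDs"])
    handler.validate_and_process_dataset()

    # 2. Pre-loaded dataset mode: a dataset dict is passed in directly.
    handler = DataHandler("train", data_params, simPSF, n_bins_lambda=8, dataset=my_dataset)

    # 3. Automatic loading mode (the default): loads from disk via data_params.
    handler = DataHandler("train", data_params, simPSF, n_bins_lambda=8)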
""" + self.dataset_type = dataset_type - self.data_params = data_params.__dict__[dataset_type] + self.data_params = data_params self.simPSF = simPSF self.n_bins_lambda = n_bins_lambda self.load_data_on_init = load_data - if self.load_data_on_init: + + if dataset is not None: + self.dataset = dataset + self.process_sed_data(sed_data) + self.validate_and_process_dataset() + elif self.load_data_on_init: self.load_dataset() - self.process_sed_data() + self.process_sed_data(self.dataset["SEDs"]) + self.validate_and_process_dataset() else: self.dataset = None self.sed_data = None @@ -104,6 +147,34 @@ def load_dataset(self): os.path.join(self.data_params.data_dir, self.data_params.file), allow_pickle=True, )[()] + + def validate_and_process_dataset(self): + """Validate the dataset structure and convert fields to TensorFlow tensors.""" + self._validate_dataset_structure() + self._convert_dataset_to_tensorflow() + + def _validate_dataset_structure(self): + """Validate dataset structure based on dataset_type.""" + if self.dataset is None: + raise ValueError("Dataset is None") + + if "positions" not in self.dataset: + raise ValueError("Dataset missing required field: 'positions'") + + if self.dataset_type == "train": + if "noisy_stars" not in self.dataset: + logger.warning("Missing 'noisy_stars' in 'train' dataset.") + elif self.dataset_type == "test": + if "stars" not in self.dataset: + logger.warning("Missing 'stars' in 'test' dataset.") + elif self.dataset_type == "inference": + pass + else: + logger.warning(f"Unrecognized dataset_type: {self.dataset_type}") + + def _convert_dataset_to_tensorflow(self): + """Convert dataset to TensorFlow tensors.""" + self.dataset["positions"] = tf.convert_to_tensor( self.dataset["positions"], dtype=tf.float32 ) @@ -119,22 +190,45 @@ def load_dataset(self): self.dataset["stars"] = tf.convert_to_tensor( self.dataset["stars"], dtype=tf.float32 ) - else: - logger.warning(f"Missing 'stars' in {self.dataset_type} dataset.") - elif "inference" == self.dataset_type: - pass - def process_sed_data(self): - """Process SED Data. + def process_sed_data(self, sed_data): + """ + Generate and process SED (Spectral Energy Distribution) data. - A method to generate and process SED data. + This method transforms raw SED inputs into TensorFlow tensors suitable for model input. + It generates wavelength-binned SED elements using the PSF simulator, converts the result + into a tensor, and transposes it to match the expected shape for training or inference. + Parameters + ---------- + sed_data : list or array-like + A list or array of raw SEDs, where each SED is typically a vector of flux values + or coefficients. These will be processed using the PSF simulator. + + Raises + ------ + ValueError + If `sed_data` is None. + + Notes + ----- + The resulting tensor is stored in `self.sed_data` and has shape + `(num_samples, n_bins_lambda, n_components)`, where: + - `num_samples` is the number of SEDs, + - `n_bins_lambda` is the number of wavelength bins, + - `n_components` is the number of components per SED (e.g., filters or basis terms). + + The intermediate tensor is created with `tf.float64` for precision during generation, + but is converted to `tf.float32` after processing for use in training. 
""" + if sed_data is None: + raise ValueError("SED data must be provided explicitly or via dataset.") + self.sed_data = [ utils.generate_SED_elems_in_tensorflow( _sed, self.simPSF, n_bins=self.n_bins_lambda, tf_dtype=tf.float64 ) - for _sed in self.dataset["SEDs"] + for _sed in sed_data ] self.sed_data = tf.convert_to_tensor(self.sed_data, dtype=tf.float32) self.sed_data = tf.transpose(self.sed_data, perm=[0, 2, 1]) From e28fe00b8c2cef8a1e2db5af4bbb5b3cee2c53c8 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Sun, 18 May 2025 12:23:27 +0200 Subject: [PATCH 027/135] Update unit tests associated to changes in data_handler.py --- .../tests/test_data/data_handler_test.py | 58 ++++++++++++++----- .../tests/test_utils/configs_handler_test.py | 17 +++++- 2 files changed, 59 insertions(+), 16 deletions(-) diff --git a/src/wf_psf/tests/test_data/data_handler_test.py b/src/wf_psf/tests/test_data/data_handler_test.py index 5d167a6b..42e19a71 100644 --- a/src/wf_psf/tests/test_data/data_handler_test.py +++ b/src/wf_psf/tests/test_data/data_handler_test.py @@ -11,6 +11,26 @@ from unittest.mock import patch +def mock_sed(): + # Create a fake SED with shape (n_wavelengths,) — match what your real SEDs look like + return np.linspace(0.1, 1.0, 50) + + +def test_process_sed_data(data_params, simPSF): + # Test processing SED data without initialization + data_handler = DataHandler( + "training", data_params, simPSF, n_bins_lambda=10, load_data=False + ) + assert data_handler.sed_data is None # SED data should not be processed + + +def test_process_sed_data_auto_load(data_params, simPSF): + # load_data=True → dataset is used and SEDs processed automatically + data_handler = DataHandler( + "training", data_params, simPSF, n_bins_lambda=10, load_data=True + ) + + def test_load_train_dataset(tmp_path, data_params, simPSF): # Create a temporary directory and a temporary data file data_dir = tmp_path / "data" @@ -71,7 +91,11 @@ def test_load_test_dataset(tmp_path, data_params, simPSF): n_bins_lambda = 10 data_handler = DataHandler( - "test", data_params, simPSF, n_bins_lambda, load_data=False + dataset_type="test", + data_params=data_params.test, + simPSF=simPSF, + n_bins_lambda=n_bins_lambda, + load_data=False, ) # Call the load_dataset method @@ -83,8 +107,8 @@ def test_load_test_dataset(tmp_path, data_params, simPSF): assert np.array_equal(data_handler.dataset["SEDs"], mock_dataset["SEDs"]) -def test_load_train_dataset_missing_noisy_stars(tmp_path, data_params, simPSF): - """Test that a warning is raised if 'noisy_stars' is missing in training data.""" +def test_validate_train_dataset_missing_noisy_stars_raises(tmp_path, simPSF): + """Test that validation raises an error if 'noisy_stars' is missing in training data.""" data_dir = tmp_path / "data" data_dir.mkdir() temp_data_file = data_dir / "train_data.npy" @@ -140,24 +164,32 @@ def test_load_test_dataset_missing_stars(tmp_path, data_params, simPSF): def test_process_sed_data(data_params, simPSF): mock_dataset = { "positions": np.array([[1, 2], [3, 4]]), - "noisy_stars": np.array([[5, 6], [7, 8]]), "SEDs": np.array([[[0.1, 0.2], [0.3, 0.4]], [[0.5, 0.6], [0.7, 0.8]]]), + # Missing 'noisy_stars' } # Initialize DataHandler instance n_bins_lambda = 4 data_handler = DataHandler("training", data_params, simPSF, n_bins_lambda, False) - data_handler.dataset = mock_dataset - data_handler.process_sed_data() - # Assertions - assert isinstance(data_handler.sed_data, tf.Tensor) - assert data_handler.sed_data.dtype == tf.float32 - assert data_handler.sed_data.shape == 
(
-        len(data_handler.dataset["positions"]),
-        n_bins_lambda,
-        len(["feasible_N", "feasible_wv", "SED_norm"]),
+    np.save(temp_data_file, mock_dataset)
+
+    data_params = RecursiveNamespace(data_dir=str(data_dir), file="train_data.npy")
+
+    data_handler = DataHandler(
+        dataset_type="train",
+        data_params=data_params,
+        simPSF=simPSF,
+        n_bins_lambda=10,
+        load_data=False,
     )
+    data_handler.load_dataset()
+    data_handler.process_sed_data(mock_dataset["SEDs"])
+
+    with patch("wf_psf.data.data_handler.logger.warning") as mock_warning:
+        data_handler._validate_dataset_structure()
+        mock_warning.assert_called_with("Missing 'noisy_stars' in 'train' dataset.")
+

 def test_get_obs_positions(mock_data):
     observed_positions = get_obs_positions(mock_data)
diff --git a/src/wf_psf/tests/test_utils/configs_handler_test.py b/src/wf_psf/tests/test_utils/configs_handler_test.py
index 2ca1d17d..f8299997 100644
--- a/src/wf_psf/tests/test_utils/configs_handler_test.py
+++ b/src/wf_psf/tests/test_utils/configs_handler_test.py
@@ -119,10 +119,21 @@ def test_data_config_handler_init(mock_training_conf, mock_data_read_conf, mocke
         "wf_psf.psf_models.psf_models.simPSF", return_value=mock_simPSF_instance
     )

-    # Patch the load_dataset and process_sed_data methods inside DataHandler
-    mocker.patch.object(DataHandler, "load_dataset")
+    # Patch the process_sed_data method
     mocker.patch.object(DataHandler, "process_sed_data")

+    # Patch the validate_and_process_dataset method
+    mocker.patch.object(DataHandler, "validate_and_process_dataset")
+
+    # Patch load_dataset to assign dataset
+    def mock_load_dataset(self):
+        self.dataset = {
+            "SEDs": ["dummy_sed_data"],
+            "positions": ["dummy_positions_data"],
+        }
+
+    mocker.patch.object(DataHandler, "load_dataset", new=mock_load_dataset)
+
     # Create DataConfigHandler instance
     data_config_handler = DataConfigHandler(
         "/path/to/data_config.yaml",
@@ -144,7 +155,7 @@ def test_data_config_handler_init(mock_training_conf, mock_data_read_conf, mocke
     assert (
         data_config_handler.batch_size
         == mock_training_conf.training.training_hparams.batch_size
-    )  # Default value
+    )


 def test_training_config_handler_init(mocker, mock_training_conf, mock_file_handler):

From 29e3abf94b711719a02299d94585b63e1a2ecfa4 Mon Sep 17 00:00:00 2001
From: Jennifer Pollack
Date: Sun, 18 May 2025 12:24:21 +0200
Subject: [PATCH 028/135] Change exception handling in DataConfigHandler;
 modify args to DataHandler

---
 src/wf_psf/utils/configs_handler.py | 15 +++++++++------
 1 file changed, 9 insertions(+), 6 deletions(-)

diff --git a/src/wf_psf/utils/configs_handler.py b/src/wf_psf/utils/configs_handler.py
index b7ca6444..a7e42ceb 100644
--- a/src/wf_psf/utils/configs_handler.py
+++ b/src/wf_psf/utils/configs_handler.py
@@ -129,28 +129,31 @@ class DataConfigHandler:
     def __init__(self, data_conf, training_model_params, batch_size=16, load_data=True):
         try:
             self.data_conf = read_conf(data_conf)
-        except FileNotFoundError as e:
-            logger.exception(e)
-            exit()
-        except TypeError as e:
+        except (FileNotFoundError, TypeError) as e:
             logger.exception(e)
             exit()

         self.simPSF = psf_models.simPSF(training_model_params)
+
+        # Extract sub-configs early
+        train_params = self.data_conf.data.training
+        test_params = self.data_conf.data.test
+
         self.training_data = DataHandler(
             dataset_type="training",
-            data_params=self.data_conf.data,
+            data_params=train_params,
             simPSF=self.simPSF,
             n_bins_lambda=training_model_params.n_bins_lda,
             load_data=load_data,
         )
         self.test_data = DataHandler(
             dataset_type="test",
-            data_params=self.data_conf.data,
+
data_params=test_params, simPSF=self.simPSF, n_bins_lambda=training_model_params.n_bins_lda, load_data=load_data, ) + self.batch_size = batch_size From d1c86731cb3ff669f7ed55f747a5ce5c88f68851 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Mon, 19 May 2025 10:49:14 +0200 Subject: [PATCH 029/135] Add data and psf_model_imports into inference and sketch out methods --- src/wf_psf/inference/psf_inference.py | 59 +++++++++++++++++++++++++-- 1 file changed, 55 insertions(+), 4 deletions(-) diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py index 4f5b39e5..5437ee61 100644 --- a/src/wf_psf/inference/psf_inference.py +++ b/src/wf_psf/inference/psf_inference.py @@ -12,10 +12,61 @@ import glob import logging import numpy as np -from wf_psf.psf_models import psf_models, psf_model_loader +from wf_psf.data.data_handler import DataHandler +from wf_psf.psf_models import psf_models +from wf_psf.psf_models.psf_model_loader import load_trained_psf_model import tensorflow as tf +def prepare_inputs(dataset): + + # Convert dataset to tensorflow Dataset + dataset["positions"] = tf.convert_to_tensor(dataset["positions"], dtype=tf.float32) + + + +def get_trained_psf_model(model_path, model_dir_name, cycle, training_conf, data_conf): + + trained_model_path = model_path + model_subdir = model_dir_name + cycle = cycle + + model_name = training_conf.training.model_params.model_name + id_name = training_conf.training.id_name + + weights_path_pattern = os.path.join( + trained_model_path, + model_subdir, + ( + f"{model_subdir}*_{model_name}" + f"*{id_name}_cycle{cycle}*" + ), + ) + return load_trained_psf_model( + training_conf, + data_conf, + weights_path_pattern, + ) + + +def generate_psfs(psf_model, inputs): + pass + + +def run_pipeline(): + psf_model = get_trained_psf_model( + model_path, + model_dir, + cycle, + training_conf, + data_conf + ) + inputs = prepare_inputs( + + ) + psfs = generate_psfs( + psf_model, + inputs, + batch_size=1, + ) + return psfs -#def prepare_inputs(...): ... -#def generate_psfs(...): ... -#def run_pipeline(...): ... From 0672d33252cfd1ca787516f3d481df6ce7fdc8ea Mon Sep 17 00:00:00 2001 From: Tobias Liaudat Date: Mon, 19 May 2025 11:56:50 +0200 Subject: [PATCH 030/135] add base psf inference --- src/wf_psf/inference/psf_inference.py | 182 ++++++++++++++++++-------- 1 file changed, 127 insertions(+), 55 deletions(-) diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py index 5437ee61..6787be61 100644 --- a/src/wf_psf/inference/psf_inference.py +++ b/src/wf_psf/inference/psf_inference.py @@ -1,8 +1,8 @@ """Inference. -A module which provides a set of functions to perform inference -on PSF models. It includes functions to load a trained model, -perform inference on a dataset of SEDs and positions, and generate a polychromatic PSF. +A module which provides a PSFInference class to perform inference +with trained PSF models. It is able to load a trained model, +perform inference on a dataset of SEDs and positions, and generate polychromatic PSFs. 
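For reference, the glob pattern assembled in `get_trained_psf_model` above expands as in the following sketch (every name here is hypothetical, chosen only to show the shape of the pattern):

    import os

    trained_model_path = "/path/to/trained/model"  # hypothetical
    model_subdir = "psf_model"                     # hypothetical
    model_name = "poly"                            # hypothetical
    id_name = "_run1"                              # hypothetical
    cycle = 2

    pattern = os.path.join(
        trained_model_path,
        model_subdir,
        f"{model_subdir}*_{model_name}" f"*{id_name}_cycle{cycle}*",
    )
    # pattern == "/path/to/trained/model/psf_model/psf_model*_poly*_run1_cycle2*"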
:Authors: Jennifer Pollack @@ -13,60 +13,132 @@ import logging import numpy as np from wf_psf.data.data_handler import DataHandler +from wf_psf.utils.read_config import read_conf from wf_psf.psf_models import psf_models from wf_psf.psf_models.psf_model_loader import load_trained_psf_model import tensorflow as tf +from typing import Optional -def prepare_inputs(dataset): - - # Convert dataset to tensorflow Dataset - dataset["positions"] = tf.convert_to_tensor(dataset["positions"], dtype=tf.float32) - - - -def get_trained_psf_model(model_path, model_dir_name, cycle, training_conf, data_conf): - - trained_model_path = model_path - model_subdir = model_dir_name - cycle = cycle - - model_name = training_conf.training.model_params.model_name - id_name = training_conf.training.id_name - - weights_path_pattern = os.path.join( - trained_model_path, - model_subdir, - ( - f"{model_subdir}*_{model_name}" - f"*{id_name}_cycle{cycle}*" - ), - ) - return load_trained_psf_model( - training_conf, - data_conf, - weights_path_pattern, - ) - - -def generate_psfs(psf_model, inputs): - pass - - -def run_pipeline(): - psf_model = get_trained_psf_model( - model_path, - model_dir, - cycle, - training_conf, - data_conf - ) - inputs = prepare_inputs( - - ) - psfs = generate_psfs( - psf_model, - inputs, - batch_size=1, - ) - return psfs +class PSFInference: + """Class to perform inference on PSF models.""" + + def __init__( + self, + trained_model_path: str, + model_subdir: str, + cycle: int, + training_conf_path: str, + data_conf_path: str, + batch_size: Optional[int] = None, + ): + self.trained_model_path = trained_model_path + self.model_subdir = model_subdir + self.cycle = cycle + self.training_conf_path = training_conf_path + self.data_conf_path = data_conf_path + + # Set source parameters + self.x_field = None + self.y_field = None + self.seds = None + self.trained_psf_model = None + + # Load the training and data configurations + self.training_conf = read_conf(training_conf_path) + self.data_conf = read_conf(data_conf_path) + + # Set the number of labmda bins + self.n_bins_lambda = self.training_conf.training.model_params.n_bins_lambda + + # Set the batch size + self.batch_size = ( + batch_size + if batch_size is not None + else self.training_conf.training.model_params.batch_size + ) + + # Instantiate the PSF simulator object + self.simPSF = psf_models.simPSF(self.training_conf.training.model_params) + + # Instantiate the data handler + self.data_handler = DataHandler( + dataset_type="inference", + data_params=self.data_conf, + simPSF=self.simPSF, + n_bins_lambda=self.n_bins_lambda, + load_data=False, + ) + + # Load the trained PSF model + self.trained_psf_model = self.get_trained_psf_model() + + def get_trained_psf_model(self): + """Get the trained PSF model.""" + + model_name = self.training_conf.training.model_params.model_name + id_name = self.training_conf.training.id_name + + weights_path_pattern = os.path.join( + self.trained_model_path, + self.model_subdir, + (f"{self.model_subdir}*_{model_name}" f"*{id_name}_cycle{self.cycle}*"), + ) + return load_trained_psf_model( + self.training_conf, + self.data_conf, + weights_path_pattern, + ) + + def set_source_parameters(self, x_field, y_field, seds): + """Set the input source parameters for inferring the PSF. + + Parameters + ---------- + x_field : array-like + X coordinates of the sources in WaveDiff format. + y_field : array-like + Y coordinates of the sources in WaveDiff format. 
+ seds : list or array-like + A list or array of raw SEDs, where each SED is typically a vector of flux values + or coefficients. These will be processed using the PSF simulator. + It assumes the standard WaveDiff SED format. + + """ + # Positions array is of shape (n_sources, 2) + self.positions = tf.convert_to_tensor( + np.array([x_field, y_field]).T, dtype=tf.float32 + ) + # Process SED data + self.sed_data = self.data_handler.process_sed_data(seds) + + def get_psfs(self): + """Generate PSFs on the input source parameters.""" + + while counter < n_samples: + # Calculate the batch end element + if counter + batch_size <= n_samples: + end_sample = counter + batch_size + else: + end_sample = n_samples + + # Define the batch positions + batch_pos = pos[counter:end_sample, :] + + inputs = [self.positions, self.sed_data] + poly_psfs = self.trained_psf_model(inputs, training=False) + + return poly_psfs + + +# def run_pipeline(): +# psf_model = get_trained_psf_model( +# model_path, model_dir, cycle, training_conf, data_conf +# ) +# inputs = prepare_inputs() +# psfs = generate_psfs( +# psf_model, +# inputs, +# batch_size=1, +# ) +# return psfs From 9fbf8b154279a57d3df7d9eb41b069140689fd53 Mon Sep 17 00:00:00 2001 From: Tobias Liaudat Date: Mon, 19 May 2025 11:57:05 +0200 Subject: [PATCH 031/135] add common call interface through PSF models --- src/wf_psf/psf_models/models/psf_model_parametric.py | 2 +- src/wf_psf/psf_models/models/psf_model_semiparametric.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/wf_psf/psf_models/models/psf_model_parametric.py b/src/wf_psf/psf_models/models/psf_model_parametric.py index 3d85d2bc..4a28417d 100644 --- a/src/wf_psf/psf_models/models/psf_model_parametric.py +++ b/src/wf_psf/psf_models/models/psf_model_parametric.py @@ -215,7 +215,7 @@ def predict_opd(self, input_positions): return opd_maps - def call(self, inputs): + def call(self, inputs, **kwargs): """Define the PSF field forward model. [1] From positions to Zernike coefficients diff --git a/src/wf_psf/psf_models/models/psf_model_semiparametric.py b/src/wf_psf/psf_models/models/psf_model_semiparametric.py index c370956c..7b2ff04d 100644 --- a/src/wf_psf/psf_models/models/psf_model_semiparametric.py +++ b/src/wf_psf/psf_models/models/psf_model_semiparametric.py @@ -421,7 +421,7 @@ def project_DD_features(self, tf_zernike_cube=None): s_new = self.tf_np_poly_opd.S_mat - s_mat_projected self.assign_S_mat(s_new) - def call(self, inputs): + def call(self, inputs, **kwargs): """Define the PSF field forward model. 
[1] From positions to Zernike coefficients From b88d5625a7e779b680771a014fde2e0a30f009ca Mon Sep 17 00:00:00 2001 From: Tobias Liaudat Date: Mon, 19 May 2025 14:35:29 +0200 Subject: [PATCH 032/135] add handling of inference params --- src/wf_psf/inference/psf_inference.py | 101 +++++++++++++++++++++----- 1 file changed, 81 insertions(+), 20 deletions(-) diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py index 6787be61..a5ca12a6 100644 --- a/src/wf_psf/inference/psf_inference.py +++ b/src/wf_psf/inference/psf_inference.py @@ -27,16 +27,15 @@ def __init__( self, trained_model_path: str, model_subdir: str, - cycle: int, training_conf_path: str, data_conf_path: str, - batch_size: Optional[int] = None, + inference_conf_path: str, ): self.trained_model_path = trained_model_path self.model_subdir = model_subdir - self.cycle = cycle self.training_conf_path = training_conf_path self.data_conf_path = data_conf_path + self.inference_conf_path = inference_conf_path # Set source parameters self.x_field = None @@ -44,18 +43,24 @@ def __init__( self.seds = None self.trained_psf_model = None + # Set compute PSF placeholder + self.inferred_psfs = None + # Load the training and data configurations self.training_conf = read_conf(training_conf_path) self.data_conf = read_conf(data_conf_path) + self.inference_conf = read_conf(inference_conf_path) # Set the number of labmda bins - self.n_bins_lambda = self.training_conf.training.model_params.n_bins_lambda - + self.n_bins_lambda = self.inference_conf.inference.model_params.n_bins_lda # Set the batch size - self.batch_size = ( - batch_size - if batch_size is not None - else self.training_conf.training.model_params.batch_size + self.batch_size = self.inference_conf.inference.batch_size + # Set the cycle to use for inference + self.cycle = self.inference_conf.inference.cycle + + # Overwrite the model parameters with the inference configuration + self.training_conf.training.model_params = self.overwrite_model_params( + self.training_conf, self.inference_conf ) # Instantiate the PSF simulator object @@ -73,6 +78,29 @@ def __init__( # Load the trained PSF model self.trained_psf_model = self.get_trained_psf_model() + @staticmethod + def overwrite_model_params(training_conf=None, inference_conf=None): + """Overwrite model_params of the training_conf with the inference_conf. + + Parameters + ---------- + training_conf : RecursiveNamespace + Configuration object containing model parameters and training hyperparameters. + inference_conf : RecursiveNamespace + Configuration object containing inference-related parameters. 
+ + """ + model_params = training_conf.training.model_params + inf_model_params = inference_conf.inference.model_params + if model_params is not None and inf_model_params is not None: + for key, value in inf_model_params.__dict__.items(): + # Check if model_params has the attribute + if hasattr(model_params, key): + # Set the attribute of model_params to the new value + setattr(model_params, key, value) + + return model_params + def get_trained_psf_model(self): """Get the trained PSF model.""" @@ -110,25 +138,58 @@ def set_source_parameters(self, x_field, y_field, seds): np.array([x_field, y_field]).T, dtype=tf.float32 ) # Process SED data - self.sed_data = self.data_handler.process_sed_data(seds) - - def get_psfs(self): - """Generate PSFs on the input source parameters.""" + self.data_handler.process_sed_data(seds) + self.sed_data = self.data_handler.sed_data + + def compute_psfs(self): + """Compute the PSFs for the input source parameters.""" + + # Check if source parameters are set + if self.positions is None or self.sed_data is None: + raise ValueError( + "Source parameters not set. Call set_source_parameters first." + ) + + # Get the number of samples + n_samples = self.positions.shape[0] + # Initialize counter + counter = 0 + # Initialize PSF array + self.inferred_psfs = np.zeros((n_samples,)) + psf_array = [] while counter < n_samples: # Calculate the batch end element - if counter + batch_size <= n_samples: - end_sample = counter + batch_size + if counter + self.batch_size <= n_samples: + end_sample = counter + self.batch_size else: end_sample = n_samples - # Define the batch positions - batch_pos = pos[counter:end_sample, :] + # Define the batch positions + batch_pos = self.positions[counter:end_sample, :] + batch_seds = self.sed_data[counter:end_sample, :, :] + + # Generate PSFs for the current batch + batch_inputs = [batch_pos, batch_seds] + batch_poly_psfs = self.trained_psf_model(batch_inputs, training=False) + + # Append to the PSF array + psf_array.append(poly_psfs) + + # Update the counter + counter += self.batch_size + + return tf.concat(psf_array, axis=0) + + def get_psfs(self) -> np.ndarray: + """Get all the generated PSFs.""" + + pass - inputs = [self.positions, self.sed_data] - poly_psfs = self.trained_psf_model(inputs, training=False) + def get_psf(self, index) -> np.ndarray: + """Generate the generated PSF at a specific index.""" - return poly_psfs + pass # def run_pipeline(): From aa9bad361ac741edc8da9042e4b94e555f3c5617 Mon Sep 17 00:00:00 2001 From: Tobias Liaudat Date: Mon, 19 May 2025 14:45:24 +0200 Subject: [PATCH 033/135] automatic formatting --- src/wf_psf/data/data_handler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/wf_psf/data/data_handler.py b/src/wf_psf/data/data_handler.py index 63f4650e..84af1eba 100644 --- a/src/wf_psf/data/data_handler.py +++ b/src/wf_psf/data/data_handler.py @@ -28,7 +28,7 @@ class DataHandler: DataHandler for WaveDiff PSF modeling. This class manages loading, preprocessing, and TensorFlow conversion of datasets used - for PSF model training, testing, inference, and validation in the WaveDiff framework. + for PSF model training, testing, and inference in the WaveDiff framework. 
 Parameters
 ----------

From 98e4806c540a3ff95216851477ba5a66d07508bc Mon Sep 17 00:00:00 2001
From: Tobias Liaudat
Date: Mon, 19 May 2025 14:45:43 +0200
Subject: [PATCH 034/135] add first completed class draft

---
 src/wf_psf/inference/psf_inference.py | 49 +++++++++++++++++++++------
 1 file changed, 39 insertions(+), 10 deletions(-)

diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py
index a5ca12a6..80a2c1c1 100644
--- a/src/wf_psf/inference/psf_inference.py
+++ b/src/wf_psf/inference/psf_inference.py
@@ -21,7 +21,23 @@


 class PSFInference:
-    """Class to perform inference on PSF models."""
+    """Class to perform inference on PSF models.
+
+
+    Parameters
+    ----------
+    trained_model_path : str
+        Path to the directory containing the trained model.
+    model_subdir : str
+        Subdirectory name of the trained model.
+    training_conf_path : str
+        Path to the training configuration file used to train the model.
+    data_conf_path : str
+        Path to the data configuration file.
+    inference_conf_path : str
+        Path to the inference configuration file.
+
+    """

     def __init__(
         self,
@@ -55,8 +71,11 @@ def __init__(
         self.n_bins_lambda = self.inference_conf.inference.model_params.n_bins_lda
         # Set the batch size
         self.batch_size = self.inference_conf.inference.batch_size
+        assert self.batch_size > 0, "Batch size must be greater than 0."
         # Set the cycle to use for inference
         self.cycle = self.inference_conf.inference.cycle
+        # Get output psf dimensions
+        self.output_dim = self.inference_conf.inference.model_params.output_dim

         # Overwrite the model parameters with the inference configuration
         self.training_conf.training.model_params = self.overwrite_model_params(
@@ -73,6 +92,7 @@ def __init__(
             simPSF=self.simPSF,
             n_bins_lambda=self.n_bins_lambda,
             load_data=False,
+            dataset=None,
         )

         # Load the trained PSF model
@@ -155,8 +175,7 @@ def compute_psfs(self):
         # Initialize counter
         counter = 0
         # Initialize PSF array
-        self.inferred_psfs = np.zeros((n_samples,))
-        psf_array = []
+        self.inferred_psfs = np.zeros((n_samples, self.output_dim, self.output_dim))

         while counter < n_samples:
             # Calculate the batch end element
@@ -174,22 +193,32 @@ def compute_psfs(self):
             batch_poly_psfs = self.trained_psf_model(batch_inputs, training=False)

             # Append to the PSF array
-            psf_array.append(poly_psfs)
+            self.inferred_psfs[counter:end_sample, :, :] = batch_poly_psfs.numpy()

             # Update the counter
             counter += self.batch_size

-        return tf.concat(psf_array, axis=0)
-
     def get_psfs(self) -> np.ndarray:
-        """Get all the generated PSFs."""
+        """Get all the generated PSFs.

-        pass
+        Returns
+        -------
+        np.ndarray
+            The generated PSFs for the input source parameters.
+            Shape is (n_samples, output_dim, output_dim).
+        """
+        return self.inferred_psfs

     def get_psf(self, index) -> np.ndarray:
-        """Generate the generated PSF at a specific index."""
+        """Get the generated PSF at a specific index.

-        pass
+        Returns
+        -------
+        np.ndarray
+            The generated PSF for the input source parameters at the given index.
+            Shape is (output_dim, output_dim).
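A small worked check of the batch bookkeeping in `compute_psfs` (pure arithmetic, no model needed): with 37 sources and a batch size of 16, the loop visits the slices [0:16], [16:32] and [32:37], covering each sample exactly once; `min()` is used below as a compact equivalent of the if/else in the patch.

    n_samples, batch_size = 37, 16  # example numbers
    counter, slices = 0, []
    while counter < n_samples:
        end_sample = min(counter + batch_size, n_samples)
        slices.append((counter, end_sample))
        counter += batch_size
    assert slices == [(0, 16), (16, 32), (32, 37)]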
+ """ + return self.inferred_psfs[index] # def run_pipeline(): From e3ec2d2fc18c03d69fa949fad3eb8d03afd65b00 Mon Sep 17 00:00:00 2001 From: Tobias Liaudat Date: Mon, 19 May 2025 14:45:55 +0200 Subject: [PATCH 035/135] add inference config file --- config/inference_conf.yaml | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 config/inference_conf.yaml diff --git a/config/inference_conf.yaml b/config/inference_conf.yaml new file mode 100644 index 00000000..af67cde6 --- /dev/null +++ b/config/inference_conf.yaml @@ -0,0 +1,28 @@ + +inference: + # Inference batch size + batch_size: 16 + + # Cycle to use for inference. Can be: 1, 2, ... + cycle: 2 + + # The following parameters will overwrite the `model_params` in the training config file. + model_params: + # Num of wavelength bins to reconstruct polychromatic objects. + n_bins_lda: 20 + + # Downsampling rate to match the oversampled model to the specified telescope's sampling. + output_Q: 3 + + # Oversampling rate used for the OPD/WFE PSF model. + oversampling_rate: 3 + + # Dimension of the pixel PSF postage stamp + output_dim: 32 + + # Dimension of the OPD/Wavefront space. + pupil_diameter: 256 + + + + From 634c81c05bcc67b1f98af241c40318900d221667 Mon Sep 17 00:00:00 2001 From: Tobias Liaudat Date: Mon, 19 May 2025 14:46:16 +0200 Subject: [PATCH 036/135] remove unused code --- src/wf_psf/inference/psf_inference.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py index 80a2c1c1..97db58da 100644 --- a/src/wf_psf/inference/psf_inference.py +++ b/src/wf_psf/inference/psf_inference.py @@ -219,16 +219,3 @@ def get_psf(self, index) -> np.ndarray: Shape is (output_dim, output_dim). """ return self.inferred_psfs[index] - - -# def run_pipeline(): -# psf_model = get_trained_psf_model( -# model_path, model_dir, cycle, training_conf, data_conf -# ) -# inputs = prepare_inputs() -# psfs = generate_psfs( -# psf_model, -# inputs, -# batch_size=1, -# ) -# return psfs From 9075fa6b393958e52b2f36452c3552cba7c3fa27 Mon Sep 17 00:00:00 2001 From: Tobias Liaudat Date: Mon, 19 May 2025 14:56:51 +0200 Subject: [PATCH 037/135] update params --- config/inference_conf.yaml | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/config/inference_conf.yaml b/config/inference_conf.yaml index af67cde6..cf4de4ff 100644 --- a/config/inference_conf.yaml +++ b/config/inference_conf.yaml @@ -12,17 +12,8 @@ inference: n_bins_lda: 20 # Downsampling rate to match the oversampled model to the specified telescope's sampling. - output_Q: 3 - - # Oversampling rate used for the OPD/WFE PSF model. - oversampling_rate: 3 + output_Q: 1 # Dimension of the pixel PSF postage stamp output_dim: 32 - - # Dimension of the OPD/Wavefront space. - pupil_diameter: 256 - - - From 1cc14cc83deb98919a27f8962c1d2ef4a6b32cf9 Mon Sep 17 00:00:00 2001 From: Tobias Liaudat Date: Mon, 19 May 2025 14:57:17 +0200 Subject: [PATCH 038/135] update params --- config/inference_conf.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/config/inference_conf.yaml b/config/inference_conf.yaml index cf4de4ff..afb0174c 100644 --- a/config/inference_conf.yaml +++ b/config/inference_conf.yaml @@ -9,11 +9,11 @@ inference: # The following parameters will overwrite the `model_params` in the training config file. model_params: # Num of wavelength bins to reconstruct polychromatic objects. 
-    n_bins_lda: 20
+    n_bins_lda: 8

     # Downsampling rate to match the oversampled model to the specified telescope's sampling.
     output_Q: 1

     # Dimension of the pixel PSF postage stamp
-    output_dim: 32
+    output_dim: 64

From 72c80b61e5e4206ae97fcfee4fc279a2a0abcc93 Mon Sep 17 00:00:00 2001
From: Tobias Liaudat
Date: Mon, 19 May 2025 17:04:40 +0200
Subject: [PATCH 039/135] update inference

---
 config/inference_conf.yaml | 14 +++++++++++++-
 1 file changed, 13 insertions(+), 1 deletion(-)

diff --git a/config/inference_conf.yaml b/config/inference_conf.yaml
index afb0174c..7e971957 100644
--- a/config/inference_conf.yaml
+++ b/config/inference_conf.yaml
@@ -2,10 +2,22 @@
 inference:
   # Inference batch size
   batch_size: 16
-
   # Cycle to use for inference. Can be: 1, 2, ...
   cycle: 2

+  # Paths to the configuration files and trained model directory
+  configs:
+    # Path to the directory containing the trained model
+    training_config_path: models/
+
+    # Subdirectory name of the trained model
+    model_subdir: models
+
+    # Path to the training configuration file used to train the model
+    trained_model_path: config/training_config.yaml
+
+    # Path to the data config file (this could contain prior information)
+    data_conf_path:
+
   # The following parameters will overwrite the `model_params` in the training config file.
   model_params:
     # Num of wavelength bins to reconstruct polychromatic objects.

From 566668ba4e2007f7395383b195833463b Mon Sep 17 00:00:00 2001
From: Tobias Liaudat
Date: Mon, 19 May 2025 17:05:00 +0200
Subject: [PATCH 040/135] reduce arguments and add compute psfs when
 appropriate

---
 src/wf_psf/inference/psf_inference.py | 48 +++++++++++++--------------
 1 file changed, 23 insertions(+), 25 deletions(-)

diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py
index 97db58da..e73231cb 100644
--- a/src/wf_psf/inference/psf_inference.py
+++ b/src/wf_psf/inference/psf_inference.py
@@ -26,32 +26,31 @@ class PSFInference:
     """Class to perform inference on PSF models.


     Parameters
     ----------
-    trained_model_path : str
-        Path to the directory containing the trained model.
-    model_subdir : str
-        Subdirectory name of the trained model.
-    training_conf_path : str
-        Path to the training configuration file used to train the model.
-    data_conf_path : str
-        Path to the data configuration file.
     inference_conf_path : str
         Path to the inference configuration file.
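The nested YAML shown in the preceding patches is consumed via attribute access; schematically (plain PyYAML here, whereas `read_conf` wraps the result in a RecursiveNamespace):

    import yaml

    conf = yaml.safe_load(
        "inference:\n"
        "  batch_size: 16\n"
        "  cycle: 2\n"
        "  model_params:\n"
        "    n_bins_lda: 8\n"
        "    output_dim: 64\n"
    )
    # RecursiveNamespace turns these lookups into the attribute chains used in
    # the code, e.g. inference_conf.inference.model_params.n_bins_lda
    assert conf["inference"]["model_params"]["n_bins_lda"] == 8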
""" - def __init__( - self, - trained_model_path: str, - model_subdir: str, - training_conf_path: str, - data_conf_path: str, - inference_conf_path: str, - ): - self.trained_model_path = trained_model_path - self.model_subdir = model_subdir - self.training_conf_path = training_conf_path - self.data_conf_path = data_conf_path + def __init__(self, inference_conf_path: str): + self.inference_conf_path = inference_conf_path + # Load the training and data configurations + self.inference_conf = read_conf(inference_conf_path) + + # Set config paths + self.config_paths = self.inference_conf.inference.configs.config_paths + self.trained_model_path = self.config_paths.trained_model_path + self.model_subdir = self.config_paths.model_subdir + self.training_config_path = self.config_paths.training_config_path + self.data_conf_path = self.config_paths.data_conf_path + + # Load the training and data configurations + self.training_conf = read_conf(self.training_conf_path) + if self.data_conf_path is not None: + # Load the data configuration + self.data_conf = read_conf(self.data_conf_path) + else: + self.data_conf = None # Set source parameters self.x_field = None @@ -62,11 +61,6 @@ def __init__( # Set compute PSF placeholder self.inferred_psfs = None - # Load the training and data configurations - self.training_conf = read_conf(training_conf_path) - self.data_conf = read_conf(data_conf_path) - self.inference_conf = read_conf(inference_conf_path) - # Set the number of labmda bins self.n_bins_lambda = self.inference_conf.inference.model_params.n_bins_lda # Set the batch size @@ -207,6 +201,8 @@ def get_psfs(self) -> np.ndarray: The generated PSFs for the input source parameters. Shape is (n_samples, output_dim, output_dim). """ + if self.inferred_psfs is None: + self.compute_psfs() return self.inferred_psfs def get_psf(self, index) -> np.ndarray: @@ -218,4 +214,6 @@ def get_psf(self, index) -> np.ndarray: The generated PSFs for the input source parameters. Shape is (output_dim, output_dim). 
""" + if self.inferred_psfs is None: + self.compute_psfs() return self.inferred_psfs[index] From 068d78eaae8b9ead2a2556400affd002db3cdd8b Mon Sep 17 00:00:00 2001 From: Tobias Liaudat Date: Mon, 19 May 2025 17:08:49 +0200 Subject: [PATCH 041/135] add config handler class --- src/wf_psf/inference/psf_inference.py | 53 +++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py index e73231cb..d0f2b930 100644 --- a/src/wf_psf/inference/psf_inference.py +++ b/src/wf_psf/inference/psf_inference.py @@ -20,6 +20,59 @@ from typing import Optional +class InferenceConfigHandler: + ids = ("inference_conf",) + + def __init__( + self, + trained_model_path: str, + model_subdir: str, + training_conf_path: str, + data_conf_path: str, + inference_conf_path: str, + ): + self.trained_model_path = trained_model_path + self.model_subdir = model_subdir + self.training_conf_path = training_conf_path + self.data_conf_path = data_conf_path + self.inference_conf_path = inference_conf_path + + # Overwrite the model parameters with the inference configuration + self.model_params = self.overwrite_model_params( + self.training_conf, self.inference_conf + ) + + def read_configurations(self): + # Load the training and data configurations + self.training_conf = read_conf(training_conf_path) + self.data_conf = read_conf(data_conf_path) + self.inference_conf = read_conf(inference_conf_path) + + @staticmethod + def overwrite_model_params(training_conf=None, inference_conf=None): + """Overwrite model_params of the training_conf with the inference_conf. + + Parameters + ---------- + training_conf : RecursiveNamespace + Configuration object containing model parameters and training hyperparameters. + inference_conf : RecursiveNamespace + Configuration object containing inference-related parameters. + + """ + model_params = training_conf.training.model_params + inf_model_params = inference_conf.inference.model_params + + if model_params is not None and inf_model_params is not None: + for key, value in inf_model_params.__dict__.items(): + # Check if model_params has the attribute + if hasattr(model_params, key): + # Set the attribute of model_params to the new value + setattr(model_params, key, value) + + return model_params + + class PSFInference: """Class to perform inference on PSF models. 
From ec445304cc9bc872d780d311f0679c9a712397a8 Mon Sep 17 00:00:00 2001
From: Tobias Liaudat
Date: Mon, 19 May 2025 17:48:47 +0200
Subject: [PATCH 042/135] set up inference config handler and simplify
 PSFInference init

---
 src/wf_psf/inference/psf_inference.py | 103 ++++++++++++++------------
 1 file changed, 55 insertions(+), 48 deletions(-)

diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py
index d0f2b930..57f84080 100644
--- a/src/wf_psf/inference/psf_inference.py
+++ b/src/wf_psf/inference/psf_inference.py
@@ -23,30 +23,43 @@
 class InferenceConfigHandler:
     ids = ("inference_conf",)

-    def __init__(
-        self,
-        trained_model_path: str,
-        model_subdir: str,
-        training_conf_path: str,
-        data_conf_path: str,
-        inference_conf_path: str,
-    ):
-        self.trained_model_path = trained_model_path
-        self.model_subdir = model_subdir
-        self.training_conf_path = training_conf_path
-        self.data_conf_path = data_conf_path
+    def __init__(self, inference_conf_path: str):
         self.inference_conf_path = inference_conf_path

+        # Load the inference configuration
+        self.read_configurations()
+
         # Overwrite the model parameters with the inference configuration
         self.model_params = self.overwrite_model_params(
             self.training_conf, self.inference_conf
         )

     def read_configurations(self):
+        """Read the configuration files."""
+        # Load the inference configuration
+        self.inference_conf = read_conf(self.inference_conf_path)
+        # Set config paths
+        self.set_config_paths()
         # Load the training and data configurations
-        self.training_conf = read_conf(training_conf_path)
-        self.data_conf = read_conf(data_conf_path)
-        self.inference_conf = read_conf(inference_conf_path)
+        self.training_conf = read_conf(self.training_conf_path)
+        if self.data_conf_path is not None:
+            # Load the data configuration
+            self.data_conf = read_conf(self.data_conf_path)
+        else:
+            self.data_conf = None
+
+    def set_config_paths(self):
+        """Extract and set the configuration paths."""
+        # Set config paths
+        self.config_paths = self.inference_conf.inference.configs.config_paths
+        self.trained_model_path = self.config_paths.trained_model_path
+        self.model_subdir = self.config_paths.model_subdir
+        self.training_config_path = self.config_paths.training_config_path
+        self.data_conf_path = self.config_paths.data_conf_path
+
+    def get_configs(self):
+        """Get the configurations."""
+        return (self.inference_conf, self.training_conf, self.data_conf)

     @staticmethod
     def overwrite_model_params(training_conf=None, inference_conf=None):
@@ -86,48 +99,25 @@ class PSFInference:

     def __init__(self, inference_conf_path: str):

-        self.inference_conf_path = inference_conf_path
-        # Load the training and data configurations
-        self.inference_conf = read_conf(inference_conf_path)
-
-        # Set config paths
-        self.config_paths = self.inference_conf.inference.configs.config_paths
-        self.trained_model_path = self.config_paths.trained_model_path
-        self.model_subdir = self.config_paths.model_subdir
-        self.training_config_path = self.config_paths.training_config_path
-        self.data_conf_path = self.config_paths.data_conf_path
+        self.inference_config_handler = InferenceConfigHandler(
+            inference_conf_path=inference_conf_path
+        )

-        # Load the training and data configurations
-        self.training_conf = read_conf(self.training_conf_path)
-        if self.data_conf_path is not None:
-            # Load the data configuration
-            self.data_conf = read_conf(self.data_conf_path)
-        else:
-            self.data_conf = None
+        self.inference_conf, self.training_conf, self.data_conf = (
self.inference_config_handler.get_configs() + ) - # Set source parameters + # Init source parameters self.x_field = None self.y_field = None self.seds = None self.trained_psf_model = None - # Set compute PSF placeholder + # Init compute PSF placeholder self.inferred_psfs = None - # Set the number of labmda bins - self.n_bins_lambda = self.inference_conf.inference.model_params.n_bins_lda - # Set the batch size - self.batch_size = self.inference_conf.inference.batch_size - assert self.batch_size > 0, "Batch size must be greater than 0." - # Set the cycle to use for inference - self.cycle = self.inference_conf.inference.cycle - # Get output psf dimensions - self.output_dim = self.inference_conf.inference.model_params.output_dim - - # Overwrite the model parameters with the inference configuration - self.training_conf.training.model_params = self.overwrite_model_params( - self.training_conf, self.inference_conf - ) + # Load inference parameters + self.load_inference_params() # Instantiate the PSF simulator object self.simPSF = psf_models.simPSF(self.training_conf.training.model_params) @@ -145,6 +135,18 @@ def __init__(self, inference_conf_path: str): # Load the trained PSF model self.trained_psf_model = self.get_trained_psf_model() + def load_inference_params(self): + """Load the inference parameters from the configuration file.""" + # Set the number of labmda bins + self.n_bins_lambda = self.inference_conf.inference.model_params.n_bins_lda + # Set the batch size + self.batch_size = self.inference_conf.inference.batch_size + assert self.batch_size > 0, "Batch size must be greater than 0." + # Set the cycle to use for inference + self.cycle = self.inference_conf.inference.cycle + # Get output psf dimensions + self.output_dim = self.inference_conf.inference.model_params.output_dim + @staticmethod def overwrite_model_params(training_conf=None, inference_conf=None): """Overwrite model_params of the training_conf with the inference_conf. @@ -188,6 +190,11 @@ def get_trained_psf_model(self): def set_source_parameters(self, x_field, y_field, seds): """Set the input source parameters for inferring the PSF. + Note + ---- + The input source parameters are expected to be in the WaveDiff format. See the simulated data + format for more details. + Parameters ---------- x_field : array-like From 6b073bde66a051502093e4e248fc265da1a3c6f3 Mon Sep 17 00:00:00 2001 From: Tobias Liaudat Date: Mon, 19 May 2025 17:49:47 +0200 Subject: [PATCH 043/135] remove unused imports --- src/wf_psf/inference/psf_inference.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py index 57f84080..797a410f 100644 --- a/src/wf_psf/inference/psf_inference.py +++ b/src/wf_psf/inference/psf_inference.py @@ -4,20 +4,17 @@ with trained PSF models. It is able to load a trained model, perform inference on a dataset of SEDs and positions, and generate polychromatic PSFs. 
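The handler/consumer split keeps file I/O in one place; usage is then simply (path hypothetical):

    handler = InferenceConfigHandler("config/inference_conf.yaml")
    inference_conf, training_conf, data_conf = handler.get_configs()
    # data_conf is None when data_conf_path is left empty in the YAML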
-:Authors: Jennifer Pollack +:Authors: Jennifer Pollack , Tobias Liaudat """ import os -import glob -import logging import numpy as np from wf_psf.data.data_handler import DataHandler from wf_psf.utils.read_config import read_conf from wf_psf.psf_models import psf_models from wf_psf.psf_models.psf_model_loader import load_trained_psf_model import tensorflow as tf -from typing import Optional class InferenceConfigHandler: From be08c3bc68cafaad19ef8928ad3b0b4ef02adc33 Mon Sep 17 00:00:00 2001 From: Tobias Liaudat Date: Mon, 19 May 2025 17:50:52 +0200 Subject: [PATCH 044/135] update inference --- config/inference_conf.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/config/inference_conf.yaml b/config/inference_conf.yaml index 7e971957..0c846fca 100644 --- a/config/inference_conf.yaml +++ b/config/inference_conf.yaml @@ -5,6 +5,7 @@ inference: # Cycle to use for inference. Can be: 1, 2, ... cycle: 2 + # Paths to the configuration files and trained model directory configs: # Path to the directory containing the trained model training_config_path: models/ From ec87d79b5933ad8c2c75641e645be2e25d927e04 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 23 May 2025 11:43:09 +0200 Subject: [PATCH 045/135] Add single-space lines to improve readability; Remove duplicated static method --- src/wf_psf/inference/psf_inference.py | 34 ++++++++------------------- 1 file changed, 10 insertions(+), 24 deletions(-) diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py index 797a410f..f8496453 100644 --- a/src/wf_psf/inference/psf_inference.py +++ b/src/wf_psf/inference/psf_inference.py @@ -35,10 +35,13 @@ def read_configurations(self): """Read the configuration files.""" # Load the inference configuration self.inference_conf = read_conf(self.inference_conf_path) + # Set config paths self.set_config_paths() + # Load the training and data configurations self.training_conf = read_conf(self.training_conf_path) + if self.data_conf_path is not None: # Load the data configuration self.data_conf = read_conf(self.data_conf_path) @@ -136,37 +139,18 @@ def load_inference_params(self): """Load the inference parameters from the configuration file.""" # Set the number of labmda bins self.n_bins_lambda = self.inference_conf.inference.model_params.n_bins_lda + # Set the batch size self.batch_size = self.inference_conf.inference.batch_size assert self.batch_size > 0, "Batch size must be greater than 0." + # Set the cycle to use for inference self.cycle = self.inference_conf.inference.cycle + # Get output psf dimensions self.output_dim = self.inference_conf.inference.model_params.output_dim - - @staticmethod - def overwrite_model_params(training_conf=None, inference_conf=None): - """Overwrite model_params of the training_conf with the inference_conf. - - Parameters - ---------- - training_conf : RecursiveNamespace - Configuration object containing model parameters and training hyperparameters. - inference_conf : RecursiveNamespace - Configuration object containing inference-related parameters. 
- - """ - model_params = training_conf.training.model_params - inf_model_params = inference_conf.inference.model_params - if model_params is not None and inf_model_params is not None: - for key, value in inf_model_params.__dict__.items(): - # Check if model_params has the attribute - if hasattr(model_params, key): - # Set the attribute of model_params to the new value - setattr(model_params, key, value) - - return model_params - + + def get_trained_psf_model(self): """Get the trained PSF model.""" @@ -223,8 +207,10 @@ def compute_psfs(self): # Get the number of samples n_samples = self.positions.shape[0] + # Initialize counter counter = 0 + # Initialize PSF array self.inferred_psfs = np.zeros((n_samples, self.output_dim, self.output_dim)) From b5b31f05710bba0eaee0756dc358865f9a3544c4 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Mon, 26 May 2025 16:19:40 +0200 Subject: [PATCH 046/135] Add additional PSFInference class attributes; update set_source_parameters to use class attributes; assign variable names in get_trained_psf_model; Update class and __init__ docstrings --- src/wf_psf/inference/psf_inference.py | 83 ++++++++++++++++++++++----- 1 file changed, 68 insertions(+), 15 deletions(-) diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py index f8496453..552abd02 100644 --- a/src/wf_psf/inference/psf_inference.py +++ b/src/wf_psf/inference/psf_inference.py @@ -87,17 +87,67 @@ def overwrite_model_params(training_conf=None, inference_conf=None): class PSFInference: - """Class to perform inference on PSF models. + """ + Perform PSF inference using a pre-trained WaveDiff model. + This class handles the setup for PSF inference, including loading configuration + files, instantiating the PSF simulator and data handler, and preparing the + input data required for inference. Parameters ---------- - inference_conf_path : str - Path to the inference configuration file. - + inference_conf_path : str, optional + Path to the inference configuration YAML file. This file should define + paths and parameters for the inference, training, and data configurations. + x_field : array-like, optional + Array of x field-of-view coordinates in the SHE convention to be transformed + and passed to the WaveDiff model. + y_field : array-like, optional + Array of y field-of-view coordinates in the SHE convention to be transformed + and passed to the WaveDiff model. + seds : array-like, optional + Spectral energy distributions (SEDs) for the sources being modeled. These + will be used as part of the input to the PSF simulator. + + Attributes + ---------- + inference_config_handler : InferenceConfigHandler + Handler object to load and parse inference, training, and data configs. + inference_conf : dict + Dictionary containing inference configuration settings. + training_conf : dict + Dictionary containing training configuration settings. + data_conf : dict + Dictionary containing data configuration settings. + x_field : array-like + Input x coordinates after transformation (if applicable). + y_field : array-like + Input y coordinates after transformation (if applicable). + seds : array-like + Input spectral energy distributions. + trained_psf_model : keras.Model + Loaded PSF model used for prediction. + inferred_psfs : array-like or None + Array of inferred PSF images, populated after inference is performed. + simPSF : psf_models.simPSF + PSF simulator instance initialized with training model parameters. 
+    data_handler : DataHandler
+        Data handler configured for inference, used to prepare inputs to the model.
+    n_bins_lambda : int
+        Number of spectral bins used for PSF simulation (loaded from config).
+
+    Methods
+    -------
+    load_inference_params()
+        Load parameters required for inference, including spectral binning.
+    get_trained_psf_model()
+        Load and return the trained Keras model for PSF inference.
+    set_source_parameters()
+        Register the input positions and SEDs used for inference.
+    compute_psfs()
+        Run the model on the input data and populate the inferred PSFs.
     """

-    def __init__(self, inference_conf_path: str):
+
+    def __init__(self, inference_conf_path: str, x_field=None, y_field=None, seds=None):

         self.inference_config_handler = InferenceConfigHandler(
             inference_conf_path=inference_conf_path
@@ -108,9 +158,9 @@ def __init__(self, inference_conf_path: str):
         )

         # Init source parameters
-        self.x_field = None
-        self.y_field = None
-        self.seds = None
+        self.x_field = x_field
+        self.y_field = y_field
+        self.seds = seds
         self.trained_psf_model = None

         # Init compute PSF placeholder
@@ -149,18 +199,21 @@ def load_inference_params(self):
         # Get output psf dimensions
         self.output_dim = self.inference_conf.inference.model_params.output_dim
-
+
     def get_trained_psf_model(self):
         """Get the trained PSF model."""

+        # Load the trained PSF model
+        model_path = self.inference_config_handler.trained_model_path
+        model_dir_name = self.inference_config_handler.model_subdir
         model_name = self.training_conf.training.model_params.model_name
         id_name = self.training_conf.training.id_name

         weights_path_pattern = os.path.join(
-            self.trained_model_path,
-            self.model_subdir,
-            (f"{self.model_subdir}*_{model_name}" f"*{id_name}_cycle{self.cycle}*"),
+            model_path,
+            model_dir_name,
+            (f"{model_dir_name}*_{model_name}" f"*{id_name}_cycle{self.cycle}*"),
         )
         return load_trained_psf_model(
             self.training_conf,
@@ -168,7 +221,7 @@ def get_trained_psf_model(self):
             weights_path_pattern,
         )

-    def set_source_parameters(self, x_field, y_field, seds):
+    def set_source_parameters(self):
         """Set the input source parameters for inferring the PSF.
Note @@ -190,10 +243,10 @@ def set_source_parameters(self, x_field, y_field, seds): """ # Positions array is of shape (n_sources, 2) self.positions = tf.convert_to_tensor( - np.array([x_field, y_field]).T, dtype=tf.float32 + np.array([self.x_field, self.y_field]).T, dtype=tf.float32 ) # Process SED data - self.data_handler.process_sed_data(seds) + self.data_handler.process_sed_data(self.seds) self.sed_data = self.data_handler.sed_data def compute_psfs(self): From 4c8db2bca53b208845b1d9ed7fac2d50df015c7f Mon Sep 17 00:00:00 2001 From: jeipollack Date: Tue, 27 May 2025 18:09:04 +0100 Subject: [PATCH 047/135] Add checks to convert to np.ndarray and expand dimensions if needed --- src/wf_psf/data/centroids.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/wf_psf/data/centroids.py b/src/wf_psf/data/centroids.py index 27f7894b..83ccf8c7 100644 --- a/src/wf_psf/data/centroids.py +++ b/src/wf_psf/data/centroids.py @@ -247,6 +247,19 @@ def __init__( self, im, mask=None, sigma_init=7.5, n_iter=5, auto_run=True, xc=None, yc=None ): """Initialize class attributes.""" + + # Convert to np.ndarray if not already + im = np.asarray(im) + if mask is not None: + mask = np.asarray(mask) + + # Check im dimensions convert to batch, if 2D + if im.ndim == 2: + # Single stamp → convert to batch of one + im = np.expand_dims(im, axis=0) + elif im.ndim != 3: + raise ValueError(f"Expected 2D or 3D input, got shape {im.shape}") + self.im = im self.mask = mask if self.mask is not None: From fbdd32a78956f2da19b871bbec344c846d8a19ca Mon Sep 17 00:00:00 2001 From: jeipollack Date: Thu, 5 Jun 2025 13:53:34 +0100 Subject: [PATCH 048/135] Update pyproject.toml with numpy dependency limits - sdc-uk --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 187ead5e..3107f12c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ maintainers = [ description = 'A software framework to perform Differentiable wavefront-based PSF modelling.' 
dependencies = [ - "numpy>=1.26.4,<2.0", + "numpy>=1.18,<1.24", "scipy", "tensorflow==2.11.0", "tensorflow-estimator", From 94a424e1c1b73c3ad95a661e66976b66166919cb Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 6 Jun 2025 12:53:26 +0200 Subject: [PATCH 049/135] Correct name of psf_inference_test module to follow repo naming convention --- .../{test_psf_inference.py => psf_inference_test.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename src/wf_psf/tests/test_inference/{test_psf_inference.py => psf_inference_test.py} (100%) diff --git a/src/wf_psf/tests/test_inference/test_psf_inference.py b/src/wf_psf/tests/test_inference/psf_inference_test.py similarity index 100% rename from src/wf_psf/tests/test_inference/test_psf_inference.py rename to src/wf_psf/tests/test_inference/psf_inference_test.py From 44cf3757cf4012887afcdae57b7776a52aa70c88 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Sun, 8 Jun 2025 18:22:14 +0200 Subject: [PATCH 050/135] Correct config subkey names for defining trained_model_path and trained_model_config_path --- config/inference_conf.yaml | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/config/inference_conf.yaml b/config/inference_conf.yaml index 0c846fca..c9d29cb8 100644 --- a/config/inference_conf.yaml +++ b/config/inference_conf.yaml @@ -8,16 +8,16 @@ inference: # Paths to the configuration files and trained model directory configs: # Path to the directory containing the trained model - training_config_path: models/ + trained_model_path: /path/to/trained/model/ - # Subdirectory name of the trained model - model_subdir: models + # Subdirectory name of the trained model, e.g. psf_model + model_subdir: model - # Path to the training configuration file used to train the model - trained_model_path: config/training_config.yaml + # Relative Path to the training configuration file used to train the model + trained_model_config_path: config/training_config.yaml # Path to the data config file (this could contain prior information) - data_conf_path: + data_config_path: # The following parameters will overwrite the `model_params` in the training config file. 
model_params: From cfd3bafbfc3854e2ec1d06fbdfa6dbc1c9a3528b Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Sun, 8 Jun 2025 18:24:26 +0200 Subject: [PATCH 051/135] Refactor psf_inference adding PSFInferenceEngine to separate concerns, enabling isolated testing; implement lazy loaders for config handling, model loading, and inference --- src/wf_psf/inference/psf_inference.py | 426 +++++++++++++------------- 1 file changed, 206 insertions(+), 220 deletions(-) diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py index 552abd02..9012e3fc 100644 --- a/src/wf_psf/inference/psf_inference.py +++ b/src/wf_psf/inference/psf_inference.py @@ -9,6 +9,7 @@ """ import os +from pathlib import Path import numpy as np from wf_psf.data.data_handler import DataHandler from wf_psf.utils.read_config import read_conf @@ -20,71 +21,60 @@ class InferenceConfigHandler: ids = ("inference_conf",) - def __init__(self, inference_conf_path: str): - self.inference_conf_path = inference_conf_path + def __init__(self, inference_config_path: str): + self.inference_config_path = inference_config_path + self.inference_config = None + self.training_config = None + self.data_config = None - # Load the inference configuration - self.read_configurations() - # Overwrite the model parameters with the inference configuration - self.model_params = self.overwrite_model_params( - self.training_conf, self.inference_conf - ) - - def read_configurations(self): - """Read the configuration files.""" - # Load the inference configuration - self.inference_conf = read_conf(self.inference_conf_path) - - # Set config paths + def load_configs(self): + """Load configuration files based on the inference config.""" + self.inference_config = read_conf(self.inference_config_path) self.set_config_paths() - - # Load the training and data configurations - self.training_conf = read_conf(self.training_conf_path) - - if self.data_conf_path is not None: + self.training_config = read_conf(self.trained_model_config_path) + + if self.data_config_path is not None: # Load the data configuration - self.data_conf = read_conf(self.data_conf_path) - else: - self.data_conf = None + self.data_conf = read_conf(self.data_config_path) + def set_config_paths(self): """Extract and set the configuration paths.""" # Set config paths - self.config_paths = self.inference_conf.inference.configs.config_paths - self.trained_model_path = self.config_paths.trained_model_path - self.model_subdir = self.config_paths.model_subdir - self.training_config_path = self.config_paths.training_config_path - self.data_conf_path = self.config_paths.data_conf_path + config_paths = self.inference_config.inference.configs + + self.trained_model_path = Path(config_paths.trained_model_path) + self.model_subdir = config_paths.model_subdir + self.trained_model_config_path = self.trained_model_path / config_paths.trained_model_config_path + self.data_config_path = config_paths.data_config_path - def get_configs(self): - """Get the configurations.""" - return (self.inference_conf, self.training_conf, self.data_conf) @staticmethod - def overwrite_model_params(training_conf=None, inference_conf=None): - """Overwrite model_params of the training_conf with the inference_conf. + def overwrite_model_params(training_config=None, inference_config=None): + """ + Overwrite training model_params with values from inference_config if available. 
Parameters ---------- - training_conf : RecursiveNamespace - Configuration object containing model parameters and training hyperparameters. - inference_conf : RecursiveNamespace - Configuration object containing inference-related parameters. - + training_config : RecursiveNamespace + Configuration object from training phase. + inference_config : RecursiveNamespace + Configuration object from inference phase. + + Notes + ----- + Updates are applied in-place to training_config.training.model_params. """ - model_params = training_conf.training.model_params - inf_model_params = inference_conf.inference.model_params + model_params = training_config.training.model_params + inf_model_params = inference_config.inference.model_params - if model_params is not None and inf_model_params is not None: + if model_params and inf_model_params: for key, value in inf_model_params.__dict__.items(): - # Check if model_params has the attribute if hasattr(model_params, key): - # Set the attribute of model_params to the new value setattr(model_params, key, value) - return model_params - + class PSFInference: """ @@ -96,220 +86,216 @@ class PSFInference: Parameters ---------- - inference_conf_path : str, optional - Path to the inference configuration YAML file. This file should define - paths and parameters for the inference, training, and data configurations. + inference_config_path : str + Path to the inference configuration YAML file. x_field : array-like, optional - Array of x field-of-view coordinates in the SHE convention to be transformed - and passed to the WaveDiff model. + x coordinates in SHE convention. y_field : array-like, optional - Array of y field-of-view coordinates in the SHE convention to be transformed - and passed to the WaveDiff model. + y coordinates in SHE convention. seds : array-like, optional - Spectral energy distributions (SEDs) for the sources being modeled. These - will be used as part of the input to the PSF simulator. - - Attributes - ---------- - inference_config_handler : InferenceConfigHandler - Handler object to load and parse inference, training, and data configs. - inference_conf : dict - Dictionary containing inference configuration settings. - training_conf : dict - Dictionary containing training configuration settings. - data_conf : dict - Dictionary containing data configuration settings. - x_field : array-like - Input x coordinates after transformation (if applicable). - y_field : array-like - Input y coordinates after transformation (if applicable). - seds : array-like - Input spectral energy distributions. - trained_psf_model : keras.Model - Loaded PSF model used for prediction. - inferred_psfs : array-like or None - Array of inferred PSF images, populated after inference is performed. - simPSF : psf_models.simPSF - PSF simulator instance initialized with training model parameters. - data_handler : DataHandler - Data handler configured for inference, used to prepare inputs to the model. - n_bins_lambda : int - Number of spectral bins used for PSF simulation (loaded from config). - - Methods - ------- - load_inference_params() - Load parameters required for inference, including spectral binning. - get_trained_psf_model() - Load and return the trained Keras model for PSF inference. - run_inference() - Run the model on the input data and generate predicted PSFs. + Spectral energy distributions (SEDs). 
""" + def __init__(self, inference_config_path: str, x_field=None, y_field=None, seds=None): - def __init__(self, inference_conf_path: str, x_field=None, y_field=None, seds=None): + self.inference_config_path = inference_config_path - self.inference_config_handler = InferenceConfigHandler( - inference_conf_path=inference_conf_path - ) - - self.inference_conf, self.training_conf, self.data_conf = ( - self.inference_config_handler.get_configs() - ) - - # Init source parameters + # Inputs for the model self.x_field = x_field self.y_field = y_field self.seds = seds - self.trained_psf_model = None - - # Init compute PSF placeholder - self.inferred_psfs = None - - # Load inference parameters - self.load_inference_params() - - # Instantiate the PSF simulator object - self.simPSF = psf_models.simPSF(self.training_conf.training.model_params) - - # Instantiate the data handler - self.data_handler = DataHandler( - dataset_type="inference", - data_params=self.data_conf, - simPSF=self.simPSF, - n_bins_lambda=self.n_bins_lambda, - load_data=False, - dataset=None, + + # Internal caches for lazy-loading + self._config_handler = None + self._simPSF = None + self._data_handler = None + self._trained_psf_model = None + self._n_bins_lambda = None + self._batch_size = None + self._cycle = None + self._output_dim = None + + # Initialise PSF Inference engine + self.engine = None + + @property + def config_handler(self): + if self._config_handler is None: + self._config_handler = InferenceConfigHandler(self.inference_config_path) + self._config_handler.load_configs() + return self._config_handler + + def prepare_configs(self): + """Prepare the configuration for inference.""" + # Overwrite model parameters with inference config + self.config_handler.overwrite_model_params( + self.training_config, self.inference_config ) - # Load the trained PSF model - self.trained_psf_model = self.get_trained_psf_model() + @property + def inference_config(self): + return self.config_handler.inference_config + + @property + def training_config(self): + return self.config_handler.training_config + + @property + def data_config(self): + return self.config_handler.data_config + + @property + def simPSF(self): + if self._simPSF is None: + self._simPSF = psf_models.simPSF(self.model_params) + return self._simPSF + + @property + def data_handler(self): + if self._data_handler is None: + # Instantiate the data handler + self._data_handler = DataHandler( + dataset_type="inference", + data_params=self.data_config, + simPSF=self.simPSF, + n_bins_lambda=self.n_bins_lambda, + load_data=False, + dataset=None, + ) + return self._data_handler - def load_inference_params(self): - """Load the inference parameters from the configuration file.""" - # Set the number of labmda bins - self.n_bins_lambda = self.inference_conf.inference.model_params.n_bins_lda + @property + def trained_psf_model(self): + if self._trained_psf_model is None: + self._trained_psf_model = self.load_inference_model() + return self._trained_psf_model - # Set the batch size - self.batch_size = self.inference_conf.inference.batch_size - assert self.batch_size > 0, "Batch size must be greater than 0." 
+ def load_inference_model(self): + # Prepare the configuration for inference + self.prepare_configs() - # Set the cycle to use for inference - self.cycle = self.inference_conf.inference.cycle - - # Get output psf dimensions - self.output_dim = self.inference_conf.inference.model_params.output_dim - - - def get_trained_psf_model(self): - """Get the trained PSF model.""" - - # Load the trained PSF model - model_path = self.inference_config_handler.trained_model_path - model_dir_name = self.inference_config_handler - model_name = self.training_conf.training.model_params.model_name - id_name = self.training_conf.training.id_name + model_path = self.config_handler.trained_model_path + model_dir = self.config_handler.model_subdir + model_name = self.training_config.training.model_params.model_name + id_name = self.training_config.training.id_name weights_path_pattern = os.path.join( model_path, - model_dir_name, - (f"{model_dir_name}*_{model_name}" f"*{id_name}_cycle{self.cycle}*"), + model_dir, + f"{model_dir}*_{model_name}*{id_name}_cycle{self.cycle}*" ) + + # Load the trained PSF model return load_trained_psf_model( - self.training_conf, - self.data_conf, + self.training_config, + self.data_config, weights_path_pattern, ) - def set_source_parameters(self): - """Set the input source parameters for inferring the PSF. - - Note - ---- - The input source parameters are expected to be in the WaveDiff format. See the simulated data - format for more details. - - Parameters - ---------- - x_field : array-like - X coordinates of the sources in WaveDiff format. - y_field : array-like - Y coordinates of the sources in WaveDiff format. - seds : list or array-like - A list or array of raw SEDs, where each SED is typically a vector of flux values - or coefficients. These will be processed using the PSF simulator. - It assumes the standard WaveDiff SED format. - - """ - # Positions array is of shape (n_sources, 2) - self.positions = tf.convert_to_tensor( + @property + def n_bins_lambda(self): + if self._n_bins_lambda is None: + self._n_bins_lambda = self.inference_config.inference.model_params.n_bins_lda + return self._n_bins_lambda + + @property + def batch_size(self): + if self._batch_size is None: + self._batch_size = self.inference_config.inference.batch_size + assert self._batch_size > 0, "Batch size must be greater than 0." + return self._batch_size + + @property + def cycle(self): + if self._cycle is None: + self._cycle = self.inference_config.inference.cycle + return self._cycle + + @property + def output_dim(self): + if self._output_dim is None: + self._output_dim = self.inference_config.inference.model_params.output_dim + return self._output_dim + + def _prepare_positions_and_seds(self): + """Preprocess and return tensors for positions and SEDs.""" + positions = tf.convert_to_tensor( np.array([self.x_field, self.y_field]).T, dtype=tf.float32 ) - # Process SED data self.data_handler.process_sed_data(self.seds) - self.sed_data = self.data_handler.sed_data + sed_data = self.data_handler.sed_data + return positions, sed_data - def compute_psfs(self): - """Compute the PSFs for the input source parameters.""" + def run_inference(self): + """Run PSF inference and return the full PSF array.""" + positions, sed_data = self._prepare_positions_and_seds() - # Check if source parameters are set - if self.positions is None or self.sed_data is None: - raise ValueError( - "Source parameters not set. Call set_source_parameters first." 
-            )
+        self.engine = PSFInferenceEngine(
+            trained_model=self.trained_psf_model,
+            batch_size=self.batch_size,
+            output_dim=self.output_dim,
+        )
+        return self.engine.compute_psfs(positions, sed_data)
+
+    def _ensure_psf_inference_completed(self):
+        if self.engine is None or self.engine.inferred_psfs is None:
+            self.run_inference()
+
+    def get_psfs(self):
+        self._ensure_psf_inference_completed()
+        return self.engine.get_psfs()
+
+    def get_psf(self, index):
+        self._ensure_psf_inference_completed()
+        return self.engine.get_psf(index)
 
-        # Get the number of samples
-        n_samples = self.positions.shape[0]
 
+class PSFInferenceEngine:
+    def __init__(self, trained_model, batch_size: int, output_dim: int):
+        self.trained_model = trained_model
+        self.batch_size = batch_size
+        self.output_dim = output_dim
+        self._inferred_psfs = None
+
+    @property
+    def inferred_psfs(self) -> np.ndarray:
+        """Access the cached inferred PSFs, if available."""
+        return self._inferred_psfs
+
+    def compute_psfs(self, positions: tf.Tensor, sed_data: tf.Tensor) -> np.ndarray:
+        """Compute and cache PSFs for the input source parameters."""
+        n_samples = positions.shape[0]
+        self._inferred_psfs = np.zeros((n_samples, self.output_dim, self.output_dim), dtype=np.float32)
 
         # Initialize counter
         counter = 0
-
-        # Initialize PSF array
-        self.inferred_psfs = np.zeros((n_samples, self.output_dim, self.output_dim))
         while counter < n_samples:
             # Calculate the batch end element
-            if counter + self.batch_size <= n_samples:
-                end_sample = counter + self.batch_size
-            else:
-                end_sample = n_samples
+            end = min(counter + self.batch_size, n_samples)
 
             # Define the batch positions
-            batch_pos = self.positions[counter:end_sample, :]
-            batch_seds = self.sed_data[counter:end_sample, :, :]
-
-            # Generate PSFs for the current batch
+            batch_pos = positions[counter:end, :]
+            batch_seds = sed_data[counter:end, :, :]
             batch_inputs = [batch_pos, batch_seds]
-            batch_poly_psfs = self.trained_psf_model(batch_inputs, training=False)
-
-            # Append to the PSF array
-            self.inferred_psfs[counter:end_sample, :, :] = batch_poly_psfs.numpy()
+
+            # Generate PSFs for the current batch
+            batch_psfs = self.trained_model(batch_inputs, training=False)
+            self._inferred_psfs[counter:end, :, :] = batch_psfs.numpy()
 
             # Update the counter
-            counter += self.batch_size
+            counter = end
+
+        return self._inferred_psfs
 
     def get_psfs(self) -> np.ndarray:
-        """Get all the generated PSFs.
+        """Get all the generated PSFs."""
+        if self._inferred_psfs is None:
+            raise ValueError("PSFs not yet computed. Call compute_psfs() first.")
+        return self._inferred_psfs
+
+    def get_psf(self, index: int) -> np.ndarray:
+        """Get the PSF at a specific index."""
+        if self._inferred_psfs is None:
+            raise ValueError("PSFs not yet computed. Call compute_psfs() first.")
+        return self._inferred_psfs[index]
+
 
-        Returns
-        -------
-        np.ndarray
-            The generated PSFs for the input source parameters.
-            Shape is (n_samples, output_dim, output_dim).
-        """
-        if self.inferred_psfs is None:
-            self.compute_psfs()
-        return self.inferred_psfs
-
-    def get_psf(self, index) -> np.ndarray:
-        """Generate the generated PSF at a specific index.
-
-        Returns
-        -------
-        np.ndarray
-            The generated PSFs for the input source parameters.
-            Shape is (output_dim, output_dim).
- """ - if self.inferred_psfs is None: - self.compute_psfs() - return self.inferred_psfs[index] From d1a4fb5878e1942beeb715d6892b07194cfc85ee Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Sun, 8 Jun 2025 18:26:26 +0200 Subject: [PATCH 052/135] Add unit tests for psf_inference --- .../test_inference/psf_inference_test.py | 255 ++++++++++++++++++ 1 file changed, 255 insertions(+) diff --git a/src/wf_psf/tests/test_inference/psf_inference_test.py b/src/wf_psf/tests/test_inference/psf_inference_test.py index e69de29b..5709ed6d 100644 --- a/src/wf_psf/tests/test_inference/psf_inference_test.py +++ b/src/wf_psf/tests/test_inference/psf_inference_test.py @@ -0,0 +1,255 @@ +"""UNIT TESTS FOR PACKAGE MODULE: PSF Inference. + +This module contains unit tests for the wf_psf.inference.psf_inference module. + +:Author: Jennifer Pollack + +""" + +import numpy as np +import os +from pathlib import Path +import pytest +import tensorflow as tf +from types import SimpleNamespace +from unittest.mock import MagicMock, patch, PropertyMock +from wf_psf.inference.psf_inference import ( + InferenceConfigHandler, + PSFInference, + PSFInferenceEngine +) + +from wf_psf.utils.read_config import RecursiveNamespace + +@pytest.fixture +def mock_training_config(): + training_config = RecursiveNamespace( + training=RecursiveNamespace( + id_name="mock_id", + model_params=RecursiveNamespace( + model_name="mock_model", + output_Q=2, + output_dim=32 + ) + ) + ) + return training_config + +@pytest.fixture +def mock_inference_config(): + inference_config = RecursiveNamespace( + inference=RecursiveNamespace( + batch_size=16, + cycle=2, + configs=RecursiveNamespace( + trained_model_path='/path/to/trained/model', + model_subdir='psf_model', + trained_model_config_path='config/training_config.yaml', + data_config_path=None + ), + model_params=RecursiveNamespace( + output_Q=1, + output_dim=64 + ) + ) + ) + return inference_config + + +@pytest.fixture +def psf_test_setup(mock_inference_config): + num_sources = 2 + num_bins = 10 + output_dim = 32 + + mock_positions = tf.convert_to_tensor([[0.1, 0.1], [0.2, 0.2]], dtype=tf.float32) + mock_seds = tf.convert_to_tensor(np.random.rand(num_sources, 2, num_bins), dtype=tf.float32) + expected_psfs = np.random.rand(num_sources, output_dim, output_dim).astype(np.float32) + + inference = PSFInference( + "dummy_path.yaml", + x_field=[0.1, 0.2], + y_field=[0.1, 0.2], + seds=np.random.rand(num_sources, num_bins) + ) + inference._config_handler = MagicMock() + inference._config_handler.inference_config = mock_inference_config + inference._trained_psf_model = MagicMock() + + return { + "inference": inference, + "mock_positions": mock_positions, + "mock_seds": mock_seds, + "expected_psfs": expected_psfs, + "num_sources": num_sources, + "num_bins": num_bins, + "output_dim": output_dim + } + + +def test_set_config_paths(mock_inference_config): + """Test setting configuration paths.""" + + # Initialize handler and inject mock config + config_handler = InferenceConfigHandler("fake/path") + config_handler.inference_config = mock_inference_config + + # Call the method under test + config_handler.set_config_paths() + + # Assertions + assert config_handler.trained_model_path == Path("/path/to/trained/model") + assert config_handler.model_subdir == "psf_model" + assert config_handler.trained_model_config_path == Path("/path/to/trained/model/config/training_config.yaml") + assert config_handler.data_config_path == None + + +def test_overwrite_model_params(mock_training_config, mock_inference_config): + 
"""Test that model_params can be overwritten.""" + # Mock the model_params object with some initial values + training_config = mock_training_config + inference_config = mock_inference_config + + InferenceConfigHandler.overwrite_model_params( + training_config, inference_config + ) + + # Assert that the model_params were overwritten correctly + assert training_config.training.model_params.output_Q == 1, "output_Q should be overwritten" + assert training_config.training.model_params.output_dim == 64, "output_dim should be overwritten" + + assert training_config.training.id_name == "mock_id", "id_name should not be overwritten" + + +def test_prepare_configs(mock_training_config, mock_inference_config): + """Test preparing configurations for inference.""" + # Mock the model_params object with some initial values + training_config = mock_training_config + inference_config = mock_inference_config + + # Make copy of the original training config model_params + original_model_params = mock_training_config.training.model_params + + # Instantiate PSFInference + psf_inf = PSFInference('/dummy/path.yaml') + + # Mock the config handler attribute with a mock InferenceConfigHandler + mock_config_handler = MagicMock(spec=InferenceConfigHandler) + mock_config_handler.training_config = training_config + mock_config_handler.inference_config = inference_config + + # Patch the overwrite_model_params to use the real static method + mock_config_handler.overwrite_model_params.side_effect = InferenceConfigHandler.overwrite_model_params + + psf_inf._config_handler = mock_config_handler + + # Run prepare_configs + psf_inf.prepare_configs() + + # Assert that the training model_params were updated + assert original_model_params.output_Q == 1 + assert original_model_params.output_dim == 64 + + +def test_config_handler_lazy_load(monkeypatch): + inference = PSFInference("dummy_path.yaml") + + called = {} + + class DummyHandler: + def load_configs(self): + called['load'] = True + self.inference_config = {} + self.training_config = {} + self.data_config = {} + def overwrite_model_params(self, *args): pass + + monkeypatch.setattr("wf_psf.inference.psf_inference.InferenceConfigHandler", lambda path: DummyHandler()) + + inference.prepare_configs() + + assert 'load' in called # Confirm lazy load happened + +def test_batch_size_positive(): + inference = PSFInference("dummy_path.yaml") + inference._config_handler = MagicMock() + inference._config_handler.inference_config = SimpleNamespace( + inference=SimpleNamespace(batch_size=4, model_params=SimpleNamespace(output_dim=32)) + ) + assert inference.batch_size == 4 + + +@patch.object(PSFInference, 'prepare_configs') +@patch('wf_psf.inference.psf_inference.load_trained_psf_model') +def test_load_inference_model(mock_load_trained_psf_model, mock_prepare_configs, mock_training_config, mock_inference_config): + + data_config = MagicMock() + mock_config_handler = MagicMock(spec=InferenceConfigHandler) + mock_config_handler.trained_model_path = "mock/path/to/model" + mock_config_handler.training_config = mock_training_config + mock_config_handler.inference_config = mock_inference_config + mock_config_handler.model_subdir = "psf_model" + mock_config_handler.data_config = data_config + + psf_inf = PSFInference("dummy_path.yaml") + psf_inf._config_handler = mock_config_handler + + psf_inf.load_inference_model() + + weights_path_pattern = os.path.join( + mock_config_handler.trained_model_path, + mock_config_handler.model_subdir, + 
f"{mock_config_handler.model_subdir}*_{mock_config_handler.training_config.training.model_params.model_name}*{mock_config_handler.training_config.training.id_name}_cycle{mock_config_handler.inference_config.inference.cycle}*" + ) + + # Assert calls to the mocked methods + mock_prepare_configs.assert_called_once() + mock_load_trained_psf_model.assert_called_once_with( + mock_config_handler.training_config, + mock_config_handler.data_config, + weights_path_pattern + ) + + +@patch.object(PSFInference, '_prepare_positions_and_seds') +@patch.object(PSFInferenceEngine, 'compute_psfs') +def test_run_inference(mock_compute_psfs, mock_prepare_positions_and_seds, psf_test_setup): + inference = psf_test_setup["inference"] + mock_positions = psf_test_setup["mock_positions"] + mock_seds = psf_test_setup["mock_seds"] + expected_psfs = psf_test_setup["expected_psfs"] + + mock_prepare_positions_and_seds.return_value = (mock_positions, mock_seds) + mock_compute_psfs.return_value = expected_psfs + + psfs = inference.run_inference() + + assert isinstance(psfs, np.ndarray) + assert psfs.shape == expected_psfs.shape + mock_prepare_positions_and_seds.assert_called_once() + mock_compute_psfs.assert_called_once_with(mock_positions, mock_seds) + + +@patch.object(PSFInference, '_prepare_positions_and_seds') +@patch.object(PSFInferenceEngine, 'compute_psfs') +def test_get_psfs_runs_inference(mock_compute_psfs, mock_prepare_positions_and_seds, psf_test_setup): + inference = psf_test_setup["inference"] + mock_positions = psf_test_setup["mock_positions"] + mock_seds = psf_test_setup["mock_seds"] + expected_psfs = psf_test_setup["expected_psfs"] + + mock_prepare_positions_and_seds.return_value = (mock_positions, mock_seds) + + def fake_compute_psfs(positions, seds): + inference.engine._inferred_psfs = expected_psfs + return expected_psfs + + mock_compute_psfs.side_effect = fake_compute_psfs + + psfs_1 = inference.get_psfs() + assert np.all(psfs_1 == expected_psfs) + + psfs_2 = inference.get_psfs() + assert np.all(psfs_2 == expected_psfs) + + assert mock_compute_psfs.call_count == 1 From 5ea8bcb5efaa49f91bb6564a6c2e19f927f94ff9 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Thu, 12 Jun 2025 13:22:51 +0200 Subject: [PATCH 053/135] Bugfix: Ensure updated training_config.training.model_params are passed to simPSF Move call to prepare_configs() into run_inference() to ensure model_params are overwritten before preparing SEDs. 
--- src/wf_psf/inference/psf_inference.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py index 9012e3fc..01e7cf1a 100644 --- a/src/wf_psf/inference/psf_inference.py +++ b/src/wf_psf/inference/psf_inference.py @@ -147,7 +147,7 @@ def data_config(self): @property def simPSF(self): if self._simPSF is None: - self._simPSF = psf_models.simPSF(self.model_params) + self._simPSF = psf_models.simPSF(self.training_config.training.model_params) return self._simPSF @property @@ -171,9 +171,7 @@ def trained_psf_model(self): return self._trained_psf_model def load_inference_model(self): - # Prepare the configuration for inference - self.prepare_configs() - + """Load the trained PSF model based on the inference configuration.""" model_path = self.config_handler.trained_model_path model_dir = self.config_handler.model_subdir model_name = self.training_config.training.model_params.model_name @@ -228,6 +226,10 @@ def _prepare_positions_and_seds(self): def run_inference(self): """Run PSF inference and return the full PSF array.""" + # Prepare the configuration for inference + self.prepare_configs() + + # Prepare positions and SEDs for inference positions, sed_data = self._prepare_positions_and_seds() self.engine = PSFInferenceEngine( From c82836b7af8fbbf36230148bbb537436226717d7 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Thu, 12 Jun 2025 13:25:28 +0200 Subject: [PATCH 054/135] test(simPSF): add unit test to verify updated model_params are passed Also moves prepare_configs() assertion from test load_inference() to unit test for run_inference(). --- .../test_inference/psf_inference_test.py | 47 ++++++++++++++++--- 1 file changed, 40 insertions(+), 7 deletions(-) diff --git a/src/wf_psf/tests/test_inference/psf_inference_test.py b/src/wf_psf/tests/test_inference/psf_inference_test.py index 5709ed6d..4add460b 100644 --- a/src/wf_psf/tests/test_inference/psf_inference_test.py +++ b/src/wf_psf/tests/test_inference/psf_inference_test.py @@ -179,17 +179,15 @@ def test_batch_size_positive(): assert inference.batch_size == 4 -@patch.object(PSFInference, 'prepare_configs') @patch('wf_psf.inference.psf_inference.load_trained_psf_model') -def test_load_inference_model(mock_load_trained_psf_model, mock_prepare_configs, mock_training_config, mock_inference_config): +def test_load_inference_model(mock_load_trained_psf_model, mock_training_config, mock_inference_config): - data_config = MagicMock() mock_config_handler = MagicMock(spec=InferenceConfigHandler) mock_config_handler.trained_model_path = "mock/path/to/model" mock_config_handler.training_config = mock_training_config mock_config_handler.inference_config = mock_inference_config mock_config_handler.model_subdir = "psf_model" - mock_config_handler.data_config = data_config + mock_config_handler.data_config = MagicMock() psf_inf = PSFInference("dummy_path.yaml") psf_inf._config_handler = mock_config_handler @@ -203,17 +201,16 @@ def test_load_inference_model(mock_load_trained_psf_model, mock_prepare_configs, ) # Assert calls to the mocked methods - mock_prepare_configs.assert_called_once() mock_load_trained_psf_model.assert_called_once_with( mock_config_handler.training_config, mock_config_handler.data_config, weights_path_pattern ) - +@patch.object(PSFInference, 'prepare_configs') @patch.object(PSFInference, '_prepare_positions_and_seds') @patch.object(PSFInferenceEngine, 'compute_psfs') -def test_run_inference(mock_compute_psfs, 
mock_prepare_positions_and_seds, psf_test_setup): +def test_run_inference(mock_compute_psfs, mock_prepare_positions_and_seds, mock_prepare_configs, psf_test_setup): inference = psf_test_setup["inference"] mock_positions = psf_test_setup["mock_positions"] mock_seds = psf_test_setup["mock_seds"] @@ -228,6 +225,42 @@ def test_run_inference(mock_compute_psfs, mock_prepare_positions_and_seds, psf_t assert psfs.shape == expected_psfs.shape mock_prepare_positions_and_seds.assert_called_once() mock_compute_psfs.assert_called_once_with(mock_positions, mock_seds) + mock_prepare_configs.assert_called_once() + +@patch("wf_psf.inference.psf_inference.psf_models.simPSF") +def test_simpsf_uses_updated_model_params(mock_simpsf, mock_training_config, mock_inference_config): + """Test that simPSF uses the updated model parameters.""" + training_config = mock_training_config + inference_config = mock_inference_config + + # Set the expected output_Q + expected_output_Q = inference_config.inference.model_params.output_Q + training_config.training.model_params.output_Q = expected_output_Q + + # Create fake psf instance + fake_psf_instance = MagicMock() + fake_psf_instance.output_Q = expected_output_Q + mock_simpsf.return_value = fake_psf_instance + + mock_config_handler = MagicMock(spec=InferenceConfigHandler) + mock_config_handler.trained_model_path = "mock/path/to/model" + mock_config_handler.training_config = training_config + mock_config_handler.inference_config = inference_config + mock_config_handler.model_subdir = "psf_model" + mock_config_handler.data_config = MagicMock() + + modeller = PSFInference("dummy_path.yaml") + modeller._config_handler = mock_config_handler + + modeller.prepare_configs() + result = modeller.simPSF + + # Confirm simPSF was called once with the updated model_params + mock_simpsf.assert_called_once() + called_args, _ = mock_simpsf.call_args + model_params_passed = called_args[0] + assert model_params_passed.output_Q == expected_output_Q + assert result.output_Q == expected_output_Q @patch.object(PSFInference, '_prepare_positions_and_seds') From c99b1e8e0a45e2d4d76f7edbac484bdbf518be07 Mon Sep 17 00:00:00 2001 From: jeipollack Date: Thu, 12 Jun 2025 16:34:20 +0100 Subject: [PATCH 055/135] Bug: replace self.data_conf with self.data_config --- src/wf_psf/inference/psf_inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py index 01e7cf1a..6f798245 100644 --- a/src/wf_psf/inference/psf_inference.py +++ b/src/wf_psf/inference/psf_inference.py @@ -36,7 +36,7 @@ def load_configs(self): if self.data_config_path is not None: # Load the data configuration - self.data_conf = read_conf(self.data_config_path) + self.data_config = read_conf(self.data_config_path) def set_config_paths(self): From cf0e8e1050781f0e2afe3bd953290ef41a057933 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Tue, 5 Aug 2025 19:59:51 +0200 Subject: [PATCH 056/135] Change logger.warnings to ValueErrors for missing fields in datasets & remove redundant check --- src/wf_psf/data/data_handler.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/wf_psf/data/data_handler.py b/src/wf_psf/data/data_handler.py index 84af1eba..f341c28d 100644 --- a/src/wf_psf/data/data_handler.py +++ b/src/wf_psf/data/data_handler.py @@ -161,16 +161,16 @@ def _validate_dataset_structure(self): if "positions" not in self.dataset: raise ValueError("Dataset missing required field: 'positions'") - if 
self.dataset_type == "train": + if self.dataset_type == "training": if "noisy_stars" not in self.dataset: - logger.warning("Missing 'noisy_stars' in 'train' dataset.") + raise ValueError(f"Missing required field 'noisy_stars' in {self.dataset_type} dataset.") elif self.dataset_type == "test": if "stars" not in self.dataset: - logger.warning("Missing 'stars' in 'test' dataset.") + raise ValueError(f"Missing required field 'stars' in {self.dataset_type} dataset.") elif self.dataset_type == "inference": pass else: - logger.warning(f"Unrecognized dataset_type: {self.dataset_type}") + raise ValueError(f"Unrecognized dataset_type: {self.dataset_type}") def _convert_dataset_to_tensorflow(self): """Convert dataset to TensorFlow tensors.""" @@ -179,12 +179,10 @@ def _convert_dataset_to_tensorflow(self): self.dataset["positions"], dtype=tf.float32 ) if self.dataset_type == "training": - if "noisy_stars" in self.dataset: self.dataset["noisy_stars"] = tf.convert_to_tensor( self.dataset["noisy_stars"], dtype=tf.float32 ) - else: - logger.warning(f"Missing 'noisy_stars' in {self.dataset_type} dataset.") + elif self.dataset_type == "test": if "stars" in self.dataset: self.dataset["stars"] = tf.convert_to_tensor( From 814d047658463909019c588dc49e28a97fd38a20 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Tue, 5 Aug 2025 20:00:21 +0200 Subject: [PATCH 057/135] Update unit tests with changes to data_handler.py --- .../tests/test_data/data_handler_test.py | 78 +++++-------------- 1 file changed, 19 insertions(+), 59 deletions(-) diff --git a/src/wf_psf/tests/test_data/data_handler_test.py b/src/wf_psf/tests/test_data/data_handler_test.py index 42e19a71..8d1b5421 100644 --- a/src/wf_psf/tests/test_data/data_handler_test.py +++ b/src/wf_psf/tests/test_data/data_handler_test.py @@ -16,22 +16,16 @@ def mock_sed(): return np.linspace(0.1, 1.0, 50) -def test_process_sed_data(data_params, simPSF): - # Test processing SED data without initialization - data_handler = DataHandler( - "training", data_params, simPSF, n_bins_lambda=10, load_data=False - ) - assert data_handler.sed_data is None # SED data should not be processed - - def test_process_sed_data_auto_load(data_params, simPSF): # load_data=True → dataset is used and SEDs processed automatically data_handler = DataHandler( - "training", data_params, simPSF, n_bins_lambda=10, load_data=True + "training", data_params.training, simPSF, n_bins_lambda=10, load_data=True ) + assert data_handler.sed_data is not None + assert data_handler.sed_data.shape[1] == 10 # n_bins_lambda -def test_load_train_dataset(tmp_path, data_params, simPSF): +def test_load_train_dataset(tmp_path, simPSF): # Create a temporary directory and a temporary data file data_dir = tmp_path / "data" data_dir.mkdir() @@ -48,9 +42,7 @@ def test_load_train_dataset(tmp_path, data_params, simPSF): np.save(temp_data_dir, mock_dataset) # Initialize DataHandler instance - data_params = RecursiveNamespace( - training=RecursiveNamespace(data_dir=str(data_dir), file="train_data.npy") - ) + data_params = RecursiveNamespace(data_dir=str(data_dir), file="train_data.npy") n_bins_lambda = 10 data_handler = DataHandler( @@ -68,7 +60,7 @@ def test_load_train_dataset(tmp_path, data_params, simPSF): assert np.array_equal(data_handler.dataset["SEDs"], mock_dataset["SEDs"]) -def test_load_test_dataset(tmp_path, data_params, simPSF): +def test_load_test_dataset(tmp_path, simPSF): # Create a temporary directory and a temporary data file data_dir = tmp_path / "data" data_dir.mkdir() @@ -85,14 +77,12 @@ def 
test_load_test_dataset(tmp_path, data_params, simPSF): np.save(temp_data_dir, mock_dataset) # Initialize DataHandler instance - data_params = RecursiveNamespace( - test=RecursiveNamespace(data_dir=str(data_dir), file="test_data.npy") - ) + data_params = RecursiveNamespace(data_dir=str(data_dir), file="test_data.npy") n_bins_lambda = 10 data_handler = DataHandler( dataset_type="test", - data_params=data_params.test, + data_params=data_params, simPSF=simPSF, n_bins_lambda=n_bins_lambda, load_data=False, @@ -120,21 +110,21 @@ def test_validate_train_dataset_missing_noisy_stars_raises(tmp_path, simPSF): np.save(temp_data_file, mock_dataset) - data_params = RecursiveNamespace( - training=RecursiveNamespace(data_dir=str(data_dir), file="train_data.npy") - ) + data_params = RecursiveNamespace(data_dir=str(data_dir), file="train_data.npy") n_bins_lambda = 10 data_handler = DataHandler( "training", data_params, simPSF, n_bins_lambda, load_data=False ) - with patch("wf_psf.data.data_handler.logger.warning") as mock_warning: + with pytest.raises( + ValueError, match="Missing required field 'noisy_stars' in training dataset." + ): data_handler.load_dataset() - mock_warning.assert_called_with("Missing 'noisy_stars' in training dataset.") + data_handler.validate_and_process_dataset() -def test_load_test_dataset_missing_stars(tmp_path, data_params, simPSF): +def test_load_test_dataset_missing_stars(tmp_path, simPSF): """Test that a warning is raised if 'stars' is missing in test data.""" data_dir = tmp_path / "data" data_dir.mkdir() @@ -147,48 +137,18 @@ def test_load_test_dataset_missing_stars(tmp_path, data_params, simPSF): np.save(temp_data_file, mock_dataset) - data_params = RecursiveNamespace( - test=RecursiveNamespace(data_dir=str(data_dir), file="test_data.npy") - ) + data_params = RecursiveNamespace(data_dir=str(data_dir), file="test_data.npy") n_bins_lambda = 10 data_handler = DataHandler( "test", data_params, simPSF, n_bins_lambda, load_data=False ) - with patch("wf_psf.data.data_handler.logger.warning") as mock_warning: + with pytest.raises( + ValueError, match="Missing required field 'stars' in test dataset." 
+ ): data_handler.load_dataset() - mock_warning.assert_called_with("Missing 'stars' in test dataset.") - - -def test_process_sed_data(data_params, simPSF): - mock_dataset = { - "positions": np.array([[1, 2], [3, 4]]), - "SEDs": np.array([[[0.1, 0.2], [0.3, 0.4]], [[0.5, 0.6], [0.7, 0.8]]]), - # Missing 'noisy_stars' - } - # Initialize DataHandler instance - n_bins_lambda = 4 - data_handler = DataHandler("training", data_params, simPSF, n_bins_lambda, False) - - np.save(temp_data_file, mock_dataset) - - data_params = RecursiveNamespace(data_dir=str(data_dir), file="train_data.npy") - - data_handler = DataHandler( - dataset_type="train", - data_params=data_params, - simPSF=simPSF, - n_bins_lambda=10, - load_data=False, - ) - - data_handler.load_dataset() - data_handler.process_sed_data(mock_dataset["SEDs"]) - - with patch("wf_psf.data.data_handler.logger.warning") as mock_warning: - data_handler._validate_dataset_structure() - mock_warning.assert_called_with("Missing 'noisy_stars' in 'train' dataset.") + data_handler.validate_and_process_dataset() def test_get_obs_positions(mock_data): From dfae6ced67740b2c88d17ea10011e0223a96c54a Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 16 May 2025 14:41:17 +0200 Subject: [PATCH 058/135] Refactor: reorganise modules, relocate utility functions, rename modules, update import statements and unit tests --- src/wf_psf/data/data_handler.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/wf_psf/data/data_handler.py b/src/wf_psf/data/data_handler.py index f341c28d..6ec6ec56 100644 --- a/src/wf_psf/data/data_handler.py +++ b/src/wf_psf/data/data_handler.py @@ -163,10 +163,14 @@ def _validate_dataset_structure(self): if self.dataset_type == "training": if "noisy_stars" not in self.dataset: - raise ValueError(f"Missing required field 'noisy_stars' in {self.dataset_type} dataset.") + raise ValueError( + f"Missing required field 'noisy_stars' in {self.dataset_type} dataset." + ) elif self.dataset_type == "test": if "stars" not in self.dataset: - raise ValueError(f"Missing required field 'stars' in {self.dataset_type} dataset.") + raise ValueError( + f"Missing required field 'stars' in {self.dataset_type} dataset." 
+ ) elif self.dataset_type == "inference": pass else: @@ -179,10 +183,10 @@ def _convert_dataset_to_tensorflow(self): self.dataset["positions"], dtype=tf.float32 ) if self.dataset_type == "training": - self.dataset["noisy_stars"] = tf.convert_to_tensor( - self.dataset["noisy_stars"], dtype=tf.float32 - ) - + self.dataset["noisy_stars"] = tf.convert_to_tensor( + self.dataset["noisy_stars"], dtype=tf.float32 + ) + elif self.dataset_type == "test": if "stars" in self.dataset: self.dataset["stars"] = tf.convert_to_tensor( From c49ebc7b3f6d80150617d800a999f035c6b880b2 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Sun, 18 May 2025 12:22:27 +0200 Subject: [PATCH 059/135] Refactor data_handler with new utility functions to validate and process datasets and update docstrings --- src/wf_psf/data/data_handler.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/wf_psf/data/data_handler.py b/src/wf_psf/data/data_handler.py index 6ec6ec56..7ec53b59 100644 --- a/src/wf_psf/data/data_handler.py +++ b/src/wf_psf/data/data_handler.py @@ -188,10 +188,9 @@ def _convert_dataset_to_tensorflow(self): ) elif self.dataset_type == "test": - if "stars" in self.dataset: - self.dataset["stars"] = tf.convert_to_tensor( - self.dataset["stars"], dtype=tf.float32 - ) + self.dataset["stars"] = tf.convert_to_tensor( + self.dataset["stars"], dtype=tf.float32 + ) def process_sed_data(self, sed_data): """ From 49039287defcdf4cfbeedc4b11498066e0a2325e Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Sun, 18 May 2025 12:23:27 +0200 Subject: [PATCH 060/135] Update unit tests associated to changes in data_handler.py --- src/wf_psf/data/data_handler.py | 42 +++++-------------- .../tests/test_data/data_handler_test.py | 15 +++++++ 2 files changed, 25 insertions(+), 32 deletions(-) diff --git a/src/wf_psf/data/data_handler.py b/src/wf_psf/data/data_handler.py index 7ec53b59..7e7b3881 100644 --- a/src/wf_psf/data/data_handler.py +++ b/src/wf_psf/data/data_handler.py @@ -188,43 +188,21 @@ def _convert_dataset_to_tensorflow(self): ) elif self.dataset_type == "test": - self.dataset["stars"] = tf.convert_to_tensor( - self.dataset["stars"], dtype=tf.float32 - ) + if "stars" in self.dataset: + self.dataset["stars"] = tf.convert_to_tensor( + self.dataset["stars"], dtype=tf.float32 + ) + else: + logger.warning(f"Missing 'stars' in {self.dataset_type} dataset.") + elif "inference" == self.dataset_type: + pass def process_sed_data(self, sed_data): - """ - Generate and process SED (Spectral Energy Distribution) data. - - This method transforms raw SED inputs into TensorFlow tensors suitable for model input. - It generates wavelength-binned SED elements using the PSF simulator, converts the result - into a tensor, and transposes it to match the expected shape for training or inference. - - Parameters - ---------- - sed_data : list or array-like - A list or array of raw SEDs, where each SED is typically a vector of flux values - or coefficients. These will be processed using the PSF simulator. + """Process SED Data. - Raises - ------ - ValueError - If `sed_data` is None. + A method to generate and process SED data. - Notes - ----- - The resulting tensor is stored in `self.sed_data` and has shape - `(num_samples, n_bins_lambda, n_components)`, where: - - `num_samples` is the number of SEDs, - - `n_bins_lambda` is the number of wavelength bins, - - `n_components` is the number of components per SED (e.g., filters or basis terms). 
- - The intermediate tensor is created with `tf.float64` for precision during generation, - but is converted to `tf.float32` after processing for use in training. """ - if sed_data is None: - raise ValueError("SED data must be provided explicitly or via dataset.") - self.sed_data = [ utils.generate_SED_elems_in_tensorflow( _sed, self.simPSF, n_bins=self.n_bins_lambda, tf_dtype=tf.float64 diff --git a/src/wf_psf/tests/test_data/data_handler_test.py b/src/wf_psf/tests/test_data/data_handler_test.py index 8d1b5421..6545964a 100644 --- a/src/wf_psf/tests/test_data/data_handler_test.py +++ b/src/wf_psf/tests/test_data/data_handler_test.py @@ -150,6 +150,21 @@ def test_load_test_dataset_missing_stars(tmp_path, simPSF): data_handler.load_dataset() data_handler.validate_and_process_dataset() + data_handler = DataHandler( + dataset_type="train", + data_params=data_params, + simPSF=simPSF, + n_bins_lambda=10, + load_data=False + ) + + data_handler.load_dataset() + data_handler.process_sed_data(mock_dataset["SEDs"]) + + with patch("wf_psf.data.data_handler.logger.warning") as mock_warning: + data_handler._validate_dataset_structure() + mock_warning.assert_called_with("Missing 'noisy_stars' in 'train' dataset.") + def test_get_obs_positions(mock_data): observed_positions = get_obs_positions(mock_data) From 152b0d83e700597a8b9b6d4d891063b212da59a8 Mon Sep 17 00:00:00 2001 From: Tobias Liaudat Date: Mon, 19 May 2025 14:45:24 +0200 Subject: [PATCH 061/135] automatic formatting --- src/wf_psf/data/data_handler.py | 31 +++++++++++++++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/src/wf_psf/data/data_handler.py b/src/wf_psf/data/data_handler.py index 7e7b3881..b25881a2 100644 --- a/src/wf_psf/data/data_handler.py +++ b/src/wf_psf/data/data_handler.py @@ -198,11 +198,38 @@ def _convert_dataset_to_tensorflow(self): pass def process_sed_data(self, sed_data): - """Process SED Data. + """ + Generate and process SED (Spectral Energy Distribution) data. + + This method transforms raw SED inputs into TensorFlow tensors suitable for model input. + It generates wavelength-binned SED elements using the PSF simulator, converts the result + into a tensor, and transposes it to match the expected shape for training or inference. - A method to generate and process SED data. + Parameters + ---------- + sed_data : list or array-like + A list or array of raw SEDs, where each SED is typically a vector of flux values + or coefficients. These will be processed using the PSF simulator. + Raises + ------ + ValueError + If `sed_data` is None. + + Notes + ----- + The resulting tensor is stored in `self.sed_data` and has shape + `(num_samples, n_bins_lambda, n_components)`, where: + - `num_samples` is the number of SEDs, + - `n_bins_lambda` is the number of wavelength bins, + - `n_components` is the number of components per SED (e.g., filters or basis terms). + + The intermediate tensor is created with `tf.float64` for precision during generation, + but is converted to `tf.float32` after processing for use in training. 
""" + if sed_data is None: + raise ValueError("SED data must be provided explicitly or via dataset.") + self.sed_data = [ utils.generate_SED_elems_in_tensorflow( _sed, self.simPSF, n_bins=self.n_bins_lambda, tf_dtype=tf.float64 From de0892cd25a1e412e93aa0aaaf160a357ce70e87 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Wed, 18 Jun 2025 23:16:15 +0200 Subject: [PATCH 062/135] Refactor: add ZernikeInputs dataclass, ZernikeInputsFactory, helper methods for assembling zernike contributions according to run_type mode: training, simulation, or inference --- src/wf_psf/data/data_zernike_utils.py | 207 +++++++++++++++++++------- 1 file changed, 152 insertions(+), 55 deletions(-) diff --git a/src/wf_psf/data/data_zernike_utils.py b/src/wf_psf/data/data_zernike_utils.py index 760adb11..648b260c 100644 --- a/src/wf_psf/data/data_zernike_utils.py +++ b/src/wf_psf/data/data_zernike_utils.py @@ -11,14 +11,89 @@ """ +from dataclasses import dataclass +from typing import Optional, Union import numpy as np import tensorflow as tf -from typing import Optional +from wf_psf.utils.read_config import RecursiveNamespace import logging logger = logging.getLogger(__name__) +@dataclass +class ZernikeInputs: + zernike_prior: Optional[np.ndarray] # true prior, if provided (e.g. from PDC) + centroid_dataset: Optional[Union[dict, 'RecursiveNamespace']] # only used in training/simulation + misalignment_positions: Optional[np.ndarray] # needed for CCD corrections + batch_size: int + + +class ZernikeInputsFactory: + @staticmethod + def build(data, run_type: str, model_params, prior: Optional[np.ndarray] = None) -> ZernikeInputs: + """Builds a ZernikeInputs dataclass instance based on run type and data. + + Parameters + ---------- + data : Union[dict, DataConfigHandler] + Dataset object containing star positions, priors, and optionally pixel data. + run_type : str + One of 'training', 'simulation', or 'inference'. + model_params : RecursiveNamespace + Model parameters, including flags for prior/corrections. + prior : Optional[np.ndarray] + An explicitly passed prior (overrides any inferred one if provided). + + Returns + ------- + ZernikeInputs + """ + centroid_dataset = None + positions = None + + if run_type in {"training", "simulation"}: + centroid_dataset = data # Assuming data is a DataConfigHandler or similar object containing train and test datasets + positions = np.concatenate( + [ + data.training_dataset["positions"], + data.test_dataset["positions"] + ], + axis=0, + ) + + if model_params.use_prior: + if prior is not None: + logger.warning( + "Zernike prior explicitly provided; ignoring dataset-based prior despite use_prior=True." + ) + else: + prior = get_np_zernike_prior(data) + + elif run_type == "inference": + centroid_dataset = None + positions = data["positions"] + + if model_params.use_prior: + # Try to extract prior from `data`, if present + prior = getattr(data, "zernike_prior", None) if not isinstance(data, dict) else data.get("zernike_prior") + + if prior is None: + logger.warning( + "model_params.use_prior=True but no prior found in inference data. Proceeding with None." + ) + + else: + raise ValueError(f"Unsupported run_type: {run_type}") + + return ZernikeInputs( + zernike_prior=prior, + centroid_dataset=centroid_dataset, + misalignment_positions=positions, + batch_size=model_params.batch_size, + ) + + def get_np_zernike_prior(data): """Get the zernike prior from the provided dataset. 
@@ -45,80 +120,102 @@ def get_np_zernike_prior(data): return zernike_prior - -def get_zernike_prior(model_params, data, batch_size: int=16): - """Get Zernike priors from the provided dataset. - - This method concatenates the Zernike priors from both the training - and test datasets. +def pad_contribution_to_order(contribution: np.ndarray, max_order: int) -> np.ndarray: + """Pad a Zernike contribution array to the max Zernike order.""" + current_order = contribution.shape[1] + pad_width = ((0, 0), (0, max_order - current_order)) + return np.pad(contribution, pad_width=pad_width, mode="constant", constant_values=0) + +def combine_zernike_contributions(contributions: list[np.ndarray]) -> np.ndarray: + """Combine multiple Zernike contributions, padding each to the max order before summing.""" + if not contributions: + raise ValueError("No contributions provided.") + + max_order = max(contrib.shape[1] for contrib in contributions) + n_samples = contributions[0].shape[0] + + combined = np.zeros((n_samples, max_order), dtype=np.float32) + for contrib in contributions: + padded = pad_contribution_to_order(contrib, max_order) + combined += padded + + return combined + +def assemble_zernike_contributions( + model_params, + zernike_prior=None, + centroid_dataset=None, + positions=None, + batch_size=16, +): + """ + Assemble the total Zernike contribution map by combining the prior, + centroid correction, and CCD misalignment correction. Parameters ---------- model_params : RecursiveNamespace - Object containing parameters for this PSF model class. - data : DataConfigHandler - Object containing training and test datasets. - batch_size : int, optional - The batch size to use when processing the stars. Default is 16. + Parameters controlling which contributions to apply. + zernike_prior : Optional[np.ndarray or tf.Tensor] + The precomputed Zernike prior (e.g., from PDC or another model). + centroid_dataset : Optional[object] + Dataset used to compute centroid correction. Must have both training and test sets. + positions : Optional[np.ndarray or tf.Tensor] + Positions used for computing CCD misalignment. Must be available in inference mode. + batch_size : int + Batch size for centroid correction. Returns ------- tf.Tensor - Tensor containing the observed positions of the stars. - - Notes - ----- - The Zernike prior are obtained by concatenating the Zernike priors - from both the training and test datasets along the 0th axis. - + A tensor representing the full Zernike contribution map. 
""" - # List of zernike contribution + zernike_contribution_list = [] - if model_params.use_prior: - logger.info("Reading in Zernike prior into Zernike contribution list...") - zernike_contribution_list.append(get_np_zernike_prior(data)) + # Prior + if model_params.use_prior and zernike_prior is not None: + logger.info("Adding Zernike prior...") + if isinstance(zernike_prior, np.ndarray): + zernike_prior = tf.convert_to_tensor(zernike_prior, dtype=tf.float32) + zernike_contribution_list.append(zernike_prior) + else: + logger.info("Skipping Zernike prior (not used or not provided).") - if model_params.correct_centroids: - logger.info("Adding centroid correction to Zernike contribution list...") + # Centroid correction (tip/tilt) + if model_params.correct_centroids and centroid_dataset is not None: + logger.info("Computing centroid correction...") + centroid_correction = compute_centroid_correction( + model_params, centroid_dataset, batch_size=batch_size + ) zernike_contribution_list.append( - compute_centroid_correction(model_params, data, batch_size) + tf.convert_to_tensor(centroid_correction, dtype=tf.float32) ) - - if model_params.add_ccd_misalignments: - logger.info("Adding CCD mis-alignments to Zernike contribution list...") - zernike_contribution_list.append(compute_ccd_misalignment(model_params, data)) - - if len(zernike_contribution_list) == 1: - zernike_contribution = zernike_contribution_list[0] else: - # Get max zk order - max_zk_order = np.max( - np.array( - [ - zk_contribution.shape[1] - for zk_contribution in zernike_contribution_list - ] - ) - ) + logger.info("Skipping centroid correction (not enabled or no dataset).") - zernike_contribution = np.zeros( - (zernike_contribution_list[0].shape[0], max_zk_order) + # CCD misalignment (focus term) + if model_params.add_ccd_misalignments and positions is not None: + logger.info("Computing CCD misalignment correction...") + ccd_misalignment = compute_ccd_misalignment(model_params, positions) + zernike_contribution_list.append( + tf.convert_to_tensor(ccd_misalignment, dtype=tf.float32) ) + else: + logger.info("Skipping CCD misalignment correction (not enabled or no positions).") - # Pad arrays to get the same length and add the final contribution - for it in range(len(zernike_contribution_list)): - current_zk_order = zernike_contribution_list[it].shape[1] - current_zernike_contribution = np.pad( - zernike_contribution_list[it], - pad_width=[(0, 0), (0, int(max_zk_order - current_zk_order))], - mode="constant", - constant_values=0, - ) + # If no contributions, return zeros tensor to avoid crashes + if not zernike_contribution_list: + logger.warning("No Zernike contributions found. 
Returning zero tensor.") + # Infer batch size and zernike order from model_params + n_samples = 1 + n_zks = getattr(model_params.param_hparams, "n_zernikes", 10) + return tf.zeros((n_samples, n_zks), dtype=tf.float32) - zernike_contribution += current_zernike_contribution + combined_zernike_prior = combine_zernike_contributions(zernike_contribution_list) - return tf.convert_to_tensor(zernike_contribution, dtype=tf.float32) + return tf.convert_to_tensor(combined_zernike_prior, dtype=tf.float32) + def shift_x_y_to_zk1_2_wavediff(dxy, tel_focal_length=24.5, tel_diameter=1.2): From 8db8de74094704e1fbc6ecc534b9062dda2c580f Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Wed, 18 Jun 2025 23:18:57 +0200 Subject: [PATCH 063/135] Update docstring describing data_conf types permitted --- src/wf_psf/psf_models/psf_model_loader.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/wf_psf/psf_models/psf_model_loader.py b/src/wf_psf/psf_models/psf_model_loader.py index 1d2e267f..797be8fc 100644 --- a/src/wf_psf/psf_models/psf_model_loader.py +++ b/src/wf_psf/psf_models/psf_model_loader.py @@ -25,8 +25,8 @@ def load_trained_psf_model(training_conf, data_conf, weights_path_pattern): training_conf : RecursiveNamespace Configuration object containing model parameters and training hyperparameters. Supports attribute-style access to nested fields. - data_conf : RecursiveNamespace - Configuration object containing data-related parameters. + data_conf : RecursiveNamespace or dict + Configuration RecursiveNamespace object or a dictionary containing data parameters (e.g. pixel data, positions, masks, etc). weights_path_pattern : str Glob-style pattern used to locate the model weights file. From 9f7f19e20df01907aec98a8a84289a5ec00ed9f0 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Sat, 21 Jun 2025 18:45:41 +0200 Subject: [PATCH 064/135] Move imports to method to avoid circular imports --- src/wf_psf/data/centroids.py | 2 +- src/wf_psf/instrument/ccd_misalignments.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/wf_psf/data/centroids.py b/src/wf_psf/data/centroids.py index 83ccf8c7..97cb069f 100644 --- a/src/wf_psf/data/centroids.py +++ b/src/wf_psf/data/centroids.py @@ -10,7 +10,6 @@ import scipy.signal as scisig from wf_psf.data.data_handler import extract_star_data from fractions import Fraction -from wf_psf.data.data_zernike_utils import shift_x_y_to_zk1_2_wavediff import tensorflow as tf from typing import Optional @@ -127,6 +126,7 @@ def compute_zernike_tip_tilt( - This function processes all images at once using vectorized operations. - The Zernike coefficients are computed in the WaveDiff convention. 
""" + from wf_psf.data.data_zernike_utils import shift_x_y_to_zk1_2_wavediff # Vectorize the centroid computation centroid_estimator = CentroidEstimator( im=star_images, mask=star_masks, sigma_init=sigma_init, n_iter=n_iter diff --git a/src/wf_psf/instrument/ccd_misalignments.py b/src/wf_psf/instrument/ccd_misalignments.py index 35343886..76386f20 100644 --- a/src/wf_psf/instrument/ccd_misalignments.py +++ b/src/wf_psf/instrument/ccd_misalignments.py @@ -10,7 +10,7 @@ import numpy as np import matplotlib.path as mpltPath from scipy.spatial import KDTree -from wf_psf.data.data_preprocessing import defocus_to_zk4_wavediff +from wf_psf.data.data_handler import get_np_obs_positions def compute_ccd_misalignment(model_params, data): @@ -386,6 +386,8 @@ def get_zk4_from_position(self, pos): Zernike 4 value in wavediff convention corresponding to the delta z of the given input position `pos`. """ + from wf_psf.data.data_zernike_utils import defocus_to_zk4_wavediff + dz = self.get_dz_from_position(pos) return defocus_to_zk4_wavediff(dz, self.tel_focal_length, self.tel_diameter) From 3c84c1f4dcdd4912cc785d045b56fb7ccd80bae6 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Sat, 21 Jun 2025 18:47:35 +0200 Subject: [PATCH 065/135] Remove batch_size arg from ZernikeInputsFactory ; raise ValueError to check Zernike contributions have the same number of samples --- src/wf_psf/data/data_zernike_utils.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/wf_psf/data/data_zernike_utils.py b/src/wf_psf/data/data_zernike_utils.py index 648b260c..b03ff400 100644 --- a/src/wf_psf/data/data_zernike_utils.py +++ b/src/wf_psf/data/data_zernike_utils.py @@ -15,6 +15,8 @@ from typing import Optional, Union import numpy as np import tensorflow as tf +from wf_psf.data.centroids import compute_centroid_correction +from wf_psf.instrument.ccd_misalignments import compute_ccd_misalignment from wf_psf.utils.read_config import RecursiveNamespace import logging @@ -26,7 +28,6 @@ class ZernikeInputs: zernike_prior: Optional[np.ndarray] # true prior, if provided (e.g. 
from PDC) centroid_dataset: Optional[Union[dict, 'RecursiveNamespace']] # only used in training/simulation misalignment_positions: Optional[np.ndarray] # needed for CCD corrections - batch_size: int class ZernikeInputsFactory: @@ -89,8 +90,7 @@ def build(data, run_type: str, model_params, prior: Optional[np.ndarray] = None) return ZernikeInputs( zernike_prior=prior, centroid_dataset=centroid_dataset, - misalignment_positions=positions, - batch_size=model_params.batch_size, + misalignment_positions=positions ) @@ -133,6 +133,8 @@ def combine_zernike_contributions(contributions: list[np.ndarray]) -> np.ndarray max_order = max(contrib.shape[1] for contrib in contributions) n_samples = contributions[0].shape[0] + if any(c.shape[0] != n_samples for c in contributions): + raise ValueError("All contributions must have the same number of samples.") combined = np.zeros((n_samples, max_order), dtype=np.float32) for contrib in contributions: From 362ed76396c78b9ce19c939a3d69ab1a8930894a Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Sat, 21 Jun 2025 18:49:32 +0200 Subject: [PATCH 066/135] Add and set run_type attribute to DataConfigHandler object in TrainingConfigHandler constructor --- src/wf_psf/utils/configs_handler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/wf_psf/utils/configs_handler.py b/src/wf_psf/utils/configs_handler.py index a7e42ceb..f59fdc17 100644 --- a/src/wf_psf/utils/configs_handler.py +++ b/src/wf_psf/utils/configs_handler.py @@ -188,6 +188,7 @@ def __init__(self, training_conf, file_handler): self.training_conf.training.training_hparams.batch_size, self.training_conf.training.load_data_on_init, ) + self.data_conf.run_type = "training" self.file_handler.copy_conffile_to_output_dir( self.training_conf.training.data_config ) From c8f060ec44e7480f57acc1ee3549363156ed144b Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Sat, 21 Jun 2025 18:51:04 +0200 Subject: [PATCH 067/135] Add and set run_type attribute ; Replace var name end with end_sample in PSFInferenceEngine --- src/wf_psf/inference/psf_inference.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py index 6f798245..5be49343 100644 --- a/src/wf_psf/inference/psf_inference.py +++ b/src/wf_psf/inference/psf_inference.py @@ -162,6 +162,7 @@ def data_handler(self): load_data=False, dataset=None, ) + self._data_handler.run_type = "inference" return self._data_handler @property @@ -272,7 +273,7 @@ def compute_psfs(self, positions: tf.Tensor, sed_data: tf.Tensor) -> np.ndarray: counter = 0 while counter < n_samples: # Calculate the batch end element - end = min(counter + self.batch_size, n_samples) + end_sample = min(counter + self.batch_size, n_samples) # Define the batch positions batch_pos = positions[counter:end_sample, :] @@ -281,10 +282,10 @@ def compute_psfs(self, positions: tf.Tensor, sed_data: tf.Tensor) -> np.ndarray: # Generate PSFs for the current batch batch_psfs = self.trained_model(batch_inputs, training=False) - self.inferred_psfs[counter:end, :, :] = batch_psfs.numpy() + self.inferred_psfs[counter:end_sample, :, :] = batch_psfs.numpy() # Update the counter - counter = end + counter = end_sample return self._inferred_psfs From 8458a9ad73bc0d862ac270755e8ba900712ad7c7 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Sat, 21 Jun 2025 18:52:26 +0200 Subject: [PATCH 068/135] Refactor TFPhysicalPolychromaticField to lazy load property objects and attributes dynamically at run-time according to the 
run_type: training or inference --- .../psf_model_physical_polychromatic.py | 339 ++++++++---------- 1 file changed, 148 insertions(+), 191 deletions(-) diff --git a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py index baec8923..971dbb63 100644 --- a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py +++ b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py @@ -11,7 +11,7 @@ import tensorflow as tf from tensorflow.python.keras.engine import data_adapter from wf_psf.data.data_handler import get_obs_positions -from wf_psf.data.data_zernike_utils import get_zernike_prior +from wf_psf.data.data_zernike_utils import ZernikeInputsFactory, assemble_zernike_contributions from wf_psf.psf_models import psf_models as psfm from wf_psf.psf_models.tf_modules.tf_layers import ( TFPolynomialZernikeField, @@ -98,8 +98,8 @@ def __init__(self, model_params, training_params, data, coeff_mat=None): A Recursive Namespace object containing parameters for this PSF model class. training_params: Recursive Namespace A Recursive Namespace object containing training hyperparameters for this PSF model class. - data: DataConfigHandler - A DataConfigHandler object that provides access to training and test datasets, as well as prior knowledge like Zernike coefficients. + data: DataConfigHandler or dict + A DataConfigHandler object or dict that provides access to single or multiple datasets (e.g. train and test), as well as prior knowledge like Zernike coefficients. coeff_mat: Tensor or None, optional Coefficient matrix defining the parametric PSF field model. @@ -109,204 +109,151 @@ def __init__(self, model_params, training_params, data, coeff_mat=None): Initialized instance of the TFPhysicalPolychromaticField class. """ super().__init__(model_params, training_params, coeff_mat) - self._initialize_parameters_and_layers( - model_params, training_params, data, coeff_mat - ) + self.model_params = model_params + self.training_params = training_params + self.data = data + self.run_type = data.run_type - def _initialize_parameters_and_layers( - self, - model_params: RecursiveNamespace, - training_params: RecursiveNamespace, - data: DataConfigHandler, - coeff_mat: Optional[tf.Tensor] = None, - ): - """Initialize Parameters of the PSF model. - - This method sets up the PSF model parameters, observational positions, - Zernike coefficients, and components required for the automatically - differentiable optical forward model. - - Parameters - ---------- - model_params: Recursive Namespace - A Recursive Namespace object containing parameters for this PSF model class. - training_params: Recursive Namespace - A Recursive Namespace object containing training hyperparameters for this PSF model class. - data: DataConfigHandler object - A DataConfigHandler object providing access to training and tests datasets, as well as prior knowledge like Zernike coefficients. - coeff_mat: Tensor or None, optional - Initialization of the coefficient matrix defining the parametric psf field model. - - Notes - ----- - - Initializes Zernike parameters based on dataset priors. - - Configures the PSF model layers according to `model_params`. - - If `coeff_mat` is provided, the model coefficients are updated accordingly. 
- """ + # Initialize the model parameters and layers self.output_Q = model_params.output_Q - self.obs_pos = get_obs_positions(data) self.l2_param = model_params.param_hparams.l2_param - # Inputs: Save optimiser history Parametric model features - self.save_optim_history_param = ( - model_params.param_hparams.save_optim_history_param - ) - # Inputs: Save optimiser history NonParameteric model features - self.save_optim_history_nonparam = ( - model_params.nonparam_hparams.save_optim_history_nonparam - ) - self._initialize_zernike_parameters(model_params, data) - self._initialize_layers(model_params, training_params) + self.output_dim = model_params.output_dim + + # Initialise lazy loading of external Zernike prior + self._external_prior = None # Initialize the model parameters with non-default value if coeff_mat is not None: self.assign_coeff_matrix(coeff_mat) - def _initialize_zernike_parameters(self, model_params, data): - """Initialize the Zernike parameters. - - Parameters - ---------- - model_params: Recursive Namespace - A Recursive Namespace object containing parameters for this PSF model class. - data: DataConfigHandler object - A DataConfigHandler object providing access to training and tests datasets, as well as prior knowledge like Zernike coefficients. - """ - self.zks_prior = get_zernike_prior(model_params, data, data.batch_size) - self.n_zks_total = max( - model_params.param_hparams.n_zernikes, - tf.cast(tf.shape(self.zks_prior)[1], tf.int32), + def _assemble_zernike_contributions(self): + zks_inputs = ZernikeInputsFactory.build( + data=self.data, + run_type=self.run_type, + model_params=self.model_params, + prior=self._external_prior if hasattr(self, "_external_prior") else None, ) - self.zernike_maps = psfm.generate_zernike_maps_3d( - self.n_zks_total, model_params.pupil_diameter + return assemble_zernike_contributions( + model_params=self.model_params, + zernike_prior=zks_inputs.zernike_prior, + centroid_dataset=zks_inputs.centroid_dataset, + positions=zks_inputs.misalignment_positions, + batch_size=self.training_params.batch_size, ) - def _initialize_layers(self, model_params, training_params): - """Initialize the layers of the PSF model. - - This method initializes the layers of the PSF model, including the physical layer, polynomial Zernike field, batch polychromatic layer, and non-parametric OPD layer. - - Parameters - ---------- - model_params: Recursive Namespace - A Recursive Namespace object containing parameters for this PSF model class. - training_params: Recursive Namespace - A Recursive Namespace object containing training hyperparameters for this PSF model class. - coeff_mat: Tensor or None, optional - Initialization of the coefficient matrix defining the parametric psf field model. - - """ - self._initialize_physical_layer(model_params) - self._initialize_polynomial_Z_field(model_params) - self._initialize_Zernike_OPD(model_params) - self._initialize_batch_polychromatic_layer(model_params, training_params) - self._initialize_nonparametric_opd_layer(model_params, training_params) - - def _initialize_physical_layer(self, model_params): - """Initialize the physical layer of the PSF model. - - This method initializes the physical layer of the PSF model using parameters - specified in the `model_params` object. - - Parameters - ---------- - model_params: Recursive Namespace - A Recursive Namespace object containing parameters for this PSF model class. 
- """ - self.tf_physical_layer = TFPhysicalLayer( - self.obs_pos, - self.zks_prior, - interpolation_type=model_params.interpolation_type, - interpolation_args=model_params.interpolation_args, + @property + def save_param_history(self) -> bool: + """Check if the model should save the optimization history for parametric features.""" + return getattr(self.model_params.param_hparams, "save_optim_history_param", False) + + @property + def save_nonparam_history(self) -> bool: + """Check if the model should save the optimization history for non-parametric features.""" + return getattr(self.model_params.nonparam_hparams, "save_optim_history_nonparam", False) + + + # === Lazy properties ===. + @property + def obs_pos(self): + """Lazy loading of the observation positions.""" + if not hasattr(self, "_obs_pos"): + if self.run_type == "training" or self.run_type == "simulation": + # Get the observation positions from the data handler + self._obs_pos = get_obs_positions(self.data) + elif self.run_type == "inference": + # For inference, we might not have a data handler, so we use the model parameters + self._obs_pos = self.data.dataset["positions"] + return self._obs_pos + + @property + def zks_total_contribution(self): + """Lazily load all Zernike contributions, including prior and corrections.""" + if not hasattr(self, "_zks_total_contribution"): + self._zks_total_contribution = self._assemble_zernike_contributions() + return self._zks_total_contribution + + @property + def n_zks_total(self): + """Get the total number of Zernike coefficients.""" + if not hasattr(self, "_n_zks_total"): + self._n_zks_total = max( + self.model_params.param_hparams.n_zernikes, + tf.cast(tf.shape(self.zks_total_contribution)[1], tf.int32), ) - - def _initialize_polynomial_Z_field(self, model_params): - """Initialize the polynomial Zernike field of the PSF model. - - This method initializes the polynomial Zernike field of the PSF model using - parameters specified in the `model_params` object. - - Parameters - ---------- - model_params: Recursive Namespace - A Recursive Namespace object containing parameters for this PSF model class. - - """ - self.tf_poly_Z_field = TFPolynomialZernikeField( - x_lims=model_params.x_lims, - y_lims=model_params.y_lims, - random_seed=model_params.param_hparams.random_seed, - n_zernikes=model_params.param_hparams.n_zernikes, - d_max=model_params.param_hparams.d_max, + return self._n_zks_total + + @property + def zernike_maps(self): + """Lazy loading of the Zernike maps.""" + if not hasattr(self, "_zernike_maps"): + self._zernike_maps = psfm.generate_zernike_maps_3d( + self.n_zks_total, self.model_params.pupil_diameter ) + return self._zernike_maps + + @property + def tf_poly_Z_field(self): + """Lazy loading of the polynomial Zernike field layer.""" + if not hasattr(self, "_tf_poly_Z_field"): + self._tf_poly_Z_field = TFPolynomialZernikeField( + x_lims=self.model_params.x_lims, + y_lims=self.model_params.y_lims, + random_seed=self.model_params.param_hparams.random_seed, + n_zernikes=self.model_params.param_hparams.n_zernikes, + d_max=self.model_params.param_hparams.d_max, + ) + return self._tf_poly_Z_field - def _initialize_Zernike_OPD(self, model_params): - """Initialize the Zernike OPD field of the PSF model. - - This method initializes the Zernike Optical Path Difference - field of the PSF model using parameters specified in the `model_params` object. - - Parameters - ---------- - model_params: Recursive Namespace - A Recursive Namespace object containing parameters for this PSF model class. 
- - """ - # Initialize the zernike to OPD layer - self.tf_zernike_OPD = TFZernikeOPD(zernike_maps=self.zernike_maps) - - def _initialize_batch_polychromatic_layer(self, model_params, training_params): - """Initialize the batch polychromatic PSF layer. - - This method initializes the batch opd to batch polychromatic PSF layer - using the provided `model_params` and `training_params`. - - Parameters - ---------- - model_params: Recursive Namespace - A Recursive Namespace object containing parameters for this PSF model class. - training_params: Recursive Namespace - A Recursive Namespace object containing training hyperparameters for this PSF model class. - - - """ - self.batch_size = training_params.batch_size - self.obscurations = psfm.tf_obscurations( - pupil_diam=model_params.pupil_diameter, - N_filter=model_params.LP_filter_length, - rotation_angle=model_params.obscuration_rotation_angle, - ) - self.output_dim = model_params.output_dim + @tf_poly_Z_field.deleter + def tf_poly_Z_field(self): + del self._tf_poly_Z_field - self.tf_batch_poly_PSF = TFBatchPolychromaticPSF( - obscurations=self.obscurations, - output_Q=self.output_Q, - output_dim=self.output_dim, + @property + def tf_physical_layer(self): + """Lazy loading of the physical layer of the PSF model.""" + if not hasattr(self, "_tf_physical_layer"): + self._tf_physical_layer = TFPhysicalLayer( + self.obs_pos, + self.zks_total_contribution, + interpolation_type=self.model_params.interpolation_type, + interpolation_args=self.model_params.interpolation_args, ) + + @property + def tf_zernike_OPD(self): + """Lazy loading of the Zernike Optical Path Difference (OPD) layer.""" + if not hasattr(self, "_tf_zernike_OPD"): + self._tf_zernike_OPD = TFZernikeOPD(zernike_maps=self.zernike_maps) + return self._tf_zernike_OPD + + @property + def tf_batch_poly_PSF(self): + """Lazily initialize the batch polychromatic PSF layer.""" + if not hasattr(self, "_tf_batch_poly_PSF"): + obscurations = psfm.tf_obscurations( + pupil_diam=self.model_params.pupil_diameter, + N_filter=self.model_params.LP_filter_length, + rotation_angle=self.model_params.obscuration_rotation_angle, + ) - def _initialize_nonparametric_opd_layer(self, model_params, training_params): - """Initialize the non-parametric OPD layer. - - This method initializes the non-parametric OPD layer using the provided - `model_params` and `training_params`. - - Parameters - ---------- - model_params: Recursive Namespace - A Recursive Namespace object containing parameters for this PSF model class. - training_params: Recursive Namespace - A Recursive Namespace object containing training hyperparameters for this PSF model class. 
- - """ - # self.d_max_nonparam = model_params.nonparam_hparams.d_max_nonparam - # self.opd_dim = tf.shape(self.zernike_maps)[1].numpy() - - self.tf_np_poly_opd = TFNonParametricPolynomialVariationsOPD( - x_lims=model_params.x_lims, - y_lims=model_params.y_lims, - random_seed=model_params.param_hparams.random_seed, - d_max=model_params.nonparam_hparams.d_max_nonparam, - opd_dim=tf.shape(self.zernike_maps)[1].numpy(), - ) + self._tf_batch_poly_PSF = TFBatchPolychromaticPSF( + obscurations=obscurations, + output_Q=self.output_Q, + output_dim=self.output_dim, + ) + return self._tf_batch_poly_PSF + + @property + def tf_np_poly_opd(self): + """Lazy loading of the non-parametric polynomial variations OPD layer.""" + if not hasattr(self, "_tf_np_poly_opd"): + self._tf_np_poly_opd = TFNonParametricPolynomialVariationsOPD( + x_lims=self.model_params.x_lims, + y_lims=self.model_params.y_lims, + d_max=self.model_params.nonparam_hparams.d_max_nonparam, + opd_dim=tf.shape(self.zernike_maps)[1].numpy(), + ) + return self._tf_np_poly_opd def get_coeff_matrix(self): """Get coefficient matrix.""" @@ -331,23 +278,21 @@ def assign_coeff_matrix(self, coeff_mat: Optional[tf.Tensor]) -> None: """ self.tf_poly_Z_field.assign_coeff_matrix(coeff_mat) + def set_output_Q(self, output_Q: float, output_dim: Optional[int] = None) -> None: """Set the output sampling rate (output_Q) for PSF generation. This method updates the `output_Q` parameter, which defines the resampling factor for generating PSFs at different resolutions - relative to the telescope's native sampling. It also allows optionally - updating `output_dim`, which sets the output resolution of the PSF model. + relative to the telescope's native sampling. It also allows optionally updating `output_dim`, which sets the output resolution of the PSF model. If `output_dim` is provided, the PSF model's output resolution is updated. - The method then reinitializes the batch polychromatic PSF generator - to reflect the updated parameters. + The method then reinitializes the batch polychromatic PSF generator to reflect the updated parameters. Parameters ---------- output_Q : float - The resampling factor that determines the output PSF resolution - relative to the telescope's native sampling. + The resampling factor that determines the output PSF resolution relative to the telescope's native sampling. output_dim : Optional[int], default=None The new output dimension for the PSF model. If `None`, the output dimension remains unchanged. 
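For context, the intended call pattern for `set_output_Q` is sketched below
(the resampling factor and output dimension are assumed values for
illustration, not taken from this patch series):

    # Hypothetical usage sketch -- illustrative values only.
    psf_field.set_output_Q(output_Q=1.0, output_dim=64)
    # The batch polychromatic PSF generator is rebuilt, so subsequent calls
    # render PSFs at the requested sampling and output resolution.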
@@ -359,6 +304,7 @@ def set_output_Q(self, output_Q: float, output_dim: Optional[int] = None) -> Non self.output_Q = output_Q if output_dim is not None: self.output_dim = output_dim + # Reinitialize the PSF batch poly generator self.tf_batch_poly_PSF = TFBatchPolychromaticPSF( obscurations=self.obscurations, @@ -472,12 +418,16 @@ def predict_step(self, data, evaluate_step=False): # Compute zernikes from parametric model and physical layer zks_coeffs = self.predict_zernikes(input_positions) + # Propagate to obtain the OPD param_opd_maps = self.tf_zernike_OPD(zks_coeffs) + # Calculate the non parametric part nonparam_opd_maps = self.tf_np_poly_opd(input_positions) + # Add the estimations opd_maps = tf.math.add(param_opd_maps, nonparam_opd_maps) + # Compute the polychromatic PSFs poly_psfs = self.tf_batch_poly_PSF([opd_maps, packed_SEDs]) @@ -520,10 +470,13 @@ def predict_mono_psfs(self, input_positions, lambda_obs, phase_N): # Predict zernikes from parametric model and physical layer zks_coeffs = self.predict_zernikes(input_positions) + # Propagate to obtain the OPD param_opd_maps = self.tf_zernike_OPD(zks_coeffs) + # Calculate the non parametric part nonparam_opd_maps = self.tf_np_poly_opd(input_positions) + # Add the estimations opd_maps = tf.math.add(param_opd_maps, nonparam_opd_maps) @@ -548,10 +501,13 @@ def predict_opd(self, input_positions): """ # Predict zernikes from parametric model and physical layer zks_coeffs = self.predict_zernikes(input_positions) + # Propagate to obtain the OPD param_opd_maps = self.tf_zernike_OPD(zks_coeffs) + # Calculate the non parametric part nonparam_opd_maps = self.tf_np_poly_opd(input_positions) + # Add the estimations opd_maps = tf.math.add(param_opd_maps, nonparam_opd_maps) @@ -589,6 +545,7 @@ def compute_zernikes(self, input_positions): padded_zernike_params, padded_zernike_prior = self.pad_zernikes( zernike_params, zernike_prior ) + zernike_coeffs = tf.math.add(padded_zernike_params, padded_zernike_prior) return zernike_coeffs @@ -685,7 +642,7 @@ def call(self, inputs, training=True): packed_SEDs = inputs[1] # For the training - if training: + if training: # Compute zernikes from parametric model and physical layer zks_coeffs = self.compute_zernikes(input_positions) From 61fb8dc9035036b2d53282807cbd908d48b02d0a Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Sat, 21 Jun 2025 18:53:46 +0200 Subject: [PATCH 069/135] Update/Add unit tests to test refactoring changes to psf_model_physical_polychromatic.py and data_zernike_utils.py --- .../test_data/data_zernike_utils_test.py | 296 ++++++++++++++++-- .../psf_model_physical_polychromatic_test.py | 202 ++---------- 2 files changed, 310 insertions(+), 188 deletions(-) diff --git a/src/wf_psf/tests/test_data/data_zernike_utils_test.py b/src/wf_psf/tests/test_data/data_zernike_utils_test.py index 692624be..90994dde 100644 --- a/src/wf_psf/tests/test_data/data_zernike_utils_test.py +++ b/src/wf_psf/tests/test_data/data_zernike_utils_test.py @@ -1,40 +1,296 @@ import pytest import numpy as np +from unittest.mock import MagicMock, patch import tensorflow as tf from wf_psf.data.data_zernike_utils import ( - get_zernike_prior, + ZernikeInputs, + ZernikeInputsFactory, + get_np_zernike_prior, + pad_contribution_to_order, + combine_zernike_contributions, + assemble_zernike_contributions, compute_zernike_tip_tilt, ) from wf_psf.tests.test_data.test_data_utils import MockData, MockDataset +from types import SimpleNamespace as RecursiveNamespace -def test_get_zernike_prior(model_params, mock_data): - 
zernike_priors = get_zernike_prior(model_params, mock_data) - expected_shape = ( - 4, - 2, - ) # Assuming 2 Zernike priors for each dataset (training and test) - assert zernike_priors.shape == expected_shape +@pytest.fixture +def mock_model_params(): + return RecursiveNamespace( + use_prior=True, + correct_centroids=True, + add_ccd_misalignments=True, + param_hparams=RecursiveNamespace(n_zernikes=6), + ) + +@pytest.fixture +def dummy_prior(): + return np.ones((4, 6), dtype=np.float32) + +@pytest.fixture +def dummy_positions(): + return np.random.rand(4, 2).astype(np.float32) + +@pytest.fixture +def dummy_centroid_dataset(): + return {"training": "dummy_train", "test": "dummy_test"} + +def test_training_without_prior(mock_model_params): + mock_model_params.use_prior = False + data = MagicMock() + data.training_dataset = {"positions": np.ones((2, 2))} + data.test_dataset = {"positions": np.zeros((3, 2))} + + zinputs = ZernikeInputsFactory.build(data=data, run_type="training", model_params=mock_model_params) + + assert zinputs.centroid_dataset is data + assert zinputs.zernike_prior is None + np.testing.assert_array_equal( + zinputs.misalignment_positions, + np.concatenate([data.training_dataset["positions"], data.test_dataset["positions"]]) + ) + +@patch("wf_psf.data.data_zernike_utils.get_np_zernike_prior") +def test_training_with_dataset_prior(mock_get_prior, mock_model_params): + mock_model_params.use_prior = True + data = MagicMock() + data.training_dataset = {"positions": np.ones((2, 2))} + data.test_dataset = {"positions": np.zeros((2, 2))} + mock_get_prior.return_value = np.array([1.0, 2.0, 3.0]) + + zinputs = ZernikeInputsFactory.build(data=data, run_type="training", model_params=mock_model_params) + + assert zinputs.zernike_prior.tolist() == [1.0, 2.0, 3.0] + mock_get_prior.assert_called_once_with(data) + +def test_training_with_explicit_prior(mock_model_params, caplog): + mock_model_params.use_prior = True + data = MagicMock() + data.training_dataset = {"positions": np.ones((1, 2))} + data.test_dataset = {"positions": np.zeros((1, 2))} + + explicit_prior = np.array([9.0, 9.0, 9.0]) + + with caplog.at_level("WARNING"): + zinputs = ZernikeInputsFactory.build(data, "training", mock_model_params, prior=explicit_prior) + + assert "Zernike prior explicitly provided" in caplog.text + assert (zinputs.zernike_prior == explicit_prior).all() + + +def test_inference_with_dict_and_prior(mock_model_params): + mock_model_params.use_prior = True + data = { + "positions": np.ones((5, 2)), + "zernike_prior": np.array([42.0, 0.0]) + } + + zinputs = ZernikeInputsFactory.build(data, "inference", mock_model_params) + + assert zinputs.centroid_dataset is None + assert (zinputs.zernike_prior == data["zernike_prior"]).all() + np.testing.assert_array_equal(zinputs.misalignment_positions, data["positions"]) + + +def test_invalid_run_type(mock_model_params): + data = {"positions": np.ones((2, 2))} + with pytest.raises(ValueError, match="Unsupported run_type"): + ZernikeInputsFactory.build(data, "invalid_mode", mock_model_params) -def test_get_zernike_prior_dtype(model_params, mock_data): - zernike_priors = get_zernike_prior(model_params, mock_data) - assert zernike_priors.dtype == np.float32 +def test_get_np_zernike_prior(): + # Mock training and test data + training_prior = np.array([[1, 2, 3], [4, 5, 6]]) + test_prior = np.array([[7, 8, 9]]) -def test_get_zernike_prior_concatenation(model_params, mock_data): - zernike_priors = get_zernike_prior(model_params, mock_data) - expected_zernike_priors = 
tf.convert_to_tensor( - np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]]), dtype=tf.float32 + # Construct fake DataConfigHandler structure using RecursiveNamespace + data = RecursiveNamespace( + training_data=RecursiveNamespace(dataset={"zernike_prior": training_prior}), + test_data=RecursiveNamespace(dataset={"zernike_prior": test_prior}) ) - assert np.array_equal(zernike_priors, expected_zernike_priors) + expected_prior = np.concatenate((training_prior, test_prior), axis=0) + result = get_np_zernike_prior(data) -def test_get_zernike_prior_empty_data(model_params): - empty_data = MockData(np.array([]), np.array([]), np.array([]), np.array([])) - zernike_priors = get_zernike_prior(model_params, empty_data) - assert zernike_priors.shape == tf.TensorShape([0]) # Check for empty array shape + # Assert shape and values match expected + np.testing.assert_array_equal(result, expected_prior) + +def test_pad_contribution_to_order(): + # Input: batch of 2 samples, each with 3 Zernike coefficients + input_contribution = np.array([ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + ]) + + max_order = 5 # Target size: pad to 5 coefficients + + expected_output = np.array([ + [1.0, 2.0, 3.0, 0.0, 0.0], + [4.0, 5.0, 6.0, 0.0, 0.0], + ]) + + padded = pad_contribution_to_order(input_contribution, max_order) + + assert padded.shape == (2, 5), "Output shape should match padded shape" + np.testing.assert_array_equal(padded, expected_output) + + +def test_no_padding_needed(): + """If current order equals max_order, return should be unchanged.""" + input_contribution = np.array([[1, 2, 3], [4, 5, 6]]) + max_order = 3 + output = pad_contribution_to_order(input_contribution, max_order) + assert output.shape == input_contribution.shape + np.testing.assert_array_equal(output, input_contribution) + +def test_padding_to_much_higher_order(): + """Pad from order 2 to order 10.""" + input_contribution = np.array([[1, 2], [3, 4]]) + max_order = 10 + expected_output = np.hstack([input_contribution, np.zeros((2, 8))]) + output = pad_contribution_to_order(input_contribution, max_order) + assert output.shape == (2, 10) + np.testing.assert_array_equal(output, expected_output) + +def test_empty_contribution(): + """Test behavior with empty input array (0 features).""" + input_contribution = np.empty((3, 0)) # 3 samples, 0 coefficients + max_order = 4 + expected_output = np.zeros((3, 4)) + output = pad_contribution_to_order(input_contribution, max_order) + assert output.shape == (3, 4) + np.testing.assert_array_equal(output, expected_output) + + +def test_zero_samples(): + """Test with zero samples (empty batch).""" + input_contribution = np.empty((0, 3)) # 0 samples, 3 coefficients + max_order = 5 + expected_output = np.empty((0, 5)) + output = pad_contribution_to_order(input_contribution, max_order) + assert output.shape == (0, 5) + np.testing.assert_array_equal(output, expected_output) + + +def test_combine_zernike_contributions_basic_case(): + """Combine two contributions with matching sample count and varying order.""" + contrib1 = np.array([[1, 2], [3, 4]]) # shape (2, 2) + contrib2 = np.array([[5], [6]]) # shape (2, 1) + expected = np.array([ + [1 + 5, 2 + 0], + [3 + 6, 4 + 0] + ]) # padded contrib2 to (2, 2) + result = combine_zernike_contributions([contrib1, contrib2]) + np.testing.assert_array_equal(result, expected) + +def test_combine_multiple_contributions(): + """Combine three contributions.""" + c1 = np.array([[1, 2, 3]]) # shape (1, 3) + c2 = np.array([[4, 5]]) # shape (1, 2) + c3 = np.array([[6]]) # shape (1, 1) 
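+    # c2 and c3 are zero-padded on the right to order 3 (the widest
+    # contribution, c1) before the element-wise sum.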
+ expected = np.array([[1+4+6, 2+5+0, 3+0+0]]) # shape (1, 3) + result = combine_zernike_contributions([c1, c2, c3]) + np.testing.assert_array_equal(result, expected) + +def test_empty_input_list(): + """Raise ValueError when input list is empty.""" + with pytest.raises(ValueError, match="No contributions provided."): + combine_zernike_contributions([]) + +def test_inconsistent_sample_count(): + """Raise error or produce incorrect shape if contributions have inconsistent sample counts.""" + c1 = np.array([[1, 2], [3, 4]]) # shape (2, 2) + c2 = np.array([[5, 6]]) # shape (1, 2) + with pytest.raises(ValueError): + combine_zernike_contributions([c1, c2]) + +def test_single_contribution(): + """Combining a single contribution should return the same array (no-op).""" + contrib = np.array([[7, 8, 9], [10, 11, 12]]) + result = combine_zernike_contributions([contrib]) + np.testing.assert_array_equal(result, contrib) + +def test_zero_order_contributions(): + """Contributions with 0 Zernike coefficients.""" + contrib1 = np.empty((2, 0)) # 2 samples, 0 coefficients + contrib2 = np.empty((2, 0)) + expected = np.empty((2, 0)) + result = combine_zernike_contributions([contrib1, contrib2]) + assert result.shape == (2, 0) + np.testing.assert_array_equal(result, expected) + +@patch("wf_psf.data.data_zernike_utils.compute_centroid_correction") +@patch("wf_psf.data.data_zernike_utils.compute_ccd_misalignment") +def test_full_contribution_combination(mock_ccd, mock_centroid, mock_model_params, dummy_prior, dummy_centroid_dataset, dummy_positions): + mock_centroid.return_value = np.full((4, 6), 2.0) + mock_ccd.return_value = np.full((4, 6), 3.0) + + result = assemble_zernike_contributions( + model_params=mock_model_params, + zernike_prior=dummy_prior, + centroid_dataset=dummy_centroid_dataset, + positions=dummy_positions + ) + + expected = dummy_prior + 2.0 + 3.0 + np.testing.assert_allclose(result.numpy(), expected) + +def test_prior_only(mock_model_params, dummy_prior): + mock_model_params.correct_centroids = False + mock_model_params.add_ccd_misalignments = False + + result = assemble_zernike_contributions( + model_params=mock_model_params, + zernike_prior=dummy_prior, + centroid_dataset=None, + positions=None + ) + + np.testing.assert_array_equal(result.numpy(), dummy_prior) + +def test_no_contributions_returns_zeros(): + model_params = RecursiveNamespace( + use_prior=False, + correct_centroids=False, + add_ccd_misalignments=False, + param_hparams=RecursiveNamespace(n_zernikes=8), + ) + + result = assemble_zernike_contributions(model_params) + + assert isinstance(result, tf.Tensor) + assert result.shape == (1, 8) + np.testing.assert_array_equal(result.numpy(), np.zeros((1, 8))) + +def test_prior_as_tensor(mock_model_params): + tensor_prior = tf.ones((4, 6), dtype=tf.float32) + + mock_model_params.correct_centroids = False + mock_model_params.add_ccd_misalignments = False + + result = assemble_zernike_contributions( + model_params=mock_model_params, + zernike_prior=tensor_prior + ) + + assert isinstance(result, tf.Tensor) + np.testing.assert_array_equal(result.numpy(), np.ones((4, 6))) + +@patch("wf_psf.data.data_zernike_utils.compute_centroid_correction") +def test_inconsistent_shapes_raises_error(mock_centroid, mock_model_params, dummy_prior, dummy_centroid_dataset): + mock_model_params.add_ccd_misalignments = False + mock_centroid.return_value = np.ones((5, 6)) # 5 samples instead of 4 + + with pytest.raises(ValueError, match="All contributions must have the same number of samples"): + 
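+        # The mocked centroid correction yields 5 samples while the prior
+        # has 4, so the sample counts disagree and the combine step raises.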
assemble_zernike_contributions( + model_params=mock_model_params, + zernike_prior=dummy_prior, + centroid_dataset=dummy_centroid_dataset, + positions=None + ) def test_compute_zernike_tip_tilt_single_batch(mocker, simple_image, identity_mask): diff --git a/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py b/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py index 7e967465..13d78667 100644 --- a/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py +++ b/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py @@ -9,6 +9,7 @@ import pytest import numpy as np import tensorflow as tf +from unittest.mock import PropertyMock from wf_psf.psf_models.models.psf_model_physical_polychromatic import ( TFPhysicalPolychromaticField, ) @@ -31,6 +32,7 @@ def zks_prior(): def mock_data(mocker): mock_instance = mocker.Mock(spec=DataConfigHandler) # Configure the mock data object to have the necessary attributes + mock_instance.run_type = "training" mock_instance.training_data = mocker.Mock() mock_instance.training_data.dataset = {"positions": np.array([[1, 2], [3, 4]])} mock_instance.test_data = mocker.Mock() @@ -46,145 +48,11 @@ def mock_model_params(mocker): model_params_mock.pupil_diameter = 256 return model_params_mock - -def test_initialize_parameters(mocker, mock_data, mock_model_params, zks_prior): - # Create mock objects for model_params, training_params - # model_params_mock = mocker.MagicMock() - mock_training_params = mocker.Mock() - - # Mock internal methods called during initialization - mocker.patch( - "wf_psf.psf_models.models.psf_model_physical_polychromatic.get_zernike_prior", - return_value=zks_prior, - ) - - mocker.patch( - "wf_psf.data.data_handler.get_obs_positions", return_value=True - ) - - # Create TFPhysicalPolychromaticField instance - field_instance = TFPhysicalPolychromaticField( - mock_model_params, mock_training_params, mock_data - ) - - mocker.patch.object(field_instance, "_initialize_zernike_parameters") - mocker.patch.object(field_instance, "_initialize_layers") - mocker.patch.object(field_instance, "assign_coeff_matrix") - - # Call the method being tested - field_instance._initialize_parameters_and_layers( - mock_model_params, mock_training_params, mock_data - ) - - # Check if internal methods were called with the correct arguments - field_instance._initialize_zernike_parameters.assert_called_once_with( - mock_model_params, mock_data - ) - field_instance._initialize_layers.assert_called_once_with( - mock_model_params, mock_training_params - ) - field_instance.assign_coeff_matrix.assert_not_called() # Because coeff_mat is None in this test - - -def test_initialize_zernike_parameters(mocker, mock_model_params, mock_data, zks_prior): - # Create training params mock object - mock_training_params = mocker.Mock() - - # Mock internal methods called during initialization - mocker.patch( - "wf_psf.psf_models.models.psf_model_physical_polychromatic.get_zernike_prior", - return_value=zks_prior, - ) - - # Create TFPhysicalPolychromaticField instance - field_instance = TFPhysicalPolychromaticField( - mock_model_params, mock_training_params, mock_data - ) - - # Assert that the attributes are set correctly - # assert field_instance.n_zernikes == mock_model_params.param_hparams.n_zernikes - assert np.array_equal(field_instance.zks_prior.numpy(), zks_prior.numpy()) - assert field_instance.n_zks_total == mock_model_params.param_hparams.n_zernikes - assert isinstance( - field_instance.zernike_maps, tf.Tensor 
- ) # Check if the returned value is a TensorFlow tensor - assert ( - field_instance.zernike_maps.dtype == tf.float32 - ) # Check if the data type of the tensor is float32 - - # Expected shape of the tensor based on the input parameters - expected_shape = ( - field_instance.n_zks_total, - mock_model_params.pupil_diameter, - mock_model_params.pupil_diameter, - ) - assert field_instance.zernike_maps.shape == expected_shape - - # Modify model_params to simulate zks_prior > n_zernikes - mock_model_params.param_hparams.n_zernikes = 2 - - # Call the method again to initialize the parameters - field_instance._initialize_zernike_parameters(mock_model_params, mock_data) - - assert field_instance.n_zks_total == tf.cast( - tf.shape(field_instance.zks_prior)[1], tf.int32 - ) - # Expected shape of the tensor based on the input parameters - expected_shape = ( - field_instance.n_zks_total, - mock_model_params.pupil_diameter, - mock_model_params.pupil_diameter, - ) - assert field_instance.zernike_maps.shape == expected_shape - - -def test_initialize_physical_layer_mocking( - mocker, mock_model_params, mock_data, zks_prior -): - # Create training params mock object - mock_training_params = mocker.Mock() - - # Mock internal methods called during initialization - mocker.patch( - "wf_psf.psf_models.models.psf_model_physical_polychromatic.get_zernike_prior", - return_value=zks_prior, - ) - - # Create a mock for the TFPhysicalLayer class - mock_physical_layer_class = mocker.patch( - "wf_psf.psf_models.models.psf_model_physical_polychromatic.TFPhysicalLayer" - ) - - # Create TFPhysicalPolychromaticField instance - field_instance = TFPhysicalPolychromaticField( - mock_model_params, mock_training_params, mock_data - ) - - # Assert that the TFPhysicalLayer class was called with the expected arguments - mock_physical_layer_class.assert_called_once_with( - field_instance.obs_pos, - field_instance.zks_prior, - interpolation_type=mock_model_params.interpolation_type, - interpolation_args=mock_model_params.interpolation_args, - ) - - @pytest.fixture def physical_layer_instance(mocker, mock_model_params, mock_data, zks_prior): # Create training params mock object mock_training_params = mocker.Mock() - # Mock internal methods called during initialization - mocker.patch( - "wf_psf.psf_models.models.psf_model_physical_polychromatic.get_zernike_prior", - return_value=zks_prior, - ) - - # Create a mock for the TFPhysicalLayer class - mocker.patch( - "wf_psf.psf_models.models.psf_model_physical_polychromatic.TFPhysicalLayer" - ) - # Create TFPhysicalPolychromaticField instance psf_field_instance = TFPhysicalPolychromaticField( mock_model_params, mock_training_params, mock_data @@ -202,8 +70,8 @@ def test_pad_zernikes_num_of_zernikes_equal(physical_layer_instance): zk_prior = tf.reshape(zk_prior, (1, 2, 1, 1)) # Reshaping tensor2 to (1, 2, 1, 1) # Reset n_zks_total attribute - physical_layer_instance.n_zks_total = max( - tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy() + physical_layer_instance._n_zks_total = max( + tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy() ) # Call the method under test padded_zk_param, padded_zk_prior = physical_layer_instance.pad_zernikes( @@ -224,7 +92,7 @@ def test_pad_zernikes_prior_greater_than_param(physical_layer_instance): zk_prior = tf.reshape(zk_prior, (1, 5, 1, 1)) # Reshaping tensor2 to (1, 5, 1, 1) # Reset n_zks_total attribute - physical_layer_instance.n_zks_total = max( + physical_layer_instance._n_zks_total = max( tf.shape(zk_param)[1].numpy(), 
tf.shape(zk_prior)[1].numpy() ) @@ -247,7 +115,7 @@ def test_pad_zernikes_param_greater_than_prior(physical_layer_instance): zk_prior = tf.reshape(zk_prior, (1, 2, 1, 1)) # Reshaping tensor2 to (1, 4, 1, 1) # Reset n_zks_total attribute - physical_layer_instance.n_zks_total = max( + physical_layer_instance._n_zks_total = max( tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy() ) @@ -262,43 +130,41 @@ def test_pad_zernikes_param_greater_than_prior(physical_layer_instance): def test_compute_zernikes(mocker, physical_layer_instance): - # Mock padded tensors - padded_zk_param = tf.constant( - [[[[10]], [[20]], [[30]], [[40]]]] - ) # Shape: (1, 4, 1, 1) - padded_zk_prior = tf.constant([[[[1]], [[2]], [[0]], [[0]]]]) # Shape: (1, 4, 1, 1) - - # Reset n_zks_total attribute - physical_layer_instance.n_zks_total = 4 # Assuming a specific value for simplicity + # Expected output of mock components + padded_zernike_param = tf.constant([[[[10]], [[20]], [[30]], [[40]]]], dtype=tf.float32) + padded_zernike_prior = tf.constant([[[[1]], [[2]], [[0]], [[0]]]], dtype=tf.float32) + expected_values = tf.constant([[[[11]], [[22]], [[30]], [[40]]]], dtype=tf.float32) - # Define the mock return values for tf_poly_Z_field and tf_physical_layer.call - padded_zernike_param = tf.constant( - [[[[10]], [[20]], [[30]], [[40]]]] - ) # Shape: (1, 4, 1, 1) - padded_zernike_prior = tf.constant( - [[[[1]], [[2]], [[0]], [[0]]]] - ) # Shape: (1, 4, 1, 1) + # Patch tf_poly_Z_field property + mock_tf_poly_Z_field = mocker.Mock(return_value=padded_zernike_param) + mocker.patch.object( + TFPhysicalPolychromaticField, + 'tf_poly_Z_field', + new_callable=PropertyMock, + return_value=mock_tf_poly_Z_field + ) + # Patch tf_physical_layer property + mock_tf_physical_layer = mocker.Mock() + mock_tf_physical_layer.call.return_value = padded_zernike_prior mocker.patch.object( - physical_layer_instance, "tf_poly_Z_field", return_value=padded_zk_param + TFPhysicalPolychromaticField, + 'tf_physical_layer', + new_callable=PropertyMock, + return_value=mock_tf_physical_layer ) - mocker.patch.object(physical_layer_instance, "call", return_value=padded_zk_prior) + + # Patch pad_zernikes instance method directly (this one isn't a property) mocker.patch.object( physical_layer_instance, - "pad_zernikes", - return_value=(padded_zernike_param, padded_zernike_prior), + 'pad_zernikes', + return_value=(padded_zernike_param, padded_zernike_prior) ) - # Call the method under test - zernike_coeffs = physical_layer_instance.compute_zernikes(tf.constant([[0.0, 0.0]])) - # Define the expected values - expected_values = tf.constant( - [[[[11]], [[22]], [[30]], [[40]]]] - ) # Shape: (1, 4, 1, 1) - - # Assert that the shapes are equal - assert zernike_coeffs.shape == expected_values.shape + # Run the test + zernike_coeffs = physical_layer_instance.compute_zernikes(tf.constant([[0.0, 0.0]])) - # Assert that the tensor values are equal - assert tf.reduce_all(tf.equal(zernike_coeffs, expected_values)) + # Assertions + tf.debugging.assert_equal(zernike_coeffs, expected_values) + assert zernike_coeffs.shape == expected_values.shape \ No newline at end of file From 60dfe56fa479320d36313271edcab524373b911b Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Sun, 22 Jun 2025 00:08:51 +0200 Subject: [PATCH 070/135] Replace arg: data in compute_ccd_misalignment with positions --- src/wf_psf/instrument/ccd_misalignments.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/wf_psf/instrument/ccd_misalignments.py 
b/src/wf_psf/instrument/ccd_misalignments.py index 76386f20..f739c4fc 100644 --- a/src/wf_psf/instrument/ccd_misalignments.py +++ b/src/wf_psf/instrument/ccd_misalignments.py @@ -13,22 +13,23 @@ from wf_psf.data.data_handler import get_np_obs_positions -def compute_ccd_misalignment(model_params, data): +def compute_ccd_misalignment(model_params, positions: np.ndarray) -> np.ndarray: """Compute CCD misalignment. Parameters ---------- model_params : RecursiveNamespace Object containing parameters for this PSF model class. - data : DataConfigHandler - Object containing training and test datasets. + positions : np.ndarray + Numpy array containing the positions of the stars in the focal plane. + Shape: (n_stars, 2), where n_stars is the number of stars and 2 corresponds to x and y coordinates. Returns ------- zernike_ccd_misalignment_array : np.ndarray Numpy array containing the Zernike contributions to model the CCD chip misalignments. """ - obs_positions = get_np_obs_positions(data) + obs_positions = positions ccd_misalignment_calculator = CCDMisalignmentCalculator( tiles_path=model_params.ccd_misalignments_input_path, From 29f9fc59ebd169346b9891f6c666211965a38857 Mon Sep 17 00:00:00 2001 From: jeipollack Date: Sun, 22 Jun 2025 00:15:06 +0200 Subject: [PATCH 071/135] Correct object attributes for DataConfigHandler in ZernikeInputsFactory --- src/wf_psf/data/data_zernike_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/wf_psf/data/data_zernike_utils.py b/src/wf_psf/data/data_zernike_utils.py index b03ff400..e50381c3 100644 --- a/src/wf_psf/data/data_zernike_utils.py +++ b/src/wf_psf/data/data_zernike_utils.py @@ -57,8 +57,8 @@ def build(data, run_type: str, model_params, prior: Optional[np.ndarray] = None) centroid_dataset = data # Assuming data is a DataConfigHandler or similar object containing train and test datasets positions = np.concatenate( [ - data.training_dataset["positions"], - data.test_dataset["positions"] + data.training_data.dataset["positions"], + data.test_data.dataset["positions"] ], axis=0, ) From a943dbe2d35c677d682c2e089436c1d12e222bd7 Mon Sep 17 00:00:00 2001 From: jeipollack Date: Sun, 22 Jun 2025 00:29:06 +0200 Subject: [PATCH 072/135] Add missing return for tf_physical_layer property --- src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py index 971dbb63..7f19171e 100644 --- a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py +++ b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py @@ -218,6 +218,7 @@ def tf_physical_layer(self): interpolation_type=self.model_params.interpolation_type, interpolation_args=self.model_params.interpolation_args, ) + return self._tf_physical_layer @property def tf_zernike_OPD(self): From e15444f603653a6d7386cfd38f3ff89579fefa60 Mon Sep 17 00:00:00 2001 From: jeipollack Date: Sun, 22 Jun 2025 18:02:17 +0200 Subject: [PATCH 073/135] Add tf_utils.py module to tf_modules subpackage --- src/wf_psf/psf_models/tf_modules/tf_utils.py | 45 ++++++++++++++++++++ 1 file changed, 45 insertions(+) create mode 100644 src/wf_psf/psf_models/tf_modules/tf_utils.py diff --git a/src/wf_psf/psf_models/tf_modules/tf_utils.py b/src/wf_psf/psf_models/tf_modules/tf_utils.py new file mode 100644 index 00000000..a4795f89 --- /dev/null +++ b/src/wf_psf/psf_models/tf_modules/tf_utils.py @@ -0,0 +1,45 @@ 
+"""TensorFlow Utilities Module. + +Provides lightweight utility functions for safely converting and managing data types +within TensorFlow-based workflows. + +Includes: +- `ensure_tensor`: ensures inputs are TensorFlow tensors with specified dtype + +These tools are designed to support PSF model components, including lazy property evaluation, +data input validation, and type normalization. + +This module is intended for internal use in model layers and inference components to enforce +TensorFlow-compatible inputs. + +Authors: Jennifer Pollack +""" + +import tensorflow as tf +import numpy as np + +def ensure_tensor(input_array, dtype=tf.float32): + """ + Ensure the input is a TensorFlow tensor of the specified dtype. + + Parameters + ---------- + input_array : array-like, tf.Tensor, or np.ndarray + The input to convert. + dtype : tf.DType, optional + The desired TensorFlow dtype (default: tf.float32). + + Returns + ------- + tf.Tensor + A TensorFlow tensor with the specified dtype. + """ + if tf.is_tensor(input_array): + # If already a tensor, optionally cast dtype if different + if input_array.dtype != dtype: + return tf.cast(input_array, dtype) + return input_array + else: + # Convert numpy arrays or other types to tensor + return tf.convert_to_tensor(input_array, dtype=dtype) + From 00b07f768bc0a08619d9cbbc8dc9fe40dda90e9d Mon Sep 17 00:00:00 2001 From: jeipollack Date: Sun, 22 Jun 2025 18:03:50 +0200 Subject: [PATCH 074/135] Use ensure_tensor method from tf_utils.py to check/convert to tensorflow type; Remove get_obs_positions and replace with ensure_tensor method; add property tf_positions --- src/wf_psf/data/data_handler.py | 26 +++++++------------------- 1 file changed, 7 insertions(+), 19 deletions(-) diff --git a/src/wf_psf/data/data_handler.py b/src/wf_psf/data/data_handler.py index b25881a2..2419b730 100644 --- a/src/wf_psf/data/data_handler.py +++ b/src/wf_psf/data/data_handler.py @@ -15,6 +15,7 @@ import os import numpy as np import wf_psf.utils.utils as utils +from wf_psf.psf_models.tf_modules.tf_utils import ensure_tensor import tensorflow as tf from fractions import Fraction from typing import Optional, Union @@ -137,6 +138,10 @@ def __init__( self.dataset = None self.sed_data = None + @property + def tf_positions(self): + return ensure_tensor(self.dataset["positions"]) + def load_dataset(self): """Load dataset. @@ -236,7 +241,8 @@ def process_sed_data(self, sed_data): ) for _sed in sed_data ] - self.sed_data = tf.convert_to_tensor(self.sed_data, dtype=tf.float32) + # Convert list of generated SED tensors to a single TensorFlow tensor of float32 dtype + self.sed_data = ensure_tensor(self.sed_data, dtype=tf.float32) self.sed_data = tf.transpose(self.sed_data, perm=[0, 2, 1]) @@ -272,24 +278,6 @@ def get_np_obs_positions(data): return obs_positions -def get_obs_positions(data): - """Get observed positions from the provided dataset. - - Parameters - ---------- - data : DataConfigHandler - Object containing training and test datasets. - - Returns - ------- - tf.Tensor - Tensor containing the observed positions of the stars. - """ - obs_positions = get_np_obs_positions(data) - - return tf.convert_to_tensor(obs_positions, dtype=tf.float32) - - def extract_star_data(data, train_key: str, test_key: str) -> np.ndarray: """Extract specific star-related data from training and test datasets. 
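For reference, a minimal eager-mode sketch of the new `ensure_tensor` helper
introduced in tf_utils.py (the array values are illustrative only):

    import numpy as np
    import tensorflow as tf
    from wf_psf.psf_models.tf_modules.tf_utils import ensure_tensor

    pos = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float64)
    t1 = ensure_tensor(pos)              # ndarray -> tf.float32 tensor
    t2 = ensure_tensor(t1)               # already float32: returned unchanged
    t3 = ensure_tensor(t1, tf.float64)   # dtype differs: cast via tf.cast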
From 3446280a3938019a0aa30b363210febd2ee7a751 Mon Sep 17 00:00:00 2001 From: jeipollack Date: Sun, 22 Jun 2025 18:08:05 +0200 Subject: [PATCH 075/135] Refactor: Add eager-mode helpers and avoid lazy-loading obscurations in graph mode --- .../psf_model_physical_polychromatic.py | 76 +++++++++++-------- 1 file changed, 46 insertions(+), 30 deletions(-) diff --git a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py index 7f19171e..0c5d3d06 100644 --- a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py +++ b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py @@ -10,7 +10,7 @@ from typing import Optional import tensorflow as tf from tensorflow.python.keras.engine import data_adapter -from wf_psf.data.data_handler import get_obs_positions +from wf_psf.data.data_handler import get_np_obs_positions from wf_psf.data.data_zernike_utils import ZernikeInputsFactory, assemble_zernike_contributions from wf_psf.psf_models import psf_models as psfm from wf_psf.psf_models.tf_modules.tf_layers import ( @@ -21,6 +21,7 @@ TFNonParametricPolynomialVariationsOPD, TFPhysicalLayer, ) +from wf_psf.psf_models.tf_modules.tf_utils import ensure_tensor from wf_psf.utils.read_config import RecursiveNamespace from wf_psf.utils.configs_handler import DataConfigHandler import logging @@ -112,7 +113,8 @@ def __init__(self, model_params, training_params, data, coeff_mat=None): self.model_params = model_params self.training_params = training_params self.data = data - self.run_type = data.run_type + self.run_type = self._get_run_type(data) + self.obs_pos = self.get_obs_pos() # Initialize the model parameters and layers self.output_Q = model_params.output_Q @@ -126,6 +128,22 @@ def __init__(self, model_params, training_params, data, coeff_mat=None): if coeff_mat is not None: self.assign_coeff_matrix(coeff_mat) + # Eagerly initialise tf_batch_poly_PSF + self.tf_batch_poly_PSF = self._build_tf_batch_poly_PSF() + + + def _get_run_type(self, data): + if hasattr(data, 'run_type'): + run_type = data.run_type + elif isinstance(data, dict) and 'run_type' in data: + run_type = data['run_type'] + else: + raise ValueError("data must have a 'run_type' attribute or key") + + if run_type not in {"training", "simulation", "inference"}: + raise ValueError(f"Unknown run_type: {run_type}") + return run_type + def _assemble_zernike_contributions(self): zks_inputs = ZernikeInputsFactory.build( data=self.data, @@ -150,21 +168,20 @@ def save_param_history(self) -> bool: def save_nonparam_history(self) -> bool: """Check if the model should save the optimization history for non-parametric features.""" return getattr(self.model_params.nonparam_hparams, "save_optim_history_nonparam", False) - - # === Lazy properties ===. 
- @property - def obs_pos(self): - """Lazy loading of the observation positions.""" - if not hasattr(self, "_obs_pos"): - if self.run_type == "training" or self.run_type == "simulation": - # Get the observation positions from the data handler - self._obs_pos = get_obs_positions(self.data) - elif self.run_type == "inference": - # For inference, we might not have a data handler, so we use the model parameters - self._obs_pos = self.data.dataset["positions"] - return self._obs_pos + def get_obs_pos(self): + assert self.run_type in {"training", "simulation", "inference"}, f"Unknown run_type: {self.run_type}" + + if self.run_type in {"training", "simulation"}: + raw_pos = get_np_obs_positions(self.data) + else: + raw_pos = self.data.dataset["positions"] + obs_pos = ensure_tensor(raw_pos, dtype=tf.float32) + + return obs_pos + + # === Lazy properties ===. @property def zks_total_contribution(self): """Lazily load all Zernike contributions, including prior and corrections.""" @@ -227,22 +244,20 @@ def tf_zernike_OPD(self): self._tf_zernike_OPD = TFZernikeOPD(zernike_maps=self.zernike_maps) return self._tf_zernike_OPD - @property - def tf_batch_poly_PSF(self): - """Lazily initialize the batch polychromatic PSF layer.""" - if not hasattr(self, "_tf_batch_poly_PSF"): - obscurations = psfm.tf_obscurations( + def _build_tf_batch_poly_PSF(self): + """Eagerly build the TFBatchPolychromaticPSF layer with numpy-based obscurations.""" + + obscurations = psfm.tf_obscurations( pupil_diam=self.model_params.pupil_diameter, N_filter=self.model_params.LP_filter_length, rotation_angle=self.model_params.obscuration_rotation_angle, ) - self._tf_batch_poly_PSF = TFBatchPolychromaticPSF( + return TFBatchPolychromaticPSF( obscurations=obscurations, output_Q=self.output_Q, output_dim=self.output_dim, ) - return self._tf_batch_poly_PSF @property def tf_np_poly_opd(self): @@ -646,23 +661,24 @@ def call(self, inputs, training=True): if training: # Compute zernikes from parametric model and physical layer zks_coeffs = self.compute_zernikes(input_positions) - - # Propagate to obtain the OPD + + # Parametric OPD maps from Zernikes param_opd_maps = self.tf_zernike_OPD(zks_coeffs) - # Add l2 loss on the parametric OPD + # Add L2 regularization loss on parametric OPD maps self.add_loss( - self.l2_param * tf.math.reduce_sum(tf.math.square(param_opd_maps)) + self.l2_param * tf.reduce_sum(tf.square(param_opd_maps)) ) - # Calculate the non parametric part + # Non-parametric correction nonparam_opd_maps = self.tf_np_poly_opd(input_positions) - # Add the estimations - opd_maps = tf.math.add(param_opd_maps, nonparam_opd_maps) - + # Combine both contributions + opd_maps = tf.add(param_opd_maps, nonparam_opd_maps) + # Compute the polychromatic PSFs poly_psfs = self.tf_batch_poly_PSF([opd_maps, packed_SEDs]) + # For the inference else: # Compute predictions From 278479f712a1202be077207265fbfcb8df9b1318 Mon Sep 17 00:00:00 2001 From: jeipollack Date: Sun, 22 Jun 2025 18:09:01 +0200 Subject: [PATCH 076/135] Replace deprecated get_obs_positions with get_np_obs_positions and apply ensure_tensor to convert obs_pos to tensorflow float32 --- src/wf_psf/psf_models/tf_modules/tf_psf_field.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/wf_psf/psf_models/tf_modules/tf_psf_field.py b/src/wf_psf/psf_models/tf_modules/tf_psf_field.py index 23a19b8b..b2d1b53e 100644 --- a/src/wf_psf/psf_models/tf_modules/tf_psf_field.py +++ b/src/wf_psf/psf_models/tf_modules/tf_psf_field.py @@ -16,8 +16,9 @@ TFPhysicalLayer, ) from 
wf_psf.psf_models.models.psf_model_semiparametric import TFSemiParametricField -from wf_psf.data.data_handler import get_obs_positions +from wf_psf.data.data_handler import get_np_obs_positions from wf_psf.psf_models import psf_models as psfm +from wf_psf.psf_models.tf_modules.tf_utils import ensure_tensor import logging logger = logging.getLogger(__name__) @@ -221,7 +222,7 @@ def __init__(self, model_params, training_params, data, coeff_mat): self.output_Q = model_params.output_Q # Inputs: TF_physical_layer - self.obs_pos = get_obs_positions(data) + self.obs_pos = ensure_tensor(get_np_obs_positions(data), dtype=tf.float32) self.zks_prior = get_ground_truth_zernike(data) self.n_zks_prior = tf.shape(self.zks_prior)[1].numpy() From 0eacd98f495a7a8ab1756472d20668a4dcacc322 Mon Sep 17 00:00:00 2001 From: jeipollack Date: Mon, 23 Jun 2025 18:06:04 +0200 Subject: [PATCH 077/135] Remove tf.convert_to_tensor from all Zernike list contributors --- src/wf_psf/data/data_zernike_utils.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/src/wf_psf/data/data_zernike_utils.py b/src/wf_psf/data/data_zernike_utils.py index e50381c3..309732d8 100644 --- a/src/wf_psf/data/data_zernike_utils.py +++ b/src/wf_psf/data/data_zernike_utils.py @@ -53,7 +53,7 @@ def build(data, run_type: str, model_params, prior: Optional[np.ndarray] = None) centroid_dataset = None positions = None - if run_type in {"training", "simulation"}: + if run_type in {"training", "simulation", "metrics"}: centroid_dataset = data # Assuming data is a DataConfigHandler or similar object containing train and test datasets positions = np.concatenate( [ @@ -178,9 +178,9 @@ def assemble_zernike_contributions( # Prior if model_params.use_prior and zernike_prior is not None: logger.info("Adding Zernike prior...") - if isinstance(zernike_prior, np.ndarray): - zernike_prior = tf.convert_to_tensor(zernike_prior, dtype=tf.float32) - zernike_contribution_list.append(zernike_prior) + if isinstance(zernike_prior, tf.Tensor): + zernike_prior = zernike_prior.numpy() + zernike_contribution_list.append(zernike_prior) else: logger.info("Skipping Zernike prior (not used or not provided).") @@ -190,9 +190,7 @@ def assemble_zernike_contributions( centroid_correction = compute_centroid_correction( model_params, centroid_dataset, batch_size=batch_size ) - zernike_contribution_list.append( - tf.convert_to_tensor(centroid_correction, dtype=tf.float32) - ) + zernike_contribution_list.append(centroid_correction) else: logger.info("Skipping centroid correction (not enabled or no dataset).") @@ -200,9 +198,7 @@ def assemble_zernike_contributions( if model_params.add_ccd_misalignments and positions is not None: logger.info("Computing CCD misalignment correction...") ccd_misalignment = compute_ccd_misalignment(model_params, positions) - zernike_contribution_list.append( - tf.convert_to_tensor(ccd_misalignment, dtype=tf.float32) - ) + zernike_contribution_list.append(ccd_misalignment) else: logger.info("Skipping CCD misalignment correction (not enabled or no positions).") From 98fde06e4d559750fa7466cc51969d9b72809030 Mon Sep 17 00:00:00 2001 From: jeipollack Date: Mon, 23 Jun 2025 18:07:20 +0200 Subject: [PATCH 078/135] Add and set self.data_conf.run_type value to 'metrics' in MetricsConfigHandler --- src/wf_psf/utils/configs_handler.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/wf_psf/utils/configs_handler.py b/src/wf_psf/utils/configs_handler.py index f59fdc17..13ceb6de 100644 --- 
a/src/wf_psf/utils/configs_handler.py +++ b/src/wf_psf/utils/configs_handler.py @@ -134,7 +134,7 @@ def __init__(self, data_conf, training_model_params, batch_size=16, load_data=Tr exit() self.simPSF = psf_models.simPSF(training_model_params) - + # Extract sub-configs early train_params = self.data_conf.data.training test_params = self.data_conf.data.test @@ -153,7 +153,7 @@ def __init__(self, data_conf, training_model_params, batch_size=16, load_data=Tr n_bins_lambda=training_model_params.n_bins_lda, load_data=load_data, ) - + self.batch_size = batch_size @@ -262,6 +262,7 @@ def __init__(self, metrics_conf, file_handler, training_conf=None): self._file_handler = file_handler self.training_conf = training_conf self.data_conf = self._load_data_conf() + self.data_conf.run_type = "metrics" self.metrics_dir = self._file_handler.get_metrics_dir( self._file_handler._run_output_dir ) From 5c3b585e13929e7dd9b62d890329ca07788ea360 Mon Sep 17 00:00:00 2001 From: jeipollack Date: Mon, 23 Jun 2025 18:19:32 +0200 Subject: [PATCH 079/135] Eagerly precompute Zernike components; add support for 'metrics' run_type - Precomputed `zks_total_contribution` outside the TensorFlow graph and converted it to tf.float32. - Calculated `n_zks_total` from the contribution shape and param config. - Eagerly generated Zernike maps and stored as tf.float32. - Derived OPD dimension (`opd_dim`) from Zernike map shape. - Generated obscurations via `tf_obscurations` and stored as tf.complex64. - Added 'metrics' as a valid `run_type` alongside 'training', 'simulation', and 'inference'. - Adjusted `get_obs_pos()` logic to treat 'metrics' like 'training' and 'simulation' for position loading. These changes avoid runtime `.numpy()` calls inside `@tf.function` contexts, improve robustness across run modes, and ensure compatibility with training and evaluation pipelines. 
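
Illustrative sketch of the pattern this change adopts (for exposition only; the
tensor and function names below are hypothetical, not part of this patch):

    import tensorflow as tf

    zks = tf.random.normal((4, 10))  # stand-in for the assembled Zernike contributions

    # Build time runs eagerly, so Python-level shape access is safe here.
    n_zks_total = max(6, zks.shape[1])  # a plain Python int, baked into the graph below

    @tf.function
    def pad_to_total(x):
        # Inside tf.function, tf.shape(x)[1].numpy() would raise, because
        # symbolic tensors have no .numpy(); the precomputed int avoids that.
        return tf.pad(x, [(0, 0), (0, n_zks_total - x.shape[1])])

    print(pad_to_total(zks).shape)  # (4, 10): the input already has n_zks_total columns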
--- .../psf_model_physical_polychromatic.py | 68 ++++++++++++------- 1 file changed, 42 insertions(+), 26 deletions(-) diff --git a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py index 0c5d3d06..08c5ce71 100644 --- a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py +++ b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py @@ -128,6 +128,32 @@ def __init__(self, model_params, training_params, data, coeff_mat=None): if coeff_mat is not None: self.assign_coeff_matrix(coeff_mat) + # Compute contributions once eagerly (outside graph) + zks_total_contribution_np = self._assemble_zernike_contributions().numpy() + self._zks_total_contribution = tf.convert_to_tensor(zks_total_contribution_np, dtype=tf.float32) + + # Compute n_zks_total as int + self._n_zks_total = max( + self.model_params.param_hparams.n_zernikes, + zks_total_contribution_np.shape[1] + ) + + # Precompute zernike maps as tf.float32 + self._zernike_maps = psfm.generate_zernike_maps_3d( + n_zernikes=self._n_zks_total, + pupil_diam=self.model_params.pupil_diameter + ) + + # Precompute OPD dimension + self._opd_dim = self._zernike_maps.shape[1] + + # Precompute obscurations as tf.complex64 + self._obscurations = psfm.tf_obscurations( + pupil_diam=self.model_params.pupil_diameter, + N_filter=self.model_params.LP_filter_length, + rotation_angle=self.model_params.obscuration_rotation_angle, + ) + # Eagerly initialise tf_batch_poly_PSF self.tf_batch_poly_PSF = self._build_tf_batch_poly_PSF() @@ -140,7 +166,7 @@ def _get_run_type(self, data): else: raise ValueError("data must have a 'run_type' attribute or key") - if run_type not in {"training", "simulation", "inference"}: + if run_type not in {"training", "simulation", "metrics", "inference"}: raise ValueError(f"Unknown run_type: {run_type}") return run_type @@ -170,9 +196,9 @@ def save_nonparam_history(self) -> bool: return getattr(self.model_params.nonparam_hparams, "save_optim_history_nonparam", False) def get_obs_pos(self): - assert self.run_type in {"training", "simulation", "inference"}, f"Unknown run_type: {self.run_type}" + assert self.run_type in {"training", "simulation", "metrics", "inference"}, f"Unknown run_type: {self.run_type}" - if self.run_type in {"training", "simulation"}: + if self.run_type in {"training", "simulation", "metrics"}: raw_pos = get_np_obs_positions(self.data) else: raw_pos = self.data.dataset["positions"] @@ -184,30 +210,26 @@ def get_obs_pos(self): # === Lazy properties ===. 
 @property
     def zks_total_contribution(self):
-        """Lazily load all Zernike contributions, including prior and corrections."""
-        if not hasattr(self, "_zks_total_contribution"):
-            self._zks_total_contribution = self._assemble_zernike_contributions()
         return self._zks_total_contribution
-    
+
     @property
     def n_zks_total(self):
         """Get the total number of Zernike coefficients."""
-        if not hasattr(self, "_n_zks_total"):
-            self._n_zks_total = max(
-                self.model_params.param_hparams.n_zernikes,
-                tf.cast(tf.shape(self.zks_total_contribution)[1], tf.int32),
-            )
         return self._n_zks_total
 
     @property
     def zernike_maps(self):
-        """Lazy loading of the Zernike maps."""
-        if not hasattr(self, "_zernike_maps"):
-            self._zernike_maps = psfm.generate_zernike_maps_3d(
-                self.n_zks_total, self.model_params.pupil_diameter
-            )
+        """Get Zernike maps."""
         return self._zernike_maps
-    
+
+    @property
+    def opd_dim(self):
+        return self._opd_dim
+
+    @property
+    def obscurations(self):
+        return self._obscurations
+
     @property
     def tf_poly_Z_field(self):
         """Lazy loading of the polynomial Zernike field layer."""
@@ -246,15 +268,9 @@ def tf_zernike_OPD(self):
 
     def _build_tf_batch_poly_PSF(self):
         """Eagerly build the TFBatchPolychromaticPSF layer with numpy-based obscurations."""
-
-        obscurations = psfm.tf_obscurations(
-            pupil_diam=self.model_params.pupil_diameter,
-            N_filter=self.model_params.LP_filter_length,
-            rotation_angle=self.model_params.obscuration_rotation_angle,
-        )
 
         return TFBatchPolychromaticPSF(
-            obscurations=obscurations,
+            obscurations=self.obscurations,
             output_Q=self.output_Q,
             output_dim=self.output_dim,
         )
@@ -267,7 +283,7 @@ def tf_np_poly_opd(self):
             x_lims=self.model_params.x_lims,
             y_lims=self.model_params.y_lims,
             d_max=self.model_params.nonparam_hparams.d_max_nonparam,
-            opd_dim=tf.shape(self.zernike_maps)[1].numpy(),
+            opd_dim=self.opd_dim,
         )
         return self._tf_np_poly_opd
 

From e3aea260dd09c43fe319ff08e8b0a80c00f89c4f Mon Sep 17 00:00:00 2001
From: jeipollack
Date: Wed, 25 Jun 2025 11:13:33 +0200
Subject: [PATCH 080/135] Correct value error: replace 'training' with 'train' in dataset_type check

---
 src/wf_psf/data/data_handler.py | 14 +++++---------
 1 file changed, 5 insertions(+), 9 deletions(-)

diff --git a/src/wf_psf/data/data_handler.py b/src/wf_psf/data/data_handler.py
index 2419b730..40611c71 100644
--- a/src/wf_psf/data/data_handler.py
+++ b/src/wf_psf/data/data_handler.py
@@ -184,17 +184,13 @@ def _validate_dataset_structure(self):
 
     def _convert_dataset_to_tensorflow(self):
         """Convert dataset to TensorFlow tensors."""
-        self.dataset["positions"] = tf.convert_to_tensor(
-            self.dataset["positions"], dtype=tf.float32
-        )
-        if self.dataset_type == "training":
-            self.dataset["noisy_stars"] = tf.convert_to_tensor(
-                self.dataset["noisy_stars"], dtype=tf.float32
-            )
-
+        self.dataset["positions"] = ensure_tensor(self.dataset["positions"], dtype=tf.float32)
+
+        if self.dataset_type == "train":
+            self.dataset["noisy_stars"] = ensure_tensor(self.dataset["noisy_stars"], dtype=tf.float32)
         elif self.dataset_type == "test":
             if "stars" in self.dataset:
-                self.dataset["stars"] = tf.convert_to_tensor(
+                self.dataset["stars"] = ensure_tensor(
                     self.dataset["stars"], dtype=tf.float32
                 )
             else:

From ee0f8e03d51da98db8c9da7aec1cb19cf579c954 Mon Sep 17 00:00:00 2001
From: jeipollack
Date: Wed, 25 Jun 2025 22:17:36 +0200
Subject: [PATCH 081/135] fix: pass random seed to TFNonParametricPolynomialVariationsOPD constructor in psf_model_physical_polychromatic.py

---
 src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py | 1 +
 1 file changed, 1 
insertion(+) diff --git a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py index 08c5ce71..9979bcf1 100644 --- a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py +++ b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py @@ -282,6 +282,7 @@ def tf_np_poly_opd(self): self._tf_np_poly_opd = TFNonParametricPolynomialVariationsOPD( x_lims=self.model_params.x_lims, y_lims=self.model_params.y_lims, + random_seed=self.model_params.param_hparams.random_seed, d_max=self.model_params.nonparam_hparams.d_max_nonparam, opd_dim=self.opd_dim, ) From 776c019799f46789958f2f791fbb01ecf4533bed Mon Sep 17 00:00:00 2001 From: jeipollack Date: Thu, 26 Jun 2025 14:09:58 +0200 Subject: [PATCH 082/135] Refactor to suppress TensorFlow debug msgs: replace lambda in call method with a proper function: find_position_indices enable batch processing for better graph optimization --- src/wf_psf/psf_models/tf_modules/tf_layers.py | 13 ++--- src/wf_psf/psf_models/tf_modules/tf_utils.py | 55 +++++++++++++++++++ 2 files changed, 61 insertions(+), 7 deletions(-) diff --git a/src/wf_psf/psf_models/tf_modules/tf_layers.py b/src/wf_psf/psf_models/tf_modules/tf_layers.py index fcb5e8f9..5bb5e2df 100644 --- a/src/wf_psf/psf_models/tf_modules/tf_layers.py +++ b/src/wf_psf/psf_models/tf_modules/tf_layers.py @@ -9,6 +9,7 @@ import tensorflow as tf import tensorflow_addons as tfa from wf_psf.psf_models.tf_modules.tf_modules import TFMonochromaticPSF +from wf_psf.psf_models.tf_modules.tf_utils import find_position_indices from wf_psf.utils.utils import calc_poly_position_mat import wf_psf.utils.utils as utils from wf_psf.utils.interpolation import tfa_interpolate_spline_rbf @@ -16,7 +17,6 @@ logger = logging.getLogger(__name__) - class TFPolynomialZernikeField(tf.keras.layers.Layer): """Calculate the zernike coefficients for a given position. @@ -964,6 +964,7 @@ def interpolate_independent_Zk(self, positions): return interp_zks[:, :, tf.newaxis, tf.newaxis] + def call(self, positions): """Calculate the prior Zernike coefficients for a batch of positions. @@ -999,12 +1000,10 @@ def call(self, positions): """ - def calc_index(idx_pos): - return tf.where(tf.equal(self.obs_pos, idx_pos))[0, 0] + # Find indices for all positions in one batch operation + idx = find_position_indices(self.obs_pos, positions) - # Calculate the indices of the input batch - indices = tf.map_fn(calc_index, positions, fn_output_signature=tf.int64) - # Recover the prior zernikes from the batch indexes - batch_zks = tf.gather(self.zks_prior, indices=indices, axis=0, batch_dims=0) + # Gather the corresponding Zernike coefficients + batch_zks = tf.gather(self.zks_prior, idx, axis=0) return batch_zks[:, :, tf.newaxis, tf.newaxis] diff --git a/src/wf_psf/psf_models/tf_modules/tf_utils.py b/src/wf_psf/psf_models/tf_modules/tf_utils.py index a4795f89..d0a2002c 100644 --- a/src/wf_psf/psf_models/tf_modules/tf_utils.py +++ b/src/wf_psf/psf_models/tf_modules/tf_utils.py @@ -18,6 +18,61 @@ import tensorflow as tf import numpy as np + +@tf.function +def find_position_indices(obs_pos, batch_positions): + """Find indices of batch positions within observed positions using vectorized operations. + + This function locates the indices of multiple query positions within a + reference set of observed positions using broadcasting and vectorized operations. + Each position in the batch must have an exact match in the observed positions. 
+ + Parameters + ---------- + obs_pos : tf.Tensor + Reference positions tensor of shape (n_obs, 2), where n_obs is the number of + observed positions. Each row contains [x, y] coordinates. + batch_positions : tf.Tensor + Query positions tensor of shape (batch_size, 2), where batch_size is the number + of positions to look up. Each row contains [x, y] coordinates. + + Returns + ------- + indices : tf.Tensor + Tensor of shape (batch_size,) containing the indices of each batch position + within obs_pos. The dtype is tf.int64. + + Raises + ------ + tf.errors.InvalidArgumentError + If any position in batch_positions is not found in obs_pos. + + Notes + ----- + Uses exact equality matching - positions must match exactly. More efficient than + iterative lookups for multiple positions due to vectorized operations. + """ + # Shape: obs_pos (n_obs, 2), batch_positions (batch_size, 2) + # Expand for broadcasting: (1, n_obs, 2) and (batch_size, 1, 2) + obs_expanded = tf.expand_dims(obs_pos, 0) + pos_expanded = tf.expand_dims(batch_positions, 1) + + # Compare all positions at once: (batch_size, n_obs) + matches = tf.reduce_all(tf.equal(obs_expanded, pos_expanded), axis=2) + + # Find the index of the matching position for each batch item + # argmax returns the first True value's index along axis=1 + indices = tf.argmax(tf.cast(matches, tf.int32), axis=1) + + # Verify all positions were found + tf.debugging.assert_equal( + tf.reduce_all(tf.reduce_any(matches, axis=1)), + True, + message="Some positions not found in obs_pos" + ) + + return indices + def ensure_tensor(input_array, dtype=tf.float32): """ Ensure the input is a TensorFlow tensor of the specified dtype. From ab8856e8c3d58f9952c5e41e3f05c7eaffbebffc Mon Sep 17 00:00:00 2001 From: jeipollack Date: Thu, 26 Jun 2025 17:24:28 +0200 Subject: [PATCH 083/135] Match old behaviour with conditional and float64 accumulation --- src/wf_psf/data/data_zernike_utils.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/wf_psf/data/data_zernike_utils.py b/src/wf_psf/data/data_zernike_utils.py index 309732d8..6ef2c93c 100644 --- a/src/wf_psf/data/data_zernike_utils.py +++ b/src/wf_psf/data/data_zernike_utils.py @@ -131,12 +131,17 @@ def combine_zernike_contributions(contributions: list[np.ndarray]) -> np.ndarray if not contributions: raise ValueError("No contributions provided.") + if len(contributions) == 1: + return contributions[0] + max_order = max(contrib.shape[1] for contrib in contributions) n_samples = contributions[0].shape[0] + if any(c.shape[0] != n_samples for c in contributions): raise ValueError("All contributions must have the same number of samples.") - combined = np.zeros((n_samples, max_order), dtype=np.float32) + combined = np.zeros((n_samples, max_order)) + for contrib in contributions: padded = pad_contribution_to_order(contrib, max_order) combined += padded From ea854776db88f36b9060bb0dceeb716038b9ba05 Mon Sep 17 00:00:00 2001 From: jeipollack Date: Tue, 8 Jul 2025 18:26:02 +0200 Subject: [PATCH 084/135] Add helper to stack x/y field coordinates into (N, 2) positions array - Introduced get_positions() to convert x_field and y_field into a stacked (N, 2) array. - Updated data handler to pass positions and sed_data explicitly. - Includes validation for shape mismatches and None inputs. 
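
A quick sketch of the stacking behaviour described above (made-up values, for
exposition only):

    import numpy as np

    x_field = np.array([[0.1, 0.2], [0.3, 0.4]])  # any input shape is accepted
    y_field = np.array([[5.0, 6.0], [7.0, 8.0]])

    # Flattening first reduces scalar, 1-D, and 2-D inputs to matched vectors,
    # which are then stacked into the (N, 2) layout the PSF model expects.
    positions = np.column_stack((x_field.flatten(), y_field.flatten()))
    print(positions.shape)  # (4, 2)
    print(positions[0])     # [0.1 5. ]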
---
 src/wf_psf/inference/psf_inference.py | 36 +++++++++++++++++++++++++--
 1 file changed, 34 insertions(+), 2 deletions(-)

diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py
index 5be49343..f3e876c3 100644
--- a/src/wf_psf/inference/psf_inference.py
+++ b/src/wf_psf/inference/psf_inference.py
@@ -160,7 +160,8 @@ def data_handler(self):
             simPSF=self.simPSF,
             n_bins_lambda=self.n_bins_lambda,
             load_data=False,
-            dataset=None,
+            dataset={"positions": self.get_positions()},
+            sed_data=self.seds,
         )
         self._data_handler.run_type = "inference"
         return self._data_handler
@@ -171,6 +172,37 @@ def trained_psf_model(self):
             self._trained_psf_model = self.load_inference_model()
         return self._trained_psf_model
 
+    def get_positions(self):
+        """
+        Combine x_field and y_field into position pairs.
+
+        Returns
+        -------
+        numpy.ndarray
+            Array of shape (num_positions, 2) where each row contains [x, y] coordinates.
+            Returns None if either x_field or y_field is None.
+
+        Raises
+        ------
+        ValueError
+            If x_field and y_field have different lengths.
+        """
+        if self.x_field is None or self.y_field is None:
+            return None
+
+        x_arr = np.asarray(self.x_field)
+        y_arr = np.asarray(self.y_field)
+
+        if x_arr.size != y_arr.size:
+            raise ValueError(f"x_field and y_field must have the same length. "
+                             f"Got {x_arr.size} and {y_arr.size}")
+
+        # Flatten arrays to handle any input shape, then stack
+        x_flat = x_arr.flatten()
+        y_flat = y_arr.flatten()
+
+        return np.column_stack((x_flat, y_flat))
+
     def load_inference_model(self):
         """Load the trained PSF model based on the inference configuration."""
         model_path = self.config_handler.trained_model_path
@@ -187,7 +219,7 @@ def load_inference_model(self):
         # Load the trained PSF model
         return load_trained_psf_model(
             self.training_config,
-            self.data_config,
+            self.data_handler,
             weights_path_pattern,
         )

From 7b24c45a3b13a4eb89c5da65d0302cd92c10d07d Mon Sep 17 00:00:00 2001
From: jeipollack
Date: Thu, 10 Jul 2025 10:35:39 +0200
Subject: [PATCH 085/135] Add helper method to prepare dataset for inference & handle empty/None fields

---
 src/wf_psf/inference/psf_inference.py | 15 +++++++++++++--
 1 file changed, 13 insertions(+), 2 deletions(-)

diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py
index f3e876c3..6bba2282 100644
--- a/src/wf_psf/inference/psf_inference.py
+++ b/src/wf_psf/inference/psf_inference.py
@@ -150,9 +150,17 @@ def simPSF(self):
             self._simPSF = psf_models.simPSF(self.training_config.training.model_params)
         return self._simPSF
 
+
+    def _prepare_dataset_for_inference(self):
+        """Prepare dataset dictionary for inference, returning None if positions are invalid."""
+        positions = self.get_positions()
+        if positions is None:
+            return None
+        return {"positions": positions}
+
     @property
     def data_handler(self):
-        if self._data_handler is None: 
+        if self._data_handler is None:  # Instantiate the data handler
             self._data_handler = DataHandler(
                 dataset_type="inference",
@@ -160,7 +168,7 @@ def data_handler(self):
             simPSF=self.simPSF,
             n_bins_lambda=self.n_bins_lambda,
             load_data=False,
-            dataset={"positions": self.get_positions()},
+            dataset=self._prepare_dataset_for_inference(),
             sed_data=self.seds,
         )
         self._data_handler.run_type = "inference"
@@ -193,6 +201,9 @@ def get_positions(self):
         x_arr = np.asarray(self.x_field)
         y_arr = np.asarray(self.y_field)
 
+        if x_arr.size == 0 or y_arr.size == 0:
+            return None
+
         if x_arr.size != y_arr.size:
             raise ValueError(f"x_field and y_field must have the same length. "
                              f"Got {x_arr.size} and {y_arr.size}")

From aeeafa1bfea1a22bbe4360e9e76d36c19f78fe34 Mon Sep 17 00:00:00 2001
From: Jennifer Pollack
Date: Tue, 22 Jul 2025 10:09:13 +0200
Subject: [PATCH 086/135] Update data_handler_test: replace deprecated "get_obs_positions" with "get_np_obs_positions"

---
 src/wf_psf/tests/test_data/data_handler_test.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/wf_psf/tests/test_data/data_handler_test.py b/src/wf_psf/tests/test_data/data_handler_test.py
index 6545964a..fdf5a436 100644
--- a/src/wf_psf/tests/test_data/data_handler_test.py
+++ b/src/wf_psf/tests/test_data/data_handler_test.py
@@ -3,7 +3,7 @@
 import tensorflow as tf
 from wf_psf.data.data_handler import (
     DataHandler,
-    get_obs_positions,
+    get_np_obs_positions,
     extract_star_data,
 )
 from wf_psf.utils.read_config import RecursiveNamespace
@@ -166,8 +166,8 @@ def test_load_test_dataset_missing_stars(tmp_path, simPSF):
         mock_warning.assert_called_with("Missing 'noisy_stars' in 'train' dataset.")
 
 
-def test_get_obs_positions(mock_data):
-    observed_positions = get_obs_positions(mock_data)
+def test_get_np_obs_positions(mock_data):
+    observed_positions = get_np_obs_positions(mock_data)
     expected_positions = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
     assert tf.reduce_all(tf.equal(observed_positions, expected_positions))
 

From 66b6e6b3a164b14e4c60625273eb76dca0116545 Mon Sep 17 00:00:00 2001
From: Jennifer Pollack
Date: Wed, 6 Aug 2025 17:58:48 +0200
Subject: [PATCH 087/135] Remove deprecated code from rebase

---
 src/wf_psf/tests/test_data/data_handler_test.py | 15 ---------------
 1 file changed, 15 deletions(-)

diff --git a/src/wf_psf/tests/test_data/data_handler_test.py b/src/wf_psf/tests/test_data/data_handler_test.py
index fdf5a436..c1310d21 100644
--- a/src/wf_psf/tests/test_data/data_handler_test.py
+++ b/src/wf_psf/tests/test_data/data_handler_test.py
@@ -150,21 +150,6 @@ def test_load_test_dataset_missing_stars(tmp_path, simPSF):
     data_handler.load_dataset()
     data_handler.validate_and_process_dataset()
 
-    data_handler = DataHandler(
-        dataset_type="train",
-        data_params=data_params,
-        simPSF=simPSF,
-        n_bins_lambda=10,
-        load_data=False
-    )
-
-    data_handler.load_dataset()
-    data_handler.process_sed_data(mock_dataset["SEDs"])
-
-    with patch("wf_psf.data.data_handler.logger.warning") as mock_warning:
-        data_handler._validate_dataset_structure()
-        mock_warning.assert_called_with("Missing 'noisy_stars' in 'train' dataset.")
-

From 29970095a01672fcfa92cf8dbe7451a723aadd2a Mon Sep 17 00:00:00 2001
From: Jennifer Pollack
Date: Wed, 6 Aug 2025 17:59:32 +0200
Subject: [PATCH 088/135] Remove duplicated checks on arg existence

---
 src/wf_psf/data/data_handler.py | 9 +--------
 1 file changed, 1 insertion(+), 8 deletions(-)

diff --git a/src/wf_psf/data/data_handler.py b/src/wf_psf/data/data_handler.py
index 40611c71..120974eb 100644
--- a/src/wf_psf/data/data_handler.py
+++ b/src/wf_psf/data/data_handler.py
@@ -189,14 +189,7 @@ def _convert_dataset_to_tensorflow(self):
         if self.dataset_type == "train":
             self.dataset["noisy_stars"] = ensure_tensor(self.dataset["noisy_stars"], dtype=tf.float32)
         elif self.dataset_type == "test":
-            if "stars" in self.dataset:
-                self.dataset["stars"] = ensure_tensor(
-                    self.dataset["stars"], dtype=tf.float32
-                )
-            else:
-                logger.warning(f"Missing 'stars' in {self.dataset_type} dataset.")
-        elif "inference" == self.dataset_type:
-            pass
+            self.dataset["stars"] = 
ensure_tensor(self.dataset["stars"], dtype=tf.float32) def process_sed_data(self, sed_data): """ From f4adcd453a5ef19981eb1f0f9afd459e0fe13025 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 8 Aug 2025 20:16:57 +0200 Subject: [PATCH 089/135] Improve Zernike prior handling in assemble_zernike_contributions - Support both NumPy arrays and TensorFlow tensors as valid inputs for zernike_prior. - Added an eager execution check before calling `.numpy()` to ensure safe conversion of tensors. - Raise informative errors for unsupported types or if eager mode is disabled. - Updated function docstring to reflect accepted types and behavior. - Removed extraneous whitespace and added clarifying comments in related code paths. --- src/wf_psf/data/data_zernike_utils.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/src/wf_psf/data/data_zernike_utils.py b/src/wf_psf/data/data_zernike_utils.py index 6ef2c93c..8dd74e4d 100644 --- a/src/wf_psf/data/data_zernike_utils.py +++ b/src/wf_psf/data/data_zernike_utils.py @@ -62,7 +62,6 @@ def build(data, run_type: str, model_params, prior: Optional[np.ndarray] = None) ], axis=0, ) - if model_params.use_prior: if prior is not None: logger.warning( @@ -141,7 +140,7 @@ def combine_zernike_contributions(contributions: list[np.ndarray]) -> np.ndarray raise ValueError("All contributions must have the same number of samples.") combined = np.zeros((n_samples, max_order)) - + # Pad each contribution to the max order and sum them for contrib in contributions: padded = pad_contribution_to_order(contrib, max_order) combined += padded @@ -164,7 +163,8 @@ def assemble_zernike_contributions( model_params : RecursiveNamespace Parameters controlling which contributions to apply. zernike_prior : Optional[np.ndarray or tf.Tensor] - The precomputed Zernike prior (e.g., from PDC or another model). + The precomputed Zernike prior. Can be either a NumPy array or a TensorFlow tensor. + If a Tensor, will be converted to NumPy in eager mode. centroid_dataset : Optional[object] Dataset used to compute centroid correction. Must have both training and test sets. positions : Optional[np.ndarray or tf.Tensor] @@ -184,8 +184,17 @@ def assemble_zernike_contributions( if model_params.use_prior and zernike_prior is not None: logger.info("Adding Zernike prior...") if isinstance(zernike_prior, tf.Tensor): - zernike_prior = zernike_prior.numpy() + if tf.executing_eagerly(): + zernike_prior = zernike_prior.numpy() + else: + raise RuntimeError( + "Zernike prior is a TensorFlow tensor but eager execution is disabled. " + "Cannot call `.numpy()` outside of eager mode." + ) + elif isinstance(zernike_prior, np.ndarray): zernike_contribution_list.append(zernike_prior) + else: + raise TypeError("Unsupported zernike_prior type. Must be np.ndarray or tf.Tensor.") else: logger.info("Skipping Zernike prior (not used or not provided).") @@ -220,7 +229,6 @@ def assemble_zernike_contributions( return tf.convert_to_tensor(combined_zernike_prior, dtype=tf.float32) - def shift_x_y_to_zk1_2_wavediff(dxy, tel_focal_length=24.5, tel_diameter=1.2): """Compute Zernike 1(2) for a given shifts in x(y) in WaveDifff conventions. 
From e8f60758356b8c10df8abba19f008bfbbb1b6c37 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 8 Aug 2025 20:25:35 +0200 Subject: [PATCH 090/135] Fix bug where Tensor zernike_prior was not appended after eager conversion --- src/wf_psf/data/data_zernike_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/wf_psf/data/data_zernike_utils.py b/src/wf_psf/data/data_zernike_utils.py index 8dd74e4d..1157be84 100644 --- a/src/wf_psf/data/data_zernike_utils.py +++ b/src/wf_psf/data/data_zernike_utils.py @@ -191,10 +191,10 @@ def assemble_zernike_contributions( "Zernike prior is a TensorFlow tensor but eager execution is disabled. " "Cannot call `.numpy()` outside of eager mode." ) - elif isinstance(zernike_prior, np.ndarray): - zernike_contribution_list.append(zernike_prior) - else: + + elif not isinstance(zernike_prior, np.ndarray): raise TypeError("Unsupported zernike_prior type. Must be np.ndarray or tf.Tensor.") + zernike_contribution_list.append(zernike_prior) else: logger.info("Skipping Zernike prior (not used or not provided).") From f8ba288f6d530eab2a59132583777aafa00d7ffd Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 8 Aug 2025 20:28:46 +0200 Subject: [PATCH 091/135] Update unit tests with latest changes to fixtures and data_zernike_utils.py --- .../test_data/data_zernike_utils_test.py | 67 ++++++++++--------- 1 file changed, 35 insertions(+), 32 deletions(-) diff --git a/src/wf_psf/tests/test_data/data_zernike_utils_test.py b/src/wf_psf/tests/test_data/data_zernike_utils_test.py index 90994dde..cc6e22de 100644 --- a/src/wf_psf/tests/test_data/data_zernike_utils_test.py +++ b/src/wf_psf/tests/test_data/data_zernike_utils_test.py @@ -12,10 +12,9 @@ assemble_zernike_contributions, compute_zernike_tip_tilt, ) -from wf_psf.tests.test_data.test_data_utils import MockData, MockDataset +from wf_psf.tests.test_data.test_data_utils import MockData from types import SimpleNamespace as RecursiveNamespace - @pytest.fixture def mock_model_params(): return RecursiveNamespace( @@ -29,41 +28,45 @@ def mock_model_params(): def dummy_prior(): return np.ones((4, 6), dtype=np.float32) -@pytest.fixture -def dummy_positions(): - return np.random.rand(4, 2).astype(np.float32) @pytest.fixture def dummy_centroid_dataset(): return {"training": "dummy_train", "test": "dummy_test"} -def test_training_without_prior(mock_model_params): + +def test_training_without_prior(mock_model_params, mock_data): mock_model_params.use_prior = False - data = MagicMock() - data.training_dataset = {"positions": np.ones((2, 2))} - data.test_dataset = {"positions": np.zeros((3, 2))} - zinputs = ZernikeInputsFactory.build(data=data, run_type="training", model_params=mock_model_params) + # Clear priors to simulate no prior being used + mock_data.training_data.dataset.pop("zernike_prior", None) + mock_data.test_data.dataset.pop("zernike_prior", None) - assert zinputs.centroid_dataset is data + zinputs = ZernikeInputsFactory.build(data=mock_data, run_type="training", model_params=mock_model_params) + + assert zinputs.centroid_dataset is mock_data assert zinputs.zernike_prior is None - np.testing.assert_array_equal( - zinputs.misalignment_positions, - np.concatenate([data.training_dataset["positions"], data.test_dataset["positions"]]) - ) -@patch("wf_psf.data.data_zernike_utils.get_np_zernike_prior") -def test_training_with_dataset_prior(mock_get_prior, mock_model_params): + expected_positions = np.concatenate([ + mock_data.training_data.dataset["positions"], + 
mock_data.test_data.dataset["positions"] + ]) + np.testing.assert_array_equal(zinputs.misalignment_positions, expected_positions) + + +def test_training_with_dataset_prior(mock_model_params, mock_data): mock_model_params.use_prior = True - data = MagicMock() - data.training_dataset = {"positions": np.ones((2, 2))} - data.test_dataset = {"positions": np.zeros((2, 2))} - mock_get_prior.return_value = np.array([1.0, 2.0, 3.0]) - zinputs = ZernikeInputsFactory.build(data=data, run_type="training", model_params=mock_model_params) + zinputs = ZernikeInputsFactory.build(data=mock_data, run_type="training", model_params=mock_model_params) + + expected_priors = np.concatenate( + ( + mock_data.training_data.dataset["zernike_prior"], + mock_data.test_data.dataset["zernike_prior"], + ), + axis=0, + ) + np.testing.assert_array_equal(zinputs.zernike_prior, expected_priors) - assert zinputs.zernike_prior.tolist() == [1.0, 2.0, 3.0] - mock_get_prior.assert_called_once_with(data) def test_training_with_explicit_prior(mock_model_params, caplog): mock_model_params.use_prior = True @@ -224,17 +227,18 @@ def test_zero_order_contributions(): @patch("wf_psf.data.data_zernike_utils.compute_centroid_correction") @patch("wf_psf.data.data_zernike_utils.compute_ccd_misalignment") -def test_full_contribution_combination(mock_ccd, mock_centroid, mock_model_params, dummy_prior, dummy_centroid_dataset, dummy_positions): +def test_full_contribution_combination(mock_ccd, mock_centroid, mock_model_params, dummy_prior, dummy_centroid_dataset): mock_centroid.return_value = np.full((4, 6), 2.0) mock_ccd.return_value = np.full((4, 6), 3.0) + dummy_positions = np.full((4, 6), 1.0) result = assemble_zernike_contributions( model_params=mock_model_params, zernike_prior=dummy_prior, centroid_dataset=dummy_centroid_dataset, - positions=dummy_positions + positions = dummy_positions ) - + expected = dummy_prior + 2.0 + 3.0 np.testing.assert_allclose(result.numpy(), expected) @@ -275,7 +279,7 @@ def test_prior_as_tensor(mock_model_params): model_params=mock_model_params, zernike_prior=tensor_prior ) - + assert tf.executing_eagerly(), "TensorFlow must be in eager mode for this test" assert isinstance(result, tf.Tensor) np.testing.assert_array_equal(result.numpy(), np.ones((4, 6))) @@ -294,7 +298,7 @@ def test_inconsistent_shapes_raises_error(mock_centroid, mock_model_params, dumm def test_compute_zernike_tip_tilt_single_batch(mocker, simple_image, identity_mask): - """Test compute_zernike_tip_tilt with single batch input and mocks.""" + """Test compute_zernike_tip_tilt handling with single batch input and mocks.""" # Mock the CentroidEstimator class mock_centroid_calc = mocker.patch("wf_psf.data.centroids.CentroidEstimator", autospec=True) @@ -332,11 +336,11 @@ def test_compute_zernike_tip_tilt_single_batch(mocker, simple_image, identity_ma np.testing.assert_allclose(args[0][1], expected_dy * pixel_sampling, rtol=1e-7, atol=0) # Expected values based on mock side_effect (0.5 * shift) - np.testing.assert_allclose(zernike_corrections[0, 0], expected_dx * pixel_sampling * 0.5) # Zk1 + np.testing.assert_allclose(zernike_corrections[0, 0], expected_dx * pixel_sampling * 0.5) # Zk1 np.testing.assert_allclose(zernike_corrections[0, 1], expected_dy * pixel_sampling * 0.5) # Zk2 def test_compute_zernike_tip_tilt_batch(mocker, multiple_images): - """Test compute_zernike_tip_tilt with batch input and mocks.""" + """Test compute_zernike_tip_tilt batch handling of multiple inputs.""" # Mock the CentroidEstimator class mock_centroid_calc = 
mocker.patch("wf_psf.data.centroids.CentroidEstimator", autospec=True) @@ -377,7 +381,6 @@ def test_compute_zernike_tip_tilt_batch(mocker, multiple_images): # Process the displacements and expected values for each image in the batch expected_dx = reference_shifts[1] - mock_instance.get_intra_pixel_shifts.return_value[:, 1] # Expected x-axis shift in meters - expected_dy = reference_shifts[0] - mock_instance.get_intra_pixel_shifts.return_value[:, 0] # Expected y-axis shift in meters # Compare expected values with the actual arguments passed to the mock function From a97cc66a5f812ed8c2c7f7ec97edd6e114a7cebe Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 8 Aug 2025 20:30:00 +0200 Subject: [PATCH 092/135] Set mock Zernike priors to None in test_data_utils.py helper module --- src/wf_psf/tests/test_data/test_data_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/wf_psf/tests/test_data/test_data_utils.py b/src/wf_psf/tests/test_data/test_data_utils.py index a5ead298..1ebc00cb 100644 --- a/src/wf_psf/tests/test_data/test_data_utils.py +++ b/src/wf_psf/tests/test_data/test_data_utils.py @@ -8,8 +8,8 @@ def __init__( self, training_positions, test_positions, - training_zernike_priors, - test_zernike_priors, + training_zernike_priors=None, + test_zernike_priors=None, noisy_stars=None, noisy_masks=None, stars=None, From 88e49bcc35fc7d4a38e5a59364d5b46df8bce520 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 8 Aug 2025 20:31:12 +0200 Subject: [PATCH 093/135] Remove -1.0 multiplicative factor applied to Zernike tip and tilt values --- src/wf_psf/data/centroids.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/wf_psf/data/centroids.py b/src/wf_psf/data/centroids.py index 97cb069f..b01af2dd 100644 --- a/src/wf_psf/data/centroids.py +++ b/src/wf_psf/data/centroids.py @@ -32,12 +32,11 @@ def compute_centroid_correction(model_params, data, batch_size: int=1) -> np.nda batch_size : int, optional The batch size to use when processing the stars. Default is 16. - Returns ------- zernike_centroid_array : np.ndarray A 2D NumPy array of shape `(n_stars, 3)`, where `n_stars` is the number of - observed stars. The array contains the computed Zernike contributions, + observed stars. The array contains the computed Zernike (Z1, Z2) contributions, with zero padding applied to the first column to ensure a consistent shape. 
""" star_postage_stamps = extract_star_data(data=data, train_key="noisy_stars", test_key="stars") @@ -70,7 +69,7 @@ def compute_centroid_correction(model_params, data, batch_size: int=1) -> np.nda batch_masks = star_masks[i:i + batch_size] if star_masks is not None else None # Compute Zernike 1 and Zernike 2 for the batch - zk1_2_batch = -1.0 * compute_zernike_tip_tilt( + zk1_2_batch = compute_zernike_tip_tilt( batch_postage_stamps, batch_masks, pix_sampling, reference_shifts ) From a1d055b323df69b47d74d399c2625b09f1cdfbe5 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 8 Aug 2025 22:42:28 +0200 Subject: [PATCH 094/135] Move TFPhysicalPolychromaticField.pad_zernikes to helper method pad_tf_zernikes in data_zernike_utils.py - Import pad_tf_zernikes into psf_model_physical_polychromatic.py - Replace calls to pad_zernikes with pad_tf_zernikes - Move padding zernike unit tests to data_zernike_utils_test.py - Update/Remove unit tests in psf_model_physical_polychromatic_test.py --- src/wf_psf/data/data_zernike_utils.py | 40 +++++ .../psf_model_physical_polychromatic.py | 14 +- .../test_data/data_zernike_utils_test.py | 75 ++++++++- .../psf_model_physical_polychromatic_test.py | 144 +++++------------- 4 files changed, 164 insertions(+), 109 deletions(-) diff --git a/src/wf_psf/data/data_zernike_utils.py b/src/wf_psf/data/data_zernike_utils.py index 1157be84..3a7fa90b 100644 --- a/src/wf_psf/data/data_zernike_utils.py +++ b/src/wf_psf/data/data_zernike_utils.py @@ -147,6 +147,46 @@ def combine_zernike_contributions(contributions: list[np.ndarray]) -> np.ndarray return combined + +def pad_tf_zernikes(zk_param: tf.Tensor, zk_prior: tf.Tensor, n_zks_total: int): + """ + Pad the Zernike coefficient tensors to match the specified total number of Zernikes. + + Parameters + ---------- + zk_param : tf.Tensor + Zernike coefficients for the parametric part. Shape [batch, n_zks_param, 1, 1]. + zk_prior : tf.Tensor + Zernike coefficients for the prior part. Shape [batch, n_zks_prior, 1, 1]. + n_zks_total : int + Total number of Zernikes to pad to. + + Returns + ------- + padded_zk_param : tf.Tensor + Padded Zernike coefficients for the parametric part. Shape [batch, n_zks_total, 1, 1]. + padded_zk_prior : tf.Tensor + Padded Zernike coefficients for the prior part. Shape [batch, n_zks_total, 1, 1]. 
+ """ + + pad_num_param = n_zks_total - tf.shape(zk_param)[1] + pad_num_prior = n_zks_total - tf.shape(zk_prior)[1] + + padded_zk_param = tf.cond( + tf.not_equal(pad_num_param, 0), + lambda: tf.pad(zk_param, [(0, 0), (0, pad_num_param), (0, 0), (0, 0)]), + lambda: zk_param, + ) + + padded_zk_prior = tf.cond( + tf.not_equal(pad_num_prior, 0), + lambda: tf.pad(zk_prior, [(0, 0), (0, pad_num_prior), (0, 0), (0, 0)]), + lambda: zk_prior, + ) + + return padded_zk_param, padded_zk_prior + + def assemble_zernike_contributions( model_params, zernike_prior=None, diff --git a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py index 9979bcf1..918ffc4b 100644 --- a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py +++ b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py @@ -11,7 +11,11 @@ import tensorflow as tf from tensorflow.python.keras.engine import data_adapter from wf_psf.data.data_handler import get_np_obs_positions -from wf_psf.data.data_zernike_utils import ZernikeInputsFactory, assemble_zernike_contributions +from wf_psf.data.data_zernike_utils import ( + ZernikeInputsFactory, + assemble_zernike_contributions, + pad_tf_zernikes +) from wf_psf.psf_models import psf_models as psfm from wf_psf.psf_models.tf_modules.tf_layers import ( TFPolynomialZernikeField, @@ -575,8 +579,8 @@ def compute_zernikes(self, input_positions): zernike_prior = self.tf_physical_layer.call(input_positions) # Pad and sum the zernike coefficients - padded_zernike_params, padded_zernike_prior = self.pad_zernikes( - zernike_params, zernike_prior + padded_zernike_params, padded_zernike_prior = pad_tf_zernikes( + zernike_params, zernike_prior, self.n_zks_total ) zernike_coeffs = tf.math.add(padded_zernike_params, padded_zernike_prior) @@ -613,8 +617,8 @@ def predict_zernikes(self, input_positions): physical_layer_prediction = self.tf_physical_layer.predict(input_positions) # Pad and sum the Zernike coefficients - padded_zernike_params, padded_physical_layer_prediction = self.pad_zernikes( - zernike_params, physical_layer_prediction + padded_zernike_params, padded_physical_layer_prediction = pad_tf_zernikes( + zernike_params, physical_layer_prediction, self.n_zks_total ) zernike_coeffs = tf.math.add( padded_zernike_params, padded_physical_layer_prediction diff --git a/src/wf_psf/tests/test_data/data_zernike_utils_test.py b/src/wf_psf/tests/test_data/data_zernike_utils_test.py index cc6e22de..afafc1db 100644 --- a/src/wf_psf/tests/test_data/data_zernike_utils_test.py +++ b/src/wf_psf/tests/test_data/data_zernike_utils_test.py @@ -4,13 +4,13 @@ from unittest.mock import MagicMock, patch import tensorflow as tf from wf_psf.data.data_zernike_utils import ( - ZernikeInputs, ZernikeInputsFactory, get_np_zernike_prior, pad_contribution_to_order, combine_zernike_contributions, assemble_zernike_contributions, compute_zernike_tip_tilt, + pad_tf_zernikes ) from wf_psf.tests.test_data.test_data_utils import MockData from types import SimpleNamespace as RecursiveNamespace @@ -297,6 +297,79 @@ def test_inconsistent_shapes_raises_error(mock_centroid, mock_model_params, dumm ) +def test_pad_zernikes_num_of_zernikes_equal(): + # Prepare your test tensors + zk_param = tf.constant([[[[1.0]]], [[[2.0]]]]) # Shape (2, 1, 1, 1) + zk_prior = tf.constant([[[[1.0]]], [[[2.0]]]]) # Same shape + + # Reshape to (1, 2, 1, 1) + zk_param = tf.reshape(zk_param, (1, 2, 1, 1)) + zk_prior = tf.reshape(zk_prior, (1, 2, 1, 1)) + + # Reset _n_zks_total 
to max number of zernikes (2 here) + n_zks_total = max( + tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy() + ) + + # Call pad_zernikes method + padded_zk_param, padded_zk_prior = pad_tf_zernikes( + zk_param, zk_prior, n_zks_total + ) + + # Assert shapes are equal and correct + assert padded_zk_param.shape[1] == n_zks_total + assert padded_zk_prior.shape[1] == n_zks_total + + # If num zernikes already equal, output should be unchanged + np.testing.assert_array_equal(padded_zk_param.numpy(), zk_param.numpy()) + np.testing.assert_array_equal(padded_zk_prior.numpy(), zk_prior.numpy()) + +def test_pad_zernikes_prior_greater_than_param(): + zk_param = tf.constant([[[[1]]], [[[2]]]]) # Shape: (2, 1, 1, 1) + zk_prior = tf.constant([[[[1]], [[2]], [[3]], [[4]], [[5]]]]) # Shape: (5, 1, 1, 1) + + # Reshape the tensors to have the desired shapes + zk_param = tf.reshape(zk_param, (1, 2, 1, 1)) # Reshaping tensor1 to (1, 2, 1, 1) + zk_prior = tf.reshape(zk_prior, (1, 5, 1, 1)) # Reshaping tensor2 to (1, 5, 1, 1) + + # Reset n_zks_total attribute + n_zks_total = max( + tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy() + ) + + # Call the method under test + padded_zk_param, padded_zk_prior = pad_tf_zernikes( + zk_param, zk_prior, n_zks_total + ) + + # Assert that the padded tensors have the correct shapes + assert padded_zk_param.shape == (1, 5, 1, 1) + assert padded_zk_prior.shape == (1, 5, 1, 1) + + +def test_pad_zernikes_param_greater_than_prior(): + zk_param = tf.constant([[[[10]], [[20]], [[30]], [[40]]]]) # Shape: (4, 1, 1, 1) + zk_prior = tf.constant([[[[1]]], [[[2]]]]) # Shape: (2, 1, 1, 1) + + # Reshape the tensors to have the desired shapes + zk_param = tf.reshape(zk_param, (1, 4, 1, 1)) # Reshaping tensor1 to (1, 2, 1, 1) + zk_prior = tf.reshape(zk_prior, (1, 2, 1, 1)) # Reshaping tensor2 to (1, 4, 1, 1) + + # Reset n_zks_total attribute + n_zks_total = max( + tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy() + ) + + # Call the method under test + padded_zk_param, padded_zk_prior = pad_tf_zernikes( + zk_param, zk_prior, n_zks_total + ) + + # Assert that the padded tensors have the correct shapes + assert padded_zk_param.shape == (1, 4, 1, 1) + assert padded_zk_prior.shape == (1, 4, 1, 1) + + def test_compute_zernike_tip_tilt_single_batch(mocker, simple_image, identity_mask): """Test compute_zernike_tip_tilt handling with single batch input and mocks.""" diff --git a/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py b/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py index 13d78667..95ad6287 100644 --- a/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py +++ b/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py @@ -9,7 +9,7 @@ import pytest import numpy as np import tensorflow as tf -from unittest.mock import PropertyMock +from unittest.mock import patch from wf_psf.psf_models.models.psf_model_physical_polychromatic import ( TFPhysicalPolychromaticField, ) @@ -29,15 +29,27 @@ def zks_prior(): @pytest.fixture -def mock_data(mocker): +def mock_data(mocker, zks_prior): mock_instance = mocker.Mock(spec=DataConfigHandler) - # Configure the mock data object to have the necessary attributes mock_instance.run_type = "training" + + training_dataset = { + "positions": np.array([[1, 2], [3, 4]]), + "zernike_prior": zks_prior, + "noisy_stars": np.zeros((2, 1, 1, 1)), + } + test_dataset = { + "positions": np.array([[5, 6], [7, 8]]), + "zernike_prior": zks_prior, + "stars": 
np.zeros((2, 1, 1, 1)), + } + mock_instance.training_data = mocker.Mock() - mock_instance.training_data.dataset = {"positions": np.array([[1, 2], [3, 4]])} + mock_instance.training_data.dataset = training_dataset mock_instance.test_data = mocker.Mock() - mock_instance.test_data.dataset = {"positions": np.array([[5, 6], [7, 8]])} - mock_instance.batch_size = 32 + mock_instance.test_data.dataset = test_dataset + mock_instance.batch_size = 16 + return mock_instance @@ -49,122 +61,48 @@ def mock_model_params(mocker): return model_params_mock @pytest.fixture -def physical_layer_instance(mocker, mock_model_params, mock_data, zks_prior): - # Create training params mock object - mock_training_params = mocker.Mock() - - # Create TFPhysicalPolychromaticField instance - psf_field_instance = TFPhysicalPolychromaticField( - mock_model_params, mock_training_params, mock_data - ) - return psf_field_instance - - -def test_pad_zernikes_num_of_zernikes_equal(physical_layer_instance): - # Define input tensors with same length and num of Zernikes - zk_param = tf.constant([[[[1]]], [[[2]]]]) # Shape: (2, 1, 1, 1) - zk_prior = tf.constant([[[[1]]], [[[2]]]]) # Shape: (2, 1, 1, 1) - - # Reshape the tensors to have the desired shapes - zk_param = tf.reshape(zk_param, (1, 2, 1, 1)) # Reshaping tensor1 to (1, 2, 1, 1) - zk_prior = tf.reshape(zk_prior, (1, 2, 1, 1)) # Reshaping tensor2 to (1, 2, 1, 1) - - # Reset n_zks_total attribute - physical_layer_instance._n_zks_total = max( - tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy() - ) - # Call the method under test - padded_zk_param, padded_zk_prior = physical_layer_instance.pad_zernikes( - zk_param, zk_prior - ) - - # Assert that the padded tensors have the correct shapes - assert padded_zk_param.shape == (1, 2, 1, 1) - assert padded_zk_prior.shape == (1, 2, 1, 1) - - -def test_pad_zernikes_prior_greater_than_param(physical_layer_instance): - zk_param = tf.constant([[[[1]]], [[[2]]]]) # Shape: (2, 1, 1, 1) - zk_prior = tf.constant([[[[1]], [[2]], [[3]], [[4]], [[5]]]]) # Shape: (5, 1, 1, 1) - - # Reshape the tensors to have the desired shapes - zk_param = tf.reshape(zk_param, (1, 2, 1, 1)) # Reshaping tensor1 to (1, 2, 1, 1) - zk_prior = tf.reshape(zk_prior, (1, 5, 1, 1)) # Reshaping tensor2 to (1, 5, 1, 1) - - # Reset n_zks_total attribute - physical_layer_instance._n_zks_total = max( - tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy() - ) - - # Call the method under test - padded_zk_param, padded_zk_prior = physical_layer_instance.pad_zernikes( - zk_param, zk_prior - ) - - # Assert that the padded tensors have the correct shapes - assert padded_zk_param.shape == (1, 5, 1, 1) - assert padded_zk_prior.shape == (1, 5, 1, 1) - - -def test_pad_zernikes_param_greater_than_prior(physical_layer_instance): - zk_param = tf.constant([[[[10]], [[20]], [[30]], [[40]]]]) # Shape: (4, 1, 1, 1) - zk_prior = tf.constant([[[[1]]], [[[2]]]]) # Shape: (2, 1, 1, 1) - - # Reshape the tensors to have the desired shapes - zk_param = tf.reshape(zk_param, (1, 4, 1, 1)) # Reshaping tensor1 to (1, 2, 1, 1) - zk_prior = tf.reshape(zk_prior, (1, 2, 1, 1)) # Reshaping tensor2 to (1, 4, 1, 1) - - # Reset n_zks_total attribute - physical_layer_instance._n_zks_total = max( - tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy() - ) - - # Call the method under test - padded_zk_param, padded_zk_prior = physical_layer_instance.pad_zernikes( - zk_param, zk_prior - ) - - # Assert that the padded tensors have the correct shapes - assert padded_zk_param.shape == (1, 4, 1, 1) 
- assert padded_zk_prior.shape == (1, 4, 1, 1) - +def physical_layer_instance(mocker, mock_model_params, mock_data): + # Patch expensive methods during construction to avoid errors + with patch("wf_psf.psf_models.models.psf_model_physical_polychromatic.TFPhysicalPolychromaticField._assemble_zernike_contributions", return_value=tf.constant([[[[1.0]]], [[[2.0]]]])): + from wf_psf.psf_models.models.psf_model_physical_polychromatic import TFPhysicalPolychromaticField + instance = TFPhysicalPolychromaticField(mock_model_params, mocker.Mock(), mock_data) + return instance def test_compute_zernikes(mocker, physical_layer_instance): - # Expected output of mock components + # Expected output of mock components padded_zernike_param = tf.constant([[[[10]], [[20]], [[30]], [[40]]]], dtype=tf.float32) padded_zernike_prior = tf.constant([[[[1]], [[2]], [[0]], [[0]]]], dtype=tf.float32) - expected_values = tf.constant([[[[11]], [[22]], [[30]], [[40]]]], dtype=tf.float32) - - # Patch tf_poly_Z_field property - mock_tf_poly_Z_field = mocker.Mock(return_value=padded_zernike_param) + n_zks_total = physical_layer_instance.n_zks_total + expected_values_list = [11, 22, 30, 40] + [0] * (n_zks_total - 4) + expected_values = tf.constant( + [[[[v]] for v in expected_values_list]], + dtype=tf.float32 +) + # Patch tf_poly_Z_field method mocker.patch.object( TFPhysicalPolychromaticField, - 'tf_poly_Z_field', - new_callable=PropertyMock, - return_value=mock_tf_poly_Z_field + "tf_poly_Z_field", + return_value=padded_zernike_param ) - # Patch tf_physical_layer property + # Patch tf_physical_layer.call method mock_tf_physical_layer = mocker.Mock() mock_tf_physical_layer.call.return_value = padded_zernike_prior mocker.patch.object( TFPhysicalPolychromaticField, - 'tf_physical_layer', - new_callable=PropertyMock, - return_value=mock_tf_physical_layer + "tf_physical_layer", + mock_tf_physical_layer ) - # Patch pad_zernikes instance method directly (this one isn't a property) - mocker.patch.object( - physical_layer_instance, - 'pad_zernikes', + # Patch pad_tf_zernikes function + mocker.patch( + "wf_psf.data.data_zernike_utils.pad_tf_zernikes", return_value=(padded_zernike_param, padded_zernike_prior) ) - # Run the test zernike_coeffs = physical_layer_instance.compute_zernikes(tf.constant([[0.0, 0.0]])) # Assertions tf.debugging.assert_equal(zernike_coeffs, expected_values) - assert zernike_coeffs.shape == expected_values.shape \ No newline at end of file + assert zernike_coeffs.shape == expected_values.shape From b721f2fffc698361eddeefd243210f85de7d1e7c Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 8 Aug 2025 22:59:07 +0200 Subject: [PATCH 095/135] Correct bug in test_load_inference_model - Add missing fixture arguments to mock_training_config and mock_inference_config - Add patch for DataHandler - Remove unused import --- .../test_inference/psf_inference_test.py | 38 +++++++++++++++---- 1 file changed, 30 insertions(+), 8 deletions(-) diff --git a/src/wf_psf/tests/test_inference/psf_inference_test.py b/src/wf_psf/tests/test_inference/psf_inference_test.py index 4add460b..91328300 100644 --- a/src/wf_psf/tests/test_inference/psf_inference_test.py +++ b/src/wf_psf/tests/test_inference/psf_inference_test.py @@ -12,11 +12,11 @@ import pytest import tensorflow as tf from types import SimpleNamespace -from unittest.mock import MagicMock, patch, PropertyMock +from unittest.mock import MagicMock, patch from wf_psf.inference.psf_inference import ( InferenceConfigHandler, PSFInference, - PSFInferenceEngine + 
PSFInferenceEngine ) from wf_psf.utils.read_config import RecursiveNamespace @@ -29,7 +29,26 @@ def mock_training_config(): model_params=RecursiveNamespace( model_name="mock_model", output_Q=2, - output_dim=32 + output_dim=32, + pupil_diameter=256, + oversampling_rate=3, + interpolation_type=None, + interpolation_args=None, + sed_interp_pts_per_bin=0, + sed_extrapolate=True, + sed_interp_kind="linear", + sed_sigma=0, + x_lims=[0.0, 1000.0], + y_lims=[0.0, 1000.0], + pix_sampling=12, + tel_diameter=1.2, + tel_focal_length=24.5, + euclid_obsc=True, + LP_filter_length=3, + param_hparams=RecursiveNamespace( + n_zernikes=10, + + ) ) ) ) @@ -48,9 +67,10 @@ def mock_inference_config(): data_config_path=None ), model_params=RecursiveNamespace( + n_bins_lda=8, output_Q=1, output_dim=64 - ) + ), ) ) return inference_config @@ -179,9 +199,11 @@ def test_batch_size_positive(): assert inference.batch_size == 4 +@patch('wf_psf.inference.psf_inference.DataHandler') @patch('wf_psf.inference.psf_inference.load_trained_psf_model') -def test_load_inference_model(mock_load_trained_psf_model, mock_training_config, mock_inference_config): - +def test_load_inference_model(mock_load_trained_psf_model, mock_data_handler, mock_training_config, mock_inference_config): + mock_data_config = MagicMock() + mock_data_handler.return_value = mock_data_config mock_config_handler = MagicMock(spec=InferenceConfigHandler) mock_config_handler.trained_model_path = "mock/path/to/model" mock_config_handler.training_config = mock_training_config @@ -202,8 +224,8 @@ def test_load_inference_model(mock_load_trained_psf_model, mock_training_config, # Assert calls to the mocked methods mock_load_trained_psf_model.assert_called_once_with( - mock_config_handler.training_config, - mock_config_handler.data_config, + mock_training_config, + mock_data_config, weights_path_pattern ) From 98d92a0372a11cce388291c8827fa7c0c05f7293 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Mon, 18 Aug 2025 15:33:26 +0200 Subject: [PATCH 096/135] Revert sign change applied to compute_centroid_correction --- src/wf_psf/data/centroids.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/wf_psf/data/centroids.py b/src/wf_psf/data/centroids.py index b01af2dd..7052f833 100644 --- a/src/wf_psf/data/centroids.py +++ b/src/wf_psf/data/centroids.py @@ -69,7 +69,7 @@ def compute_centroid_correction(model_params, data, batch_size: int=1) -> np.nda batch_masks = star_masks[i:i + batch_size] if star_masks is not None else None # Compute Zernike 1 and Zernike 2 for the batch - zk1_2_batch = compute_zernike_tip_tilt( + zk1_2_batch = -1.0 * compute_zernike_tip_tilt( batch_postage_stamps, batch_masks, pix_sampling, reference_shifts ) From 2e644d7918e63c7bbfd985e15322dc9c560d1489 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Tue, 19 Aug 2025 23:37:52 +0200 Subject: [PATCH 097/135] Refactor _prepare_positions_and_seds to enforce shape consistency and add validation - Ensure x_field and y_field are at least 1D and broadcast to positions of shape (n_samples, 2) - Validate that SEDs batch size matches the number of positions; raise ValueError if not - Validate that SEDs last dimension is 2 (flux, wavelength) - Process SEDs via data_handler as before - Removes need for broadcasting SEDs silently, avoiding hidden shape mismatches - Supports single-star, multi-star, and scalar inputs - Improves error messages for easier debugging in unit tests Also updates unit tests to cover: - Single-star input shapes - Mismatched x/y fields or SED batch size 
(ValueError cases) --- src/wf_psf/inference/psf_inference.py | 60 ++++++-- .../test_inference/psf_inference_test.py | 137 +++++++++++++++++- src/wf_psf/utils/utils.py | 41 ++++++ 3 files changed, 226 insertions(+), 12 deletions(-) diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py index 6bba2282..76619f5a 100644 --- a/src/wf_psf/inference/psf_inference.py +++ b/src/wf_psf/inference/psf_inference.py @@ -13,6 +13,7 @@ import numpy as np from wf_psf.data.data_handler import DataHandler from wf_psf.utils.read_config import read_conf +from wf_psf.utils.utils import ensure_batch from wf_psf.psf_models import psf_models from wf_psf.psf_models.psf_model_loader import load_trained_psf_model import tensorflow as tf @@ -260,13 +261,40 @@ def output_dim(self): return self._output_dim def _prepare_positions_and_seds(self): - """Preprocess and return tensors for positions and SEDs.""" - positions = tf.convert_to_tensor( - np.array([self.x_field, self.y_field]).T, dtype=tf.float32 - ) - self.data_handler.process_sed_data(self.seds) - sed_data = self.data_handler.sed_data - return positions, sed_data + """ + Preprocess and return tensors for positions and SEDs with consistent shapes. + + Handles single-star, multi-star, and even scalar inputs, ensuring: + - positions: shape (n_samples, 2) + - sed_data: shape (n_samples, n_bins, 2) + """ + # Ensure x_field and y_field are at least 1D + x_arr = np.atleast_1d(self.x_field) + y_arr = np.atleast_1d(self.y_field) + + if x_arr.size != y_arr.size: + raise ValueError(f"x_field and y_field must have the same length. " + f"Got {x_arr.size} and {y_arr.size}") + + # Combine into positions array (n_samples, 2) + positions = np.column_stack((x_arr, y_arr)) + positions = tf.convert_to_tensor(positions, dtype=tf.float32) + + # Ensure SEDs have shape (n_samples, n_bins, 2) + sed_data = ensure_batch(self.seds) + + if sed_data.shape[0] != positions.shape[0]: + raise ValueError(f"SEDs batch size {sed_data.shape[0]} does not match number of positions {positions.shape[0]}") + + if sed_data.shape[2] != 2: + raise ValueError(f"SEDs last dimension must be 2 (flux, wavelength). Got {sed_data.shape}") + + # Process SEDs through the data handler + self.data_handler.process_sed_data(sed_data) + sed_data_tensor = self.data_handler.sed_data + + return positions, sed_data_tensor + def run_inference(self): """Run PSF inference and return the full PSF array.""" @@ -291,9 +319,23 @@ def get_psfs(self): self._ensure_psf_inference_completed() return self.engine.get_psfs() - def get_psf(self, index): + def get_psf(self, index: int = 0) -> np.ndarray: + """ + Get the PSF at a specific index. + + If only a single star was passed during instantiation, the index defaults to 0. 
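+
+        Parameters
+        ----------
+        index : int, optional
+            Index of the star within the inference batch. Default is 0.
+
+        Returns
+        -------
+        np.ndarray
+            The inferred PSF at the requested index.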
+ """ self._ensure_psf_inference_completed() - return self.engine.get_psf(index) + + inferred_psfs = self.engine.get_psfs() + + # If a single-star batch, ignore index bounds + if inferred_psfs.shape[0] == 1: + return inferred_psfs[0] + + # Otherwise, return the PSF at the requested index + return inferred_psfs[index] + class PSFInferenceEngine: def __init__(self, trained_model, batch_size: int, output_dim: int): diff --git a/src/wf_psf/tests/test_inference/psf_inference_test.py b/src/wf_psf/tests/test_inference/psf_inference_test.py index 91328300..ec9f0495 100644 --- a/src/wf_psf/tests/test_inference/psf_inference_test.py +++ b/src/wf_psf/tests/test_inference/psf_inference_test.py @@ -12,7 +12,7 @@ import pytest import tensorflow as tf from types import SimpleNamespace -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, patch, PropertyMock from wf_psf.inference.psf_inference import ( InferenceConfigHandler, PSFInference, @@ -21,6 +21,19 @@ from wf_psf.utils.read_config import RecursiveNamespace +def _patch_data_handler(): + """Helper for patching data_handler to avoid full PSF logic.""" + patcher = patch.object(PSFInference, "data_handler", new_callable=PropertyMock) + mock_data_handler = patcher.start() + mock_instance = MagicMock() + mock_data_handler.return_value = mock_instance + + def fake_process(x): + mock_instance.sed_data = tf.convert_to_tensor(x) + + mock_instance.process_sed_data.side_effect = fake_process + return patcher, mock_instance + @pytest.fixture def mock_training_config(): training_config = RecursiveNamespace( @@ -83,14 +96,46 @@ def psf_test_setup(mock_inference_config): output_dim = 32 mock_positions = tf.convert_to_tensor([[0.1, 0.1], [0.2, 0.2]], dtype=tf.float32) - mock_seds = tf.convert_to_tensor(np.random.rand(num_sources, 2, num_bins), dtype=tf.float32) + mock_seds = tf.convert_to_tensor(np.random.rand(num_sources, num_bins, 2), dtype=tf.float32) expected_psfs = np.random.rand(num_sources, output_dim, output_dim).astype(np.float32) inference = PSFInference( "dummy_path.yaml", x_field=[0.1, 0.2], y_field=[0.1, 0.2], - seds=np.random.rand(num_sources, num_bins) + seds=np.random.rand(num_sources, num_bins, 2) + ) + inference._config_handler = MagicMock() + inference._config_handler.inference_config = mock_inference_config + inference._trained_psf_model = MagicMock() + + return { + "inference": inference, + "mock_positions": mock_positions, + "mock_seds": mock_seds, + "expected_psfs": expected_psfs, + "num_sources": num_sources, + "num_bins": num_bins, + "output_dim": output_dim + } + +@pytest.fixture +def psf_single_star_setup(mock_inference_config): + num_sources = 1 + num_bins = 10 + output_dim = 32 + + # Single position + mock_positions = tf.convert_to_tensor([[0.1, 0.1]], dtype=tf.float32) + # Shape (1, 2, num_bins) + mock_seds = tf.convert_to_tensor(np.random.rand(num_sources, 2, num_bins), dtype=tf.float32) + expected_psfs = np.random.rand(num_sources, output_dim, output_dim).astype(np.float32) + + inference = PSFInference( + "dummy_path.yaml", + x_field=0.1, # scalar for single star + y_field=0.1, + seds=np.random.rand(num_bins, 2) # shape (num_bins, 2) before batching ) inference._config_handler = MagicMock() inference._config_handler.inference_config = mock_inference_config @@ -308,3 +353,89 @@ def fake_compute_psfs(positions, seds): assert np.all(psfs_2 == expected_psfs) assert mock_compute_psfs.call_count == 1 + + + +def test_single_star_inference_shape(psf_single_star_setup): + setup = psf_single_star_setup + + _, 
mock_instance = _patch_data_handler() + + # Run the method under test + positions, sed_data = setup["inference"]._prepare_positions_and_seds() + + # Check shapes + assert sed_data.shape == (1, setup["num_bins"], 2) + assert positions.shape == (1, 2) + + # Verify the call happened + mock_instance.process_sed_data.assert_called_once() + args, _ = mock_instance.process_sed_data.call_args + input_array = args[0] + + # Check input SED had the right shape before being tensorized + assert input_array.shape == (1, setup["num_bins"], 2), \ + "process_sed_data should have been called with shape (1, num_bins, 2)" + + + +def test_multiple_star_inference_shape(psf_test_setup): + """Test that _prepare_positions_and_seds returns correct shapes for multiple stars.""" + setup = psf_test_setup + + _, mock_instance = _patch_data_handler() + + # Run the method under test + positions, sed_data = setup["inference"]._prepare_positions_and_seds() + + # Check shapes + assert sed_data.shape == (2, setup["num_bins"], 2) + assert positions.shape == (2, 2) + + # Verify the call happened + mock_instance.process_sed_data.assert_called_once() + args, _ = mock_instance.process_sed_data.call_args + input_array = args[0] + + # Check input SED had the right shape before being tensorized + assert input_array.shape == (2, setup["num_bins"], 2), \ + "process_sed_data should have been called with shape (2, num_bins, 2)" + + +def test_valueerror_on_mismatched_batches(psf_single_star_setup): + """Raise if sed_data batch size != positions batch size and sed_data != 1.""" + setup = psf_single_star_setup + inference = setup["inference"] + + patcher, _ = _patch_data_handler() + try: + # Force sed_data to have 2 sources while positions has 1 + bad_sed = np.ones((2, setup["num_bins"], 2), dtype=np.float32) + + # Replace fixture's sed_data with mismatched one + inference.seds = bad_sed + inference.positions = np.ones((1, 2), dtype=np.float32) + + with pytest.raises(ValueError, match="SEDs batch size 2 does not match number of positions 1"): + inference._prepare_positions_and_seds() + finally: + patcher.stop() + + +def test_valueerror_on_mismatched_positions(psf_single_star_setup): + """Raise if positions batch size != sed_data batch size (opposite mismatch).""" + setup = psf_single_star_setup + inference = setup["inference"] + + patcher, _ = _patch_data_handler() + try: + # Force positions to have 3 entries while sed_data has 2 + bad_sed = np.ones((2, setup["num_bins"], 2), dtype=np.float32) + inference.seds = bad_sed + inference.x_field = np.ones((3, 1), dtype=np.float32) + inference.y_field = np.ones((3, 1), dtype=np.float32) + + with pytest.raises(ValueError, match="SEDs batch size 2 does not match number of positions 3"): + inference._prepare_positions_and_seds() + finally: + patcher.stop() \ No newline at end of file diff --git a/src/wf_psf/utils/utils.py b/src/wf_psf/utils/utils.py index 1b1f2d6d..17219ad9 100644 --- a/src/wf_psf/utils/utils.py +++ b/src/wf_psf/utils/utils.py @@ -26,6 +26,47 @@ pass +def scale_to_range(input_array, old_range, new_range): + # Scale to [0,1] + input_array = (input_array - old_range[0]) / (old_range[1] - old_range[0]) + # Scale to new_range + input_array = input_array * (new_range[1] - new_range[0]) + new_range[0] + return input_array + + +def ensure_batch(arr): + """ + Ensure array/tensor has a batch dimension. Converts shape (M, N) → (1, M, N). + + Parameters + ---------- + arr : np.ndarray or tf.Tensor + Input 2D or 3D array/tensor. 
+ + Returns + ------- + np.ndarray or tf.Tensor + With batch dimension prepended if needed. + """ + if isinstance(arr, np.ndarray): + return arr if arr.ndim == 3 else np.expand_dims(arr, axis=0) + elif isinstance(arr, tf.Tensor): + return arr if arr.ndim == 3 else tf.expand_dims(arr, axis=0) + else: + raise TypeError(f"Expected np.ndarray or tf.Tensor, got {type(arr)}") + + +def calc_wfe(zernike_basis, zks): + wfe = np.einsum("ijk,ijk->jk", zernike_basis, zks.reshape(-1, 1, 1)) + return wfe + + +def calc_wfe_rms(zernike_basis, zks, pupil_mask): + wfe = calc_wfe(zernike_basis, zks) + wfe_rms = np.sqrt(np.mean((wfe[pupil_mask] - np.mean(wfe[pupil_mask])) ** 2)) + return wfe_rms + + def generalised_sigmoid(x, max_val=1, power_k=1): """ Apply a generalized sigmoid function to the input. From dd99fe27b08a673a9a546cccfb0220f7e7b56544 Mon Sep 17 00:00:00 2001 From: jeipollack Date: Thu, 21 Aug 2025 16:43:44 +0200 Subject: [PATCH 098/135] Fix tensor handling in ZernikeInputsFactory Explicitly convert positions to NumPy arrays with `.numpy()` and adjust inference path to read from `data.dataset` for positions and priors. --- src/wf_psf/data/data_zernike_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/wf_psf/data/data_zernike_utils.py b/src/wf_psf/data/data_zernike_utils.py index 3a7fa90b..4149d79e 100644 --- a/src/wf_psf/data/data_zernike_utils.py +++ b/src/wf_psf/data/data_zernike_utils.py @@ -57,8 +57,8 @@ def build(data, run_type: str, model_params, prior: Optional[np.ndarray] = None) centroid_dataset = data # Assuming data is a DataConfigHandler or similar object containing train and test datasets positions = np.concatenate( [ - data.training_data.dataset["positions"], - data.test_data.dataset["positions"] + data.training_data.dataset["positions"].numpy(), + data.test_data.dataset["positions"].numpy() ], axis=0, ) @@ -72,11 +72,11 @@ def build(data, run_type: str, model_params, prior: Optional[np.ndarray] = None) elif run_type == "inference": centroid_dataset = None - positions = data["positions"] + positions = data.dataset["positions"].numpy() if model_params.use_prior: # Try to extract prior from `data`, if present - prior = getattr(data, "zernike_prior", None) if not isinstance(data, dict) else data.get("zernike_prior") + prior = getattr(data.dataset, "zernike_prior", None) if not isinstance(data, dict) else data.dataset.get("zernike_prior") if prior is None: logger.warning( From 818dc73806dad8d350b8b48ddbdf68d1f8f6395c Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Thu, 21 Aug 2025 11:50:45 +0200 Subject: [PATCH 099/135] Reformat and remove unused import --- src/wf_psf/data/data_handler.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/wf_psf/data/data_handler.py b/src/wf_psf/data/data_handler.py index 120974eb..fe660940 100644 --- a/src/wf_psf/data/data_handler.py +++ b/src/wf_psf/data/data_handler.py @@ -17,7 +17,6 @@ import wf_psf.utils.utils as utils from wf_psf.psf_models.tf_modules.tf_utils import ensure_tensor import tensorflow as tf -from fractions import Fraction from typing import Optional, Union import logging @@ -184,12 +183,18 @@ def _validate_dataset_structure(self): def _convert_dataset_to_tensorflow(self): """Convert dataset to TensorFlow tensors.""" - self.dataset["possitions"] = ensure_tensor(self.dataset["positions"], dtype=tf.float32) - + self.dataset["positions"] = ensure_tensor( + self.dataset["positions"], dtype=tf.float32 + ) + if self.dataset_type == "train": - 
self.dataset["noisy_stars"] = ensure_tensor(self.dataset["noisy_stars"], dtype=tf.float32) + self.dataset["noisy_stars"] = ensure_tensor( + self.dataset["noisy_stars"], dtype=tf.float32 + ) elif self.dataset_type == "test": - self.dataset["stars"] = ensure_tensor(self.dataset["stars"], dtype=tf.float32) + self.dataset["stars"] = ensure_tensor( + self.dataset["stars"], dtype=tf.float32 + ) def process_sed_data(self, sed_data): """ From f3d4f165b6b549c0eab21c348b45bddfe133dac7 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 22 Aug 2025 09:39:40 +0200 Subject: [PATCH 100/135] Correct zernike_prior extraction when dataset is a dict, reformat file --- src/wf_psf/data/data_zernike_utils.py | 76 ++++++++++++++++----------- 1 file changed, 45 insertions(+), 31 deletions(-) diff --git a/src/wf_psf/data/data_zernike_utils.py b/src/wf_psf/data/data_zernike_utils.py index 4149d79e..f7653540 100644 --- a/src/wf_psf/data/data_zernike_utils.py +++ b/src/wf_psf/data/data_zernike_utils.py @@ -26,13 +26,17 @@ @dataclass class ZernikeInputs: zernike_prior: Optional[np.ndarray] # true prior, if provided (e.g. from PDC) - centroid_dataset: Optional[Union[dict, 'RecursiveNamespace']] # only used in training/simulation + centroid_dataset: Optional[ + Union[dict, "RecursiveNamespace"] + ] # only used in training/simulation misalignment_positions: Optional[np.ndarray] # needed for CCD corrections class ZernikeInputsFactory: @staticmethod - def build(data, run_type: str, model_params, prior: Optional[np.ndarray] = None) -> ZernikeInputs: + def build( + data, run_type: str, model_params, prior: Optional[np.ndarray] = None + ) -> ZernikeInputs: """Builds a ZernikeInputs dataclass instance based on run type and data. Parameters @@ -58,7 +62,7 @@ def build(data, run_type: str, model_params, prior: Optional[np.ndarray] = None) positions = np.concatenate( [ data.training_data.dataset["positions"].numpy(), - data.test_data.dataset["positions"].numpy() + data.test_data.dataset["positions"].numpy(), ], axis=0, ) @@ -76,7 +80,11 @@ def build(data, run_type: str, model_params, prior: Optional[np.ndarray] = None) if model_params.use_prior: # Try to extract prior from `data`, if present - prior = getattr(data.dataset, "zernike_prior", None) if not isinstance(data, dict) else data.dataset.get("zernike_prior") + prior = ( + getattr(data.dataset, "zernike_prior", None) + if not isinstance(data.dataset, dict) + else data.dataset.get("zernike_prior") + ) if prior is None: logger.warning( @@ -89,7 +97,7 @@ def build(data, run_type: str, model_params, prior: Optional[np.ndarray] = None) return ZernikeInputs( zernike_prior=prior, centroid_dataset=centroid_dataset, - misalignment_positions=positions + misalignment_positions=positions, ) @@ -119,12 +127,14 @@ def get_np_zernike_prior(data): return zernike_prior + def pad_contribution_to_order(contribution: np.ndarray, max_order: int) -> np.ndarray: """Pad a Zernike contribution array to the max Zernike order.""" current_order = contribution.shape[1] pad_width = ((0, 0), (0, max_order - current_order)) return np.pad(contribution, pad_width=pad_width, mode="constant", constant_values=0) + def combine_zernike_contributions(contributions: list[np.ndarray]) -> np.ndarray: """Combine multiple Zernike contributions, padding each to the max order before summing.""" if not contributions: @@ -228,12 +238,14 @@ def assemble_zernike_contributions( zernike_prior = zernike_prior.numpy() else: raise RuntimeError( - "Zernike prior is a TensorFlow tensor but eager execution is disabled. 
" - "Cannot call `.numpy()` outside of eager mode." + "Zernike prior is a TensorFlow tensor but eager execution is disabled. " + "Cannot call `.numpy()` outside of eager mode." ) - + elif not isinstance(zernike_prior, np.ndarray): - raise TypeError("Unsupported zernike_prior type. Must be np.ndarray or tf.Tensor.") + raise TypeError( + "Unsupported zernike_prior type. Must be np.ndarray or tf.Tensor." + ) zernike_contribution_list.append(zernike_prior) else: logger.info("Skipping Zernike prior (not used or not provided).") @@ -254,7 +266,9 @@ def assemble_zernike_contributions( ccd_misalignment = compute_ccd_misalignment(model_params, positions) zernike_contribution_list.append(ccd_misalignment) else: - logger.info("Skipping CCD misalignment correction (not enabled or no positions).") + logger.info( + "Skipping CCD misalignment correction (not enabled or no positions)." + ) # If no contributions, return zeros tensor to avoid crashes if not zernike_contribution_list: @@ -267,7 +281,7 @@ def assemble_zernike_contributions( combined_zernike_prior = combine_zernike_contributions(zernike_contribution_list) return tf.convert_to_tensor(combined_zernike_prior, dtype=tf.float32) - + def shift_x_y_to_zk1_2_wavediff(dxy, tel_focal_length=24.5, tel_diameter=1.2): """Compute Zernike 1(2) for a given shifts in x(y) in WaveDifff conventions. @@ -303,18 +317,19 @@ def shift_x_y_to_zk1_2_wavediff(dxy, tel_focal_length=24.5, tel_diameter=1.2): * 3.0 ) + def compute_zernike_tip_tilt( star_images: np.ndarray, star_masks: Optional[np.ndarray] = None, pixel_sampling: float = 12e-6, - reference_shifts: list[float] = [-1/3, -1/3], + reference_shifts: list[float] = [-1 / 3, -1 / 3], sigma_init: float = 2.5, n_iter: int = 20, ) -> np.ndarray: """ Compute Zernike tip-tilt corrections for a batch of PSF images. - This function estimates the centroid shifts of multiple PSFs and computes + This function estimates the centroid shifts of multiple PSFs and computes the corresponding Zernike tip-tilt corrections to align them with a reference. Parameters @@ -330,7 +345,7 @@ def compute_zernike_tip_tilt( pixel_sampling : float, optional The pixel size in meters. Defaults to `12e-6 m` (12 microns). reference_shifts : list[float], optional - The target centroid shifts in pixels, specified as `[dy, dx]`. + The target centroid shifts in pixels, specified as `[dy, dx]`. Defaults to `[-1/3, -1/3]` (nominal Euclid conditions). sigma_init : float, optional Initial standard deviation for centroid estimation. Default is `2.5`. @@ -343,21 +358,18 @@ def compute_zernike_tip_tilt( An array of shape `(num_images, 2)`, where: - Column 0 contains `Zk1` (tip) values. - Column 1 contains `Zk2` (tilt) values. - + Notes ----- - This function processes all images at once using vectorized operations. - The Zernike coefficients are computed in the WaveDiff convention. 
""" from wf_psf.data.centroids import CentroidEstimator - + # Vectorize the centroid computation centroid_estimator = CentroidEstimator( - im=star_images, - mask=star_masks, - sigma_init=sigma_init, - n_iter=n_iter - ) + im=star_images, mask=star_masks, sigma_init=sigma_init, n_iter=n_iter + ) shifts = centroid_estimator.get_intra_pixel_shifts() @@ -365,23 +377,25 @@ def compute_zernike_tip_tilt( reference_shifts = np.array(reference_shifts) # Reshape to ensure it's a column vector (1, 2) - reference_shifts = reference_shifts[None,:] - + reference_shifts = reference_shifts[None, :] + # Broadcast reference_shifts to match the shape of shifts - reference_shifts = np.broadcast_to(reference_shifts, shifts.shape) - + reference_shifts = np.broadcast_to(reference_shifts, shifts.shape) + # Compute displacements - displacements = (reference_shifts - shifts) # - + displacements = reference_shifts - shifts # + # Ensure the correct axis order for displacements (x-axis, then y-axis) - displacements_swapped = displacements[:, [1, 0]] # Adjust axis order if necessary + displacements_swapped = displacements[:, [1, 0]] # Adjust axis order if necessary # Call shift_x_y_to_zk1_2_wavediff directly on the vector of displacements - zk1_2_array = shift_x_y_to_zk1_2_wavediff(displacements_swapped.flatten() * pixel_sampling ) # vectorized call - + zk1_2_array = shift_x_y_to_zk1_2_wavediff( + displacements_swapped.flatten() * pixel_sampling + ) # vectorized call + # Reshape the result back to the original shape of displacements zk1_2_array = zk1_2_array.reshape(displacements.shape) - + return zk1_2_array From 829136a24b396c728bb3544f3b6f78077dd2b385 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 22 Aug 2025 09:41:46 +0200 Subject: [PATCH 101/135] Replace np.array with Tensorflow tensors in unit test and fixtures, reformat files --- src/wf_psf/tests/test_data/conftest.py | 50 ++-- .../test_data/data_zernike_utils_test.py | 242 +++++++++++------- 2 files changed, 176 insertions(+), 116 deletions(-) diff --git a/src/wf_psf/tests/test_data/conftest.py b/src/wf_psf/tests/test_data/conftest.py index 04a56893..47eed929 100644 --- a/src/wf_psf/tests/test_data/conftest.py +++ b/src/wf_psf/tests/test_data/conftest.py @@ -100,38 +100,35 @@ def mock_data(scope="module"): """Fixture to provide mock data for testing.""" # Mock positions and Zernike priors - training_positions = np.array([[1, 2], [3, 4]]) - test_positions = np.array([[5, 6], [7, 8]]) - training_zernike_priors = np.array([[0.1, 0.2], [0.3, 0.4]]) - test_zernike_priors = np.array([[0.5, 0.6], [0.7, 0.8]]) + training_positions = tf.constant([[1, 2], [3, 4]]) + test_positions = tf.constant([[5, 6], [7, 8]]) + training_zernike_priors = tf.constant([[0.1, 0.2], [0.3, 0.4]]) + test_zernike_priors = tf.constant([[0.5, 0.6], [0.7, 0.8]]) # Define dummy 5x5 image patches for stars (mock star images) # Define varied values for 5x5 star images - noisy_stars = tf.constant([ - np.arange(25).reshape(5, 5), - np.arange(25, 50).reshape(5, 5) - ], dtype=tf.float32) - - noisy_masks = tf.constant([ - np.eye(5), - np.ones((5, 5)) - ], dtype=tf.float32) - - stars = tf.constant([ - np.full((5, 5), 100), - np.full((5, 5), 200) - ], dtype=tf.float32) - - masks = tf.constant([ - np.zeros((5, 5)), - np.tri(5) - ], dtype=tf.float32) + noisy_stars = tf.constant( + [np.arange(25).reshape(5, 5), np.arange(25, 50).reshape(5, 5)], dtype=tf.float32 + ) + + noisy_masks = tf.constant([np.eye(5), np.ones((5, 5))], dtype=tf.float32) + + stars = tf.constant([np.full((5, 5), 100), 
np.full((5, 5), 200)], dtype=tf.float32) + + masks = tf.constant([np.zeros((5, 5)), np.tri(5)], dtype=tf.float32) return MockData( - training_positions, test_positions, training_zernike_priors, - test_zernike_priors, noisy_stars, noisy_masks, stars, masks + training_positions, + test_positions, + training_zernike_priors, + test_zernike_priors, + noisy_stars, + noisy_masks, + stars, + masks, ) + @pytest.fixture def simple_image(scope="module"): """Fixture for a simple star image.""" @@ -140,11 +137,13 @@ def simple_image(scope="module"): image[:, 2, 2] = 1 # Place the star at the center for each image return image + @pytest.fixture def identity_mask(scope="module"): """Creates a mask where all pixels are fully considered.""" return np.ones((5, 5)) + @pytest.fixture def multiple_images(scope="module"): """Fixture for a batch of images with stars at different positions.""" @@ -154,6 +153,7 @@ def multiple_images(scope="module"): images[2, 3, 1] = 1 # Star at (3, 1) in image 2 return images + @pytest.fixture(scope="module", params=[data]) def data_params(): return data diff --git a/src/wf_psf/tests/test_data/data_zernike_utils_test.py b/src/wf_psf/tests/test_data/data_zernike_utils_test.py index afafc1db..390d10f9 100644 --- a/src/wf_psf/tests/test_data/data_zernike_utils_test.py +++ b/src/wf_psf/tests/test_data/data_zernike_utils_test.py @@ -1,4 +1,3 @@ - import pytest import numpy as np from unittest.mock import MagicMock, patch @@ -10,11 +9,12 @@ combine_zernike_contributions, assemble_zernike_contributions, compute_zernike_tip_tilt, - pad_tf_zernikes + pad_tf_zernikes, ) from wf_psf.tests.test_data.test_data_utils import MockData from types import SimpleNamespace as RecursiveNamespace + @pytest.fixture def mock_model_params(): return RecursiveNamespace( @@ -24,6 +24,7 @@ def mock_model_params(): param_hparams=RecursiveNamespace(n_zernikes=6), ) + @pytest.fixture def dummy_prior(): return np.ones((4, 6), dtype=np.float32) @@ -41,22 +42,28 @@ def test_training_without_prior(mock_model_params, mock_data): mock_data.training_data.dataset.pop("zernike_prior", None) mock_data.test_data.dataset.pop("zernike_prior", None) - zinputs = ZernikeInputsFactory.build(data=mock_data, run_type="training", model_params=mock_model_params) + zinputs = ZernikeInputsFactory.build( + data=mock_data, run_type="training", model_params=mock_model_params + ) assert zinputs.centroid_dataset is mock_data assert zinputs.zernike_prior is None - expected_positions = np.concatenate([ - mock_data.training_data.dataset["positions"], - mock_data.test_data.dataset["positions"] - ]) + expected_positions = np.concatenate( + [ + mock_data.training_data.dataset["positions"], + mock_data.test_data.dataset["positions"], + ] + ) np.testing.assert_array_equal(zinputs.misalignment_positions, expected_positions) def test_training_with_dataset_prior(mock_model_params, mock_data): mock_model_params.use_prior = True - zinputs = ZernikeInputsFactory.build(data=mock_data, run_type="training", model_params=mock_model_params) + zinputs = ZernikeInputsFactory.build( + data=mock_data, run_type="training", model_params=mock_model_params + ) expected_priors = np.concatenate( ( @@ -77,7 +84,9 @@ def test_training_with_explicit_prior(mock_model_params, caplog): explicit_prior = np.array([9.0, 9.0, 9.0]) with caplog.at_level("WARNING"): - zinputs = ZernikeInputsFactory.build(data, "training", mock_model_params, prior=explicit_prior) + zinputs = ZernikeInputsFactory.build( + data, "training", mock_model_params, prior=explicit_prior + ) assert 
"Zernike prior explicitly provided" in caplog.text assert (zinputs.zernike_prior == explicit_prior).all() @@ -85,16 +94,24 @@ def test_training_with_explicit_prior(mock_model_params, caplog): def test_inference_with_dict_and_prior(mock_model_params): mock_model_params.use_prior = True - data = { - "positions": np.ones((5, 2)), - "zernike_prior": np.array([42.0, 0.0]) - } + data = RecursiveNamespace( + dataset={ + "positions": tf.ones((5, 2)), + "zernike_prior": tf.constant([42.0, 0.0]), + } + ) zinputs = ZernikeInputsFactory.build(data, "inference", mock_model_params) assert zinputs.centroid_dataset is None - assert (zinputs.zernike_prior == data["zernike_prior"]).all() - np.testing.assert_array_equal(zinputs.misalignment_positions, data["positions"]) + + # NumPy array comparison + np.testing.assert_array_equal( + zinputs.misalignment_positions, data.dataset["positions"].numpy() + ) + + # TensorFlow tensor comparison + tf.debugging.assert_equal(zinputs.zernike_prior, data.dataset["zernike_prior"]) def test_invalid_run_type(mock_model_params): @@ -111,7 +128,7 @@ def test_get_np_zernike_prior(): # Construct fake DataConfigHandler structure using RecursiveNamespace data = RecursiveNamespace( training_data=RecursiveNamespace(dataset={"zernike_prior": training_prior}), - test_data=RecursiveNamespace(dataset={"zernike_prior": test_prior}) + test_data=RecursiveNamespace(dataset={"zernike_prior": test_prior}), ) expected_prior = np.concatenate((training_prior, test_prior), axis=0) @@ -121,19 +138,24 @@ def test_get_np_zernike_prior(): # Assert shape and values match expected np.testing.assert_array_equal(result, expected_prior) + def test_pad_contribution_to_order(): # Input: batch of 2 samples, each with 3 Zernike coefficients - input_contribution = np.array([ - [1.0, 2.0, 3.0], - [4.0, 5.0, 6.0], - ]) - + input_contribution = np.array( + [ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + ] + ) + max_order = 5 # Target size: pad to 5 coefficients - expected_output = np.array([ - [1.0, 2.0, 3.0, 0.0, 0.0], - [4.0, 5.0, 6.0, 0.0, 0.0], - ]) + expected_output = np.array( + [ + [1.0, 2.0, 3.0, 0.0, 0.0], + [4.0, 5.0, 6.0, 0.0, 0.0], + ] + ) padded = pad_contribution_to_order(input_contribution, max_order) @@ -149,6 +171,7 @@ def test_no_padding_needed(): assert output.shape == input_contribution.shape np.testing.assert_array_equal(output, input_contribution) + def test_padding_to_much_higher_order(): """Pad from order 2 to order 10.""" input_contribution = np.array([[1, 2], [3, 4]]) @@ -158,6 +181,7 @@ def test_padding_to_much_higher_order(): assert output.shape == (2, 10) np.testing.assert_array_equal(output, expected_output) + def test_empty_contribution(): """Test behavior with empty input array (0 features).""" input_contribution = np.empty((3, 0)) # 3 samples, 0 coefficients @@ -180,42 +204,44 @@ def test_zero_samples(): def test_combine_zernike_contributions_basic_case(): """Combine two contributions with matching sample count and varying order.""" - contrib1 = np.array([[1, 2], [3, 4]]) # shape (2, 2) - contrib2 = np.array([[5], [6]]) # shape (2, 1) - expected = np.array([ - [1 + 5, 2 + 0], - [3 + 6, 4 + 0] - ]) # padded contrib2 to (2, 2) + contrib1 = np.array([[1, 2], [3, 4]]) # shape (2, 2) + contrib2 = np.array([[5], [6]]) # shape (2, 1) + expected = np.array([[1 + 5, 2 + 0], [3 + 6, 4 + 0]]) # padded contrib2 to (2, 2) result = combine_zernike_contributions([contrib1, contrib2]) np.testing.assert_array_equal(result, expected) + def test_combine_multiple_contributions(): """Combine three 
contributions.""" - c1 = np.array([[1, 2, 3]]) # shape (1, 3) - c2 = np.array([[4, 5]]) # shape (1, 2) - c3 = np.array([[6]]) # shape (1, 1) - expected = np.array([[1+4+6, 2+5+0, 3+0+0]]) # shape (1, 3) + c1 = np.array([[1, 2, 3]]) # shape (1, 3) + c2 = np.array([[4, 5]]) # shape (1, 2) + c3 = np.array([[6]]) # shape (1, 1) + expected = np.array([[1 + 4 + 6, 2 + 5 + 0, 3 + 0 + 0]]) # shape (1, 3) result = combine_zernike_contributions([c1, c2, c3]) np.testing.assert_array_equal(result, expected) + def test_empty_input_list(): """Raise ValueError when input list is empty.""" with pytest.raises(ValueError, match="No contributions provided."): combine_zernike_contributions([]) + def test_inconsistent_sample_count(): """Raise error or produce incorrect shape if contributions have inconsistent sample counts.""" c1 = np.array([[1, 2], [3, 4]]) # shape (2, 2) - c2 = np.array([[5, 6]]) # shape (1, 2) + c2 = np.array([[5, 6]]) # shape (1, 2) with pytest.raises(ValueError): combine_zernike_contributions([c1, c2]) + def test_single_contribution(): """Combining a single contribution should return the same array (no-op).""" contrib = np.array([[7, 8, 9], [10, 11, 12]]) result = combine_zernike_contributions([contrib]) np.testing.assert_array_equal(result, contrib) + def test_zero_order_contributions(): """Contributions with 0 Zernike coefficients.""" contrib1 = np.empty((2, 0)) # 2 samples, 0 coefficients @@ -225,9 +251,12 @@ def test_zero_order_contributions(): assert result.shape == (2, 0) np.testing.assert_array_equal(result, expected) + @patch("wf_psf.data.data_zernike_utils.compute_centroid_correction") @patch("wf_psf.data.data_zernike_utils.compute_ccd_misalignment") -def test_full_contribution_combination(mock_ccd, mock_centroid, mock_model_params, dummy_prior, dummy_centroid_dataset): +def test_full_contribution_combination( + mock_ccd, mock_centroid, mock_model_params, dummy_prior, dummy_centroid_dataset +): mock_centroid.return_value = np.full((4, 6), 2.0) mock_ccd.return_value = np.full((4, 6), 3.0) dummy_positions = np.full((4, 6), 1.0) @@ -236,12 +265,13 @@ def test_full_contribution_combination(mock_ccd, mock_centroid, mock_model_param model_params=mock_model_params, zernike_prior=dummy_prior, centroid_dataset=dummy_centroid_dataset, - positions = dummy_positions + positions=dummy_positions, ) - + expected = dummy_prior + 2.0 + 3.0 np.testing.assert_allclose(result.numpy(), expected) + def test_prior_only(mock_model_params, dummy_prior): mock_model_params.correct_centroids = False mock_model_params.add_ccd_misalignments = False @@ -250,11 +280,12 @@ def test_prior_only(mock_model_params, dummy_prior): model_params=mock_model_params, zernike_prior=dummy_prior, centroid_dataset=None, - positions=None + positions=None, ) np.testing.assert_array_equal(result.numpy(), dummy_prior) + def test_no_contributions_returns_zeros(): model_params = RecursiveNamespace( use_prior=False, @@ -269,6 +300,7 @@ def test_no_contributions_returns_zeros(): assert result.shape == (1, 8) np.testing.assert_array_equal(result.numpy(), np.zeros((1, 8))) + def test_prior_as_tensor(mock_model_params): tensor_prior = tf.ones((4, 6), dtype=tf.float32) @@ -276,24 +308,28 @@ def test_prior_as_tensor(mock_model_params): mock_model_params.add_ccd_misalignments = False result = assemble_zernike_contributions( - model_params=mock_model_params, - zernike_prior=tensor_prior + model_params=mock_model_params, zernike_prior=tensor_prior ) assert tf.executing_eagerly(), "TensorFlow must be in eager mode for this test" assert 
isinstance(result, tf.Tensor) np.testing.assert_array_equal(result.numpy(), np.ones((4, 6))) + @patch("wf_psf.data.data_zernike_utils.compute_centroid_correction") -def test_inconsistent_shapes_raises_error(mock_centroid, mock_model_params, dummy_prior, dummy_centroid_dataset): +def test_inconsistent_shapes_raises_error( + mock_centroid, mock_model_params, dummy_prior, dummy_centroid_dataset +): mock_model_params.add_ccd_misalignments = False mock_centroid.return_value = np.ones((5, 6)) # 5 samples instead of 4 - with pytest.raises(ValueError, match="All contributions must have the same number of samples"): + with pytest.raises( + ValueError, match="All contributions must have the same number of samples" + ): assemble_zernike_contributions( model_params=mock_model_params, zernike_prior=dummy_prior, centroid_dataset=dummy_centroid_dataset, - positions=None + positions=None, ) @@ -307,14 +343,10 @@ def test_pad_zernikes_num_of_zernikes_equal(): zk_prior = tf.reshape(zk_prior, (1, 2, 1, 1)) # Reset _n_zks_total to max number of zernikes (2 here) - n_zks_total = max( - tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy() - ) + n_zks_total = max(tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy()) # Call pad_zernikes method - padded_zk_param, padded_zk_prior = pad_tf_zernikes( - zk_param, zk_prior, n_zks_total - ) + padded_zk_param, padded_zk_prior = pad_tf_zernikes(zk_param, zk_prior, n_zks_total) # Assert shapes are equal and correct assert padded_zk_param.shape[1] == n_zks_total @@ -324,6 +356,7 @@ def test_pad_zernikes_num_of_zernikes_equal(): np.testing.assert_array_equal(padded_zk_param.numpy(), zk_param.numpy()) np.testing.assert_array_equal(padded_zk_prior.numpy(), zk_prior.numpy()) + def test_pad_zernikes_prior_greater_than_param(): zk_param = tf.constant([[[[1]]], [[[2]]]]) # Shape: (2, 1, 1, 1) zk_prior = tf.constant([[[[1]], [[2]], [[3]], [[4]], [[5]]]]) # Shape: (5, 1, 1, 1) @@ -333,14 +366,10 @@ def test_pad_zernikes_prior_greater_than_param(): zk_prior = tf.reshape(zk_prior, (1, 5, 1, 1)) # Reshaping tensor2 to (1, 5, 1, 1) # Reset n_zks_total attribute - n_zks_total = max( - tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy() - ) + n_zks_total = max(tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy()) # Call the method under test - padded_zk_param, padded_zk_prior = pad_tf_zernikes( - zk_param, zk_prior, n_zks_total - ) + padded_zk_param, padded_zk_prior = pad_tf_zernikes(zk_param, zk_prior, n_zks_total) # Assert that the padded tensors have the correct shapes assert padded_zk_param.shape == (1, 5, 1, 1) @@ -356,14 +385,10 @@ def test_pad_zernikes_param_greater_than_prior(): zk_prior = tf.reshape(zk_prior, (1, 2, 1, 1)) # Reshaping tensor2 to (1, 4, 1, 1) # Reset n_zks_total attribute - n_zks_total = max( - tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy() - ) + n_zks_total = max(tf.shape(zk_param)[1].numpy(), tf.shape(zk_prior)[1].numpy()) # Call the method under test - padded_zk_param, padded_zk_prior = pad_tf_zernikes( - zk_param, zk_prior, n_zks_total - ) + padded_zk_param, padded_zk_prior = pad_tf_zernikes(zk_param, zk_prior, n_zks_total) # Assert that the padded tensors have the correct shapes assert padded_zk_param.shape == (1, 4, 1, 1) @@ -374,16 +399,20 @@ def test_compute_zernike_tip_tilt_single_batch(mocker, simple_image, identity_ma """Test compute_zernike_tip_tilt handling with single batch input and mocks.""" # Mock the CentroidEstimator class - mock_centroid_calc = 
mocker.patch(
+        "wf_psf.data.centroids.CentroidEstimator", autospec=True
+    )

    # Create a mock instance and configure get_intra_pixel_shifts()
    mock_instance = mock_centroid_calc.return_value
-    mock_instance.get_intra_pixel_shifts.return_value = np.array([[0.05, -0.02]]) # Shape (1, 2)
+    mock_instance.get_intra_pixel_shifts.return_value = np.array(
+        [[0.05, -0.02]]
+    )  # Shape (1, 2)

    # Mock shift_x_y_to_zk1_2_wavediff to return predictable values
    mock_shift_fn = mocker.patch(
        "wf_psf.data.data_zernike_utils.shift_x_y_to_zk1_2_wavediff",
-        side_effect=lambda shift: shift * 0.5 # Mocked conversion for test
+        side_effect=lambda shift: shift * 0.5,  # Mocked conversion for test
    )

    # Define test inputs (batch of 1 image)
@@ -391,41 +420,58 @@ def test_compute_zernike_tip_tilt_single_batch(mocker, simple_image, identity_ma
    reference_shifts = [-1 / 3, -1 / 3]  # Default Euclid conditions

    # Run the function
-    zernike_corrections = compute_zernike_tip_tilt(simple_image, identity_mask, pixel_sampling, reference_shifts)
-    zernike_corrections = compute_zernike_tip_tilt(simple_image, identity_mask, pixel_sampling, reference_shifts)
+    zernike_corrections = compute_zernike_tip_tilt(
+        simple_image, identity_mask, pixel_sampling, reference_shifts
+    )

    # Expected shifts based on centroid calculation
-    expected_dx = (reference_shifts[1] - (-0.02)) # Expected x-axis shift in meters
-    expected_dy = (reference_shifts[0] - 0.05) # Expected y-axis shift in meters
+    expected_dx = reference_shifts[1] - (-0.02)  # Expected x-axis shift in meters
+    expected_dy = reference_shifts[0] - 0.05  # Expected y-axis shift in meters

    # Expected calls to the mocked function
    # Extract the arguments passed to mock_shift_fn
-    args, _ = mock_shift_fn.call_args_list[0] # Get the first call args
+    args, _ = mock_shift_fn.call_args_list[0]  # Get the first call args

    # Compare expected values with the actual arguments passed to the mock function
-    np.testing.assert_allclose(args[0][0], expected_dx * pixel_sampling, rtol=1e-7, atol=0)
+    np.testing.assert_allclose(
+        args[0][0], expected_dx * pixel_sampling, rtol=1e-7, atol=0
+    )

    # Check dy values similarly
-    np.testing.assert_allclose(args[0][1], expected_dy * pixel_sampling, rtol=1e-7, atol=0)
+    np.testing.assert_allclose(
+        args[0][1], expected_dy * pixel_sampling, rtol=1e-7, atol=0
+    )

    # Expected values based on mock side_effect (0.5 * shift)
-    np.testing.assert_allclose(zernike_corrections[0, 0], expected_dx * pixel_sampling * 0.5) # Zk1
-    np.testing.assert_allclose(zernike_corrections[0, 1], expected_dy * pixel_sampling * 0.5) # Zk2
+    np.testing.assert_allclose(
+        zernike_corrections[0, 0], expected_dx * pixel_sampling * 0.5
+    )  # Zk1
+    np.testing.assert_allclose(
+        zernike_corrections[0, 1], expected_dy * pixel_sampling * 0.5
+    )  # Zk2
+

def test_compute_zernike_tip_tilt_batch(mocker, multiple_images):
    """Test compute_zernike_tip_tilt batch handling of multiple inputs."""
-    
+
    # Mock the CentroidEstimator class
-    mock_centroid_calc = mocker.patch("wf_psf.data.centroids.CentroidEstimator", autospec=True)
+    mock_centroid_calc = mocker.patch(
+        "wf_psf.data.centroids.CentroidEstimator", autospec=True
+    )

    # Create a mock instance and configure get_intra_pixel_shifts()
    mock_instance = mock_centroid_calc.return_value
-    mock_instance.get_intra_pixel_shifts.return_value = np.array([[0.05, -0.02], [0.04, -0.01], [0.06, -0.03]])  # Shape (3, 2)
+    mock_instance.get_intra_pixel_shifts.return_value = np.array(
+        [[0.05, 
-0.02], [0.04, -0.01], [0.06, -0.03]]) # Shape (3, 2) + mock_instance.get_intra_pixel_shifts.return_value = np.array( + [[0.05, -0.02], [0.04, -0.01], [0.06, -0.03]] + ) # Shape (3, 2) # Mock shift_x_y_to_zk1_2_wavediff to return predictable values mock_shift_fn = mocker.patch( "wf_psf.data.data_zernike_utils.shift_x_y_to_zk1_2_wavediff", - side_effect=lambda shift: shift * 0.5 # Mocked conversion for test + side_effect=lambda shift: shift * 0.5, # Mocked conversion for test ) # Define test inputs (batch of 3 images) @@ -434,16 +480,18 @@ def test_compute_zernike_tip_tilt_batch(mocker, multiple_images): # Run the function zernike_corrections = compute_zernike_tip_tilt( - star_images=multiple_images, - pixel_sampling=pixel_sampling, - reference_shifts=reference_shifts - ) + star_images=multiple_images, + pixel_sampling=pixel_sampling, + reference_shifts=reference_shifts, + ) # Check if the mock function was called once with the full batch - assert len(mock_shift_fn.call_args_list) == 1, f"Expected 1 call, but got {len(mock_shift_fn.call_args_list)}" + assert ( + len(mock_shift_fn.call_args_list) == 1 + ), f"Expected 1 call, but got {len(mock_shift_fn.call_args_list)}" # Get the arguments passed to the mock function for the batch of images - args, _ = mock_shift_fn.call_args_list[0] + args, _ = mock_shift_fn.call_args_list[0] print("Shape of args[0]:", args[0].shape) print("Contents of args[0]:", args[0]) @@ -453,13 +501,25 @@ def test_compute_zernike_tip_tilt_batch(mocker, multiple_images): args_array = np.array(args[0]).reshape(-1, 2) # Process the displacements and expected values for each image in the batch - expected_dx = reference_shifts[1] - mock_instance.get_intra_pixel_shifts.return_value[:, 1] # Expected x-axis shift in meters - expected_dy = reference_shifts[0] - mock_instance.get_intra_pixel_shifts.return_value[:, 0] # Expected y-axis shift in meters + expected_dx = ( + reference_shifts[1] - mock_instance.get_intra_pixel_shifts.return_value[:, 1] + ) # Expected x-axis shift in meters + expected_dy = ( + reference_shifts[0] - mock_instance.get_intra_pixel_shifts.return_value[:, 0] + ) # Expected y-axis shift in meters # Compare expected values with the actual arguments passed to the mock function - np.testing.assert_allclose(args_array[:, 0], expected_dx * pixel_sampling, rtol=1e-7, atol=0) - np.testing.assert_allclose(args_array[:, 1], expected_dy * pixel_sampling, rtol=1e-7, atol=0) + np.testing.assert_allclose( + args_array[:, 0], expected_dx * pixel_sampling, rtol=1e-7, atol=0 + ) + np.testing.assert_allclose( + args_array[:, 1], expected_dy * pixel_sampling, rtol=1e-7, atol=0 + ) # Expected values based on mock side_effect (0.5 * shift) - np.testing.assert_allclose(zernike_corrections[:, 0], expected_dx * pixel_sampling * 0.5) # Zk1 for each image - np.testing.assert_allclose(zernike_corrections[:, 1], expected_dy * pixel_sampling * 0.5) # Zk2 for each image \ No newline at end of file + np.testing.assert_allclose( + zernike_corrections[:, 0], expected_dx * pixel_sampling * 0.5 + ) # Zk1 for each image + np.testing.assert_allclose( + zernike_corrections[:, 1], expected_dy * pixel_sampling * 0.5 + ) # Zk2 for each image From 96460a5600e748c6c9809812f406d8f11a60bda8 Mon Sep 17 00:00:00 2001 From: jeipollack Date: Wed, 27 Aug 2025 11:19:53 +0200 Subject: [PATCH 102/135] Eagerly initialise trainable layers in physical poly model constructor required for evaluation/inference --- .../models/psf_model_physical_polychromatic.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 
deletions(-) diff --git a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py index 918ffc4b..e2302da5 100644 --- a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py +++ b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py @@ -120,15 +120,15 @@ def __init__(self, model_params, training_params, data, coeff_mat=None): self.run_type = self._get_run_type(data) self.obs_pos = self.get_obs_pos() - # Initialize the model parameters and layers + # Initialize the model parameters self.output_Q = model_params.output_Q self.l2_param = model_params.param_hparams.l2_param self.output_dim = model_params.output_dim - + # Initialise lazy loading of external Zernike prior self._external_prior = None - # Initialize the model parameters with non-default value + # Set Zernike Polynomial Coefficient Matrix if not None if coeff_mat is not None: self.assign_coeff_matrix(coeff_mat) @@ -158,9 +158,10 @@ def __init__(self, model_params, training_params, data, coeff_mat=None): rotation_angle=self.model_params.obscuration_rotation_angle, ) - # Eagerly initialise tf_batch_poly_PSF + # Eagerly initialise model layers self.tf_batch_poly_PSF = self._build_tf_batch_poly_PSF() - + _ = self.tf_poly_Z_field + _ = self.tf_np_poly_opd def _get_run_type(self, data): if hasattr(data, 'run_type'): From 06d715b6e91d2eb2e4b8db63bc87cb5c8cbe4812 Mon Sep 17 00:00:00 2001 From: jeipollack Date: Wed, 27 Aug 2025 11:21:59 +0200 Subject: [PATCH 103/135] fix: use expect_partial() when loading model weights for evaluation - Add status handling to model.load_weights() call - Use expect_partial() to suppress warnings about unused optimizer state - Allows successful weight loading for metrics evaluation when checkpoint contains training artifacts --- src/wf_psf/psf_models/psf_model_loader.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/wf_psf/psf_models/psf_model_loader.py b/src/wf_psf/psf_models/psf_model_loader.py index 797be8fc..c30d31ad 100644 --- a/src/wf_psf/psf_models/psf_model_loader.py +++ b/src/wf_psf/psf_models/psf_model_loader.py @@ -12,7 +12,7 @@ get_psf_model, get_psf_model_weights_filepath ) - +import tensorflow as tf logger = logging.getLogger(__name__) @@ -48,7 +48,9 @@ def load_trained_psf_model(training_conf, data_conf, weights_path_pattern): try: logger.info(f"Loading PSF model weights from {weights_path}") - model.load_weights(weights_path) + status = model.load_weights(weights_path) + status.expect_partial() + except Exception as e: logger.exception("Failed to load model weights.") raise RuntimeError("Model weight loading failed.") from e From 5b3dee35adc843f352e44df86f898edad7215827 Mon Sep 17 00:00:00 2001 From: jeipollack Date: Wed, 27 Aug 2025 11:24:43 +0200 Subject: [PATCH 104/135] Add memory cleanup after training completion - Delete model reference and run garbage collection - Clear TensorFlow session to free GPU memory - Prevents OOM issues in subsequent operations or multiple training runs --- src/wf_psf/training/train.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/wf_psf/training/train.py b/src/wf_psf/training/train.py index bb0e3df9..f2366ca7 100644 --- a/src/wf_psf/training/train.py +++ b/src/wf_psf/training/train.py @@ -7,6 +7,7 @@ """ +import gc import numpy as np import time import tensorflow as tf @@ -538,3 +539,8 @@ def train( final_time = time.time() logger.info("\nTotal elapsed time: %f" % (final_time - starting_time)) logger.info("\n Training 
complete..") + + # Clean up memory + del psf_model + gc.collect() + tf.keras.backend.clear_session() From 6a2d24dbb5b312744b9dc6eb4a85b51a9e40c26c Mon Sep 17 00:00:00 2001 From: jeipollack Date: Thu, 4 Sep 2025 01:59:17 +0200 Subject: [PATCH 105/135] refactor: centralise PSF data extraction in data_handler - Introduce unified `get_data_array` for training/metrics/inference access - Add helpers `extract_star_data` and `_get_inference_data` - Remove redundant `get_np_obs_positions` - Move data handling logic out of `compute_centroid_corrections` - Standardise `centroid_dataset` as dict (stamps + optional masks) - Support optional keys (masks, priors) via `allow_missing` - Improve and unify docstrings for data extraction utilities - Add optional "sources" and "masks" attributes to `PSFInference - Add `correct_centroids` and `add_ccd_misalignments` as options to inference_config.yaml --- src/wf_psf/data/centroids.py | 35 ++- src/wf_psf/data/data_handler.py | 200 ++++++++++++++---- src/wf_psf/data/data_zernike_utils.py | 27 +-- src/wf_psf/inference/psf_inference.py | 6 +- .../psf_model_physical_polychromatic.py | 7 +- 5 files changed, 193 insertions(+), 82 deletions(-) diff --git a/src/wf_psf/data/centroids.py b/src/wf_psf/data/centroids.py index 7052f833..fb69c8dd 100644 --- a/src/wf_psf/data/centroids.py +++ b/src/wf_psf/data/centroids.py @@ -14,7 +14,7 @@ from typing import Optional -def compute_centroid_correction(model_params, data, batch_size: int=1) -> np.ndarray: +def compute_centroid_correction(model_params, centroid_dataset, batch_size: int=1) -> np.ndarray: """Compute centroid corrections using Zernike polynomials. This function calculates the Zernike contributions required to match the centroid @@ -25,10 +25,13 @@ def compute_centroid_correction(model_params, data, batch_size: int=1) -> np.nda model_params : RecursiveNamespace An object containing parameters for the PSF model, including pixel sampling and initial centroid window parameters. - data : DataConfigHandler - An object containing training and test datasets, including observed PSFs - and optional star masks. - + centroid_dataset : dict + Dictionary containing star data needed for centroiding: + - "stamps" : np.ndarray + Array of star postage stamps (required). + - "masks" : Optional[np.ndarray] + Array of star masks (optional, can be None). + batch_size : int, optional The batch size to use when processing the stars. Default is 16. @@ -39,24 +42,14 @@ def compute_centroid_correction(model_params, data, batch_size: int=1) -> np.nda observed stars. The array contains the computed Zernike (Z1, Z2) contributions, with zero padding applied to the first column to ensure a consistent shape. 
""" - star_postage_stamps = extract_star_data(data=data, train_key="noisy_stars", test_key="stars") - - # Get star mask catalogue only if "masks" exist in both training and test datasets - star_masks = ( - extract_star_data(data=data, train_key="masks", test_key="masks") - if ( - data.training_data.dataset.get("masks") is not None - and data.test_data.dataset.get("masks") is not None - and tf.size(data.training_data.dataset["masks"]) > 0 - and tf.size(data.test_data.dataset["masks"]) > 0 - ) - else None - ) + # Retrieve stamps and masks from centroid_dataset + star_postage_stamps = centroid_dataset.get("stamps") + star_masks = centroid_dataset.get("masks") # may be None - pix_sampling = model_params.pix_sampling * 1e-6 # Change units from [um] to [m] + if star_postage_stamps is None: + raise ValueError("centroid_dataset must contain 'stamps'") - # Ensure star_masks is properly handled - star_masks = star_masks if star_masks is not None else None + pix_sampling = model_params.pix_sampling * 1e-6 # Change units from [um] to [m] reference_shifts = [float(Fraction(value)) for value in model_params.reference_shifts] diff --git a/src/wf_psf/data/data_handler.py b/src/wf_psf/data/data_handler.py index fe660940..bd802763 100644 --- a/src/wf_psf/data/data_handler.py +++ b/src/wf_psf/data/data_handler.py @@ -240,67 +240,43 @@ def process_sed_data(self, sed_data): self.sed_data = tf.transpose(self.sed_data, perm=[0, 2, 1]) -def get_np_obs_positions(data): - """Get observed positions in numpy from the provided dataset. - - This method concatenates the positions of the stars from both the training - and test datasets to obtain the observed positions. - - Parameters - ---------- - data : DataConfigHandler - Object containing training and test datasets. - - Returns - ------- - np.ndarray - Numpy array containing the observed positions of the stars. - - Notes - ----- - The observed positions are obtained by concatenating the positions of stars - from both the training and test datasets along the 0th axis. - """ - obs_positions = np.concatenate( - ( - data.training_data.dataset["positions"], - data.test_data.dataset["positions"], - ), - axis=0, - ) - - return obs_positions - - def extract_star_data(data, train_key: str, test_key: str) -> np.ndarray: - """Extract specific star-related data from training and test datasets. + """ + Extract and concatenate star-related data from training and test datasets. - This function retrieves and concatenates specific star-related data (e.g., stars, masks) from the - star training and test datasets such as star stamps or masks, based on the provided keys. + This function retrieves arrays (e.g., postage stamps, masks, positions) from + both the training and test datasets using the specified keys, converts them + to NumPy if necessary, and concatenates them along the first axis. Parameters ---------- data : DataConfigHandler Object containing training and test datasets. train_key : str - The key to retrieve data from the training dataset (e.g., 'noisy_stars', 'masks'). + Key to retrieve data from the training dataset + (e.g., 'noisy_stars', 'masks'). test_key : str - The key to retrieve data from the test dataset (e.g., 'stars', 'masks'). + Key to retrieve data from the test dataset + (e.g., 'stars', 'masks'). Returns ------- np.ndarray - A NumPy array containing the concatenated data for the given keys. + Concatenated NumPy array containing the selected data from both + training and test sets. 
Raises ------ KeyError - If the specified keys do not exist in the training or test datasets. + If either the training or test dataset does not contain the + requested key. Notes ----- - - If the dataset contains TensorFlow tensors, they will be converted to NumPy arrays. - - Ensure that eager execution is enabled when calling this function. + - Designed for datasets with separate train/test splits, such as when + evaluating metrics on held-out data. + - TensorFlow tensors are automatically converted to NumPy arrays. + - Requires eager execution if TensorFlow tensors are present. """ # Ensure the requested keys exist in both training and test datasets missing_keys = [ @@ -327,3 +303,145 @@ def extract_star_data(data, train_key: str, test_key: str) -> np.ndarray: # Concatenate and return return np.concatenate((train_data, test_data), axis=0) + +def get_data_array( + data, + run_type: str, + key: str = None, + train_key: str = None, + test_key: str = None, + allow_missing: bool = False, +) -> np.ndarray | None: + """ + Retrieve data from dataset depending on run type. + + This function provides a unified interface for accessing data across different + execution contexts (training, simulation, metrics, inference). It handles + key resolution with sensible fallbacks and optional missing data tolerance. + + Parameters + ---------- + data : DataConfigHandler + Dataset object containing training, test, or inference data. + Expected to have methods compatible with the specified run_type. + run_type : {"training", "simulation", "metrics", "inference"} + Execution context that determines how data is retrieved: + - "training", "simulation", "metrics": Uses extract_star_data function + - "inference": Retrieves data directly from dataset using key lookup + key : str, optional + Primary key for data lookup. Used directly for inference run_type. + If None, falls back to train_key value. Default is None. + train_key : str, optional + Key for training dataset access. If None and key is provided, + defaults to key value. Default is None. + test_key : str, optional + Key for test dataset access. If None, defaults to the resolved + train_key value. Default is None. + allow_missing : bool, default False + Control behavior when data is missing or keys are not found: + - True: Return None instead of raising exceptions + - False: Raise appropriate exceptions (KeyError, ValueError) + + Returns + ------- + np.ndarray or None + Retrieved data as NumPy array. Returns None only when allow_missing=True + and the requested data is not available. + + Raises + ------ + ValueError + If run_type is not one of the supported values, or if no key can be + resolved for the operation and allow_missing=False. + KeyError + If the specified key is not found in the dataset and allow_missing=False. + + Notes + ----- + Key resolution follows this priority order: + 1. train_key = train_key or key + 2. test_key = test_key or resolved_train_key + 3. key = key or resolved_train_key (for inference fallback) + + For TensorFlow tensors, the .numpy() method is called to convert to NumPy. + Other data types are converted using np.asarray(). + + Examples + -------- + >>> # Training data retrieval + >>> train_data = get_data_array(data, "training", train_key="noisy_stars") + + >>> # Inference with fallback handling + >>> inference_data = get_data_array(data, "inference", key="positions", + ... allow_missing=True) + >>> if inference_data is None: + ... 
print("No inference data available") + + >>> # Using key parameter for both train and inference + >>> result = get_data_array(data, "inference", key="positions") + """ + # Validate run_type early + valid_run_types = {"training", "simulation", "metrics", "inference"} + if run_type not in valid_run_types: + raise ValueError(f"run_type must be one of {valid_run_types}, got '{run_type}'") + + # Simplify key resolution with clear precedence + effective_train_key = train_key or key + effective_test_key = test_key or effective_train_key + effective_key = key or effective_train_key + + try: + if run_type in {"simulation", "training", "metrics"}: + return extract_star_data(data, effective_train_key, effective_test_key) + else: # inference + return _get_inference_data(data, effective_key, allow_missing) + except Exception as e: + if allow_missing: + return None + raise + + +def _get_inference_data(data, key: str, allow_missing: bool) -> np.ndarray | None: + """ + Extract inference data with proper error handling and type conversion. + + Parameters + ---------- + data : DataConfigHandler + Dataset object with a .dataset attribute that supports .get() method. + key : str or None + Key to lookup in the dataset. If None, behavior depends on allow_missing. + allow_missing : bool + If True, return None for missing keys/data instead of raising exceptions. + + Returns + ------- + np.ndarray or None + Data converted to NumPy array, or None if allow_missing=True and + data is unavailable. + + Raises + ------ + ValueError + If key is None and allow_missing=False. + KeyError + If key is not found in dataset and allow_missing=False. + + Notes + ----- + Conversion logic: + - TensorFlow tensors: Converted using .numpy() method + - Other types: Converted using np.asarray() + """ + if key is None: + if allow_missing: + return None + raise ValueError("No key provided for inference data") + + value = data.dataset.get(key, None) + if value is None: + if allow_missing: + return None + raise KeyError(f"Key '{key}' not found in inference dataset") + + return value.numpy() if tf.is_tensor(value) else np.asarray(value) diff --git a/src/wf_psf/data/data_zernike_utils.py b/src/wf_psf/data/data_zernike_utils.py index f7653540..9bf9d1fd 100644 --- a/src/wf_psf/data/data_zernike_utils.py +++ b/src/wf_psf/data/data_zernike_utils.py @@ -16,6 +16,7 @@ import numpy as np import tensorflow as tf from wf_psf.data.centroids import compute_centroid_correction +from wf_psf.data.data_handler import get_data_array from wf_psf.instrument.ccd_misalignments import compute_ccd_misalignment from wf_psf.utils.read_config import RecursiveNamespace import logging @@ -54,29 +55,29 @@ def build( ------- ZernikeInputs """ - centroid_dataset = None - positions = None + centroid_dataset, positions = None, None if run_type in {"training", "simulation", "metrics"}: - centroid_dataset = data # Assuming data is a DataConfigHandler or similar object containing train and test datasets - positions = np.concatenate( - [ - data.training_data.dataset["positions"].numpy(), - data.test_data.dataset["positions"].numpy(), - ], - axis=0, - ) + stamps = get_data_array(data, run_type, train_key="noisy_stars", test_key="stars") + masks = get_data_array(data, run_type, key="masks", allow_missing=True) + centroid_dataset = {"stamps": stamps, "masks": masks} + + positions = get_data_array(data=data, run_type=run_type, key="positions") + if model_params.use_prior: if prior is not None: logger.warning( - "Zernike prior explicitly provided; ignoring dataset-based prior 
despite use_prior=True." + "Explicit prior provided; ignoring dataset-based prior." ) else: prior = get_np_zernike_prior(data) elif run_type == "inference": - centroid_dataset = None - positions = data.dataset["positions"].numpy() + stamps = get_data_array(data=data, run_type=run_type, key="sources") + masks = get_data_array(data, run_type, key="masks", allow_missing=True) + centroid_dataset = {"stamps": stamps, "masks": masks} + + positions = get_data_array(data=data, run_type=run_type, key="positions") if model_params.use_prior: # Try to extract prior from `data`, if present diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py index 76619f5a..a37e6ba7 100644 --- a/src/wf_psf/inference/psf_inference.py +++ b/src/wf_psf/inference/psf_inference.py @@ -97,7 +97,7 @@ class PSFInference: Spectral energy distributions (SEDs). """ - def __init__(self, inference_config_path: str, x_field=None, y_field=None, seds=None): + def __init__(self, inference_config_path: str, x_field=None, y_field=None, seds=None, sources=None, masks=None): self.inference_config_path = inference_config_path @@ -105,6 +105,8 @@ def __init__(self, inference_config_path: str, x_field=None, y_field=None, seds= self.x_field = x_field self.y_field = y_field self.seds = seds + self.sources = sources + self.masks = masks # Internal caches for lazy-loading self._config_handler = None @@ -157,7 +159,7 @@ def _prepare_dataset_for_inference(self): positions = self.get_positions() if positions is None: return None - return {"positions": positions} + return {"positions": positions, "sources": self.sources, "masks": self.masks} @property def data_handler(self): diff --git a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py index e2302da5..c8086cb7 100644 --- a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py +++ b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py @@ -10,7 +10,7 @@ from typing import Optional import tensorflow as tf from tensorflow.python.keras.engine import data_adapter -from wf_psf.data.data_handler import get_np_obs_positions +from wf_psf.data.data_handler import get_data_array from wf_psf.data.data_zernike_utils import ( ZernikeInputsFactory, assemble_zernike_contributions, @@ -203,10 +203,7 @@ def save_nonparam_history(self) -> bool: def get_obs_pos(self): assert self.run_type in {"training", "simulation", "metrics", "inference"}, f"Unknown run_type: {self.run_type}" - if self.run_type in {"training", "simulation", "metrics"}: - raw_pos = get_np_obs_positions(self.data) - else: - raw_pos = self.data.dataset["positions"] + raw_pos = get_data_array(data=self.data, run_type=self.run_type, key="positions") obs_pos = ensure_tensor(raw_pos, dtype=tf.float32) From 8376fe1de78bdf31e23ed796ad6c4564d7dd17b2 Mon Sep 17 00:00:00 2001 From: jeipollack Date: Thu, 4 Sep 2025 02:41:53 +0200 Subject: [PATCH 106/135] Add `correct_centroids` and `add_ccd_misalignments` options to inference_config.yaml (forgot to stage with previous commit) --- config/inference_conf.yaml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/config/inference_conf.yaml b/config/inference_conf.yaml index c9d29cb8..927723c7 100644 --- a/config/inference_conf.yaml +++ b/config/inference_conf.yaml @@ -30,3 +30,8 @@ inference: # Dimension of the pixel PSF postage stamp output_dim: 64 + # Flag to perform centroid error correction + correct_centroids: False + + # Flag to perform CCD misalignment error correction + add_ccd_misalignments: True From
77434be4fa9eb4d43f8ccafdcb4424e7cd72292b Mon Sep 17 00:00:00 2001 From: jeipollack Date: Thu, 4 Sep 2025 04:35:23 +0200 Subject: [PATCH 107/135] Update PSFInference doc string with new optional attributes --- src/wf_psf/inference/psf_inference.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py index a37e6ba7..5c27fdbf 100644 --- a/src/wf_psf/inference/psf_inference.py +++ b/src/wf_psf/inference/psf_inference.py @@ -95,6 +95,10 @@ class PSFInference: y coordinates in SHE convention. seds : array-like, optional Spectral energy distributions (SEDs). + sources : array-like, optional + Postage stamps of sources, e.g. star images (shape: [n_stars, h, w]). + masks : array-like, optional + Corresponding masks for the sources (same shape as sources). Defaults to None. """ def __init__(self, inference_config_path: str, x_field=None, y_field=None, seds=None, sources=None, masks=None): From 76a449c009cfedce003afd755881467a230033f6 Mon Sep 17 00:00:00 2001 From: jeipollack Date: Thu, 4 Sep 2025 05:44:34 +0200 Subject: [PATCH 108/135] Rename _get_inference_data to _get_direct_data --- src/wf_psf/data/data_handler.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/wf_psf/data/data_handler.py b/src/wf_psf/data/data_handler.py index bd802763..c0ddf7b0 100644 --- a/src/wf_psf/data/data_handler.py +++ b/src/wf_psf/data/data_handler.py @@ -394,16 +394,16 @@ def get_data_array( if run_type in {"simulation", "training", "metrics"}: return extract_star_data(data, effective_train_key, effective_test_key) else: # inference - return _get_inference_data(data, effective_key, allow_missing) + return _get_direct_data(data, effective_key, allow_missing) except Exception as e: if allow_missing: return None raise -def _get_inference_data(data, key: str, allow_missing: bool) -> np.ndarray | None: +def _get_direct_data(data, key: str, allow_missing: bool) -> np.ndarray | None: """ - Extract inference data with proper error handling and type conversion. + Extract data directly with proper error handling and type conversion. 
Parameters ---------- From b0874e76622948cd5dc3ff8499caeb445dbc11af Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 5 Sep 2025 10:37:27 -0400 Subject: [PATCH 109/135] Reformat with black --- src/wf_psf/__init__.py | 6 +- src/wf_psf/data/centroids.py | 29 ++- src/wf_psf/data/data_zernike_utils.py | 6 +- src/wf_psf/inference/psf_inference.py | 84 ++++---- .../psf_model_physical_polychromatic.py | 125 ++++++------ src/wf_psf/psf_models/psf_model_loader.py | 16 +- src/wf_psf/psf_models/tf_modules/tf_layers.py | 2 +- src/wf_psf/psf_models/tf_modules/tf_utils.py | 42 ++-- .../masked_loss/results/plot_results.ipynb | 164 ++++++++++++---- src/wf_psf/tests/test_data/test_data_utils.py | 22 ++- .../test_inference/psf_inference_test.py | 183 +++++++++++------- .../psf_model_physical_polychromatic_test.py | 39 ++-- .../tests/test_psf_models/psf_models_test.py | 2 +- src/wf_psf/tests/test_utils/utils_test.py | 42 ++-- src/wf_psf/training/train.py | 17 +- 15 files changed, 480 insertions(+), 299 deletions(-) diff --git a/src/wf_psf/__init__.py b/src/wf_psf/__init__.py index d4394f09..988b02fe 100644 --- a/src/wf_psf/__init__.py +++ b/src/wf_psf/__init__.py @@ -2,6 +2,6 @@ # Dynamically import modules to trigger side effects when wf_psf is imported importlib.import_module("wf_psf.psf_models.psf_models") -importlib.import_module("wf_psf.psf_models.psf_model_semiparametric") -importlib.import_module("wf_psf.psf_models.psf_model_physical_polychromatic") -importlib.import_module("wf_psf.psf_models.tf_psf_field") +importlib.import_module("wf_psf.psf_models.models.psf_model_semiparametric") +importlib.import_module("wf_psf.psf_models.models.psf_model_physical_polychromatic") +importlib.import_module("wf_psf.psf_models.tf_modules.tf_psf_field") diff --git a/src/wf_psf/data/centroids.py b/src/wf_psf/data/centroids.py index fb69c8dd..20391793 100644 --- a/src/wf_psf/data/centroids.py +++ b/src/wf_psf/data/centroids.py @@ -14,7 +14,9 @@ from typing import Optional -def compute_centroid_correction(model_params, centroid_dataset, batch_size: int=1) -> np.ndarray: +def compute_centroid_correction( + model_params, centroid_dataset, batch_size: int = 1 +) -> np.ndarray: """Compute centroid corrections using Zernike polynomials. This function calculates the Zernike contributions required to match the centroid @@ -31,15 +33,15 @@ def compute_centroid_correction(model_params, centroid_dataset, batch_size: int= Array of star postage stamps (required). - "masks" : Optional[np.ndarray] Array of star masks (optional, can be None). - + batch_size : int, optional The batch size to use when processing the stars. Default is 1. Returns ------- zernike_centroid_array : np.ndarray - A 2D NumPy array of shape `(n_stars, 3)`, where `n_stars` is the number of - observed stars. The array contains the computed Zernike (Z1, Z2) contributions, + A 2D NumPy array of shape `(n_stars, 3)`, where `n_stars` is the number of + observed stars. The array contains the computed Zernike (Z1, Z2) contributions, with zero padding applied to the first column to ensure a consistent shape.
""" # Retrieve stamps and masks from centroid_dataset @@ -51,15 +53,17 @@ def compute_centroid_correction(model_params, centroid_dataset, batch_size: int= pix_sampling = model_params.pix_sampling * 1e-6 # Change units from [um] to [m] - reference_shifts = [float(Fraction(value)) for value in model_params.reference_shifts] + reference_shifts = [ + float(Fraction(value)) for value in model_params.reference_shifts + ] n_stars = len(star_postage_stamps) zernike_centroid_array = [] # Batch process the stars for i in range(0, n_stars, batch_size): - batch_postage_stamps = star_postage_stamps[i:i + batch_size] - batch_masks = star_masks[i:i + batch_size] if star_masks is not None else None + batch_postage_stamps = star_postage_stamps[i : i + batch_size] + batch_masks = star_masks[i : i + batch_size] if star_masks is not None else None # Compute Zernike 1 and Zernike 2 for the batch zk1_2_batch = -1.0 * compute_zernike_tip_tilt( @@ -67,11 +71,19 @@ def compute_centroid_correction(model_params, centroid_dataset, batch_size: int= ) # Zero pad array for each batch and append - zernike_centroid_array.append(np.pad(zk1_2_batch, pad_width=[(0, 0), (1, 0)], mode="constant", constant_values=0)) + zernike_centroid_array.append( + np.pad( + zk1_2_batch, + pad_width=[(0, 0), (1, 0)], + mode="constant", + constant_values=0, + ) + ) # Combine all batches into a single array return np.concatenate(zernike_centroid_array, axis=0) + def compute_zernike_tip_tilt( star_images: np.ndarray, star_masks: Optional[np.ndarray] = None, @@ -119,6 +131,7 @@ def compute_zernike_tip_tilt( - The Zernike coefficients are computed in the WaveDiff convention. """ from wf_psf.data.data_zernike_utils import shift_x_y_to_zk1_2_wavediff + # Vectorize the centroid computation centroid_estimator = CentroidEstimator( im=star_images, mask=star_masks, sigma_init=sigma_init, n_iter=n_iter diff --git a/src/wf_psf/data/data_zernike_utils.py b/src/wf_psf/data/data_zernike_utils.py index 9bf9d1fd..399ee9ef 100644 --- a/src/wf_psf/data/data_zernike_utils.py +++ b/src/wf_psf/data/data_zernike_utils.py @@ -58,12 +58,14 @@ def build( centroid_dataset, positions = None, None if run_type in {"training", "simulation", "metrics"}: - stamps = get_data_array(data, run_type, train_key="noisy_stars", test_key="stars") + stamps = get_data_array( + data, run_type, train_key="noisy_stars", test_key="stars" + ) masks = get_data_array(data, run_type, key="masks", allow_missing=True) centroid_dataset = {"stamps": stamps, "masks": masks} positions = get_data_array(data=data, run_type=run_type, key="positions") - + if model_params.use_prior: if prior is not None: logger.warning( diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py index 5c27fdbf..c7d73249 100644 --- a/src/wf_psf/inference/psf_inference.py +++ b/src/wf_psf/inference/psf_inference.py @@ -28,7 +28,6 @@ def __init__(self, inference_config_path: str): self.training_config = None self.data_config = None - def load_configs(self): """Load configuration files based on the inference config.""" self.inference_config = read_conf(self.inference_config_path) @@ -39,7 +38,6 @@ def load_configs(self): # Load the data configuration self.data_config = read_conf(self.data_config_path) - def set_config_paths(self): """Extract and set the configuration paths.""" # Set config paths @@ -47,10 +45,11 @@ def set_config_paths(self): self.trained_model_path = Path(config_paths.trained_model_path) self.model_subdir = config_paths.model_subdir - self.trained_model_config_path = 
self.trained_model_path / config_paths.trained_model_config_path + self.trained_model_config_path = ( + self.trained_model_path / config_paths.trained_model_config_path + ) self.data_config_path = config_paths.data_config_path - @staticmethod def overwrite_model_params(training_config=None, inference_config=None): """ @@ -75,7 +74,6 @@ def overwrite_model_params(training_config=None, inference_config=None): if hasattr(model_params, key): setattr(model_params, key, value) - class PSFInference: """ @@ -101,7 +99,15 @@ class PSFInference: Corresponding masks for the sources (same shape as sources). Defaults to None. """ - def __init__(self, inference_config_path: str, x_field=None, y_field=None, seds=None, sources=None, masks=None): + def __init__( + self, + inference_config_path: str, + x_field=None, + y_field=None, + seds=None, + sources=None, + masks=None, + ): self.inference_config_path = inference_config_path @@ -111,7 +117,7 @@ def __init__(self, inference_config_path: str, x_field=None, y_field=None, seds= self.seds = seds self.sources = sources self.masks = masks - + # Internal caches for lazy-loading self._config_handler = None self._simPSF = None @@ -123,7 +129,7 @@ def __init__(self, inference_config_path: str, x_field=None, y_field=None, seds= self._output_dim = None # Initialise PSF Inference engine - self.engine = None + self.engine = None @property def config_handler(self): @@ -157,7 +163,6 @@ def simPSF(self): self._simPSF = psf_models.simPSF(self.training_config.training.model_params) return self._simPSF - def _prepare_dataset_for_inference(self): """Prepare dataset dictionary for inference, returning None if positions are invalid.""" positions = self.get_positions() @@ -167,7 +172,7 @@ def _prepare_dataset_for_inference(self): @property def data_handler(self): - if self._data_handler is None: + if self._data_handler is None: # Instantiate the data handler self._data_handler = DataHandler( dataset_type="inference", @@ -176,7 +181,7 @@ def data_handler(self): n_bins_lambda=self.n_bins_lambda, load_data=False, dataset=self._prepare_dataset_for_inference(), - sed_data = self.seds, + sed_data=self.seds, ) self._data_handler.run_type = "inference" return self._data_handler @@ -190,13 +195,13 @@ def trained_psf_model(self): def get_positions(self): """ Combine x_field and y_field into position pairs. - + Returns ------- numpy.ndarray Array of shape (num_positions, 2) where each row contains [x, y] coordinates. Returns None if either x_field or y_field is None. - + Raises ------ ValueError @@ -204,25 +209,27 @@ def get_positions(self): """ if self.x_field is None or self.y_field is None: return None - + x_arr = np.asarray(self.x_field) y_arr = np.asarray(self.y_field) - + if x_arr.size == 0 or y_arr.size == 0: return None if x_arr.size != y_arr.size: - raise ValueError(f"x_field and y_field must have the same length. " - f"Got {x_arr.size} and {y_arr.size}") - + raise ValueError( + f"x_field and y_field must have the same length. 
" + f"Got {x_arr.size} and {y_arr.size}" + ) + # Flatten arrays to handle any input shape, then stack x_flat = x_arr.flatten() y_flat = y_arr.flatten() - + return np.column_stack((x_flat, y_flat)) def load_inference_model(self): - """Load the trained PSF model based on the inference configuration.""" + """Load the trained PSF model based on the inference configuration.""" model_path = self.config_handler.trained_model_path model_dir = self.config_handler.model_subdir model_name = self.training_config.training.model_params.model_name @@ -231,9 +238,9 @@ def load_inference_model(self): weights_path_pattern = os.path.join( model_path, model_dir, - f"{model_dir}*_{model_name}*{id_name}_cycle{self.cycle}*" + f"{model_dir}*_{model_name}*{id_name}_cycle{self.cycle}*", ) - + # Load the trained PSF model return load_trained_psf_model( self.training_config, @@ -244,7 +251,9 @@ def load_inference_model(self): @property def n_bins_lambda(self): if self._n_bins_lambda is None: - self._n_bins_lambda = self.inference_config.inference.model_params.n_bins_lda + self._n_bins_lambda = ( + self.inference_config.inference.model_params.n_bins_lda + ) return self._n_bins_lambda @property @@ -279,8 +288,10 @@ def _prepare_positions_and_seds(self): y_arr = np.atleast_1d(self.y_field) if x_arr.size != y_arr.size: - raise ValueError(f"x_field and y_field must have the same length. " - f"Got {x_arr.size} and {y_arr.size}") + raise ValueError( + f"x_field and y_field must have the same length. " + f"Got {x_arr.size} and {y_arr.size}" + ) # Combine into positions array (n_samples, 2) positions = np.column_stack((x_arr, y_arr)) @@ -288,12 +299,16 @@ def _prepare_positions_and_seds(self): # Ensure SEDs have shape (n_samples, n_bins, 2) sed_data = ensure_batch(self.seds) - + if sed_data.shape[0] != positions.shape[0]: - raise ValueError(f"SEDs batch size {sed_data.shape[0]} does not match number of positions {positions.shape[0]}") + raise ValueError( + f"SEDs batch size {sed_data.shape[0]} does not match number of positions {positions.shape[0]}" + ) if sed_data.shape[2] != 2: - raise ValueError(f"SEDs last dimension must be 2 (flux, wavelength). Got {sed_data.shape}") + raise ValueError( + f"SEDs last dimension must be 2 (flux, wavelength). Got {sed_data.shape}" + ) # Process SEDs through the data handler self.data_handler.process_sed_data(sed_data) @@ -301,7 +316,6 @@ def _prepare_positions_and_seds(self): return positions, sed_data_tensor - def run_inference(self): """Run PSF inference and return the full PSF array.""" # Prepare the configuration for inference @@ -332,7 +346,7 @@ def get_psf(self, index: int = 0) -> np.ndarray: If only a single star was passed during instantiation, the index defaults to 0. 
""" self._ensure_psf_inference_completed() - + inferred_psfs = self.engine.get_psfs() # If a single-star batch, ignore index bounds @@ -358,7 +372,9 @@ def inferred_psfs(self) -> np.ndarray: def compute_psfs(self, positions: tf.Tensor, sed_data: tf.Tensor) -> np.ndarray: """Compute and cache PSFs for the input source parameters.""" n_samples = positions.shape[0] - self._inferred_psfs = np.zeros((n_samples, self.output_dim, self.output_dim), dtype=np.float32) + self._inferred_psfs = np.zeros( + (n_samples, self.output_dim, self.output_dim), dtype=np.float32 + ) # Initialize counter counter = 0 @@ -370,14 +386,14 @@ def compute_psfs(self, positions: tf.Tensor, sed_data: tf.Tensor) -> np.ndarray: batch_pos = positions[counter:end_sample, :] batch_seds = sed_data[counter:end_sample, :, :] batch_inputs = [batch_pos, batch_seds] - + # Generate PSFs for the current batch batch_psfs = self.trained_model(batch_inputs, training=False) self.inferred_psfs[counter:end_sample, :, :] = batch_psfs.numpy() # Update the counter counter = end_sample - + return self._inferred_psfs def get_psfs(self) -> np.ndarray: @@ -391,5 +407,3 @@ def get_psf(self, index: int) -> np.ndarray: if self._inferred_psfs is None: raise ValueError("PSFs not yet computed. Call compute_psfs() first.") return self._inferred_psfs[index] - - diff --git a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py index c8086cb7..e2c84868 100644 --- a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py +++ b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py @@ -12,9 +12,9 @@ from tensorflow.python.keras.engine import data_adapter from wf_psf.data.data_handler import get_data_array from wf_psf.data.data_zernike_utils import ( - ZernikeInputsFactory, + ZernikeInputsFactory, assemble_zernike_contributions, - pad_tf_zernikes + pad_tf_zernikes, ) from wf_psf.psf_models import psf_models as psfm from wf_psf.psf_models.tf_modules.tf_layers import ( @@ -124,7 +124,7 @@ def __init__(self, model_params, training_params, data, coeff_mat=None): self.output_Q = model_params.output_Q self.l2_param = model_params.param_hparams.l2_param self.output_dim = model_params.output_dim - + # Initialise lazy loading of external Zernike prior self._external_prior = None @@ -134,25 +134,26 @@ def __init__(self, model_params, training_params, data, coeff_mat=None): # Compute contributions once eagerly (outside graph) zks_total_contribution_np = self._assemble_zernike_contributions().numpy() - self._zks_total_contribution = tf.convert_to_tensor(zks_total_contribution_np, dtype=tf.float32) - + self._zks_total_contribution = tf.convert_to_tensor( + zks_total_contribution_np, dtype=tf.float32 + ) + # Compute n_zks_total as int self._n_zks_total = max( self.model_params.param_hparams.n_zernikes, - zks_total_contribution_np.shape[1] + zks_total_contribution_np.shape[1], ) - - # Precompute zernike maps as tf.float32 + + # Precompute zernike maps as tf.float32 self._zernike_maps = psfm.generate_zernike_maps_3d( - n_zernikes=self._n_zks_total, - pupil_diam=self.model_params.pupil_diameter - ) + n_zernikes=self._n_zks_total, pupil_diam=self.model_params.pupil_diameter + ) - # Precompute OPD dimension + # Precompute OPD dimension self._opd_dim = self._zernike_maps.shape[1] # Precompute obscurations as tf.complex64 - self._obscurations = psfm.tf_obscurations( + self._obscurations = psfm.tf_obscurations( pupil_diam=self.model_params.pupil_diameter, 
N_filter=self.model_params.LP_filter_length, rotation_angle=self.model_params.obscuration_rotation_angle, @@ -164,10 +165,10 @@ def __init__(self, model_params, training_params, data, coeff_mat=None): _ = self.tf_np_poly_opd def _get_run_type(self, data): - if hasattr(data, 'run_type'): + if hasattr(data, "run_type"): run_type = data.run_type - elif isinstance(data, dict) and 'run_type' in data: - run_type = data['run_type'] + elif isinstance(data, dict) and "run_type" in data: + run_type = data["run_type"] else: raise ValueError("data must have a 'run_type' attribute or key") @@ -193,17 +194,28 @@ def _assemble_zernike_contributions(self): @property def save_param_history(self) -> bool: """Check if the model should save the optimization history for parametric features.""" - return getattr(self.model_params.param_hparams, "save_optim_history_param", False) - + return getattr( + self.model_params.param_hparams, "save_optim_history_param", False + ) + @property def save_nonparam_history(self) -> bool: """Check if the model should save the optimization history for non-parametric features.""" - return getattr(self.model_params.nonparam_hparams, "save_optim_history_nonparam", False) + return getattr( + self.model_params.nonparam_hparams, "save_optim_history_nonparam", False + ) def get_obs_pos(self): - assert self.run_type in {"training", "simulation", "metrics", "inference"}, f"Unknown run_type: {self.run_type}" - - raw_pos = get_data_array(data=self.data, run_type=self.run_type, key="positions") + assert self.run_type in { + "training", + "simulation", + "metrics", + "inference", + }, f"Unknown run_type: {self.run_type}" + + raw_pos = get_data_array( + data=self.data, run_type=self.run_type, key="positions" + ) obs_pos = ensure_tensor(raw_pos, dtype=tf.float32) @@ -213,7 +225,7 @@ def get_obs_pos(self): @property def zks_total_contribution(self): return self._zks_total_contribution - + @property def n_zks_total(self): """Get the total number of Zernike coefficients.""" @@ -254,40 +266,40 @@ def tf_physical_layer(self): """Lazy loading of the physical layer of the PSF model.""" if not hasattr(self, "_tf_physical_layer"): self._tf_physical_layer = TFPhysicalLayer( - self.obs_pos, - self.zks_total_contribution, - interpolation_type=self.model_params.interpolation_type, - interpolation_args=self.model_params.interpolation_args, - ) + self.obs_pos, + self.zks_total_contribution, + interpolation_type=self.model_params.interpolation_type, + interpolation_args=self.model_params.interpolation_args, + ) return self._tf_physical_layer - + @property def tf_zernike_OPD(self): """Lazy loading of the Zernike Optical Path Difference (OPD) layer.""" if not hasattr(self, "_tf_zernike_OPD"): self._tf_zernike_OPD = TFZernikeOPD(zernike_maps=self.zernike_maps) return self._tf_zernike_OPD - + def _build_tf_batch_poly_PSF(self): """Eagerly build the TFBatchPolychromaticPSF layer with numpy-based obscurations.""" return TFBatchPolychromaticPSF( - obscurations=self.obscurations, - output_Q=self.output_Q, - output_dim=self.output_dim, - ) + obscurations=self.obscurations, + output_Q=self.output_Q, + output_dim=self.output_dim, + ) @property def tf_np_poly_opd(self): """Lazy loading of the non-parametric polynomial variations OPD layer.""" if not hasattr(self, "_tf_np_poly_opd"): self._tf_np_poly_opd = TFNonParametricPolynomialVariationsOPD( - x_lims=self.model_params.x_lims, - y_lims=self.model_params.y_lims, - random_seed=self.model_params.param_hparams.random_seed, - 
d_max=self.model_params.nonparam_hparams.d_max_nonparam, - opd_dim=self.opd_dim, - ) + x_lims=self.model_params.x_lims, + y_lims=self.model_params.y_lims, + random_seed=self.model_params.param_hparams.random_seed, + d_max=self.model_params.nonparam_hparams.d_max_nonparam, + opd_dim=self.opd_dim, + ) return self._tf_np_poly_opd def get_coeff_matrix(self): @@ -313,7 +325,6 @@ def assign_coeff_matrix(self, coeff_mat: Optional[tf.Tensor]) -> None: """ self.tf_poly_Z_field.assign_coeff_matrix(coeff_mat) - def set_output_Q(self, output_Q: float, output_dim: Optional[int] = None) -> None: """Set the output sampling rate (output_Q) for PSF generation. @@ -453,16 +464,16 @@ def predict_step(self, data, evaluate_step=False): # Compute zernikes from parametric model and physical layer zks_coeffs = self.predict_zernikes(input_positions) - + # Propagate to obtain the OPD param_opd_maps = self.tf_zernike_OPD(zks_coeffs) - + # Calculate the non parametric part nonparam_opd_maps = self.tf_np_poly_opd(input_positions) - + # Add the estimations opd_maps = tf.math.add(param_opd_maps, nonparam_opd_maps) - + # Compute the polychromatic PSFs poly_psfs = self.tf_batch_poly_PSF([opd_maps, packed_SEDs]) @@ -505,13 +516,13 @@ def predict_mono_psfs(self, input_positions, lambda_obs, phase_N): # Predict zernikes from parametric model and physical layer zks_coeffs = self.predict_zernikes(input_positions) - + # Propagate to obtain the OPD param_opd_maps = self.tf_zernike_OPD(zks_coeffs) - + # Calculate the non parametric part nonparam_opd_maps = self.tf_np_poly_opd(input_positions) - + # Add the estimations opd_maps = tf.math.add(param_opd_maps, nonparam_opd_maps) @@ -536,13 +547,13 @@ def predict_opd(self, input_positions): """ # Predict zernikes from parametric model and physical layer zks_coeffs = self.predict_zernikes(input_positions) - + # Propagate to obtain the OPD param_opd_maps = self.tf_zernike_OPD(zks_coeffs) - + # Calculate the non parametric part nonparam_opd_maps = self.tf_np_poly_opd(input_positions) - + # Add the estimations opd_maps = tf.math.add(param_opd_maps, nonparam_opd_maps) @@ -677,27 +688,25 @@ def call(self, inputs, training=True): packed_SEDs = inputs[1] # For the training - if training: + if training: # Compute zernikes from parametric model and physical layer zks_coeffs = self.compute_zernikes(input_positions) - + # Parametric OPD maps from Zernikes param_opd_maps = self.tf_zernike_OPD(zks_coeffs) # Add L2 regularization loss on parametric OPD maps - self.add_loss( - self.l2_param * tf.reduce_sum(tf.square(param_opd_maps)) - ) + self.add_loss(self.l2_param * tf.reduce_sum(tf.square(param_opd_maps))) # Non-parametric correction nonparam_opd_maps = self.tf_np_poly_opd(input_positions) # Combine both contributions opd_maps = tf.add(param_opd_maps, nonparam_opd_maps) - + # Compute the polychromatic PSFs poly_psfs = self.tf_batch_poly_PSF([opd_maps, packed_SEDs]) - + # For the inference else: # Compute predictions diff --git a/src/wf_psf/psf_models/psf_model_loader.py b/src/wf_psf/psf_models/psf_model_loader.py index c30d31ad..e445f7af 100644 --- a/src/wf_psf/psf_models/psf_model_loader.py +++ b/src/wf_psf/psf_models/psf_model_loader.py @@ -7,15 +7,14 @@ Author: Jennifer Pollack """ + import logging -from wf_psf.psf_models.psf_models import ( - get_psf_model, - get_psf_model_weights_filepath -) +from wf_psf.psf_models.psf_models import get_psf_model, get_psf_model_weights_filepath import tensorflow as tf logger = logging.getLogger(__name__) + def load_trained_psf_model(training_conf, data_conf, 
weights_path_pattern): """ Loads a trained PSF model and applies saved weights. @@ -40,9 +39,11 @@ def load_trained_psf_model(training_conf, data_conf, weights_path_pattern): RuntimeError If loading the model weights fails for any reason. """ - model = get_psf_model(training_conf.training.model_params, - training_conf.training.training_hparams, - data_conf) + model = get_psf_model( + training_conf.training.model_params, + training_conf.training.training_hparams, + data_conf, + ) weights_path = get_psf_model_weights_filepath(weights_path_pattern) @@ -55,4 +56,3 @@ def load_trained_psf_model(training_conf, data_conf, weights_path_pattern): logger.exception("Failed to load model weights.") raise RuntimeError("Model weight loading failed.") from e return model - diff --git a/src/wf_psf/psf_models/tf_modules/tf_layers.py b/src/wf_psf/psf_models/tf_modules/tf_layers.py index 5bb5e2df..7b22b5e8 100644 --- a/src/wf_psf/psf_models/tf_modules/tf_layers.py +++ b/src/wf_psf/psf_models/tf_modules/tf_layers.py @@ -17,6 +17,7 @@ logger = logging.getLogger(__name__) + class TFPolynomialZernikeField(tf.keras.layers.Layer): """Calculate the zernike coefficients for a given position. @@ -964,7 +965,6 @@ def interpolate_independent_Zk(self, positions): return interp_zks[:, :, tf.newaxis, tf.newaxis] - def call(self, positions): """Calculate the prior Zernike coefficients for a batch of positions. diff --git a/src/wf_psf/psf_models/tf_modules/tf_utils.py b/src/wf_psf/psf_models/tf_modules/tf_utils.py index d0a2002c..09540e60 100644 --- a/src/wf_psf/psf_models/tf_modules/tf_utils.py +++ b/src/wf_psf/psf_models/tf_modules/tf_utils.py @@ -1,15 +1,15 @@ """TensorFlow Utilities Module. -Provides lightweight utility functions for safely converting and managing data types +Provides lightweight utility functions for safely converting and managing data types within TensorFlow-based workflows. Includes: - `ensure_tensor`: ensures inputs are TensorFlow tensors with specified dtype -These tools are designed to support PSF model components, including lazy property evaluation, +These tools are designed to support PSF model components, including lazy property evaluation, data input validation, and type normalization. -This module is intended for internal use in model layers and inference components to enforce +This module is intended for internal use in model layers and inference components to enforce TensorFlow-compatible inputs. Authors: Jennifer Pollack @@ -22,68 +22,69 @@ @tf.function def find_position_indices(obs_pos, batch_positions): """Find indices of batch positions within observed positions using vectorized operations. - - This function locates the indices of multiple query positions within a + + This function locates the indices of multiple query positions within a reference set of observed positions using broadcasting and vectorized operations. Each position in the batch must have an exact match in the observed positions. Parameters ---------- obs_pos : tf.Tensor - Reference positions tensor of shape (n_obs, 2), where n_obs is the number of + Reference positions tensor of shape (n_obs, 2), where n_obs is the number of observed positions. Each row contains [x, y] coordinates. batch_positions : tf.Tensor - Query positions tensor of shape (batch_size, 2), where batch_size is the number + Query positions tensor of shape (batch_size, 2), where batch_size is the number of positions to look up. Each row contains [x, y] coordinates. 
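For illustration (made-up coordinates; every query position must match an
observed position exactly):

>>> obs = tf.constant([[0.0, 0.0], [1.0, 2.0], [3.0, 4.0]])
>>> batch = tf.constant([[3.0, 4.0], [0.0, 0.0]])
>>> find_position_indices(obs, batch).numpy()
array([2, 0])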
- + Returns ------- indices : tf.Tensor - Tensor of shape (batch_size,) containing the indices of each batch position + Tensor of shape (batch_size,) containing the indices of each batch position within obs_pos. The dtype is tf.int64. - + Raises ------ tf.errors.InvalidArgumentError If any position in batch_positions is not found in obs_pos. - + Notes ----- - Uses exact equality matching - positions must match exactly. More efficient than + Uses exact equality matching - positions must match exactly. More efficient than iterative lookups for multiple positions due to vectorized operations. """ # Shape: obs_pos (n_obs, 2), batch_positions (batch_size, 2) # Expand for broadcasting: (1, n_obs, 2) and (batch_size, 1, 2) obs_expanded = tf.expand_dims(obs_pos, 0) pos_expanded = tf.expand_dims(batch_positions, 1) - + # Compare all positions at once: (batch_size, n_obs) matches = tf.reduce_all(tf.equal(obs_expanded, pos_expanded), axis=2) - + # Find the index of the matching position for each batch item - # argmax returns the first True value's index along axis=1 + # argmax returns the first True value's index along axis=1 indices = tf.argmax(tf.cast(matches, tf.int32), axis=1) - + # Verify all positions were found tf.debugging.assert_equal( tf.reduce_all(tf.reduce_any(matches, axis=1)), True, - message="Some positions not found in obs_pos" + message="Some positions not found in obs_pos", ) - + return indices + def ensure_tensor(input_array, dtype=tf.float32): """ Ensure the input is a TensorFlow tensor of the specified dtype. - + Parameters ---------- input_array : array-like, tf.Tensor, or np.ndarray The input to convert. dtype : tf.DType, optional The desired TensorFlow dtype (default: tf.float32). - + Returns ------- tf.Tensor @@ -97,4 +98,3 @@ def ensure_tensor(input_array, dtype=tf.float32): else: # Convert numpy arrays or other types to tensor return tf.convert_to_tensor(input_array, dtype=dtype) - diff --git a/src/wf_psf/tests/data/validation/masked_loss/results/plot_results.ipynb b/src/wf_psf/tests/data/validation/masked_loss/results/plot_results.ipynb index bec994a4..c36a9e31 100644 --- a/src/wf_psf/tests/data/validation/masked_loss/results/plot_results.ipynb +++ b/src/wf_psf/tests/data/validation/masked_loss/results/plot_results.ipynb @@ -17,19 +17,19 @@ "outputs": [], "source": [ "# Trained on masked data, tested on masked data\n", - "metrics_path_mm = '../runs/masked_train_masked_test/wf-outputs/wf-outputs-202503131718/metrics/metrics-polymask_train_mask_test.npy'\n", + "metrics_path_mm = \"../runs/masked_train_masked_test/wf-outputs/wf-outputs-202503131718/metrics/metrics-polymask_train_mask_test.npy\"\n", "mask_train_mask_test = np.load(metrics_path_mm, allow_pickle=True)[()]\n", "\n", "# Trained on masked data, tested on unmasked data\n", - "metrics_path_mu = '../runs/masked_train_unit_mask_test/wf-outputs/wf-outputs-202503131720/metrics/metrics-polymasked_train_unit_mask_test.npy'\n", + "metrics_path_mu = \"../runs/masked_train_unit_mask_test/wf-outputs/wf-outputs-202503131720/metrics/metrics-polymasked_train_unit_mask_test.npy\"\n", "mask_train_nomask_test = np.load(metrics_path_mu, allow_pickle=True)[()]\n", "\n", "# Trained on unmasked data, tested on unmasked data\n", - "metrics_path_c = '../runs/control_train/wf-outputs/wf-outputs-202503131716/metrics/metrics-polycontrol_train.npy'\n", + "metrics_path_c = \"../runs/control_train/wf-outputs/wf-outputs-202503131716/metrics/metrics-polycontrol_train.npy\"\n", "control_train = np.load(metrics_path_c, allow_pickle=True)[()]\n", 
"\n", "# Trained and tested with unitary masks\n", - "metrics_path_u = '../runs/unit_masked_train/wf-outputs/wf-outputs-202503131721/metrics/metrics-polyunit_masked_train.npy'\n", + "metrics_path_u = \"../runs/unit_masked_train/wf-outputs/wf-outputs-202503131721/metrics/metrics-polyunit_masked_train.npy\"\n", "unitary = np.load(metrics_path_u, allow_pickle=True)[()]" ] }, @@ -50,8 +50,8 @@ ], "source": [ "print(mask_train_mask_test.keys())\n", - "print(mask_train_mask_test['test_metrics'].keys())\n", - "print(mask_train_mask_test['test_metrics']['poly_metric'].keys())" + "print(mask_train_mask_test[\"test_metrics\"].keys())\n", + "print(mask_train_mask_test[\"test_metrics\"][\"poly_metric\"].keys())" ] }, { @@ -60,17 +60,25 @@ "metadata": {}, "outputs": [], "source": [ - "mask_test_mask_test_rel_rmse = mask_train_mask_test['test_metrics']['poly_metric']['rel_rmse']\n", - "mask_test_mask_test_std_rel_rmse = mask_train_mask_test['test_metrics']['poly_metric']['std_rel_rmse']\n", + "mask_test_mask_test_rel_rmse = mask_train_mask_test[\"test_metrics\"][\"poly_metric\"][\n", + " \"rel_rmse\"\n", + "]\n", + "mask_test_mask_test_std_rel_rmse = mask_train_mask_test[\"test_metrics\"][\"poly_metric\"][\n", + " \"std_rel_rmse\"\n", + "]\n", "\n", - "mask_test_nomask_test_rel_rmse = mask_train_nomask_test['test_metrics']['poly_metric']['rel_rmse']\n", - "mask_test_nomask_test_std_rel_rmse = mask_train_nomask_test['test_metrics']['poly_metric']['std_rel_rmse']\n", + "mask_test_nomask_test_rel_rmse = mask_train_nomask_test[\"test_metrics\"][\"poly_metric\"][\n", + " \"rel_rmse\"\n", + "]\n", + "mask_test_nomask_test_std_rel_rmse = mask_train_nomask_test[\"test_metrics\"][\n", + " \"poly_metric\"\n", + "][\"std_rel_rmse\"]\n", "\n", - "control_test_rel_rmse = control_train['test_metrics']['poly_metric']['rel_rmse']\n", - "control_test_std_rel_rmse = control_train['test_metrics']['poly_metric']['std_rel_rmse']\n", + "control_test_rel_rmse = control_train[\"test_metrics\"][\"poly_metric\"][\"rel_rmse\"]\n", + "control_test_std_rel_rmse = control_train[\"test_metrics\"][\"poly_metric\"][\"std_rel_rmse\"]\n", "\n", - "unitary_test_rel_rmse = unitary['test_metrics']['poly_metric']['rel_rmse']\n", - "unitary_test_std_rel_rmse = unitary['test_metrics']['poly_metric']['std_rel_rmse']" + "unitary_test_rel_rmse = unitary[\"test_metrics\"][\"poly_metric\"][\"rel_rmse\"]\n", + "unitary_test_std_rel_rmse = unitary[\"test_metrics\"][\"poly_metric\"][\"std_rel_rmse\"]" ] }, { @@ -92,12 +100,29 @@ "source": [ "# Plot the results\n", "fig, ax = plt.subplots(figsize=(8, 4))\n", - "plt.title('Relative RMSE 1x - Test dataset')\n", - "ax.errorbar([0, 1, 2, 3], [control_test_rel_rmse, mask_test_mask_test_rel_rmse, mask_test_nomask_test_rel_rmse, unitary_test_rel_rmse], yerr=[control_test_std_rel_rmse, mask_test_mask_test_std_rel_rmse, mask_test_nomask_test_std_rel_rmse, unitary_test_std_rel_rmse], fmt='o')\n", + "plt.title(\"Relative RMSE 1x - Test dataset\")\n", + "ax.errorbar(\n", + " [0, 1, 2, 3],\n", + " [\n", + " control_test_rel_rmse,\n", + " mask_test_mask_test_rel_rmse,\n", + " mask_test_nomask_test_rel_rmse,\n", + " unitary_test_rel_rmse,\n", + " ],\n", + " yerr=[\n", + " control_test_std_rel_rmse,\n", + " mask_test_mask_test_std_rel_rmse,\n", + " mask_test_nomask_test_std_rel_rmse,\n", + " unitary_test_std_rel_rmse,\n", + " ],\n", + " fmt=\"o\",\n", + ")\n", "ax.set_xticks([0, 1, 2, 3])\n", - "ax.set_xticklabels(['Control Train', 'Mask Train Mask Test', 'Mask Train Unit Mask Test', 'Unitary'])\n", - 
"ax.grid('minor')\n", - "ax.set_ylabel('Relative RMSE')\n", + "ax.set_xticklabels(\n", + " [\"Control Train\", \"Mask Train Mask Test\", \"Mask Train Unit Mask Test\", \"Unitary\"]\n", + ")\n", + "ax.grid(\"minor\")\n", + "ax.set_ylabel(\"Relative RMSE\")\n", "plt.show()" ] }, @@ -107,17 +132,27 @@ "metadata": {}, "outputs": [], "source": [ - "mask_train_mask_test_rel_rmse = mask_train_mask_test['train_metrics']['poly_metric']['rel_rmse']\n", - "mask_train_mask_test_std_rel_rmse = mask_train_mask_test['train_metrics']['poly_metric']['std_rel_rmse']\n", + "mask_train_mask_test_rel_rmse = mask_train_mask_test[\"train_metrics\"][\"poly_metric\"][\n", + " \"rel_rmse\"\n", + "]\n", + "mask_train_mask_test_std_rel_rmse = mask_train_mask_test[\"train_metrics\"][\n", + " \"poly_metric\"\n", + "][\"std_rel_rmse\"]\n", "\n", - "mask_train_nomask_test_rel_rmse = mask_train_nomask_test['train_metrics']['poly_metric']['rel_rmse']\n", - "mask_train_nomask_test_std_rel_rmse = mask_train_nomask_test['train_metrics']['poly_metric']['std_rel_rmse']\n", + "mask_train_nomask_test_rel_rmse = mask_train_nomask_test[\"train_metrics\"][\n", + " \"poly_metric\"\n", + "][\"rel_rmse\"]\n", + "mask_train_nomask_test_std_rel_rmse = mask_train_nomask_test[\"train_metrics\"][\n", + " \"poly_metric\"\n", + "][\"std_rel_rmse\"]\n", "\n", - "control_train_rel_rmse = control_train['train_metrics']['poly_metric']['rel_rmse']\n", - "control_train_std_rel_rmse = control_train['train_metrics']['poly_metric']['std_rel_rmse']\n", + "control_train_rel_rmse = control_train[\"train_metrics\"][\"poly_metric\"][\"rel_rmse\"]\n", + "control_train_std_rel_rmse = control_train[\"train_metrics\"][\"poly_metric\"][\n", + " \"std_rel_rmse\"\n", + "]\n", "\n", - "unitary_rel_rmse = unitary['train_metrics']['poly_metric']['rel_rmse']\n", - "unitary_std_rel_rmse = unitary['train_metrics']['poly_metric']['std_rel_rmse']" + "unitary_rel_rmse = unitary[\"train_metrics\"][\"poly_metric\"][\"rel_rmse\"]\n", + "unitary_std_rel_rmse = unitary[\"train_metrics\"][\"poly_metric\"][\"std_rel_rmse\"]" ] }, { @@ -139,12 +174,29 @@ "source": [ "# Plot the results\n", "fig, ax = plt.subplots(figsize=(8, 4))\n", - "plt.title('Relative RMSE 1x - Train dataset')\n", - "ax.errorbar([0, 1, 2, 3], [control_train_rel_rmse, mask_train_mask_test_rel_rmse, mask_train_nomask_test_rel_rmse, unitary_rel_rmse], yerr=[control_train_std_rel_rmse, mask_train_mask_test_std_rel_rmse, mask_train_nomask_test_std_rel_rmse, unitary_std_rel_rmse], fmt='o')\n", + "plt.title(\"Relative RMSE 1x - Train dataset\")\n", + "ax.errorbar(\n", + " [0, 1, 2, 3],\n", + " [\n", + " control_train_rel_rmse,\n", + " mask_train_mask_test_rel_rmse,\n", + " mask_train_nomask_test_rel_rmse,\n", + " unitary_rel_rmse,\n", + " ],\n", + " yerr=[\n", + " control_train_std_rel_rmse,\n", + " mask_train_mask_test_std_rel_rmse,\n", + " mask_train_nomask_test_std_rel_rmse,\n", + " unitary_std_rel_rmse,\n", + " ],\n", + " fmt=\"o\",\n", + ")\n", "ax.set_xticks([0, 1, 2, 3])\n", - "ax.set_xticklabels(['Control Train', 'Mask Train Mask Test', 'Mask Train Unit Mask Test', 'Unitary'])\n", - "ax.set_ylabel('Relative RMSE')\n", - "ax.grid('minor')\n", + "ax.set_xticklabels(\n", + " [\"Control Train\", \"Mask Train Mask Test\", \"Mask Train Unit Mask Test\", \"Unitary\"]\n", + ")\n", + "ax.set_ylabel(\"Relative RMSE\")\n", + "ax.grid(\"minor\")\n", "plt.show()" ] }, @@ -167,16 +219,50 @@ "source": [ "# Plot test and train relative RMSE in the same plot\n", "fig, ax = plt.subplots(figsize=(8, 4))\n", - 
"plt.title('Relative RMSE 1x - Train and Test dataset')\n", - "ax.errorbar([0, 1, 2, 3], [control_train_rel_rmse, mask_train_mask_test_rel_rmse, mask_train_nomask_test_rel_rmse, unitary_rel_rmse], yerr=[control_train_std_rel_rmse, mask_train_mask_test_std_rel_rmse, mask_train_nomask_test_std_rel_rmse, unitary_std_rel_rmse], fmt='o', label='Train')\n", - "ax.errorbar([0.02, 1.02, 2.02, 3.02], [control_test_rel_rmse, mask_test_mask_test_rel_rmse, mask_test_nomask_test_rel_rmse, unitary_test_rel_rmse], yerr=[control_test_std_rel_rmse, mask_test_mask_test_std_rel_rmse, mask_test_nomask_test_std_rel_rmse, unitary_test_std_rel_rmse], fmt='o', label='Test')\n", + "plt.title(\"Relative RMSE 1x - Train and Test dataset\")\n", + "ax.errorbar(\n", + " [0, 1, 2, 3],\n", + " [\n", + " control_train_rel_rmse,\n", + " mask_train_mask_test_rel_rmse,\n", + " mask_train_nomask_test_rel_rmse,\n", + " unitary_rel_rmse,\n", + " ],\n", + " yerr=[\n", + " control_train_std_rel_rmse,\n", + " mask_train_mask_test_std_rel_rmse,\n", + " mask_train_nomask_test_std_rel_rmse,\n", + " unitary_std_rel_rmse,\n", + " ],\n", + " fmt=\"o\",\n", + " label=\"Train\",\n", + ")\n", + "ax.errorbar(\n", + " [0.02, 1.02, 2.02, 3.02],\n", + " [\n", + " control_test_rel_rmse,\n", + " mask_test_mask_test_rel_rmse,\n", + " mask_test_nomask_test_rel_rmse,\n", + " unitary_test_rel_rmse,\n", + " ],\n", + " yerr=[\n", + " control_test_std_rel_rmse,\n", + " mask_test_mask_test_std_rel_rmse,\n", + " mask_test_nomask_test_std_rel_rmse,\n", + " unitary_test_std_rel_rmse,\n", + " ],\n", + " fmt=\"o\",\n", + " label=\"Test\",\n", + ")\n", "ax.set_xticks([0, 1, 2, 3])\n", - "ax.set_xticklabels(['Control Train', 'Mask Train Mask Test', 'Mask Train Unit Mask Test', 'Unitary'])\n", - "ax.set_ylabel('Relative RMSE')\n", - "ax.grid('minor')\n", + "ax.set_xticklabels(\n", + " [\"Control Train\", \"Mask Train Mask Test\", \"Mask Train Unit Mask Test\", \"Unitary\"]\n", + ")\n", + "ax.set_ylabel(\"Relative RMSE\")\n", + "ax.grid(\"minor\")\n", "plt.legend()\n", "# plt.show()\n", - "plt.savefig('masked_loss_validation.pdf')\n" + "plt.savefig(\"masked_loss_validation.pdf\")" ] }, { diff --git a/src/wf_psf/tests/test_data/test_data_utils.py b/src/wf_psf/tests/test_data/test_data_utils.py index 1ebc00cb..de111427 100644 --- a/src/wf_psf/tests/test_data/test_data_utils.py +++ b/src/wf_psf/tests/test_data/test_data_utils.py @@ -1,8 +1,13 @@ - class MockDataset: def __init__(self, positions, zernike_priors, star_type, stars, masks): - self.dataset = {"positions": positions, "zernike_prior": zernike_priors, star_type: stars, "masks": masks} - + self.dataset = { + "positions": positions, + "zernike_prior": zernike_priors, + star_type: stars, + "masks": masks, + } + + class MockData: def __init__( self, @@ -16,15 +21,16 @@ def __init__( masks=None, ): self.training_data = MockDataset( - positions=training_positions, + positions=training_positions, zernike_priors=training_zernike_priors, star_type="noisy_stars", stars=noisy_stars, - masks=noisy_masks) + masks=noisy_masks, + ) self.test_data = MockDataset( - positions=test_positions, + positions=test_positions, zernike_priors=test_zernike_priors, star_type="stars", stars=stars, - masks=masks) - + masks=masks, + ) diff --git a/src/wf_psf/tests/test_inference/psf_inference_test.py b/src/wf_psf/tests/test_inference/psf_inference_test.py index ec9f0495..05728de7 100644 --- a/src/wf_psf/tests/test_inference/psf_inference_test.py +++ b/src/wf_psf/tests/test_inference/psf_inference_test.py @@ -14,13 +14,14 @@ from types 
import SimpleNamespace from unittest.mock import MagicMock, patch, PropertyMock from wf_psf.inference.psf_inference import ( - InferenceConfigHandler, + InferenceConfigHandler, PSFInference, - PSFInferenceEngine + PSFInferenceEngine, ) from wf_psf.utils.read_config import RecursiveNamespace + def _patch_data_handler(): """Helper for patching data_handler to avoid full PSF logic.""" patcher = patch.object(PSFInference, "data_handler", new_callable=PropertyMock) @@ -34,6 +35,7 @@ def fake_process(x): mock_instance.process_sed_data.side_effect = fake_process return patcher, mock_instance + @pytest.fixture def mock_training_config(): training_config = RecursiveNamespace( @@ -60,13 +62,13 @@ def mock_training_config(): LP_filter_length=3, param_hparams=RecursiveNamespace( n_zernikes=10, - - ) - ) - ) + ), + ), + ) ) return training_config + @pytest.fixture def mock_inference_config(): inference_config = RecursiveNamespace( @@ -74,16 +76,12 @@ def mock_inference_config(): batch_size=16, cycle=2, configs=RecursiveNamespace( - trained_model_path='/path/to/trained/model', - model_subdir='psf_model', - trained_model_config_path='config/training_config.yaml', - data_config_path=None - ), - model_params=RecursiveNamespace( - n_bins_lda=8, - output_Q=1, - output_dim=64 + trained_model_path="/path/to/trained/model", + model_subdir="psf_model", + trained_model_config_path="config/training_config.yaml", + data_config_path=None, ), + model_params=RecursiveNamespace(n_bins_lda=8, output_Q=1, output_dim=64), ) ) return inference_config @@ -96,14 +94,18 @@ def psf_test_setup(mock_inference_config): output_dim = 32 mock_positions = tf.convert_to_tensor([[0.1, 0.1], [0.2, 0.2]], dtype=tf.float32) - mock_seds = tf.convert_to_tensor(np.random.rand(num_sources, num_bins, 2), dtype=tf.float32) - expected_psfs = np.random.rand(num_sources, output_dim, output_dim).astype(np.float32) + mock_seds = tf.convert_to_tensor( + np.random.rand(num_sources, num_bins, 2), dtype=tf.float32 + ) + expected_psfs = np.random.rand(num_sources, output_dim, output_dim).astype( + np.float32 + ) inference = PSFInference( "dummy_path.yaml", x_field=[0.1, 0.2], y_field=[0.1, 0.2], - seds=np.random.rand(num_sources, num_bins, 2) + seds=np.random.rand(num_sources, num_bins, 2), ) inference._config_handler = MagicMock() inference._config_handler.inference_config = mock_inference_config @@ -116,9 +118,10 @@ def psf_test_setup(mock_inference_config): "expected_psfs": expected_psfs, "num_sources": num_sources, "num_bins": num_bins, - "output_dim": output_dim + "output_dim": output_dim, } + @pytest.fixture def psf_single_star_setup(mock_inference_config): num_sources = 1 @@ -128,14 +131,18 @@ def psf_single_star_setup(mock_inference_config): # Single position mock_positions = tf.convert_to_tensor([[0.1, 0.1]], dtype=tf.float32) # Shape (1, 2, num_bins) - mock_seds = tf.convert_to_tensor(np.random.rand(num_sources, 2, num_bins), dtype=tf.float32) - expected_psfs = np.random.rand(num_sources, output_dim, output_dim).astype(np.float32) + mock_seds = tf.convert_to_tensor( + np.random.rand(num_sources, 2, num_bins), dtype=tf.float32 + ) + expected_psfs = np.random.rand(num_sources, output_dim, output_dim).astype( + np.float32 + ) inference = PSFInference( "dummy_path.yaml", - x_field=0.1, # scalar for single star + x_field=0.1, # scalar for single star y_field=0.1, - seds=np.random.rand(num_bins, 2) # shape (num_bins, 2) before batching + seds=np.random.rand(num_bins, 2), # shape (num_bins, 2) before batching ) inference._config_handler = MagicMock() 
inference._config_handler.inference_config = mock_inference_config @@ -148,7 +155,7 @@ def psf_single_star_setup(mock_inference_config): "expected_psfs": expected_psfs, "num_sources": num_sources, "num_bins": num_bins, - "output_dim": output_dim + "output_dim": output_dim, } @@ -165,7 +172,9 @@ def test_set_config_paths(mock_inference_config): # Assertions assert config_handler.trained_model_path == Path("/path/to/trained/model") assert config_handler.model_subdir == "psf_model" - assert config_handler.trained_model_config_path == Path("/path/to/trained/model/config/training_config.yaml") + assert config_handler.trained_model_config_path == Path( + "/path/to/trained/model/config/training_config.yaml" + ) assert config_handler.data_config_path == None @@ -175,15 +184,19 @@ def test_overwrite_model_params(mock_training_config, mock_inference_config): training_config = mock_training_config inference_config = mock_inference_config - InferenceConfigHandler.overwrite_model_params( - training_config, inference_config - ) + InferenceConfigHandler.overwrite_model_params(training_config, inference_config) # Assert that the model_params were overwritten correctly - assert training_config.training.model_params.output_Q == 1, "output_Q should be overwritten" - assert training_config.training.model_params.output_dim == 64, "output_dim should be overwritten" - - assert training_config.training.id_name == "mock_id", "id_name should not be overwritten" + assert ( + training_config.training.model_params.output_Q == 1 + ), "output_Q should be overwritten" + assert ( + training_config.training.model_params.output_dim == 64 + ), "output_dim should be overwritten" + + assert ( + training_config.training.id_name == "mock_id" + ), "id_name should not be overwritten" def test_prepare_configs(mock_training_config, mock_inference_config): @@ -196,7 +209,7 @@ def test_prepare_configs(mock_training_config, mock_inference_config): original_model_params = mock_training_config.training.model_params # Instantiate PSFInference - psf_inf = PSFInference('/dummy/path.yaml') + psf_inf = PSFInference("/dummy/path.yaml") # Mock the config handler attribute with a mock InferenceConfigHandler mock_config_handler = MagicMock(spec=InferenceConfigHandler) @@ -204,7 +217,9 @@ def test_prepare_configs(mock_training_config, mock_inference_config): mock_config_handler.inference_config = inference_config # Patch the overwrite_model_params to use the real static method - mock_config_handler.overwrite_model_params.side_effect = InferenceConfigHandler.overwrite_model_params + mock_config_handler.overwrite_model_params.side_effect = ( + InferenceConfigHandler.overwrite_model_params + ) psf_inf._config_handler = mock_config_handler @@ -223,30 +238,43 @@ def test_config_handler_lazy_load(monkeypatch): class DummyHandler: def load_configs(self): - called['load'] = True + called["load"] = True self.inference_config = {} self.training_config = {} self.data_config = {} - def overwrite_model_params(self, *args): pass - monkeypatch.setattr("wf_psf.inference.psf_inference.InferenceConfigHandler", lambda path: DummyHandler()) + def overwrite_model_params(self, *args): + pass + + monkeypatch.setattr( + "wf_psf.inference.psf_inference.InferenceConfigHandler", + lambda path: DummyHandler(), + ) inference.prepare_configs() - assert 'load' in called # Confirm lazy load happened + assert "load" in called # Confirm lazy load happened + def test_batch_size_positive(): inference = PSFInference("dummy_path.yaml") inference._config_handler = MagicMock() 
inference._config_handler.inference_config = SimpleNamespace( - inference=SimpleNamespace(batch_size=4, model_params=SimpleNamespace(output_dim=32)) + inference=SimpleNamespace( + batch_size=4, model_params=SimpleNamespace(output_dim=32) + ) ) assert inference.batch_size == 4 -@patch('wf_psf.inference.psf_inference.DataHandler') -@patch('wf_psf.inference.psf_inference.load_trained_psf_model') -def test_load_inference_model(mock_load_trained_psf_model, mock_data_handler, mock_training_config, mock_inference_config): +@patch("wf_psf.inference.psf_inference.DataHandler") +@patch("wf_psf.inference.psf_inference.load_trained_psf_model") +def test_load_inference_model( + mock_load_trained_psf_model, + mock_data_handler, + mock_training_config, + mock_inference_config, +): mock_data_config = MagicMock() mock_data_handler.return_value = mock_data_config mock_config_handler = MagicMock(spec=InferenceConfigHandler) @@ -255,29 +283,33 @@ def test_load_inference_model(mock_load_trained_psf_model, mock_data_handler, mo mock_config_handler.inference_config = mock_inference_config mock_config_handler.model_subdir = "psf_model" mock_config_handler.data_config = MagicMock() - + psf_inf = PSFInference("dummy_path.yaml") psf_inf._config_handler = mock_config_handler psf_inf.load_inference_model() weights_path_pattern = os.path.join( - mock_config_handler.trained_model_path, - mock_config_handler.model_subdir, - f"{mock_config_handler.model_subdir}*_{mock_config_handler.training_config.training.model_params.model_name}*{mock_config_handler.training_config.training.id_name}_cycle{mock_config_handler.inference_config.inference.cycle}*" - ) + mock_config_handler.trained_model_path, + mock_config_handler.model_subdir, + f"{mock_config_handler.model_subdir}*_{mock_config_handler.training_config.training.model_params.model_name}*{mock_config_handler.training_config.training.id_name}_cycle{mock_config_handler.inference_config.inference.cycle}*", + ) # Assert calls to the mocked methods mock_load_trained_psf_model.assert_called_once_with( - mock_training_config, - mock_data_config, - weights_path_pattern + mock_training_config, mock_data_config, weights_path_pattern ) -@patch.object(PSFInference, 'prepare_configs') -@patch.object(PSFInference, '_prepare_positions_and_seds') -@patch.object(PSFInferenceEngine, 'compute_psfs') -def test_run_inference(mock_compute_psfs, mock_prepare_positions_and_seds, mock_prepare_configs, psf_test_setup): + +@patch.object(PSFInference, "prepare_configs") +@patch.object(PSFInference, "_prepare_positions_and_seds") +@patch.object(PSFInferenceEngine, "compute_psfs") +def test_run_inference( + mock_compute_psfs, + mock_prepare_positions_and_seds, + mock_prepare_configs, + psf_test_setup, +): inference = psf_test_setup["inference"] mock_positions = psf_test_setup["mock_positions"] mock_seds = psf_test_setup["mock_seds"] @@ -294,8 +326,11 @@ def test_run_inference(mock_compute_psfs, mock_prepare_positions_and_seds, mock_ mock_compute_psfs.assert_called_once_with(mock_positions, mock_seds) mock_prepare_configs.assert_called_once() + @patch("wf_psf.inference.psf_inference.psf_models.simPSF") -def test_simpsf_uses_updated_model_params(mock_simpsf, mock_training_config, mock_inference_config): +def test_simpsf_uses_updated_model_params( + mock_simpsf, mock_training_config, mock_inference_config +): """Test that simPSF uses the updated model parameters.""" training_config = mock_training_config inference_config = mock_inference_config @@ -315,7 +350,7 @@ def 
test_simpsf_uses_updated_model_params(mock_simpsf, mock_training_config, moc mock_config_handler.inference_config = inference_config mock_config_handler.model_subdir = "psf_model" mock_config_handler.data_config = MagicMock() - + modeller = PSFInference("dummy_path.yaml") modeller._config_handler = mock_config_handler @@ -330,9 +365,11 @@ def test_simpsf_uses_updated_model_params(mock_simpsf, mock_training_config, moc assert result.output_Q == expected_output_Q -@patch.object(PSFInference, '_prepare_positions_and_seds') -@patch.object(PSFInferenceEngine, 'compute_psfs') -def test_get_psfs_runs_inference(mock_compute_psfs, mock_prepare_positions_and_seds, psf_test_setup): +@patch.object(PSFInference, "_prepare_positions_and_seds") +@patch.object(PSFInferenceEngine, "compute_psfs") +def test_get_psfs_runs_inference( + mock_compute_psfs, mock_prepare_positions_and_seds, psf_test_setup +): inference = psf_test_setup["inference"] mock_positions = psf_test_setup["mock_positions"] mock_seds = psf_test_setup["mock_seds"] @@ -355,7 +392,6 @@ def fake_compute_psfs(positions, seds): assert mock_compute_psfs.call_count == 1 - def test_single_star_inference_shape(psf_single_star_setup): setup = psf_single_star_setup @@ -374,9 +410,11 @@ def test_single_star_inference_shape(psf_single_star_setup): input_array = args[0] # Check input SED had the right shape before being tensorized - assert input_array.shape == (1, setup["num_bins"], 2), \ - "process_sed_data should have been called with shape (1, num_bins, 2)" - + assert input_array.shape == ( + 1, + setup["num_bins"], + 2, + ), "process_sed_data should have been called with shape (1, num_bins, 2)" def test_multiple_star_inference_shape(psf_test_setup): @@ -398,9 +436,12 @@ def test_multiple_star_inference_shape(psf_test_setup): input_array = args[0] # Check input SED had the right shape before being tensorized - assert input_array.shape == (2, setup["num_bins"], 2), \ - "process_sed_data should have been called with shape (2, num_bins, 2)" - + assert input_array.shape == ( + 2, + setup["num_bins"], + 2, + ), "process_sed_data should have been called with shape (2, num_bins, 2)" + def test_valueerror_on_mismatched_batches(psf_single_star_setup): """Raise if sed_data batch size != positions batch size and sed_data != 1.""" @@ -416,7 +457,9 @@ def test_valueerror_on_mismatched_batches(psf_single_star_setup): inference.seds = bad_sed inference.positions = np.ones((1, 2), dtype=np.float32) - with pytest.raises(ValueError, match="SEDs batch size 2 does not match number of positions 1"): + with pytest.raises( + ValueError, match="SEDs batch size 2 does not match number of positions 1" + ): inference._prepare_positions_and_seds() finally: patcher.stop() @@ -435,7 +478,9 @@ def test_valueerror_on_mismatched_positions(psf_single_star_setup): inference.x_field = np.ones((3, 1), dtype=np.float32) inference.y_field = np.ones((3, 1), dtype=np.float32) - with pytest.raises(ValueError, match="SEDs batch size 2 does not match number of positions 3"): + with pytest.raises( + ValueError, match="SEDs batch size 2 does not match number of positions 3" + ): inference._prepare_positions_and_seds() finally: - patcher.stop() \ No newline at end of file + patcher.stop() diff --git a/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py b/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py index 95ad6287..e900a6d3 100644 --- a/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py +++ 
b/src/wf_psf/tests/test_psf_models/psf_model_physical_polychromatic_test.py @@ -32,16 +32,16 @@ def zks_prior(): def mock_data(mocker, zks_prior): mock_instance = mocker.Mock(spec=DataConfigHandler) mock_instance.run_type = "training" - + training_dataset = { "positions": np.array([[1, 2], [3, 4]]), "zernike_prior": zks_prior, - "noisy_stars": np.zeros((2, 1, 1, 1)), + "noisy_stars": np.zeros((2, 1, 1, 1)), } test_dataset = { "positions": np.array([[5, 6], [7, 8]]), "zernike_prior": zks_prior, - "stars": np.zeros((2, 1, 1, 1)), + "stars": np.zeros((2, 1, 1, 1)), } mock_instance.training_data = mocker.Mock() @@ -60,44 +60,53 @@ def mock_model_params(mocker): model_params_mock.pupil_diameter = 256 return model_params_mock + @pytest.fixture def physical_layer_instance(mocker, mock_model_params, mock_data): # Patch expensive methods during construction to avoid errors - with patch("wf_psf.psf_models.models.psf_model_physical_polychromatic.TFPhysicalPolychromaticField._assemble_zernike_contributions", return_value=tf.constant([[[[1.0]]], [[[2.0]]]])): - from wf_psf.psf_models.models.psf_model_physical_polychromatic import TFPhysicalPolychromaticField - instance = TFPhysicalPolychromaticField(mock_model_params, mocker.Mock(), mock_data) + with patch( + "wf_psf.psf_models.models.psf_model_physical_polychromatic.TFPhysicalPolychromaticField._assemble_zernike_contributions", + return_value=tf.constant([[[[1.0]]], [[[2.0]]]]), + ): + from wf_psf.psf_models.models.psf_model_physical_polychromatic import ( + TFPhysicalPolychromaticField, + ) + + instance = TFPhysicalPolychromaticField( + mock_model_params, mocker.Mock(), mock_data + ) return instance + def test_compute_zernikes(mocker, physical_layer_instance): # Expected output of mock components - padded_zernike_param = tf.constant([[[[10]], [[20]], [[30]], [[40]]]], dtype=tf.float32) + padded_zernike_param = tf.constant( + [[[[10]], [[20]], [[30]], [[40]]]], dtype=tf.float32 + ) padded_zernike_prior = tf.constant([[[[1]], [[2]], [[0]], [[0]]]], dtype=tf.float32) n_zks_total = physical_layer_instance.n_zks_total expected_values_list = [11, 22, 30, 40] + [0] * (n_zks_total - 4) expected_values = tf.constant( - [[[[v]] for v in expected_values_list]], - dtype=tf.float32 -) + [[[[v]] for v in expected_values_list]], dtype=tf.float32 + ) # Patch tf_poly_Z_field method mocker.patch.object( TFPhysicalPolychromaticField, "tf_poly_Z_field", - return_value=padded_zernike_param + return_value=padded_zernike_param, ) # Patch tf_physical_layer.call method mock_tf_physical_layer = mocker.Mock() mock_tf_physical_layer.call.return_value = padded_zernike_prior mocker.patch.object( - TFPhysicalPolychromaticField, - "tf_physical_layer", - mock_tf_physical_layer + TFPhysicalPolychromaticField, "tf_physical_layer", mock_tf_physical_layer ) # Patch pad_tf_zernikes function mocker.patch( "wf_psf.data.data_zernike_utils.pad_tf_zernikes", - return_value=(padded_zernike_param, padded_zernike_prior) + return_value=(padded_zernike_param, padded_zernike_prior), ) # Run the test diff --git a/src/wf_psf/tests/test_psf_models/psf_models_test.py b/src/wf_psf/tests/test_psf_models/psf_models_test.py index b7c906f6..2b907eff 100644 --- a/src/wf_psf/tests/test_psf_models/psf_models_test.py +++ b/src/wf_psf/tests/test_psf_models/psf_models_test.py @@ -10,7 +10,7 @@ from wf_psf.psf_models import psf_models from wf_psf.psf_models.models import ( psf_model_semiparametric, - psf_model_physical_polychromatic + psf_model_physical_polychromatic, ) import tensorflow as tf import numpy as np 
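[Editorial note] The compute_zernikes test above pins down the padding behaviour that pad_tf_zernikes is mocked to provide: the parametric and prior coefficient tensors, shaped [batch, n_zk, 1, 1], are zero-padded along the Zernike axis to a common order and then summed. A minimal sketch of that operation, for illustration only (pad_to_order is a hypothetical helper, not the project API):

    import tensorflow as tf

    def pad_to_order(zk: tf.Tensor, n_zks_total: int) -> tf.Tensor:
        # Zero-pad axis 1 (the Zernike order axis) up to n_zks_total.
        pad = n_zks_total - zk.shape[1]  # static shapes assumed known
        return tf.pad(zk, [[0, 0], [0, pad], [0, 0], [0, 0]])

    zk_param = tf.constant([[[[10.0]], [[20.0]], [[30.0]], [[40.0]]]])  # (1, 4, 1, 1)
    zk_prior = tf.constant([[[[1.0]], [[2.0]]]])                        # (1, 2, 1, 1)
    total = pad_to_order(zk_param, 6) + pad_to_order(zk_prior, 6)
    # total coefficients: [11, 22, 30, 40, 0, 0], matching expected_values_list above
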
diff --git a/src/wf_psf/tests/test_utils/utils_test.py b/src/wf_psf/tests/test_utils/utils_test.py index dacf0bdc..cc7f2a2b 100644 --- a/src/wf_psf/tests/test_utils/utils_test.py +++ b/src/wf_psf/tests/test_utils/utils_test.py @@ -20,11 +20,6 @@ from unittest import mock - -def test_sanity(): - assert 1 + 1 == 2 - - def test_downsample_basic(): """Test apply_mask when a zeroed mask is provided.""" img_dim = (10, 10) @@ -37,9 +32,9 @@ def test_downsample_basic(): # The result should be an array of False values, as the mask excludes all pixels expected_result = np.zeros(img_dim, dtype=bool) - assert np.array_equal(result, expected_result), ( - "apply_mask did not handle the zeroed mask correctly." - ) + assert np.array_equal( + result, expected_result + ), "apply_mask did not handle the zeroed mask correctly." def test_initialization(): @@ -121,9 +116,9 @@ def test_apply_mask_with_none_mask(): result = estimator.apply_mask(None) # Pass None as the mask # It should return the window itself when no mask is provided - assert np.array_equal(result, estimator.window), ( - "apply_mask should return the window when mask is None." - ) + assert np.array_equal( + result, estimator.window + ), "apply_mask should return the window when mask is None." def test_apply_mask_with_valid_mask(): @@ -139,9 +134,9 @@ def test_apply_mask_with_valid_mask(): # Check that the mask was applied correctly: pixel (5, 5) should be False, others True expected_result = estimator.window & custom_mask - assert np.array_equal(result, expected_result), ( - "apply_mask did not apply the mask correctly." - ) + assert np.array_equal( + result, expected_result + ), "apply_mask did not apply the mask correctly." def test_apply_mask_with_zeroed_mask(): @@ -156,9 +151,9 @@ def test_apply_mask_with_zeroed_mask(): # The result should be an array of False values, as the mask excludes all pixels expected_result = np.zeros(img_dim, dtype=bool) - assert np.array_equal(result, expected_result), ( - "apply_mask did not handle the zeroed mask correctly." - ) + assert np.array_equal( + result, expected_result + ), "apply_mask did not handle the zeroed mask correctly." 
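# [Editorial aside] The three apply_mask tests above pin down a simple
# contract: with no mask the estimator returns its boolean window unchanged;
# with a mask, the two are AND-ed elementwise, so zeros exclude pixels. A
# minimal sketch of that contract, using a hypothetical stand-in rather than
# the project's CentroidEstimator:
#
#     from typing import Optional
#     import numpy as np
#
#     def apply_mask_sketch(window: np.ndarray, mask: Optional[np.ndarray]) -> np.ndarray:
#         if mask is None:
#             return window  # no mask: keep the full centroiding window
#         return window & mask.astype(bool)  # zeros in the mask exclude pixels
#
#     # A zeroed mask therefore excludes every pixel:
#     assert not apply_mask_sketch(np.ones((5, 5), dtype=bool), np.zeros((5, 5))).any()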
def test_unobscured_zernike_projection(): @@ -252,6 +247,7 @@ def test_tf_decompose_obscured_opd_basis(): assert rmse_error < tol + def test_downsample_basic(): """Downsample a small array to a smaller square size.""" arr = np.arange(16).reshape(4, 4).astype(np.float32) @@ -262,9 +258,10 @@ def test_downsample_basic(): assert result.shape == (output_dim, output_dim), "Output shape mismatch" # Values should be averaged/downsampled; simple check - assert np.all(result >= arr.min()) and np.all(result <= arr.max()), \ - "Values outside input range" - + assert np.all(result >= arr.min()) and np.all( + result <= arr.max() + ), "Values outside input range" + def test_downsample_identity(): """Downsample to the same size should return same array (approximately).""" @@ -274,10 +271,12 @@ def test_downsample_identity(): # Since OpenCV / skimage may do minor interpolation, allow small tolerance np.testing.assert_allclose(result, arr, rtol=1e-6, atol=1e-6) + # ---------------------------- # Backend fallback tests # ---------------------------- + @mock.patch("wf_psf.utils.utils._HAS_CV2", False) @mock.patch("wf_psf.utils.utils._HAS_SKIMAGE", False) def test_downsample_no_backend(): @@ -296,10 +295,11 @@ def test_downsample_values_average(): # All output values should be close to input value np.testing.assert_allclose(result, 3.0, rtol=1e-6, atol=1e-6) + @mock.patch("wf_psf.utils.utils._HAS_CV2", True) def test_downsample_non_square_array(): """Check downsampling works for non-square arrays.""" arr = np.arange(12).reshape(3, 4).astype(np.float32) output_dim = 2 result = downsample_im(arr, output_dim) - assert result.shape == (2, 2) \ No newline at end of file + assert result.shape == (2, 2) diff --git a/src/wf_psf/training/train.py b/src/wf_psf/training/train.py index f2366ca7..ab2f0ac1 100644 --- a/src/wf_psf/training/train.py +++ b/src/wf_psf/training/train.py @@ -274,10 +274,7 @@ def _prepare_callbacks( def get_loss_metrics_monitor_and_outputs(training_handler, data_conf): - """Generate factory for loss, metrics, monitor, and outputs. - - A function to generate loss, metrics, monitor, and outputs - for training. + """Factory to return fresh loss, metrics (param & non-param), monitor, and outputs for the current cycle. Parameters ---------- @@ -370,12 +367,12 @@ def train( psf_model_dir : str Directory where the final trained PSF model weights will be saved per cycle. - Notes - ----- - - Utilizes TensorFlow and TensorFlow Addons for model training and optimization. - - Supports masked mean squared error loss for training with masked data. - - Allows for projection of data-driven features onto parametric models between cycles. - - Supports resetting of non-parametric features to initial states. 
+ Returns + ------- + None + + Side Effects + ------------ - Saves model weights to `psf_model_dir` per training cycle (or final one if not all saved) - Saves optimizer histories to `optimizer_dir` - Logs cycle information and time durations From 0e85c0f198417063b1458ffc11dee9e07bd92648 Mon Sep 17 00:00:00 2001 From: jeipollack Date: Fri, 5 Sep 2025 19:31:46 +0100 Subject: [PATCH 110/135] Correct type hint errors --- src/wf_psf/data/data_handler.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/wf_psf/data/data_handler.py b/src/wf_psf/data/data_handler.py index c0ddf7b0..bdcf9a6b 100644 --- a/src/wf_psf/data/data_handler.py +++ b/src/wf_psf/data/data_handler.py @@ -241,7 +241,7 @@ def process_sed_data(self, sed_data): def extract_star_data(data, train_key: str, test_key: str) -> np.ndarray: - """ + """ Extract and concatenate star-related data from training and test datasets. This function retrieves arrays (e.g., postage stamps, masks, positions) from @@ -310,8 +310,8 @@ def get_data_array( key: str = None, train_key: str = None, test_key: str = None, - allow_missing: bool = False, -) -> np.ndarray | None: + allow_missing: bool = True, +) -> Optional[np.ndarray]: """ Retrieve data from dataset depending on run type. @@ -337,7 +337,7 @@ def get_data_array( test_key : str, optional Key for test dataset access. If None, defaults to the resolved train_key value. Default is None. - allow_missing : bool, default False + allow_missing : bool, default True Control behavior when data is missing or keys are not found: - True: Return None instead of raising exceptions - False: Raise appropriate exceptions (KeyError, ValueError) @@ -401,7 +401,7 @@ def get_data_array( raise -def _get_direct_data(data, key: str, allow_missing: bool) -> np.ndarray | None: +def _get_direct_data(data, key: str, allow_missing: bool) -> Optional[np.ndarray]: """ Extract data directly with proper error handling and type conversion. 
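[Editorial note] For context on the type-hint fix in the patch above: the `np.ndarray | None` union uses PEP 604 syntax, which is only evaluated at runtime on Python >= 3.10 (or with `from __future__ import annotations`), whereas `typing.Optional[np.ndarray]` is equivalent and also works on older interpreters. The same patch flips `allow_missing` to default to True, so missing inference data yields None rather than an exception. A minimal illustration of the corrected style, with `fetch_positions` as a hypothetical example function:

    from typing import Optional
    import numpy as np

    def fetch_positions(dataset: dict, allow_missing: bool = True) -> Optional[np.ndarray]:
        # Mirrors the get_data_array contract: return None instead of
        # raising when the data is absent and allow_missing is True.
        value = dataset.get("positions")
        if value is None:
            if allow_missing:
                return None
            raise KeyError("Key 'positions' not found in inference dataset")
        return np.asarray(value)

    assert fetch_positions({}) is None  # allow_missing now defaults to True
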
From 2bdd6225f509158dfe6480852948c766b4870b06 Mon Sep 17 00:00:00 2001 From: jeipollack Date: Fri, 5 Sep 2025 19:32:16 +0100 Subject: [PATCH 111/135] Remove unused import --- src/wf_psf/instrument/ccd_misalignments.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/wf_psf/instrument/ccd_misalignments.py b/src/wf_psf/instrument/ccd_misalignments.py index f739c4fc..49bc8180 100644 --- a/src/wf_psf/instrument/ccd_misalignments.py +++ b/src/wf_psf/instrument/ccd_misalignments.py @@ -10,7 +10,6 @@ import numpy as np import matplotlib.path as mpltPath from scipy.spatial import KDTree -from wf_psf.data.data_handler import get_np_obs_positions def compute_ccd_misalignment(model_params, positions: np.ndarray) -> np.ndarray: From 6134d8e7cf7c6ed629b873c7f90ba0ae198dd7d8 Mon Sep 17 00:00:00 2001 From: jeipollack Date: Fri, 5 Sep 2025 19:32:56 +0100 Subject: [PATCH 112/135] Replace call to deprecated get_np_obs_positions with get_data_array --- src/wf_psf/psf_models/tf_modules/tf_psf_field.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/wf_psf/psf_models/tf_modules/tf_psf_field.py b/src/wf_psf/psf_models/tf_modules/tf_psf_field.py index b2d1b53e..21c0f9a4 100644 --- a/src/wf_psf/psf_models/tf_modules/tf_psf_field.py +++ b/src/wf_psf/psf_models/tf_modules/tf_psf_field.py @@ -16,7 +16,7 @@ TFPhysicalLayer, ) from wf_psf.psf_models.models.psf_model_semiparametric import TFSemiParametricField -from wf_psf.data.data_handler import get_np_obs_positions +from wf_psf.data.data_handler import get_data_array from wf_psf.psf_models import psf_models as psfm from wf_psf.psf_models.tf_modules.tf_utils import ensure_tensor import logging @@ -222,7 +222,7 @@ def __init__(self, model_params, training_params, data, coeff_mat): self.output_Q = model_params.output_Q # Inputs: TF_physical_layer - self.obs_pos = ensure_tensor(get_np_obs_positions(data), dtype=tf.float32) + self.obs_pos = ensure_tensor(get_data_array(data, data.run_type, key="positions"), dtype=tf.float32) self.zks_prior = get_ground_truth_zernike(data) self.n_zks_prior = tf.shape(self.zks_prior)[1].numpy() From e309487a2963fde3c83063278ee1d5012d8229fa Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 31 Oct 2025 12:33:40 +0100 Subject: [PATCH 113/135] Remove unused imports and reformat --- src/wf_psf/data/centroids.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/wf_psf/data/centroids.py b/src/wf_psf/data/centroids.py index 20391793..01135428 100644 --- a/src/wf_psf/data/centroids.py +++ b/src/wf_psf/data/centroids.py @@ -8,9 +8,7 @@ import numpy as np import scipy.signal as scisig -from wf_psf.data.data_handler import extract_star_data from fractions import Fraction -import tensorflow as tf from typing import Optional @@ -252,7 +250,6 @@ def __init__( self, im, mask=None, sigma_init=7.5, n_iter=5, auto_run=True, xc=None, yc=None ): """Initialize class attributes.""" - # Convert to np.ndarray if not already im = np.asarray(im) if mask is not None: From 88325849ecdc4f7de49476b134135b9a71a1439c Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 31 Oct 2025 12:35:13 +0100 Subject: [PATCH 114/135] Update fixtures and unit tests --- src/wf_psf/tests/test_data/centroids_test.py | 38 ++--- src/wf_psf/tests/test_data/conftest.py | 16 +- .../tests/test_data/data_handler_test.py | 148 ++++++++++++++++-- .../test_data/data_zernike_utils_test.py | 30 +++- .../test_inference/psf_inference_test.py | 1 - 5 files changed, 192 insertions(+), 41 deletions(-) diff --git 
a/src/wf_psf/tests/test_data/centroids_test.py b/src/wf_psf/tests/test_data/centroids_test.py index 85719f4f..185da8d7 100644 --- a/src/wf_psf/tests/test_data/centroids_test.py +++ b/src/wf_psf/tests/test_data/centroids_test.py @@ -9,8 +9,6 @@ import numpy as np import pytest from wf_psf.data.centroids import compute_centroid_correction, CentroidEstimator -from wf_psf.data.data_handler import extract_star_data -from wf_psf.data.data_zernike_utils import compute_zernike_tip_tilt from wf_psf.utils.read_config import RecursiveNamespace from unittest.mock import MagicMock, patch @@ -116,24 +114,23 @@ def test_compute_centroid_correction_with_masks(mock_data): reference_shifts=["-1/3", "-1/3"], ) + # Wrap mock_data into a dict to match the function signature + centroid_dataset = { + "stamps": mock_data.training_data.dataset["noisy_stars"], + "masks": mock_data.training_data.dataset["masks"], + } + # Mock the internal function calls: with ( - patch("wf_psf.data.centroids.extract_star_data") as mock_extract_star_data, patch( "wf_psf.data.centroids.compute_zernike_tip_tilt" ) as mock_compute_zernike_tip_tilt, ): - - # Mock the return values of extract_star_data and compute_zernike_tip_tilt - mock_extract_star_data.side_effect = lambda data, train_key, test_key: ( - np.array([[1, 2], [3, 4]]) - if train_key == "noisy_stars" - else np.array([[5, 6], [7, 8]]) - ) + # Mock compute_zernike_tip_tilt to return synthetic Zernike coefficients mock_compute_zernike_tip_tilt.return_value = np.array([[0.1, 0.2], [0.3, 0.4]]) # Call the function under test - result = compute_centroid_correction(model_params, mock_data) + result = compute_centroid_correction(model_params, centroid_dataset) # Ensure the result has the correct shape assert result.shape == (4, 3) # Should be (n_stars, 3 Zernike components) @@ -148,10 +145,6 @@ def test_compute_centroid_correction_with_masks(mock_data): def test_compute_centroid_correction_without_masks(mock_data): """Test compute_centroid_correction function when no masks are provided.""" - # Remove masks from mock_data - mock_data.test_data.dataset["masks"] = None - mock_data.training_data.dataset["masks"] = None - # Define model parameters model_params = RecursiveNamespace( pix_sampling=12e-6, # Example pixel sampling in meters @@ -159,26 +152,23 @@ def test_compute_centroid_correction_without_masks(mock_data): reference_shifts=["-1/3", "-1/3"], ) + # Wrap mock_data into a dict to match the function signature + centroid_dataset = { + "stamps": mock_data.training_data.dataset["noisy_stars"], + } + # Mock internal function calls with ( - patch("wf_psf.data.centroids.extract_star_data") as mock_extract_star_data, patch( "wf_psf.data.centroids.compute_zernike_tip_tilt" ) as mock_compute_zernike_tip_tilt, ): - # Mock extract_star_data to return synthetic star postage stamps - mock_extract_star_data.side_effect = lambda data, train_key, test_key: ( - np.array([[1, 2], [3, 4]]) - if train_key == "noisy_stars" - else np.array([[5, 6], [7, 8]]) - ) - # Mock compute_zernike_tip_tilt assuming no masks mock_compute_zernike_tip_tilt.return_value = np.array([[0.1, 0.2], [0.3, 0.4]]) # Call function under test - result = compute_centroid_correction(model_params, mock_data) + result = compute_centroid_correction(model_params, centroid_dataset) # Validate result shape assert result.shape == (4, 3) # (n_stars, 3 Zernike components) diff --git a/src/wf_psf/tests/test_data/conftest.py b/src/wf_psf/tests/test_data/conftest.py index 47eed929..131922e5 100644 --- 
a/src/wf_psf/tests/test_data/conftest.py +++ b/src/wf_psf/tests/test_data/conftest.py @@ -11,9 +11,11 @@ import pytest import numpy as np import tensorflow as tf +from types import SimpleNamespace + from wf_psf.utils.read_config import RecursiveNamespace from wf_psf.psf_models import psf_models -from wf_psf.tests.test_data.test_data_utils import MockData, MockDataset +from wf_psf.tests.test_data.test_data_utils import MockData training_config = RecursiveNamespace( id_name="-coherent_euclid_200stars", @@ -129,6 +131,18 @@ def mock_data(scope="module"): ) +@pytest.fixture +def mock_data_inference(): + """Flat dataset for inference path only.""" + return SimpleNamespace( + dataset={ + "positions": np.array([[9, 9], [10, 10]]), + "zernike_prior": np.array([[0.9, 0.9]]), + # no "missing_key" → used to trigger allow_missing behavior + } + ) + + @pytest.fixture def simple_image(scope="module"): """Fixture for a simple star image.""" diff --git a/src/wf_psf/tests/test_data/data_handler_test.py b/src/wf_psf/tests/test_data/data_handler_test.py index c1310d21..d29771a1 100644 --- a/src/wf_psf/tests/test_data/data_handler_test.py +++ b/src/wf_psf/tests/test_data/data_handler_test.py @@ -3,12 +3,10 @@ import tensorflow as tf from wf_psf.data.data_handler import ( DataHandler, - get_np_obs_positions, + get_data_array, extract_star_data, ) from wf_psf.utils.read_config import RecursiveNamespace -import logging -from unittest.mock import patch def mock_sed(): @@ -151,12 +149,6 @@ def test_load_test_dataset_missing_stars(tmp_path, simPSF): data_handler.validate_and_process_dataset() -def test_get_np_obs_positions(mock_data): - observed_positions = get_np_obs_positions(mock_data) - expected_positions = np.array([[1, 2], [3, 4], [5, 6], [7, 8]]) - assert tf.reduce_all(tf.equal(observed_positions, expected_positions)) - - def test_extract_star_data_valid_keys(mock_data): """Test extracting valid data from the dataset.""" result = extract_star_data(mock_data, train_key="noisy_stars", test_key="stars") @@ -229,3 +221,141 @@ def test_reference_shifts_broadcasting(): # Test the result assert displacements.shape == shifts.shape, "Shapes do not match" assert np.all(displacements.shape == (2, 2400)), "Broadcasting failed" + + +@pytest.mark.parametrize( + "run_type,data_fixture,key,train_key,test_key,allow_missing,expect", + [ + # =================================================== + # training/simulation/metrics → extract_star_data path + # =================================================== + ( + "training", + "mock_data", + None, + "positions", + None, + False, + np.array([[1, 2], [3, 4], [5, 6], [7, 8]]), + ), + ( + "simulation", + "mock_data", + "none", + "noisy_stars", + "stars", + True, + # will concatenate noisy_stars from train and stars from test + # expected shape: (4, 5, 5) + # validate shape only, not full content (too large) + "shape:(4, 5, 5)", + ), + ( + "metrics", + "mock_data", + "zernike_prior", + None, + None, + True, + np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]]), + ), + # ================= + # inference (success) + # ================= + ( + "inference", + "mock_data_inference", + "positions", + None, + None, + False, + np.array([[9, 9], [10, 10]]), + ), + ( + "inference", + "mock_data_inference", + "zernike_prior", + None, + None, + False, + np.array([[0.9, 0.9]]), + ), + # ============================== + # inference → allow_missing=True + # ============================== + ( + "inference", + "mock_data_inference", + None, + None, + None, + True, + None, + ), + ( + "inference", + 
"mock_data_inference", + "missing_key", + None, + None, + True, + None, + ), + # ================================= + # inference → allow_missing=False → errors + # ================================= + ( + "inference", + "mock_data_inference", + None, + None, + None, + False, + pytest.raises(ValueError), + ), + ( + "inference", + "mock_data_inference", + "missing_key", + None, + None, + False, + pytest.raises(KeyError), + ), + ], +) +def test_get_data_array_v2( + request, run_type, data_fixture, key, train_key, test_key, allow_missing, expect +): + data = request.getfixturevalue(data_fixture) + + if hasattr(expect, "__enter__") and hasattr(expect, "__exit__"): + with expect: + get_data_array( + data, + run_type, + key=key, + train_key=train_key, + test_key=test_key, + allow_missing=allow_missing, + ) + return + + result = get_data_array( + data, + run_type, + key=key, + train_key=train_key, + test_key=test_key, + allow_missing=allow_missing, + ) + + if expect is None: + assert result is None + elif isinstance(expect, str) and expect.startswith("shape:"): + expected_shape = tuple(eval(expect.replace("shape:", ""))) + assert isinstance(result, np.ndarray) + assert result.shape == expected_shape + else: + assert isinstance(result, np.ndarray) + assert np.allclose(result, expect, rtol=1e-6, atol=1e-8) diff --git a/src/wf_psf/tests/test_data/data_zernike_utils_test.py b/src/wf_psf/tests/test_data/data_zernike_utils_test.py index 390d10f9..66d23309 100644 --- a/src/wf_psf/tests/test_data/data_zernike_utils_test.py +++ b/src/wf_psf/tests/test_data/data_zernike_utils_test.py @@ -11,7 +11,6 @@ compute_zernike_tip_tilt, pad_tf_zernikes, ) -from wf_psf.tests.test_data.test_data_utils import MockData from types import SimpleNamespace as RecursiveNamespace @@ -46,7 +45,27 @@ def test_training_without_prior(mock_model_params, mock_data): data=mock_data, run_type="training", model_params=mock_model_params ) - assert zinputs.centroid_dataset is mock_data + mock_data_stamps = np.concatenate( + [ + mock_data.training_data.dataset["noisy_stars"], + mock_data.test_data.dataset["stars"], + ] + ) + mock_data_masks = np.concatenate( + [ + mock_data.training_data.dataset["masks"], + mock_data.test_data.dataset["masks"], + ] + ) + + assert np.allclose( + zinputs.centroid_dataset["stamps"], mock_data_stamps, rtol=1e-6, atol=1e-8 + ) + + assert np.allclose( + zinputs.centroid_dataset["masks"], mock_data_masks, rtol=1e-6, atol=1e-8 + ) + assert zinputs.zernike_prior is None expected_positions = np.concatenate( @@ -88,7 +107,7 @@ def test_training_with_explicit_prior(mock_model_params, caplog): data, "training", mock_model_params, prior=explicit_prior ) - assert "Zernike prior explicitly provided" in caplog.text + assert "Explicit prior provided; ignoring dataset-based prior." 
in caplog.text assert (zinputs.zernike_prior == explicit_prior).all() @@ -103,7 +122,8 @@ def test_inference_with_dict_and_prior(mock_model_params): zinputs = ZernikeInputsFactory.build(data, "inference", mock_model_params) - assert zinputs.centroid_dataset is None + for key in ["stamps", "masks"]: + assert zinputs.centroid_dataset[key] is None # NumPy array comparison np.testing.assert_array_equal( @@ -397,7 +417,6 @@ def test_pad_zernikes_param_greater_than_prior(): def test_compute_zernike_tip_tilt_single_batch(mocker, simple_image, identity_mask): """Test compute_zernike_tip_tilt handling with single batch input and mocks.""" - # Mock the CentroidEstimator class mock_centroid_calc = mocker.patch( "wf_psf.data.centroids.CentroidEstimator", autospec=True @@ -456,7 +475,6 @@ def test_compute_zernike_tip_tilt_single_batch(mocker, simple_image, identity_ma def test_compute_zernike_tip_tilt_batch(mocker, multiple_images): """Test compute_zernike_tip_tilt batch handling of multiple inputs.""" - # Mock the CentroidEstimator class mock_centroid_calc = mocker.patch( "wf_psf.data.centroids.CentroidEstimator", autospec=True diff --git a/src/wf_psf/tests/test_inference/psf_inference_test.py b/src/wf_psf/tests/test_inference/psf_inference_test.py index 05728de7..28a2c1af 100644 --- a/src/wf_psf/tests/test_inference/psf_inference_test.py +++ b/src/wf_psf/tests/test_inference/psf_inference_test.py @@ -161,7 +161,6 @@ def psf_single_star_setup(mock_inference_config): def test_set_config_paths(mock_inference_config): """Test setting configuration paths.""" - # Initialize handler and inject mock config config_handler = InferenceConfigHandler("fake/path") config_handler.inference_config = mock_inference_config From 529fcd3c8dcb330d76bfd0e11a0a9d2945fa2555 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 31 Oct 2025 12:36:32 +0100 Subject: [PATCH 115/135] Reformat with black --- src/wf_psf/data/data_handler.py | 41 ++- src/wf_psf/data/data_zernike_utils.py | 2 - src/wf_psf/data/old_zernike_prior.py | 335 ++++++++++++++++++ src/wf_psf/instrument/ccd_misalignments.py | 6 +- .../psf_model_physical_polychromatic.py | 3 - src/wf_psf/psf_models/psf_model_loader.py | 1 - src/wf_psf/psf_models/tf_modules/tf_layers.py | 1 - .../psf_models/tf_modules/tf_psf_field.py | 4 +- src/wf_psf/psf_models/tf_modules/tf_utils.py | 1 - .../test_metrics/metrics_interface_test.py | 2 +- .../tests/test_utils/configs_handler_test.py | 1 - src/wf_psf/utils/utils.py | 1 + 12 files changed, 361 insertions(+), 37 deletions(-) create mode 100644 src/wf_psf/data/old_zernike_prior.py diff --git a/src/wf_psf/data/data_handler.py b/src/wf_psf/data/data_handler.py index bdcf9a6b..052fe730 100644 --- a/src/wf_psf/data/data_handler.py +++ b/src/wf_psf/data/data_handler.py @@ -118,7 +118,6 @@ def __init__( `load_data=True` is used. - TensorFlow conversion is performed at the end of initialization via `convert_dataset_to_tensorflow()`. """ - self.dataset_type = dataset_type self.data_params = data_params self.simPSF = simPSF @@ -182,7 +181,6 @@ def _validate_dataset_structure(self): def _convert_dataset_to_tensorflow(self): """Convert dataset to TensorFlow tensors.""" - self.dataset["positions"] = ensure_tensor( self.dataset["positions"], dtype=tf.float32 ) @@ -244,8 +242,8 @@ def extract_star_data(data, train_key: str, test_key: str) -> np.ndarray: """ Extract and concatenate star-related data from training and test datasets. 
- This function retrieves arrays (e.g., postage stamps, masks, positions) from - both the training and test datasets using the specified keys, converts them + This function retrieves arrays (e.g., postage stamps, masks, positions) from + both the training and test datasets using the specified keys, converts them to NumPy if necessary, and concatenates them along the first axis. Parameters @@ -253,27 +251,27 @@ def extract_star_data(data, train_key: str, test_key: str) -> np.ndarray: data : DataConfigHandler Object containing training and test datasets. train_key : str - Key to retrieve data from the training dataset + Key to retrieve data from the training dataset (e.g., 'noisy_stars', 'masks'). test_key : str - Key to retrieve data from the test dataset + Key to retrieve data from the test dataset (e.g., 'stars', 'masks'). Returns ------- np.ndarray - Concatenated NumPy array containing the selected data from both + Concatenated NumPy array containing the selected data from both training and test sets. Raises ------ KeyError - If either the training or test dataset does not contain the + If either the training or test dataset does not contain the requested key. Notes ----- - - Designed for datasets with separate train/test splits, such as when + - Designed for datasets with separate train/test splits, such as when evaluating metrics on held-out data. - TensorFlow tensors are automatically converted to NumPy arrays. - Requires eager execution if TensorFlow tensors are present. @@ -304,6 +302,7 @@ def extract_star_data(data, train_key: str, test_key: str) -> np.ndarray: # Concatenate and return return np.concatenate((train_data, test_data), axis=0) + def get_data_array( data, run_type: str, @@ -334,7 +333,7 @@ def get_data_array( train_key : str, optional Key for training dataset access. If None and key is provided, defaults to key value. Default is None. - test_key : str, optional + test_key : str, optional Key for test dataset access. If None, defaults to the resolved train_key value. Default is None. allow_missing : bool, default True @@ -355,12 +354,12 @@ def get_data_array( resolved for the operation and allow_missing=False. KeyError If the specified key is not found in the dataset and allow_missing=False. - + Notes ----- Key resolution follows this priority order: 1. train_key = train_key or key - 2. test_key = test_key or resolved_train_key + 2. test_key = test_key or resolved_train_key 3. key = key or resolved_train_key (for inference fallback) For TensorFlow tensors, the .numpy() method is called to convert to NumPy. @@ -370,13 +369,13 @@ def get_data_array( -------- >>> # Training data retrieval >>> train_data = get_data_array(data, "training", train_key="noisy_stars") - + >>> # Inference with fallback handling - >>> inference_data = get_data_array(data, "inference", key="positions", + >>> inference_data = get_data_array(data, "inference", key="positions", ... allow_missing=True) >>> if inference_data is None: ... 
print("No inference data available") - + >>> # Using key parameter for both train and inference >>> result = get_data_array(data, "inference", key="positions") """ @@ -384,18 +383,18 @@ def get_data_array( valid_run_types = {"training", "simulation", "metrics", "inference"} if run_type not in valid_run_types: raise ValueError(f"run_type must be one of {valid_run_types}, got '{run_type}'") - + # Simplify key resolution with clear precedence effective_train_key = train_key or key effective_test_key = test_key or effective_train_key effective_key = key or effective_train_key - + try: if run_type in {"simulation", "training", "metrics"}: return extract_star_data(data, effective_train_key, effective_test_key) else: # inference return _get_direct_data(data, effective_key, allow_missing) - except Exception as e: + except Exception: if allow_missing: return None raise @@ -417,7 +416,7 @@ def _get_direct_data(data, key: str, allow_missing: bool) -> Optional[np.ndarray Returns ------- np.ndarray or None - Data converted to NumPy array, or None if allow_missing=True and + Data converted to NumPy array, or None if allow_missing=True and data is unavailable. Raises @@ -437,11 +436,11 @@ def _get_direct_data(data, key: str, allow_missing: bool) -> Optional[np.ndarray if allow_missing: return None raise ValueError("No key provided for inference data") - + value = data.dataset.get(key, None) if value is None: if allow_missing: return None raise KeyError(f"Key '{key}' not found in inference dataset") - + return value.numpy() if tf.is_tensor(value) else np.asarray(value) diff --git a/src/wf_psf/data/data_zernike_utils.py b/src/wf_psf/data/data_zernike_utils.py index 399ee9ef..0fad9c8e 100644 --- a/src/wf_psf/data/data_zernike_utils.py +++ b/src/wf_psf/data/data_zernike_utils.py @@ -181,7 +181,6 @@ def pad_tf_zernikes(zk_param: tf.Tensor, zk_prior: tf.Tensor, n_zks_total: int): padded_zk_prior : tf.Tensor Padded Zernike coefficients for the prior part. Shape [batch, n_zks_total, 1, 1]. """ - pad_num_param = n_zks_total - tf.shape(zk_param)[1] pad_num_prior = n_zks_total - tf.shape(zk_prior)[1] @@ -230,7 +229,6 @@ def assemble_zernike_contributions( tf.Tensor A tensor representing the full Zernike contribution map. """ - zernike_contribution_list = [] # Prior diff --git a/src/wf_psf/data/old_zernike_prior.py b/src/wf_psf/data/old_zernike_prior.py new file mode 100644 index 00000000..0feb3e70 --- /dev/null +++ b/src/wf_psf/data/old_zernike_prior.py @@ -0,0 +1,335 @@ +import numpy as np +import tensorflow as tf +from wf_psf.instrument.ccd_misalignments import CCDMisalignmentCalculator +from wf_psf.data.centroids import compute_zernike_tip_tilt +from fractions import Fraction +import logging + +logger = logging.getLogger(__name__) + + +def get_np_obs_positions(data): + """Get observed positions in numpy from the provided dataset. + + This method concatenates the positions of the stars from both the training + and test datasets to obtain the observed positions. + + Parameters + ---------- + data : DataConfigHandler + Object containing training and test datasets. + + Returns + ------- + np.ndarray + Numpy array containing the observed positions of the stars. + + Notes + ----- + The observed positions are obtained by concatenating the positions of stars + from both the training and test datasets along the 0th axis. 
+ """ + obs_positions = np.concatenate( + ( + data.training_data.dataset["positions"], + data.test_data.dataset["positions"], + ), + axis=0, + ) + + return obs_positions + + +def get_obs_positions(data): + """Get observed positions from the provided dataset. + + Parameters + ---------- + data : DataConfigHandler + Object containing training and test datasets. + + Returns + ------- + tf.Tensor + Tensor containing the observed positions of the stars. + """ + obs_positions = get_np_obs_positions(data) + + return tf.convert_to_tensor(obs_positions, dtype=tf.float32) + + +def extract_star_data(data, train_key: str, test_key: str) -> np.ndarray: + """Extract specific star-related data from training and test datasets. + + This function retrieves and concatenates specific star-related data (e.g., stars, masks) from the + star training and test datasets such as star stamps or masks, based on the provided keys. + + Parameters + ---------- + data : DataConfigHandler + Object containing training and test datasets. + train_key : str + The key to retrieve data from the training dataset (e.g., 'noisy_stars', 'masks'). + test_key : str + The key to retrieve data from the test dataset (e.g., 'stars', 'masks'). + + Returns + ------- + np.ndarray + A NumPy array containing the concatenated data for the given keys. + + Raises + ------ + KeyError + If the specified keys do not exist in the training or test datasets. + + Notes + ----- + - If the dataset contains TensorFlow tensors, they will be converted to NumPy arrays. + - Ensure that eager execution is enabled when calling this function. + """ + # Ensure the requested keys exist in both training and test datasets + missing_keys = [ + key + for key, dataset in [ + (train_key, data.training_data.dataset), + (test_key, data.test_data.dataset), + ] + if key not in dataset + ] + + if missing_keys: + raise KeyError(f"Missing keys in dataset: {missing_keys}") + + # Retrieve data from training and test sets + train_data = data.training_data.dataset[train_key] + test_data = data.test_data.dataset[test_key] + + # Convert to NumPy if necessary + if tf.is_tensor(train_data): + train_data = train_data.numpy() + if tf.is_tensor(test_data): + test_data = test_data.numpy() + + # Concatenate and return + return np.concatenate((train_data, test_data), axis=0) + + +def get_np_zernike_prior(data): + """Get the zernike prior from the provided dataset. + + This method concatenates the stars from both the training + and test datasets to obtain the full prior. + + Parameters + ---------- + data : DataConfigHandler + Object containing training and test datasets. + + Returns + ------- + zernike_prior : np.ndarray + Numpy array containing the full prior. + """ + zernike_prior = np.concatenate( + ( + data.training_data.dataset["zernike_prior"], + data.test_data.dataset["zernike_prior"], + ), + axis=0, + ) + + return zernike_prior + + +def compute_centroid_correction(model_params, data, batch_size: int = 1) -> np.ndarray: + """Compute centroid corrections using Zernike polynomials. + + This function calculates the Zernike contributions required to match the centroid + of the WaveDiff PSF model to the observed star centroids, processing in batches. + + Parameters + ---------- + model_params : RecursiveNamespace + An object containing parameters for the PSF model, including pixel sampling and initial centroid window parameters. + + data : DataConfigHandler + An object containing training and test datasets, including observed PSFs + and optional star masks. 
+ + batch_size : int, optional + The batch size to use when processing the stars. Default is 16. + + + Returns + ------- + zernike_centroid_array : np.ndarray + A 2D NumPy array of shape `(n_stars, 3)`, where `n_stars` is the number of + observed stars. The array contains the computed Zernike contributions, + with zero padding applied to the first column to ensure a consistent shape. + """ + star_postage_stamps = extract_star_data( + data=data, train_key="noisy_stars", test_key="stars" + ) + + # Get star mask catalogue only if "masks" exist in both training and test datasets + star_masks = ( + extract_star_data(data=data, train_key="masks", test_key="masks") + if ( + data.training_data.dataset.get("masks") is not None + and data.test_data.dataset.get("masks") is not None + and tf.size(data.training_data.dataset["masks"]) > 0 + and tf.size(data.test_data.dataset["masks"]) > 0 + ) + else None + ) + + pix_sampling = model_params.pix_sampling * 1e-6 # Change units from [um] to [m] + + # Ensure star_masks is properly handled + star_masks = star_masks if star_masks is not None else None + + reference_shifts = [ + float(Fraction(value)) for value in model_params.reference_shifts + ] + + n_stars = len(star_postage_stamps) + zernike_centroid_array = [] + + # Batch process the stars + for i in range(0, n_stars, batch_size): + batch_postage_stamps = star_postage_stamps[i : i + batch_size] + batch_masks = star_masks[i : i + batch_size] if star_masks is not None else None + + # Compute Zernike 1 and Zernike 2 for the batch + zk1_2_batch = -1.0 * compute_zernike_tip_tilt( + batch_postage_stamps, batch_masks, pix_sampling, reference_shifts + ) + + # Zero pad array for each batch and append + zernike_centroid_array.append( + np.pad( + zk1_2_batch, + pad_width=[(0, 0), (1, 0)], + mode="constant", + constant_values=0, + ) + ) + + # Combine all batches into a single array + return np.concatenate(zernike_centroid_array, axis=0) + + +def compute_ccd_misalignment(model_params, data): + """Compute CCD misalignment. + + Parameters + ---------- + model_params : RecursiveNamespace + Object containing parameters for this PSF model class. + data : DataConfigHandler + Object containing training and test datasets. + + Returns + ------- + zernike_ccd_misalignment_array : np.ndarray + Numpy array containing the Zernike contributions to model the CCD chip misalignments. + """ + obs_positions = get_np_obs_positions(data) + + ccd_misalignment_calculator = CCDMisalignmentCalculator( + tiles_path=model_params.ccd_misalignments_input_path, + x_lims=model_params.x_lims, + y_lims=model_params.y_lims, + tel_focal_length=model_params.tel_focal_length, + tel_diameter=model_params.tel_diameter, + ) + # Compute required zernike 4 for each position + zk4_values = np.array( + [ + ccd_misalignment_calculator.get_zk4_from_position(single_pos) + for single_pos in obs_positions + ] + ).reshape(-1, 1) + + # Zero pad array to get shape (n_stars, n_zernike=4) + zernike_ccd_misalignment_array = np.pad( + zk4_values, pad_width=[(0, 0), (3, 0)], mode="constant", constant_values=0 + ) + + return zernike_ccd_misalignment_array + + +def get_zernike_prior(model_params, data, batch_size: int = 16): + """Get Zernike priors from the provided dataset. + + This method concatenates the Zernike priors from both the training + and test datasets. + + Parameters + ---------- + model_params : RecursiveNamespace + Object containing parameters for this PSF model class. + data : DataConfigHandler + Object containing training and test datasets. 
+ batch_size : int, optional + The batch size to use when processing the stars. Default is 16. + + Returns + ------- + tf.Tensor + Tensor containing the observed positions of the stars. + + Notes + ----- + The Zernike prior are obtained by concatenating the Zernike priors + from both the training and test datasets along the 0th axis. + + """ + # List of zernike contribution + zernike_contribution_list = [] + + if model_params.use_prior: + logger.info("Reading in Zernike prior into Zernike contribution list...") + zernike_contribution_list.append(get_np_zernike_prior(data)) + + if model_params.correct_centroids: + logger.info("Adding centroid correction to Zernike contribution list...") + zernike_contribution_list.append( + compute_centroid_correction(model_params, data, batch_size) + ) + + if model_params.add_ccd_misalignments: + logger.info("Adding CCD mis-alignments to Zernike contribution list...") + zernike_contribution_list.append(compute_ccd_misalignment(model_params, data)) + + if len(zernike_contribution_list) == 1: + zernike_contribution = zernike_contribution_list[0] + else: + # Get max zk order + max_zk_order = np.max( + np.array( + [ + zk_contribution.shape[1] + for zk_contribution in zernike_contribution_list + ] + ) + ) + + zernike_contribution = np.zeros( + (zernike_contribution_list[0].shape[0], max_zk_order) + ) + + # Pad arrays to get the same length and add the final contribution + for it in range(len(zernike_contribution_list)): + current_zk_order = zernike_contribution_list[it].shape[1] + current_zernike_contribution = np.pad( + zernike_contribution_list[it], + pad_width=[(0, 0), (0, int(max_zk_order - current_zk_order))], + mode="constant", + constant_values=0, + ) + + zernike_contribution += current_zernike_contribution + + return tf.convert_to_tensor(zernike_contribution, dtype=tf.float32) diff --git a/src/wf_psf/instrument/ccd_misalignments.py b/src/wf_psf/instrument/ccd_misalignments.py index 49bc8180..873509e5 100644 --- a/src/wf_psf/instrument/ccd_misalignments.py +++ b/src/wf_psf/instrument/ccd_misalignments.py @@ -161,11 +161,7 @@ def _preprocess_tile_data(self) -> None: self.tiles_z_average = np.mean(self.tiles_z_lims) def _initialize_polygons(self): - """Initialize polygons to look for CCD IDs. - - Each CCD is represented by a polygon defined by its corner points. 
- - """ + """Initialize polygons to look for CCD IDs""" # Build polygon list corresponding to each CCD self.ccd_polygons = [] diff --git a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py index e2c84868..fb8bc902 100644 --- a/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py +++ b/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py @@ -26,8 +26,6 @@ TFPhysicalLayer, ) from wf_psf.psf_models.tf_modules.tf_utils import ensure_tensor -from wf_psf.utils.read_config import RecursiveNamespace -from wf_psf.utils.configs_handler import DataConfigHandler import logging @@ -282,7 +280,6 @@ def tf_zernike_OPD(self): def _build_tf_batch_poly_PSF(self): """Eagerly build the TFBatchPolychromaticPSF layer with numpy-based obscurations.""" - return TFBatchPolychromaticPSF( obscurations=self.obscurations, output_Q=self.output_Q, diff --git a/src/wf_psf/psf_models/psf_model_loader.py b/src/wf_psf/psf_models/psf_model_loader.py index e445f7af..e41e3536 100644 --- a/src/wf_psf/psf_models/psf_model_loader.py +++ b/src/wf_psf/psf_models/psf_model_loader.py @@ -10,7 +10,6 @@ import logging from wf_psf.psf_models.psf_models import get_psf_model, get_psf_model_weights_filepath -import tensorflow as tf logger = logging.getLogger(__name__) diff --git a/src/wf_psf/psf_models/tf_modules/tf_layers.py b/src/wf_psf/psf_models/tf_modules/tf_layers.py index 7b22b5e8..cf11d9cd 100644 --- a/src/wf_psf/psf_models/tf_modules/tf_layers.py +++ b/src/wf_psf/psf_models/tf_modules/tf_layers.py @@ -999,7 +999,6 @@ def call(self, positions): If the shape of the input `positions` tensor is not compatible. """ - # Find indices for all positions in one batch operation idx = find_position_indices(self.obs_pos, positions) diff --git a/src/wf_psf/psf_models/tf_modules/tf_psf_field.py b/src/wf_psf/psf_models/tf_modules/tf_psf_field.py index 21c0f9a4..07b523d1 100644 --- a/src/wf_psf/psf_models/tf_modules/tf_psf_field.py +++ b/src/wf_psf/psf_models/tf_modules/tf_psf_field.py @@ -222,7 +222,9 @@ def __init__(self, model_params, training_params, data, coeff_mat): self.output_Q = model_params.output_Q # Inputs: TF_physical_layer - self.obs_pos = ensure_tensor(get_data_array(data, data.run_type, key="positions"), dtype=tf.float32) + self.obs_pos = ensure_tensor( + get_data_array(data, data.run_type, key="positions"), dtype=tf.float32 + ) self.zks_prior = get_ground_truth_zernike(data) self.n_zks_prior = tf.shape(self.zks_prior)[1].numpy() diff --git a/src/wf_psf/psf_models/tf_modules/tf_utils.py b/src/wf_psf/psf_models/tf_modules/tf_utils.py index 09540e60..4bd1246a 100644 --- a/src/wf_psf/psf_models/tf_modules/tf_utils.py +++ b/src/wf_psf/psf_models/tf_modules/tf_utils.py @@ -16,7 +16,6 @@ """ import tensorflow as tf -import numpy as np @tf.function diff --git a/src/wf_psf/tests/test_metrics/metrics_interface_test.py b/src/wf_psf/tests/test_metrics/metrics_interface_test.py index bf52f0aa..5672d648 100644 --- a/src/wf_psf/tests/test_metrics/metrics_interface_test.py +++ b/src/wf_psf/tests/test_metrics/metrics_interface_test.py @@ -1,6 +1,6 @@ from unittest.mock import patch, MagicMock import pytest -from wf_psf.metrics.metrics_interface import evaluate_model, MetricsParamsHandler +from wf_psf.metrics.metrics_interface import evaluate_model from wf_psf.data.data_handler import DataHandler diff --git a/src/wf_psf/tests/test_utils/configs_handler_test.py b/src/wf_psf/tests/test_utils/configs_handler_test.py index f8299997..57dfdc8a 100644 
--- a/src/wf_psf/tests/test_utils/configs_handler_test.py +++ b/src/wf_psf/tests/test_utils/configs_handler_test.py @@ -13,7 +13,6 @@ from wf_psf.utils.io import FileIOHandler from wf_psf.utils.configs_handler import ( TrainingConfigHandler, - MetricsConfigHandler, DataConfigHandler, ) import os diff --git a/src/wf_psf/utils/utils.py b/src/wf_psf/utils/utils.py index 17219ad9..2027d01c 100644 --- a/src/wf_psf/utils/utils.py +++ b/src/wf_psf/utils/utils.py @@ -5,6 +5,7 @@ """ import numpy as np +from typing import Tuple import tensorflow as tf import PIL import zernike as zk From 6c3e9d099e6256edfc2ff1429691c988548191d9 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Mon, 8 Dec 2025 11:01:46 +0100 Subject: [PATCH 116/135] Remove outdated back module --- src/wf_psf/data/old_zernike_prior.py | 335 --------------------------- 1 file changed, 335 deletions(-) delete mode 100644 src/wf_psf/data/old_zernike_prior.py diff --git a/src/wf_psf/data/old_zernike_prior.py b/src/wf_psf/data/old_zernike_prior.py deleted file mode 100644 index 0feb3e70..00000000 --- a/src/wf_psf/data/old_zernike_prior.py +++ /dev/null @@ -1,335 +0,0 @@ -import numpy as np -import tensorflow as tf -from wf_psf.instrument.ccd_misalignments import CCDMisalignmentCalculator -from wf_psf.data.centroids import compute_zernike_tip_tilt -from fractions import Fraction -import logging - -logger = logging.getLogger(__name__) - - -def get_np_obs_positions(data): - """Get observed positions in numpy from the provided dataset. - - This method concatenates the positions of the stars from both the training - and test datasets to obtain the observed positions. - - Parameters - ---------- - data : DataConfigHandler - Object containing training and test datasets. - - Returns - ------- - np.ndarray - Numpy array containing the observed positions of the stars. - - Notes - ----- - The observed positions are obtained by concatenating the positions of stars - from both the training and test datasets along the 0th axis. - """ - obs_positions = np.concatenate( - ( - data.training_data.dataset["positions"], - data.test_data.dataset["positions"], - ), - axis=0, - ) - - return obs_positions - - -def get_obs_positions(data): - """Get observed positions from the provided dataset. - - Parameters - ---------- - data : DataConfigHandler - Object containing training and test datasets. - - Returns - ------- - tf.Tensor - Tensor containing the observed positions of the stars. - """ - obs_positions = get_np_obs_positions(data) - - return tf.convert_to_tensor(obs_positions, dtype=tf.float32) - - -def extract_star_data(data, train_key: str, test_key: str) -> np.ndarray: - """Extract specific star-related data from training and test datasets. - - This function retrieves and concatenates specific star-related data (e.g., stars, masks) from the - star training and test datasets such as star stamps or masks, based on the provided keys. - - Parameters - ---------- - data : DataConfigHandler - Object containing training and test datasets. - train_key : str - The key to retrieve data from the training dataset (e.g., 'noisy_stars', 'masks'). - test_key : str - The key to retrieve data from the test dataset (e.g., 'stars', 'masks'). - - Returns - ------- - np.ndarray - A NumPy array containing the concatenated data for the given keys. - - Raises - ------ - KeyError - If the specified keys do not exist in the training or test datasets. - - Notes - ----- - - If the dataset contains TensorFlow tensors, they will be converted to NumPy arrays. 
- - Ensure that eager execution is enabled when calling this function. - """ - # Ensure the requested keys exist in both training and test datasets - missing_keys = [ - key - for key, dataset in [ - (train_key, data.training_data.dataset), - (test_key, data.test_data.dataset), - ] - if key not in dataset - ] - - if missing_keys: - raise KeyError(f"Missing keys in dataset: {missing_keys}") - - # Retrieve data from training and test sets - train_data = data.training_data.dataset[train_key] - test_data = data.test_data.dataset[test_key] - - # Convert to NumPy if necessary - if tf.is_tensor(train_data): - train_data = train_data.numpy() - if tf.is_tensor(test_data): - test_data = test_data.numpy() - - # Concatenate and return - return np.concatenate((train_data, test_data), axis=0) - - -def get_np_zernike_prior(data): - """Get the zernike prior from the provided dataset. - - This method concatenates the stars from both the training - and test datasets to obtain the full prior. - - Parameters - ---------- - data : DataConfigHandler - Object containing training and test datasets. - - Returns - ------- - zernike_prior : np.ndarray - Numpy array containing the full prior. - """ - zernike_prior = np.concatenate( - ( - data.training_data.dataset["zernike_prior"], - data.test_data.dataset["zernike_prior"], - ), - axis=0, - ) - - return zernike_prior - - -def compute_centroid_correction(model_params, data, batch_size: int = 1) -> np.ndarray: - """Compute centroid corrections using Zernike polynomials. - - This function calculates the Zernike contributions required to match the centroid - of the WaveDiff PSF model to the observed star centroids, processing in batches. - - Parameters - ---------- - model_params : RecursiveNamespace - An object containing parameters for the PSF model, including pixel sampling and initial centroid window parameters. - - data : DataConfigHandler - An object containing training and test datasets, including observed PSFs - and optional star masks. - - batch_size : int, optional - The batch size to use when processing the stars. Default is 16. - - - Returns - ------- - zernike_centroid_array : np.ndarray - A 2D NumPy array of shape `(n_stars, 3)`, where `n_stars` is the number of - observed stars. The array contains the computed Zernike contributions, - with zero padding applied to the first column to ensure a consistent shape. 
- """ - star_postage_stamps = extract_star_data( - data=data, train_key="noisy_stars", test_key="stars" - ) - - # Get star mask catalogue only if "masks" exist in both training and test datasets - star_masks = ( - extract_star_data(data=data, train_key="masks", test_key="masks") - if ( - data.training_data.dataset.get("masks") is not None - and data.test_data.dataset.get("masks") is not None - and tf.size(data.training_data.dataset["masks"]) > 0 - and tf.size(data.test_data.dataset["masks"]) > 0 - ) - else None - ) - - pix_sampling = model_params.pix_sampling * 1e-6 # Change units from [um] to [m] - - # Ensure star_masks is properly handled - star_masks = star_masks if star_masks is not None else None - - reference_shifts = [ - float(Fraction(value)) for value in model_params.reference_shifts - ] - - n_stars = len(star_postage_stamps) - zernike_centroid_array = [] - - # Batch process the stars - for i in range(0, n_stars, batch_size): - batch_postage_stamps = star_postage_stamps[i : i + batch_size] - batch_masks = star_masks[i : i + batch_size] if star_masks is not None else None - - # Compute Zernike 1 and Zernike 2 for the batch - zk1_2_batch = -1.0 * compute_zernike_tip_tilt( - batch_postage_stamps, batch_masks, pix_sampling, reference_shifts - ) - - # Zero pad array for each batch and append - zernike_centroid_array.append( - np.pad( - zk1_2_batch, - pad_width=[(0, 0), (1, 0)], - mode="constant", - constant_values=0, - ) - ) - - # Combine all batches into a single array - return np.concatenate(zernike_centroid_array, axis=0) - - -def compute_ccd_misalignment(model_params, data): - """Compute CCD misalignment. - - Parameters - ---------- - model_params : RecursiveNamespace - Object containing parameters for this PSF model class. - data : DataConfigHandler - Object containing training and test datasets. - - Returns - ------- - zernike_ccd_misalignment_array : np.ndarray - Numpy array containing the Zernike contributions to model the CCD chip misalignments. - """ - obs_positions = get_np_obs_positions(data) - - ccd_misalignment_calculator = CCDMisalignmentCalculator( - tiles_path=model_params.ccd_misalignments_input_path, - x_lims=model_params.x_lims, - y_lims=model_params.y_lims, - tel_focal_length=model_params.tel_focal_length, - tel_diameter=model_params.tel_diameter, - ) - # Compute required zernike 4 for each position - zk4_values = np.array( - [ - ccd_misalignment_calculator.get_zk4_from_position(single_pos) - for single_pos in obs_positions - ] - ).reshape(-1, 1) - - # Zero pad array to get shape (n_stars, n_zernike=4) - zernike_ccd_misalignment_array = np.pad( - zk4_values, pad_width=[(0, 0), (3, 0)], mode="constant", constant_values=0 - ) - - return zernike_ccd_misalignment_array - - -def get_zernike_prior(model_params, data, batch_size: int = 16): - """Get Zernike priors from the provided dataset. - - This method concatenates the Zernike priors from both the training - and test datasets. - - Parameters - ---------- - model_params : RecursiveNamespace - Object containing parameters for this PSF model class. - data : DataConfigHandler - Object containing training and test datasets. - batch_size : int, optional - The batch size to use when processing the stars. Default is 16. - - Returns - ------- - tf.Tensor - Tensor containing the observed positions of the stars. - - Notes - ----- - The Zernike prior are obtained by concatenating the Zernike priors - from both the training and test datasets along the 0th axis. 
- - """ - # List of zernike contribution - zernike_contribution_list = [] - - if model_params.use_prior: - logger.info("Reading in Zernike prior into Zernike contribution list...") - zernike_contribution_list.append(get_np_zernike_prior(data)) - - if model_params.correct_centroids: - logger.info("Adding centroid correction to Zernike contribution list...") - zernike_contribution_list.append( - compute_centroid_correction(model_params, data, batch_size) - ) - - if model_params.add_ccd_misalignments: - logger.info("Adding CCD mis-alignments to Zernike contribution list...") - zernike_contribution_list.append(compute_ccd_misalignment(model_params, data)) - - if len(zernike_contribution_list) == 1: - zernike_contribution = zernike_contribution_list[0] - else: - # Get max zk order - max_zk_order = np.max( - np.array( - [ - zk_contribution.shape[1] - for zk_contribution in zernike_contribution_list - ] - ) - ) - - zernike_contribution = np.zeros( - (zernike_contribution_list[0].shape[0], max_zk_order) - ) - - # Pad arrays to get the same length and add the final contribution - for it in range(len(zernike_contribution_list)): - current_zk_order = zernike_contribution_list[it].shape[1] - current_zernike_contribution = np.pad( - zernike_contribution_list[it], - pad_width=[(0, 0), (0, int(max_zk_order - current_zk_order))], - mode="constant", - constant_values=0, - ) - - zernike_contribution += current_zernike_contribution - - return tf.convert_to_tensor(zernike_contribution, dtype=tf.float32) From af12a6cf905de372531a084f8cb05b928bbca6d9 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Mon, 8 Dec 2025 13:21:46 +0100 Subject: [PATCH 117/135] Remove unused import left after rebase --- src/wf_psf/utils/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/wf_psf/utils/utils.py b/src/wf_psf/utils/utils.py index 2027d01c..17219ad9 100644 --- a/src/wf_psf/utils/utils.py +++ b/src/wf_psf/utils/utils.py @@ -5,7 +5,6 @@ """ import numpy as np -from typing import Tuple import tensorflow as tf import PIL import zernike as zk From 07e88c7004c0fc3490649e57874d41b19af9f4ac Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 12 Dec 2025 13:17:20 +0100 Subject: [PATCH 118/135] Update doc strings and cache handling - Updated classes and methods with complete doc strings - Added two cache clearing methods to PSFInference and PSFInferenceEngine classes --- src/wf_psf/inference/psf_inference.py | 352 +++++++++++++++++++++++++- 1 file changed, 341 insertions(+), 11 deletions(-) diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py index c7d73249..dc7b5ac7 100644 --- a/src/wf_psf/inference/psf_inference.py +++ b/src/wf_psf/inference/psf_inference.py @@ -20,6 +20,37 @@ class InferenceConfigHandler: + """ + Handle configuration loading and management for PSF inference. + + This class manages the loading of inference, training, and data configuration + files required for PSF inference operations. + + Parameters + ---------- + inference_config_path : str + Path to the inference configuration YAML file. + + Attributes + ---------- + inference_config_path : str + Path to the inference configuration file. + inference_config : RecursiveNamespace or None + Loaded inference configuration. + training_config : RecursiveNamespace or None + Loaded training configuration. + data_config : RecursiveNamespace or None + Loaded data configuration. + trained_model_path : Path + Path to the trained model directory. + model_subdir : str + Subdirectory name for model files. 
+ trained_model_config_path : Path + Path to the training configuration file. + data_config_path : str or None + Path to the data configuration file. + """ + ids = ("inference_conf",) def __init__(self, inference_config_path: str): @@ -29,7 +60,20 @@ def __init__(self, inference_config_path: str): self.data_config = None def load_configs(self): - """Load configuration files based on the inference config.""" + """ + Load configuration files based on the inference config. + + Loads the inference configuration first, then uses it to determine and load + the training and data configurations. + + Notes + ----- + Updates the following attributes in-place: + - inference_config + - training_config + - data_config (if data_config_path is specified) + """ + self.inference_config = read_conf(self.inference_config_path) self.set_config_paths() self.training_config = read_conf(self.trained_model_config_path) @@ -39,7 +83,15 @@ def load_configs(self): self.data_config = read_conf(self.data_config_path) def set_config_paths(self): - """Extract and set the configuration paths.""" + """ + Extract and set the configuration paths from the inference config. + + Sets the following attributes: + - trained_model_path + - model_subdir + - trained_model_config_path + - data_config_path + """ # Set config paths config_paths = self.inference_config.inference.configs @@ -97,6 +149,35 @@ class PSFInference: Postage stamps of sources, e.g. star images (shape: [n_stars, h, w]). masks : array-like, optional Corresponding masks for the sources (same shape as sources). Defaults to None. + + + Attributes + ---------- + inference_config_path : str + Path to the inference configuration file. + x_field : array-like or None + x coordinates for PSF positions. + y_field : array-like or None + y coordinates for PSF positions. + seds : array-like or None + Spectral energy distributions. + sources : array-like or None + Source postage stamps. + masks : array-like or None + Source masks. + engine : PSFInferenceEngine or None + The inference engine instance. + + Examples + -------- + >>> psf_inf = PSFInference( + ... inference_config_path="config.yaml", + ... x_field=[100.5, 200.3], + ... y_field=[150.2, 250.8], + ... seds=sed_array + ... ) + >>> psf_inf.run_inference() + >>> psf = psf_inf.get_psf(0) """ def __init__( @@ -133,13 +214,25 @@ def __init__( @property def config_handler(self): + """ + Get or create the configuration handler. + + Returns + ------- + InferenceConfigHandler + The configuration handler instance with loaded configs. + """ if self._config_handler is None: self._config_handler = InferenceConfigHandler(self.inference_config_path) self._config_handler.load_configs() return self._config_handler def prepare_configs(self): - """Prepare the configuration for inference.""" + """ + Prepare the configuration for inference. + + Overwrites training model parameters with inference configuration values. + """ # Overwrite model parameters with inference config self.config_handler.overwrite_model_params( self.training_config, self.inference_config @@ -147,24 +240,63 @@ def prepare_configs(self): @property def inference_config(self): + """ + Get the inference configuration. + + Returns + ------- + RecursiveNamespace + The inference configuration object. + """ return self.config_handler.inference_config @property def training_config(self): + """ + Get the training configuration. + + Returns + ------- + RecursiveNamespace + The training configuration object. 
+ """ return self.config_handler.training_config @property def data_config(self): + """ + Get the data configuration. + + Returns + ------- + RecursiveNamespace or None + The data configuration object, or None if not available. + """ return self.config_handler.data_config @property def simPSF(self): + """ + Get or create the PSF simulator. + + Returns + ------- + simPSF + The PSF simulator instance. + """ if self._simPSF is None: self._simPSF = psf_models.simPSF(self.training_config.training.model_params) return self._simPSF def _prepare_dataset_for_inference(self): - """Prepare dataset dictionary for inference, returning None if positions are invalid.""" + """ + Prepare dataset dictionary for inference. + + Returns + ------- + dict or None + Dictionary containing positions, sources, and masks, or None if positions are invalid. + """ positions = self.get_positions() if positions is None: return None @@ -172,6 +304,14 @@ def _prepare_dataset_for_inference(self): @property def data_handler(self): + """ + Get or create the data handler. + + Returns + ------- + DataHandler + The data handler instance configured for inference. + """ if self._data_handler is None: # Instantiate the data handler self._data_handler = DataHandler( @@ -188,6 +328,14 @@ def data_handler(self): @property def trained_psf_model(self): + """ + Get or load the trained PSF model. + + Returns + ------- + Model + The loaded trained PSF model. + """ if self._trained_psf_model is None: self._trained_psf_model = self.load_inference_model() return self._trained_psf_model @@ -229,7 +377,19 @@ def get_positions(self): return np.column_stack((x_flat, y_flat)) def load_inference_model(self): - """Load the trained PSF model based on the inference configuration.""" + """Load the trained PSF model based on the inference configuration. + + Returns + ------- + Model + The loaded trained PSF model. + + Notes + ----- + Constructs the weights path pattern based on the trained model path, + model subdirectory, model name, id name, and cycle number specified in the + configuration files. + """ model_path = self.config_handler.trained_model_path model_dir = self.config_handler.model_subdir model_name = self.training_config.training.model_params.model_name @@ -250,6 +410,12 @@ def load_inference_model(self): @property def n_bins_lambda(self): + """Get the number of wavelength bins for inference. + + Returns + ------- + int + The number of wavelength bins used during inference.""" if self._n_bins_lambda is None: self._n_bins_lambda = ( self.inference_config.inference.model_params.n_bins_lda @@ -258,6 +424,14 @@ def n_bins_lambda(self): @property def batch_size(self): + """ + Get the batch size for inference. + + Returns + ------- + int + The batch size for processing during inference. + """ if self._batch_size is None: self._batch_size = self.inference_config.inference.batch_size assert self._batch_size > 0, "Batch size must be greater than 0." @@ -265,12 +439,26 @@ def batch_size(self): @property def cycle(self): + """Get the cycle number for inference. + + Returns + ------- + int + The cycle number used for loading the trained model. + """ if self._cycle is None: self._cycle = self.inference_config.inference.cycle return self._cycle @property def output_dim(self): + """Get the output dimension for PSF inference. + + Returns + ------- + int + The output dimension (height and width) of the inferred PSFs. 
+ """ if self._output_dim is None: self._output_dim = self.inference_config.inference.model_params.output_dim return self._output_dim @@ -317,7 +505,18 @@ def _prepare_positions_and_seds(self): return positions, sed_data_tensor def run_inference(self): - """Run PSF inference and return the full PSF array.""" + """Run PSF inference and return the full PSF array. + + Returns + ------- + numpy.ndarray + Array of inferred PSFs with shape (n_samples, output_dim, output_dim). + + Notes + ----- + Prepares configurations and input data, initializes the inference engine, + and computes the PSF for all input positions. + """ # Prepare the configuration for inference self.prepare_configs() @@ -332,10 +531,25 @@ def run_inference(self): return self.engine.compute_psfs(positions, sed_data) def _ensure_psf_inference_completed(self): + """Ensure that PSF inference has been completed. + + Runs inference if it has not been done yet. + """ if self.engine is None or self.engine.inferred_psfs is None: self.run_inference() def get_psfs(self): + """Get all inferred PSFs. + + Returns + ------- + numpy.ndarray + Array of inferred PSFs with shape (n_samples, output_dim, output_dim). + + Notes + ----- + Ensures automatically that inference has been completed before accessing the PSFs. + """ self._ensure_psf_inference_completed() return self.engine.get_psfs() @@ -343,7 +557,21 @@ def get_psf(self, index: int = 0) -> np.ndarray: """ Get the PSF at a specific index. - If only a single star was passed during instantiation, the index defaults to 0. + Parameters + ---------- + index : int, optional + Index of the PSF to retrieve (default is 0). + + Returns + ------- + numpy.ndarray + The inferred PSF at the specified index with shape (output_dim, output_dim). + + Notes + ----- + Ensures automatically that inference has been completed before accessing the PSF. + If only a single star was passed during instantiation, the index defaults to 0 + and bounds checking is relaxed. """ self._ensure_psf_inference_completed() @@ -356,8 +584,60 @@ def get_psf(self, index: int = 0) -> np.ndarray: # Otherwise, return the PSF at the requested index return inferred_psfs[index] + def clear_cache(self): + """ + Clear all cached properties and reset the instance. + + This method resets all lazy-loaded properties, including the config handler, + PSF simulator, data handler, trained model, and inference engine. Useful for + freeing memory or forcing a fresh initialization. + + Notes + ----- + After calling this method, accessing any property will trigger re-initialization. + """ + self._config_handler = None + self._simPSF = None + self._data_handler = None + self._trained_psf_model = None + self._n_bins_lambda = None + self._batch_size = None + self._cycle = None + self._output_dim = None + self.engine = None + class PSFInferenceEngine: + """Engine to perform PSF inference using a trained model. + + This class handles the batch-wise computation of PSFs using a trained PSF model. + It manages the batching of input positions and SEDs, and caches the inferred PSFs for later access. + + Parameters + ---------- + trained_model : Model + The trained PSF model to use for inference. + batch_size : int + The batch size for processing during inference. + output_dim : int + The output dimension (height and width) of the inferred PSFs. + + Attributes + ---------- + trained_model : Model + The trained PSF model used for inference. + batch_size : int + The batch size for processing during inference. 
+ output_dim : int + The output dimension (height and width) of the inferred PSFs. + + Examples + -------- + >>> engine = PSFInferenceEngine(model, batch_size=32, output_dim=64) + >>> psfs = engine.compute_psfs(positions, seds) + >>> single_psf = engine.get_psf(0) + """ + def __init__(self, trained_model, batch_size: int, output_dim: int): self.trained_model = trained_model self.batch_size = batch_size @@ -366,11 +646,35 @@ def __init__(self, trained_model, batch_size: int, output_dim: int): @property def inferred_psfs(self) -> np.ndarray: - """Access the cached inferred PSFs, if available.""" + """Access the cached inferred PSFs, if available. + + Returns + ------- + numpy.ndarray or None + The cached inferred PSFs, or None if not yet computed. + """ return self._inferred_psfs def compute_psfs(self, positions: tf.Tensor, sed_data: tf.Tensor) -> np.ndarray: - """Compute and cache PSFs for the input source parameters.""" + """Compute and cache PSFs for the input source parameters. + + Parameters + ---------- + positions : tf.Tensor + Tensor of shape (n_samples, 2) containing the (x, y) positions + sed_data : tf.Tensor + Tensor of shape (n_samples, n_bins, 2) containing the SEDs + + Returns + ------- + numpy.ndarray + Array of inferred PSFs with shape (n_samples, output_dim, output_dim). + + Notes + ----- + PSFs are computed in batches according to the specified batch_size. + Results are cached internally for subsequent access via get_psfs() or get_psf(). + """ n_samples = positions.shape[0] self._inferred_psfs = np.zeros( (n_samples, self.output_dim, self.output_dim), dtype=np.float32 @@ -397,13 +701,39 @@ def compute_psfs(self, positions: tf.Tensor, sed_data: tf.Tensor) -> np.ndarray: return self._inferred_psfs def get_psfs(self) -> np.ndarray: - """Get all the generated PSFs.""" + """Get all the generated PSFs. + + Returns + ------- + numpy.ndarray + Array of inferred PSFs with shape (n_samples, output_dim, output_dim). + """ if self._inferred_psfs is None: raise ValueError("PSFs not yet computed. Call compute_psfs() first.") return self._inferred_psfs def get_psf(self, index: int) -> np.ndarray: - """Get the PSF at a specific index.""" + """Get the PSF at a specific index. + + Returns + ------- + numpy.ndarray + The inferred PSF at the specified index with shape (output_dim, output_dim). + + Raises + ------ + ValueError + If PSFs have not yet been computed. + """ if self._inferred_psfs is None: raise ValueError("PSFs not yet computed. Call compute_psfs() first.") return self._inferred_psfs[index] + + def clear_cache(self): + """ + Clear cached inferred PSFs. + + Resets the internal PSF cache to free memory. After calling this method, + compute_psfs() must be called again before accessing PSFs. 
+ """ + self._inferred_psfs = None From 43f894ffe4404ac6840dcaf17128538abda101fe Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 12 Dec 2025 13:20:14 +0100 Subject: [PATCH 119/135] Update psf_test_setup fixture for reusability and add unit tests for cache handling --- .../test_inference/psf_inference_test.py | 111 ++++++++++++++++-- 1 file changed, 98 insertions(+), 13 deletions(-) diff --git a/src/wf_psf/tests/test_inference/psf_inference_test.py b/src/wf_psf/tests/test_inference/psf_inference_test.py index 28a2c1af..4cff7a13 100644 --- a/src/wf_psf/tests/test_inference/psf_inference_test.py +++ b/src/wf_psf/tests/test_inference/psf_inference_test.py @@ -159,6 +159,44 @@ def psf_single_star_setup(mock_inference_config): } +@pytest.fixture +def mock_compute_psfs_with_cache(psf_test_setup): + """ + Fixture that patches PSFInferenceEngine.compute_psfs with a side effect + that populates the engine's cache. + + Returns + ------- + dict + Dictionary containing: + - 'mock': The mock object for compute_psfs + - 'inference': The PSFInference instance + - 'positions': Mock positions tensor + - 'seds': Mock SEDs tensor + - 'expected_psfs': Expected PSF array + """ + inference = psf_test_setup["inference"] + mock_positions = psf_test_setup["mock_positions"] + mock_seds = psf_test_setup["mock_seds"] + expected_psfs = psf_test_setup["expected_psfs"] + + with patch.object(PSFInferenceEngine, "compute_psfs") as mock_compute_psfs: + + def fake_compute_psfs(positions, seds): + inference.engine._inferred_psfs = expected_psfs + return expected_psfs + + mock_compute_psfs.side_effect = fake_compute_psfs + + yield { + "mock": mock_compute_psfs, + "inference": inference, + "positions": mock_positions, + "seds": mock_seds, + "expected_psfs": expected_psfs, + } + + def test_set_config_paths(mock_inference_config): """Test setting configuration paths.""" # Initialize handler and inject mock config @@ -365,30 +403,25 @@ def test_simpsf_uses_updated_model_params( @patch.object(PSFInference, "_prepare_positions_and_seds") -@patch.object(PSFInferenceEngine, "compute_psfs") def test_get_psfs_runs_inference( - mock_compute_psfs, mock_prepare_positions_and_seds, psf_test_setup + mock_prepare_positions_and_seds, mock_compute_psfs_with_cache ): - inference = psf_test_setup["inference"] - mock_positions = psf_test_setup["mock_positions"] - mock_seds = psf_test_setup["mock_seds"] - expected_psfs = psf_test_setup["expected_psfs"] + """Test that get_psfs uses cached PSFs after first computation.""" + mock = mock_compute_psfs_with_cache["mock"] + inference = mock_compute_psfs_with_cache["inference"] + mock_positions = mock_compute_psfs_with_cache["positions"] + mock_seds = mock_compute_psfs_with_cache["seds"] + expected_psfs = mock_compute_psfs_with_cache["expected_psfs"] mock_prepare_positions_and_seds.return_value = (mock_positions, mock_seds) - def fake_compute_psfs(positions, seds): - inference.engine._inferred_psfs = expected_psfs - return expected_psfs - - mock_compute_psfs.side_effect = fake_compute_psfs - psfs_1 = inference.get_psfs() assert np.all(psfs_1 == expected_psfs) psfs_2 = inference.get_psfs() assert np.all(psfs_2 == expected_psfs) - assert mock_compute_psfs.call_count == 1 + assert mock.call_count == 1 def test_single_star_inference_shape(psf_single_star_setup): @@ -483,3 +516,55 @@ def test_valueerror_on_mismatched_positions(psf_single_star_setup): inference._prepare_positions_and_seds() finally: patcher.stop() + + +def test_inference_clear_cache(psf_test_setup): + """Test that PSFInference 
clear_cache resets the instance of PSFInference."""
+    inference = psf_test_setup["inference"]
+    inference._simPSF = MagicMock()
+    inference._data_handler = MagicMock()
+    inference._trained_psf_model = MagicMock()
+    inference._n_bins_lambda = MagicMock()
+    inference._batch_size = MagicMock()
+    inference._cycle = MagicMock()
+    inference._output_dim = MagicMock()
+    inference.engine = MagicMock()
+
+    # Clear the cache
+    inference.clear_cache()
+
+    # Check that the internal cache is None
+    assert all((
+        inference._config_handler is None,
+        inference._simPSF is None,
+        inference._data_handler is None,
+        inference._trained_psf_model is None,
+        inference._n_bins_lambda is None,
+        inference._batch_size is None,
+        inference._cycle is None,
+        inference._output_dim is None,
+        inference.engine is None,
+    )), "Inference attributes should be cleared to None"
+
+
+def test_engine_clear_cache(psf_test_setup):
+    """Test that clear_cache resets the internal PSF cache."""
+    inference = psf_test_setup["inference"]
+    expected_psfs = psf_test_setup["expected_psfs"]
+
+    # Create the engine and compute PSFs
+    inference.engine = PSFInferenceEngine(
+        trained_model=inference.trained_psf_model,
+        batch_size=inference.batch_size,
+        output_dim=inference.output_dim,
+    )
+
+    inference.engine._inferred_psfs = expected_psfs
+
+    # Clear the cache
+    inference.engine.clear_cache()
+
+    # Check that the internal cache is None
+    assert (
+        inference.engine._inferred_psfs is None
+    ), "PSF cache should be cleared to None"

From a88c455357a540ddc0ce07271e70995bea387c85 Mon Sep 17 00:00:00 2001
From: Jennifer Pollack
Date: Fri, 12 Dec 2025 13:21:50 +0100
Subject: [PATCH 120/135] Remove unneeded mock of logger in evaluate_model

---
 src/wf_psf/tests/test_metrics/metrics_interface_test.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/src/wf_psf/tests/test_metrics/metrics_interface_test.py b/src/wf_psf/tests/test_metrics/metrics_interface_test.py
index 5672d648..2d6b6553 100644
:recursive: wf_psf.data + wf_psf.inference wf_psf.metrics wf_psf.plotting wf_psf.psf_models From cd715c30a78ef5049df2e9fa16c3eef86e0a68e4 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 12 Dec 2025 13:30:18 +0100 Subject: [PATCH 122/135] Update version to 3.1.0 --- docs/source/conf.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 988df8fe..01edfeb6 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -21,7 +21,7 @@ else: copyright = f"{start_year}, CosmoStat" author = "CosmoStat" -release = "3.0.0" +release = "3.1.0" # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration diff --git a/pyproject.toml b/pyproject.toml index 3107f12c..ec0c8d34 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ dependencies = [ "seaborn", ] -version = "3.0.0" +version = "3.1.0" [project.optional-dependencies] docs = [ From 5d54518e99e16586fba8f2323901b830a19b0915 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 12 Dec 2025 13:34:11 +0100 Subject: [PATCH 123/135] docs: Correct syntax error and code-block formatting in doc string examples --- src/wf_psf/inference/psf_inference.py | 30 ++++++++++++++++----------- 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/src/wf_psf/inference/psf_inference.py b/src/wf_psf/inference/psf_inference.py index dc7b5ac7..70d7a9be 100644 --- a/src/wf_psf/inference/psf_inference.py +++ b/src/wf_psf/inference/psf_inference.py @@ -151,7 +151,7 @@ class PSFInference: Corresponding masks for the sources (same shape as sources). Defaults to None. - Attributes + Attributes ---------- inference_config_path : str Path to the inference configuration file. @@ -170,14 +170,18 @@ class PSFInference: Examples -------- - >>> psf_inf = PSFInference( - ... inference_config_path="config.yaml", - ... x_field=[100.5, 200.3], - ... y_field=[150.2, 250.8], - ... seds=sed_array - ... ) - >>> psf_inf.run_inference() - >>> psf = psf_inf.get_psf(0) + Basic usage with position coordinates and SEDs: + + .. code-block:: python + + psf_inf = PSFInference( + inference_config_path="config.yaml", + x_field=[100.5, 200.3], + y_field=[150.2, 250.8], + seds=sed_array + ) + psf_inf.run_inference() + psf = psf_inf.get_psf(0) """ def __init__( @@ -633,9 +637,11 @@ class PSFInferenceEngine: Examples -------- - >>> engine = PSFInferenceEngine(model, batch_size=32, output_dim=64) - >>> psfs = engine.compute_psfs(positions, seds) - >>> single_psf = engine.get_psf(0) + .. code-block:: python + + >>> engine = PSFInferenceEngine(model, batch_size=32, output_dim=64) + >>> psfs = engine.compute_psfs(positions, seds) + >>> single_psf = engine.get_psf(0) """ def __init__(self, trained_model, batch_size: int, output_dim: int): From b5da83f8d781a64d61cf5898de8d902b8a7283a8 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 23 Jan 2026 13:58:22 +0100 Subject: [PATCH 124/135] Remove weights_path references missed during rebase Rebase dropped commits that removed weights_path from the metrics interface. Cleaning up remaining references in docstrings and tests. 
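For illustration, a minimal sketch of a call site after this change (based on the
updated tests below; any remaining keyword arguments of evaluate_model are elided):

    # weights_path is no longer accepted; weights are resolved elsewhere
    evaluate_model(
        trained_model_params=trained_model_params,
        data=data,
        psf_model=psf_model,
        metrics_output="/path/to/metrics/output",
    )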
--- src/wf_psf/metrics/metrics_interface.py | 2 -- src/wf_psf/tests/test_metrics/metrics_interface_test.py | 3 --- 2 files changed, 5 deletions(-) diff --git a/src/wf_psf/metrics/metrics_interface.py b/src/wf_psf/metrics/metrics_interface.py index bd6249e2..db410f4e 100644 --- a/src/wf_psf/metrics/metrics_interface.py +++ b/src/wf_psf/metrics/metrics_interface.py @@ -328,8 +328,6 @@ def evaluate_model( DataHandler object containing training and test data psf_model: object PSF model object - weights_path: str - Directory location of model weights metrics_output: str Directory location of metrics output diff --git a/src/wf_psf/tests/test_metrics/metrics_interface_test.py b/src/wf_psf/tests/test_metrics/metrics_interface_test.py index 2d6b6553..35f3b9a1 100644 --- a/src/wf_psf/tests/test_metrics/metrics_interface_test.py +++ b/src/wf_psf/tests/test_metrics/metrics_interface_test.py @@ -106,7 +106,6 @@ def test_evaluate_model_flags( trained_model_params=mock_trained_model_params, data=mock_data, psf_model=mock_psf_model, - weights_path="/mock/path", metrics_output="/mock/output", ) @@ -134,7 +133,6 @@ def test_missing_ground_truth_model_raises( trained_model_params=mock_trained_model_params, data=mock_data, psf_model=mock_psf_model, - weights_path="/mock/weights/path", metrics_output="/mock/metrics/output", ) @@ -168,7 +166,6 @@ def test_plotting_config_passed( trained_model_params=mock_trained_model_params, data=mock_data, psf_model=mock_psf_model, - weights_path="/mock/path", metrics_output="/mock/output", ) From 98b66e94ea234153c420ac9db056d7f7ad49db9f Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 23 Jan 2026 16:22:37 +0100 Subject: [PATCH 125/135] Add changelog fragment --- ...llack_159_psf_output_from_trained_model.md | 47 +++++++++++++++++++ 1 file changed, 47 insertions(+) create mode 100644 changelog.d/20260123_160331_jennifer.pollack_159_psf_output_from_trained_model.md diff --git a/changelog.d/20260123_160331_jennifer.pollack_159_psf_output_from_trained_model.md b/changelog.d/20260123_160331_jennifer.pollack_159_psf_output_from_trained_model.md new file mode 100644 index 00000000..1206fbcd --- /dev/null +++ b/changelog.d/20260123_160331_jennifer.pollack_159_psf_output_from_trained_model.md @@ -0,0 +1,47 @@ + + + + +### New features + +- Added PSF inference capabilities for generating broadband (polychromatic) PSFs from trained models given star positions and SEDs +- Introduced `PSFInferenceEngine` class to centralize training, simulation, metrics, and inference workflows +- Added `run_type` attribute to `DataHandler` supporting training, simulation, metrics, and inference modes +- Implemented `ZernikeInputsFactory` class for building `ZernikeInputs` instances based on run type +- Added `psf_model_loader.py` module for centralized model weights loading + + + + + +### Internal changes + +- Refactored `TFPhysicalPolychromatic` and related modules to separate training vs. 
inference behavior
+- Enhanced `ZernikeInputs` data class with intelligent assembly based on run type and available data
+- Implemented hybrid loading pattern with eager loading in constructors and lazy-loading via property decorators
+- Centralized PSF data extraction in `data_handler` module
+- Improved code organization with new `tf_utils.py` module in `psf_models` sub-package
+- Updated configuration handling to support inference workflows via `inference_config.yaml`
+- Fixed incorrect argument name in `DataHandler` that prevented proper TensorFlow data type conversion
+- Removed deprecated `get_obs_positions` method
+- Updated documentation to include inference package

From 33ad4f5ff64b51527fd5ba12db8b35746879a179 Mon Sep 17 00:00:00 2001
From: jeipollack
Date: Mon, 26 Jan 2026 12:57:35 +0100
Subject: [PATCH 126/135] Fix logger formatting for relative RMSE metrics

Use f-strings instead of %-formatting to properly display percent symbols
in metric output.

---
 src/wf_psf/metrics/metrics.py | 12 +++++-------
 1 file changed, 5 insertions(+), 7 deletions(-)

diff --git a/src/wf_psf/metrics/metrics.py b/src/wf_psf/metrics/metrics.py
index 0447d596..942a0622 100644
--- a/src/wf_psf/metrics/metrics.py
+++ b/src/wf_psf/metrics/metrics.py
@@ -152,8 +152,7 @@ def compute_poly_metric(
 
     # Print RMSE values
     logger.info("Absolute RMSE:\t %.4e \t +/- %.4e", rmse, std_rmse)
-    logger.info("Relative RMSE:\t %.4e % \t +/- %.4e %", rel_rmse, std_rel_rmse)
-
+    logger.info(f"Relative RMSE:\t {rel_rmse:.4e}% \t +/- {std_rel_rmse:.4e}%")
     return rmse, rel_rmse, std_rmse, std_rel_rmse
 
 
@@ -364,9 +363,8 @@ def compute_opd_metrics(tf_semiparam_field, gt_tf_semiparam_field, pos, batch_si
     rel_rmse_std = np.std(rel_rmse_vals)
 
     # Print RMSE values
-    logger.info("Absolute RMSE:\t %.4e % \t +/- %.4e %", rmse, rmse_std)
-    logger.info("Relative RMSE:\t %.4e % \t +/- %.4e %", rel_rmse, rel_rmse_std)
-
+    logger.info(f"Absolute RMSE:\t {rmse:.4e} \t +/- {rmse_std:.4e}")
+    logger.info(f"Relative RMSE:\t {rel_rmse:.4e}% \t +/- {rel_rmse_std:.4e}%")
     return rmse, rel_rmse, rmse_std, rel_rmse_std
 
 
@@ -596,10 +594,10 @@ def compute_shape_metrics(
 
     # Print relative shape/size errors
     logger.info(
-        f"\nRelative sigma(e1) RMSE =\t {rel_rmse_e1:.4e} % \t +/- {std_rel_rmse_e1:.4e} %"
+        f"\nRelative sigma(e1) RMSE =\t {rel_rmse_e1:.4e}% \t +/- {std_rel_rmse_e1:.4e}%"
     )
     logger.info(
-        f"Relative sigma(e2) RMSE =\t {rel_rmse_e2:.4e} % \t +/- {std_rel_rmse_e2:.4e} %"
+        f"Relative sigma(e2) RMSE =\t {rel_rmse_e2:.4e}% \t +/- {std_rel_rmse_e2:.4e}%"
     )
 
     # Print number of stars

From 63e6acca71b16f9836cda3ad51757ffa647b0cf7 Mon Sep 17 00:00:00 2001
From: jeipollack
Date: Mon, 26 Jan 2026 13:05:41 +0100
Subject: [PATCH 127/135] Update changelog with entry under Bug Fixes

---
 ...0331_jennifer.pollack_159_psf_output_from_trained_model.md | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/changelog.d/20260123_160331_jennifer.pollack_159_psf_output_from_trained_model.md b/changelog.d/20260123_160331_jennifer.pollack_159_psf_output_from_trained_model.md
index 1206fbcd..453a504e 100644
--- a/changelog.d/20260123_160331_jennifer.pollack_159_psf_output_from_trained_model.md
+++ b/changelog.d/20260123_160331_jennifer.pollack_159_psf_output_from_trained_model.md
@@ -21,12 +21,10 @@
 
 For top level release notes, leave all the headers commented out.
- Added `psf_model_loader.py` module for centralized model weights loading - + + + + + + +### Internal changes + +- Remove deprecated/optional import tensorflow-addons statement from tf_layers.py + + diff --git a/src/wf_psf/psf_models/tf_modules/tf_layers.py b/src/wf_psf/psf_models/tf_modules/tf_layers.py index cf11d9cd..cdd01e16 100644 --- a/src/wf_psf/psf_models/tf_modules/tf_layers.py +++ b/src/wf_psf/psf_models/tf_modules/tf_layers.py @@ -7,7 +7,6 @@ """ import tensorflow as tf -import tensorflow_addons as tfa from wf_psf.psf_models.tf_modules.tf_modules import TFMonochromaticPSF from wf_psf.psf_models.tf_modules.tf_utils import find_position_indices from wf_psf.utils.utils import calc_poly_position_mat From ae2c4170c93544589fcac662a2c9e879f90759f4 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 13 Feb 2026 14:30:21 +0100 Subject: [PATCH 130/135] Fix Sphinx autosummary imports and improve comments - Remove tensorflow-addons from environment.yml file - Remove simulation parameters from example data_config.yaml file - Naming consistency - rename inference configuration YAML file - Improve inline comments in example metrics, plotting and training configuration files - Update pyproject.toml with setup tools to ensure all packages are found by sphinx autosummary - Remove tensorflow from autodoc_mock_imports list - Add docstring to conf.py - Add scriv fragment describing internal changes --- ...nnifer.pollack_api_documentation_update.md | 40 +++++++++++++++++++ config/data_config.yaml | 37 +---------------- ...erence_conf.yaml => inference_config.yaml} | 4 +- config/metrics_config.yaml | 5 +-- config/plotting_config.yaml | 21 ++++++---- config/training_config.yaml | 10 ++--- docs/source/conf.py | 23 ++++++----- environment.yml | 1 - pyproject.toml | 11 ++++- 9 files changed, 83 insertions(+), 69 deletions(-) create mode 100644 changelog.d/20260213_141819_jennifer.pollack_api_documentation_update.md rename config/{inference_conf.yaml => inference_config.yaml} (97%) diff --git a/changelog.d/20260213_141819_jennifer.pollack_api_documentation_update.md b/changelog.d/20260213_141819_jennifer.pollack_api_documentation_update.md new file mode 100644 index 00000000..0239f903 --- /dev/null +++ b/changelog.d/20260213_141819_jennifer.pollack_api_documentation_update.md @@ -0,0 +1,40 @@ + + + + + + + +### Internal changes + +- Fixed Sphinx autosummary import errors by removing core dependencies (tensorflow) from autodoc_mock_imports in conf.py. +- Updated pyproject.toml to include all wf_psf packages under src/ and include config/yaml files. +- Updated example configuration files with clearer inline comments. 
+ + + diff --git a/config/data_config.yaml b/config/data_config.yaml index 02a15341..ca6939af 100644 --- a/config/data_config.yaml +++ b/config/data_config.yaml @@ -5,43 +5,8 @@ data: data_dir: data/coherent_euclid_dataset/ # Provide name of training dataset file: train_Euclid_res_200_TrainStars_id_001.npy - # if training data set file does not exist, generate a new one by setting values below - stars: null - positions: null - SEDS: null - zernike_coef: null - C_poly: null - params: # - d_max: 2 - max_order: 45 - x_lims: [0, 1000.0] - y_lims: [0, 1000.0] - grid_points: [4, 4] - n_bins: 20 - max_wfe_rms: 0.1 - oversampling_rate: 3.0 - output_Q: 3.0 - output_dim: 32 - LP_filter_length: 2 - pupil_diameter: 256 - euclid_obsc: true - n_stars: 200 test: # Specify directory path to training dataset data_dir: data/coherent_euclid_dataset/ # Provide name of test dataset - file: test_Euclid_res_id_001.npy - # If test data set file not provided produce a new one - stars: null - noisy_stars: null - positions: null - SEDS: null - zernike_coef: null - C_poly: null - parameters: - d_max: 2 - max_order: 45 - x_lims: [0, 1000.0] - y_lims: [0, 1000.0] - grid_points: [4,4] - max_wfe_rms: 0.1 \ No newline at end of file + file: test_Euclid_res_id_001.npy \ No newline at end of file diff --git a/config/inference_conf.yaml b/config/inference_config.yaml similarity index 97% rename from config/inference_conf.yaml rename to config/inference_config.yaml index 927723c7..7f32a235 100644 --- a/config/inference_conf.yaml +++ b/config/inference_config.yaml @@ -12,12 +12,12 @@ inference: # Subdirectory name of the trained model, e.g. psf_model model_subdir: model - + # Relative Path to the training configuration file used to train the model trained_model_config_path: config/training_config.yaml # Path to the data config file (this could contain prior information) - data_config_path: + data_config_path: # The following parameters will overwrite the `model_params` in the training config file. model_params: diff --git a/config/metrics_config.yaml b/config/metrics_config.yaml index 9f543c31..50dbcd75 100644 --- a/config/metrics_config.yaml +++ b/config/metrics_config.yaml @@ -51,9 +51,6 @@ metrics: # Top-hat filter to avoid the aliasing effect in the obscuration mask LP_filter_length: 2 - # Boolean to define if we use sample weights based on the noise standard deviation estimation - use_sample_weights: True - # Flag to use Zernike prior use_prior: False @@ -140,7 +137,7 @@ metrics: metrics_hparams: # Batch size to use for the evaluation. batch_size: 16 - + # Metrics and model evaluation configuration optimizer: name: 'adam' # Only standard Adam used for metrics diff --git a/config/plotting_config.yaml b/config/plotting_config.yaml index fcb17ecd..8abbad89 100644 --- a/config/plotting_config.yaml +++ b/config/plotting_config.yaml @@ -1,13 +1,18 @@ plotting_params: - # Specify path to parent folder containing wf-outputs-xxxxxxxxxxx for all runs, ex: $WORK/wf-outputs/ + # Path to the parent folder containing WaveDiff output directories metrics_output_path: - # List all of the parent output directories (i.e. 
wf-outputs-xxxxxxxxxxx) that contain metrics results to be included in the plot + + # List of output directories whose metrics should be plotted + # Leave commented/empty if plotting immediately after a metrics run metrics_dir: - # - wf-outputs-xxxxxxxxxxx1 - # - wf-outputs-xxxxxxxxxxx2 - # List of name of metric config file to add to plot (would like to change such that code goes and finds them in the metrics_dir) + # - wf-outputs-xxxxxxxxxxxxxxxxxxx1 + # - wf-outputs-xxxxxxxxxxxxxxxxxxx2 + + # List of metrics config filenames corresponding to each directory + # Leave commented/empty if plotting immediately after a metrics run metrics_config: - # - metrics_config_1.yaml - # - metrics_config_2.yaml - # Show Plots Flag + # - metrics_config_1.yaml + # - metrics_config_2.yaml + + # If True, plots are shown interactively during execution plot_show: False \ No newline at end of file diff --git a/config/training_config.yaml b/config/training_config.yaml index 9fa8a50f..6348c416 100644 --- a/config/training_config.yaml +++ b/config/training_config.yaml @@ -52,7 +52,7 @@ training: ccd_misalignments_input_path: /path/to/ccd_misalignments_file.txt # Boolean to use sample weights based on the noise standard deviation estimation - use_sample_weights: True + use_sample_weights: True # Sample weight generalised sigmoid function sample_weights_sigmoid: @@ -96,7 +96,7 @@ training: # Telescope's focal length in [m]. Default is `24.5`[m] (Euclid-like). tel_focal_length: 24.5 - # Wheter to use Euclid-like obscurations. + # Use Euclid-like obscurations. euclid_obsc: True # Length of one dimension of the Low-Pass (LP) filter to apply to the @@ -154,9 +154,9 @@ training: loss: 'mask_mse' # Optimizer to use during training. Options are: 'adam' or 'rectified_adam'. - optimizer: - name: 'rectified_adam' - + optimizer: + name: 'rectified_adam' + multi_cycle_params: # Number of training cycles to perform. Each cycle may use different learning rates or number of epochs. diff --git a/docs/source/conf.py b/docs/source/conf.py index 01edfeb6..d54e4c8b 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -1,14 +1,15 @@ -# Configuration file for the Sphinx documentation builder. -# -# For the full list of built-in configuration values, see the documentation: -# https://www.sphinx-doc.org/en/master/usage/configuration.html -import sys -import os -from datetime import datetime +""" +Sphinx configuration for the wf-psf documentation. -current_year = datetime.now().year +This file sets up paths, extensions, theme, and other options +for building the HTML docs. 
+""" + +from datetime import datetime +import os +import sys -sys.path.insert(0, os.path.abspath("src/wf_psf")) +sys.path.insert(0, os.path.abspath("../../src")) # -- Project information ----------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information @@ -103,7 +104,7 @@ bibtex_reference_style = "author_year" # -- Mock imports for documentation ------------------------------------------ -autodoc_mock_imports = [ - "tensorflow", +optional_deps = [ "tensorflow_addons", ] +autodoc_mock_imports = optional_deps diff --git a/environment.yml b/environment.yml index 1b9264e4..66f9873c 100644 --- a/environment.yml +++ b/environment.yml @@ -12,7 +12,6 @@ dependencies: - pip - pip: - numpy>=1.26,<2.0 - - tensorflow-addons - tensorflow-estimator - zernike - opencv-python diff --git a/pyproject.toml b/pyproject.toml index ec0c8d34..7ace733d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ description = 'A software framework to perform Differentiable wavefront-based PS dependencies = [ "numpy>=1.18,<1.24", "scipy", - "tensorflow==2.11.0", +# "tensorflow==2.11.0", "tensorflow-estimator", "zernike", "opencv-python", @@ -88,8 +88,15 @@ quote-style = "double" indent-style = "space" line-ending = "lf" +[tool.setuptools] +package-dir = {"" = "src"} + +[tool.setuptools.packages.find] +where = ["src"] +include = ["wf_psf*"] + [tool.setuptools.package-data] -"wf_psf.config" = ["*.conf"] +"wf_psf.config" = ["*.conf", "*.yaml"] # Set per-file-ignores [tool.ruff.lint.per-file-ignores] From 274bd6ff9330f4be0d48b66717523c4cd9946570 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 13 Feb 2026 14:34:30 +0100 Subject: [PATCH 131/135] Update .gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 1c4865fc..7f3bc838 100644 --- a/.gitignore +++ b/.gitignore @@ -89,6 +89,7 @@ instance/ # Sphinx documentation docs/build/ +docs/_build/ docs/source/wf_psf*.rst docs/source/_static/file.png docs/source/_static/images/logo_colab.png From 4b5fb43b72f02ebbb713eb8342cbd60c4b6aa90e Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 13 Feb 2026 14:43:32 +0100 Subject: [PATCH 132/135] docs: Update API documentation for v3.1.0 release - Add API documentation for new inference and instrument packages - Restructure configuration documentation: * Split workflows into CLI tasks vs standalone components * Add configuration file dependency table * Document inference_config.yaml * Clarify filename flexibility and standardize section structure - Document tensorflow-addons as optional dependency - Fix bullet point rendering and ruff errors in data_handler.py and data_zernike_utils.py docstrings - Update example configuration files with clearer inline comments --- ...nnifer.pollack_api_documentation_update.md | 47 ++ docs/source/api.rst | 1 + docs/source/configuration.md | 467 ++++++++++++------ docs/source/dependencies.md | 25 +- src/wf_psf/data/data_handler.py | 14 +- src/wf_psf/data/data_zernike_utils.py | 31 +- src/wf_psf/inference/__init__.py | 2 + src/wf_psf/instrument/__init__.py | 1 + 8 files changed, 417 insertions(+), 171 deletions(-) create mode 100644 changelog.d/20260213_143539_jennifer.pollack_api_documentation_update.md diff --git a/changelog.d/20260213_143539_jennifer.pollack_api_documentation_update.md b/changelog.d/20260213_143539_jennifer.pollack_api_documentation_update.md new file mode 100644 index 00000000..9b41f68f --- /dev/null +++ 
b/changelog.d/20260213_143539_jennifer.pollack_api_documentation_update.md
@@ -0,0 +1,47 @@
+
+
+
+
+
+
+
+### Internal changes
+
+- API documentation for new `inference` package in `api.rst`
+- API documentation for new `instrument` package in `api.rst`
+- Inference Configuration section in `configuration.md` documenting `inference_config.yaml`
+- Restructured Configuration documentation:
+  - Split workflows into "CLI Tasks" and "Additional Components" sections
+  - Added configuration file dependency table showing required vs optional files per task
+  - Clarified configuration filename flexibility (filenames customizable, internal structure fixed)
+  - Standardized section titles (Training Configuration, Metrics Configuration, etc.)
+  - Improved markdown formatting and fixed broken anchor links
+- Updated `dependencies.md` to document `tensorflow-addons` as optional dependency with manual installation instructions
+- Removed `tensorflow-addons` from core dependencies documentation (now documented as optional)
+
+
diff --git a/docs/source/api.rst b/docs/source/api.rst
index cda87632..868493f4 100644
--- a/docs/source/api.rst
+++ b/docs/source/api.rst
@@ -9,6 +9,7 @@ This section contains the API reference for the main packages in WaveDiff.
 
     wf_psf.data
     wf_psf.inference
+    wf_psf.instrument
     wf_psf.metrics
     wf_psf.plotting
     wf_psf.psf_models
diff --git a/docs/source/configuration.md b/docs/source/configuration.md
index 173e229d..b7e4d638 100644
--- a/docs/source/configuration.md
+++ b/docs/source/configuration.md
@@ -1,59 +1,84 @@
 # Configuration
 
-WaveDiff uses a set of YAML and INI configuration files to control each pipeline task.
-This section provides a high-level overview of the configuration system, followed by detailed explanations of each file.
+WaveDiff uses a set of YAML and INI configuration files to control each pipeline task. This section provides a high-level overview of the configuration system, followed by detailed explanations of each file.
 
-## Overview of Pipeline Tasks
+## Overview of Workflows
 
-WaveDiff consists of four main pipeline tasks:
+WaveDiff consists of three CLI tasks, configured by passing a configuration file to the `wavediff` command (e.g., `wavediff -c configs.yaml -o output/`):
 
-| Pipeline | Purpose |
-|---------|---------|
+| Task | Purpose |
+|------|---------|
 | `training` | Trains a PSF model using the provided dataset and hyperparameters. |
-| `metrics` |Evaluates model performance using multiple metrics, optionally comparing against a ground-truth model. |
-| `plotting` | Generates figures summarizing the results from the metrics pipeline. |
-| `sims` | Simulates stellar PSFs used as training/test datasets.&#x3C;br&#x3E;
*(Currently executed via a standalone script, not via `wavediff`.)* | +| `metrics` | Evaluates model performance using multiple metrics, optionally comparing against a ground-truth model. | +| `plotting` | Generates figures summarising the results from the metrics pipeline. | -You configure these tasks by passing a configuration file to the `wavediff` command (e.g., `--config configs.yaml`). +WaveDiff also provides two standalone Python APIs used outside the `wavediff` CLI: + +| Component | Purpose | +|-----------|---------| +| `sims` | Provides classes and methods for simulating monochromatic, polychromatic, and spatially-varying PSFs
for generating custom datasets. | +| `inference` | Provides classes and methods for inferring PSFs as a function of position and SED from a trained PSF
model. | ## Configuration File Structure -WaveDiff expects the following configuration files under the `config/` directory: +WaveDiff expects configuration files under the `config/` directory. **Configuration filenames are flexible** — you can name them as you wish (e.g., `training_euclid_v2.yaml`, `my_metrics.yaml`) as long as you reference them correctly in `configs.yaml` or command-line arguments. The filenames shown below are conventional defaults used in documentation examples. + +The files required depend on which task or component you are running: ``` -config -├── configs.yaml -├── data_config.yaml -├── logging.conf -├── metrics_config.yaml -├── plotting_config.yaml -└── training_config.yaml +config/ +├── configs.yaml # Master configuration (all CLI tasks) +├── data_config.yaml # Dataset paths (training task only) +├── logging.conf # Logging configuration (all CLI tasks) +├── training_config.yaml # Training task +├── metrics_config.yaml # Metrics task +├── plotting_config.yaml # Plotting task +└── inference_config.yaml # Inference API ``` -- All `.yaml` files use standard **YAML** syntax and are loaded as nested dictionaries of key–value pairs. -- `logging.conf` uses standard **INI** syntax and configures logging behavior. -- Users may modify values but should **not rename keys or section names**, as the software depends on them. +| Task / Component | Required | Optional | +|---------------|--------------|---------| +| `training` | `configs.yaml`, `data_config.yaml` ,
`logging.conf`, `training_config.yaml` | `metrics_config.yaml`
(_triggers post-training metrics_) | +| `metrics` | `configs.yaml`, `logging.conf`,
`metrics_config.yaml` | `plotting_config.yaml`
(_triggers post-metrics plotting_) | +| `plotting` | `configs.yaml`, `logging.conf`,
`plotting_config.yaml`| — | +| `inference` | `inference_config.yaml` | `data_config.yaml` | + +**Notes:** + +- **Configuration filenames are flexible.** The names shown above (e.g., `training_config.yaml`) are conventional defaults. You may use any filename as long as you reference it correctly in `configs.yaml` or via command-line arguments. +- **Keys and section names within configuration files must be preserved.** While you can rename files, the internal YAML structure (keys like `model_params`, `training`, etc.) must remain unchanged, as the software depends on them. +- The metrics and plotting tasks retrieve dataset paths from the trained model's configuration and do not require `data_config.yaml`. +- When `metrics_config.yaml` is specified as optional for the `training` task, metrics evaluation runs automatically after training completes. +- When `plotting_config.yaml` is specified as optional for the `metrics` task, plots are generated automatically after metrics evaluation completes. +- `logging.conf` uses standard INI syntax and configures logging behaviour for all CLI tasks. Each of the configuration files is described in detail below. (data_config)= -## `data_config.yaml` — Data Configuration +## Data Configuration ### 1. Purpose -Specifies where WaveDiff loads (or later versions may generate) the training and test datasets. -All training, evaluation, and metrics pipelines depend on this file for consistent dataset paths. +Specifies the training and test datasets used by the training CLI task. + ### 2. Key Fields -- `data.training.data_dir` _(required)_ — directory containing training data -- `data.training.file` _(required)_ — filename of the training dataset -- `data.test.*` — same structure as `training`, for the test dataset -- **Simulation-related fields** — reserved for future releases -**Notes** -- The simulation options are placeholders; WaveDiff v3.x does **not yet** auto-generate datasets. -- The default dataset bundled with WaveDiff can be used by simply pointing to its directory. +Both `data.training` and `data.test` share the same structure: + +| Field | Required | Description | +|-----------|--------------|--------------| +| `data_dir` | Yes | Path to the directory containing the dataset. | +| `file` | Yes | Filename of the dataset (`.npy`). | + + +### 3. Notes + +- The default dataset bundled with WaveDiff can be used by pointing `data_dir` to its installation directory. +- The `metrics` and `plotting` tasks retrieve dataset paths automatically from the trained model's configuration file and do not require this file. +- This file is optional for the `inference` API; see [inference_config.yaml](inference_config) if you need to supply prior information for inference. + +### 4. Example -**Example (minimal)** ```yaml data: training: @@ -65,51 +90,52 @@ data: ``` (training_config)= -## `training_config.yaml` — Training Pipeline Configuration +## Training Configuration ### 1. Purpose -Controls the training pipeline, including model selection, hyperparameters, optional metrics evaluation, and data loading behavior. +Controls the training pipeline, including model selection, hyperparameters, optional post-training metrics evaluation, and data loading behaviour. ### 2. General Notes -- Every field has an inline comment in the YAML file. -- **All required parameters must be specified.** Missing values will prevent the model from being instantiated, as there is currently no default configuration provided. 
+- **All required parameters must be specified.** There is currently no default configuration — missing values will prevent the model from being instantiated. - **Optional fields:** - - `metrics_config` (run metrics after training) - - `param_hparams`, `nonparam_hparams` - - `multi_cycle_params.save_all_cycles` -- Some parameters are specific to physical or polychromatic PSF models. -- Example training configuration file is provided in the top-level root directory of the repository (`training_config.yaml`). Users can copy and adapt this template for their own runs. -- If any descriptions are unclear, or unexpected behaviour occurs, please open a [GitHub issue](https://github.com/CosmoStat/wf-psf/issues/new). + - `metrics_config` — trigger metrics evaluation after training completes + - `multi_cycle_params.save_all_cycles`— defaults to `False` +- Some parameters are specific to the physical PSF model and may be ignored by simpler model types. +- An example training configuration file is provided in the repository root (`config/training_config.yaml`). Copy and adapt this template for your own runs. +- **Fraction notation**: Fields like `reference_shifts` accept fraction strings (e.g., "`-1/3`") which are automatically converted to floats. You can also use decimal values directly (e.g., `-0.333`). +- Every field in the YAML file includes an inline comment. If any descriptions remain unclear or unexpected behavior occurs, please open a [GitHub issue](https://github.com/CosmoStat/wf-psf/issues/new). -**Note:** The values in the examples shown below correspond to a typical WaveDiff training run. Users should adapt parameters such as `model_name`, telescope dimensions, pixel/field coordinates, and SED settings to match their own instrument or dataset. All required fields must still be specified. +**Note on example values**: The parameter values shown below correspond to a typical Euclid-like WaveDiff training run. Adapt `model_name`, telescope dimensions, pixel/field coordinates, and SED settings to match your instrument and dataset. ### 3. Top-Level Training Parameters +`training` ```yaml training: - # ID name for this run (used in output files) - id_name: run__001 + # ID name for this run (used in output filenames and logs) + id_name: run_001 - # Path to Data Configuration file (required) + # Path to data configuration file data_config: data_config.yaml # Load dataset on initialization (True) or manually later (False) load_data_on_init: True - # Optional: metrics configuration to run after training + # Optional: path to metrics configuration to run after training metrics_config: ``` -### 4. Model Parameters (`model_params`) +### 4. Model Parameters -Controls PSF model type, geometry, oversampling, and preprocessing: +Controls PSF model type, geometry, oversampling, and physical corrections. +`training.model_params` ```yaml model_params: - # Model type. Options: 'poly' and 'physical_poly' + # Model type. Options: 'poly', 'physical_poly' model_name: physical_poly # Number of wavelength bins for polychromatic reconstruction @@ -134,25 +160,22 @@ model_params: # Centroid correction parameters sigma_centroid_window: 2.5 # Std dev of centroiding window - reference_shifts: [-1/3, -1/3] # Euclid-like default shifts + reference_shifts: [-0.333, -0.333] # Reference pixel shifts (Euclid default: -1/3, -1/3) - # Obscuration / geometry - obscuration_rotation_angle: 0 # Degrees (multiple of 90); counterclockwise rotation. 
+ # Obscuration geometry + obscuration_rotation_angle: 0 # Rotation in degrees (multiples of 90); counterclockwise # CCD misalignments input file path ccd_misalignments_input_path: /path/to/ccd_misalignments_file.txt - # Boolean to use sample weights based on the noise standard deviation estimation + # Sample weighting based on noise standard deviation use_sample_weights: True - # Sample weight generalised sigmoid function + # Sample weight sigmoid function parameters sample_weights_sigmoid: - # Boolean to define if we apply the sigmoid function to the sample weights - apply_sigmoid: False - # Maximum value of the sigmoid function and consequently the maximum value of the sample weights - sigmoid_max_val: 5.0 - # Power of the sigmoid function. The higher the value the steeper the sigmoid function. In the limit - sigmoid_power_k: 1.0 + apply_sigmoid: False # Enable sigmoid weighting transform + sigmoid_max_val: 5.0 # Maximum sample weight value + sigmoid_power_k: 1.0 # Sigmoid steepness (higher = steeper) # Interpolation settings for physical-poly model interpolation_type: None @@ -165,30 +188,32 @@ model_params: sed_sigma: 0 # Field and pixel coordinates - x_lims: [0.0, 1.0e3] - y_lims: [0.0, 1.0e3] - pix_sampling: 12 # in [um] + x_lims: [0.0, 1000.0] + y_lims: [0.0, 1000.0] + pix_sampling: 12 # Pixel size in microns # Telescope parameters - tel_diameter: 1.2 # [m] - tel_focal_length: 24.5 # [m] - euclid_obsc: True. # Use Euclid-specific obscuration mask. Set to False for other instruments or custom masks. - LP_filter_length: 3 # Low-pass filter for obscurations + tel_diameter: 1.2 # Aperture diameter in meters + tel_focal_length: 24.5 # Focal length in meters + euclid_obsc: True # Use Euclid-specific obscuration mask (set False for other instruments) + LP_filter_length: 3 # Low-pass filter kernel size for obscurations ``` -### 5. Parametric Model Hyperparameters (`param_hparams`) +### 5. Parametric Model Hyperparameters +`training.model_params.param_hparams` ```yaml param_hparams: random_seed: 3877572 - l2_param: 0.0 # L2 loss for OPD/WFE - n_zernikes: 15 - d_max: 2 # Max polynomial degree - save_optim_history_param: true + l2_param: 0.0 # L2 regularization weight for OPD/WFE + n_zernikes: 15 # Number of Zernike polynomials + d_max: 2 # Maximum polynomial degree + save_optim_history_param: True ``` -### 6. Non-Parametric Model Hyperparameters (`nonparam_hparams`) +### 6. Non-Parametric Model Hyperparameters +`training.model_params.nonparam_hparams` ```yaml nonparam_hparams: d_max_nonparam: 5 @@ -196,37 +221,46 @@ nonparam_hparams: l1_rate: 1.0e-8 project_dd_features: False reset_dd_features: False - save_optim_history_nonparam: true + save_optim_history_nonparam: True ``` -### 7. Training Hyperparameters (`training_hparams`) +### 7. Training Hyperparameters -Controls batches, loss, and multi-cycle learning: +Controls batch size, loss function, optimizer selection, and multi-cycle learning. +`training.training_hparams` ```yaml training_hparams: - batch_size: 32 # Number of samples per batch - loss: 'mask_mse' # Options: 'mask_mse', 'mse' + batch_size: 32 # Number of samples per training batch + loss: 'mask_mse' # Loss function. Options: 'mask_mse', 'mse' + optimizer: + name: 'rectified_adam' # Options: 'adam', 'rectified_adam' multi_cycle_params: total_cycles: 2 - cycle_def: complete # Options: 'parametric', 'non-parametric', 'complete', etc. 
- save_all_cycles: False - saved_cycle: cycle2 - - learning_rate_params: [1.0e-2, 1.0e-2] - learning_rate_non_params: [1.0e-1, 1.0e-1] - n_epochs_params: [20, 20] - n_epochs_non_params: [100, 120] + cycle_def: complete # Options: 'parametric', 'non-parametric', 'complete' + save_all_cycles: False # If True, saves checkpoints for all cycles; otherwise only saved_cycle + saved_cycle: cycle2 # Which cycle checkpoint to retain + + learning_rate_params: [1.0e-2, 1.0e-2] # Per-cycle learning rate for parametric model + learning_rate_non_params: [1.0e-1, 1.0e-1] # Per-cycle learning rate for non-parametric model + n_epochs_params: [20, 20] # Per-cycle epochs for parametric model + n_epochs_non_params: [100, 120] # Per-cycle epochs for non-parametric model ``` +**Optimizer Notes:** +- `rectified_adam` requires tensorflow-addons to be installed manually. +- If TensorFlow Addons is not installed and `rectified_adam` is requested, WaveDiff will raise a runtime error with installation instructions. +- Standard workflows (`training`, `metrics`, `plotting`) run without TensorFlow Addons. + (metrics_config)= -## `metrics_config.yaml` — Metrics Configuration +## Metrics Configuration ### 1. Purpose Defines how a trained PSF model is evaluated. This configuration specifies which metrics to compute, which model weights to use, and how ground truth stars are obtained. It allows you to: + - Select a fully trained PSF model or a checkpoint for evaluation. -- Specify which training cycle’s weights to evaluate. +- Specify which training cycle's weights to evaluate. - Compute Polychromatic, Monochromatic, OPD, and Weak Lensing Shape metrics. - Use precomputed ground truth stars from the dataset if available, or automatically generate them from the configured ground truth model. - Optionally produce plots of the computed metrics via a plotting configuration file. @@ -235,7 +269,7 @@ Defines how a trained PSF model is evaluated. This configuration specifies which - WaveDiff automatically searches the dataset used for training. If the dataset contains `stars`, `SR_stars`, or `super_res_stars` fields, these are used as the ground truth for metrics evaluation. - If precomputed ground truth stars are not found in the dataset, WaveDiff regenerates them using the `ground_truth_model` parameters. **All required fields in `model_params` must be specified**; leaving them empty will prevent the metrics pipeline from running (see [Ground Truth Model Parameters](section-ground-truth-model) for details). -- The metrics evaluation can be run independently of training by specifying trained_model_path and `trained_model_config`. +- Metrics evaluation can be run independently of training by specifying both `trained_model_path` and `trained_model_config` to point to a previously trained model. - Metrics defined in [Metrics Overview table](metrics-table) are selectively computed according to their boolean flags. The Polychromatic Pixel Reconstruction metric is always computed. - The `plotting_config` parameter triggers plotting of the metrics results if a valid configuration file is provided. If left empty, metrics are computed without generating plots (see [Plotting Configuration](section-plotting-config)). - Batch size and other evaluation hyperparameters can be set under `metrics_hparams` (see [Evaluation Hyperparameters](section-evaluation-hyperparameters)) @@ -258,6 +292,8 @@ Defines how a trained PSF model is evaluated. This configuration specifies which ### 4. 
Top-Level Configuration Parameters +`metrics` + ```yaml metrics: model_save_path: @@ -271,25 +307,29 @@ metrics: plotting_config: ``` +**Parameter descriptions:** + +- `model_save_path`: Specifies which weights to load. Options: `psf_model` (final trained weights) or `checkpoint` (intermediate checkpoint). +- `saved_training_cycle`: Which training cycle to evaluate (e.g., `1`, `2`, ...). +- `trained_model_path`: Absolute path to the parent directory of a previously trained model. Leave empty if running `training` + `metrics` sequentially in the same workflow. +- `trained_model_config`: Filename of the training configuration (located in `/config/`). +- `eval_mono_metric`: If `True`, computes the monochromatic pixel reconstruction metric. Requires `ground_truth_model` to be configured (see [Ground Truth Model Parameters](section-ground-truth-model)). +- `eval_opd_metric`: If `True`, computes the optical path difference (OPD) metric. Requires `ground_truth_model` to be configured. +- `eval_train_shape_results_dict`: If `True`, computes Weak Lensing Shape metrics on the training dataset. +- `eval_test_shape_results_dict`: If `True`, computes Weak Lensing Shape metrics on the test dataset. +- `plotting_config`: Optional filename of a plotting configuration (e.g., `plotting_config.yaml`) to automatically generate plots after metrics evaluation. Leave empty to skip plotting. + **Notes:** -- `model_save_path`: Load final PSF model weights (`psf_model`) or checkpoint weights (`checkpoint`). -- `saved_training_cycle`: Choose which training cycle to evaluate (1, 2, …). -- `trained_model_path`: Absolute path to parent directory of previously trained model. Leave empty for training + metrics in serial. -- `trained_model_config`: Name of training config file in `trained_model_path/config/`. -- `eval_mono_metric`: If True, computes the monochromatic pixel reconstruction metric. Requires a `ground_truth_model` (see). -- `eval_opd_metric`: If True, computes the optical path difference (OPD) metric. Requires a `ground_truth_model`. -- `eval_train_shape_results_dict` / `eval_test_shape_results_dict`: Compute Weak Lensing Shape metrics on the training and/or test dataset. -- `plotting_config:` Optionally provide a `plotting_config.yaml` file to generate plots after metrics evaluation. -- **Behaviour notes:** - - Metrics controlled by flags (`eval_mono_metric`, `eval_opd_metric`, `eval_train_shape_results_dict`, `eval_test_shape_results_dict`) are only computed if their respective flags are True. - - The Polychromatic Pixel Reconstruction metric is always computed, regardless of flags. - - Future releases may allow optional `ground_truth_model` instantiation if the dataset already contains precomputed stars. + +- The Polychromatic Pixel Reconstruction metric is **always computed** regardless of flag settings. +- All other metrics (`eval_mono_metric`, `eval_opd_metric`, `eval_train_shape_results_dict`, `eval_test_shape_results_dict`) are only computed when their respective flags are set to `True`. (section-ground-truth-model)= ### 5. Ground Truth Model Parameters -Mirrors training parameters for consistency: +Specifies parameters for generating ground truth PSFs when precomputed stars are not available in the dataset. This configuration includes a subset of the training parameters — only those needed to simulate ground truth PSFs for comparison. 
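+
+Before detailing the ground-truth fields, here is how the top-level skeleton from Section 4 might look once filled in. The values below are an illustrative sketch, not defaults; adapt paths, cycle number, and flags to your own run:
+
+```yaml
+metrics:
+  model_save_path: psf_model   # or: checkpoint
+  saved_training_cycle: 2
+  trained_model_path: /path/to/wf-outputs/wf-outputs-xxxxxxxxxxxxxxxxxxx1
+  trained_model_config: training_config.yaml
+  eval_mono_metric: True
+  eval_opd_metric: True
+  eval_train_shape_results_dict: False
+  eval_test_shape_results_dict: True
+  plotting_config: plotting_config.yaml
+```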
+`metrics.ground_truth_model` ```yaml ground_truth_model: model_params: @@ -300,11 +340,10 @@ ground_truth_model: output_dim: 32 pupil_diameter: 256 LP_filter_length: 2 - use_sample_weights: True use_prior: False correct_centroids: False sigma_centroid_window: 2.5 - reference_shifts: [-1/3, -1/3] + reference_shifts: [-0.333, -0.333] obscuration_rotation_angle: 0 add_ccd_misalignments: False ccd_misalignments_input_path: @@ -313,11 +352,11 @@ ground_truth_model: sed_extrapolate: True sed_interp_kind: linear sed_sigma: 0 - x_lims: [0.0, 1.0e+3] - y_lims: [0.0, 1.0e+3] + x_lims: [0.0, 1000.0] + y_lims: [0.0, 1000.0] param_hparams: random_seed: 3877572 - l2_param: 0. + l2_param: 0.0 n_zernikes: 45 d_max: 2 save_optim_history_param: True @@ -330,32 +369,48 @@ ground_truth_model: save_optim_history_nonparam: True ``` **Notes:** -- **All fields in `model_params` are required.** Do not leave them empty. Even if the dataset contains precomputed ground truth stars, omitting `model_params` will prevent the metrics pipeline from running. -- Parameters mirror `training_config.yaml` for consistency. +- **All fields shown above are required.** Do not leave them empty. Even if the dataset contains precomputed ground truth stars, omitting these fields will prevent the metrics pipeline from running. +- This configuration uses a subset of the training parameters — telescope geometry (`tel_diameter`, `tel_focal_length`, `pix_sampling`) and sample weighting (`use_sample_weights`, `sample_weights_sigmoid`) are not required for metrics evaluation, as these are only needed during model training. +- Ground truth model parameters should match the simulation settings used to generate your dataset for meaningful comparison. - Future releases may allow optional instantiation of `ground_truth_model` when precomputed stars are available in the dataset. (section-evaluation-hyperparameters)= ### 6. Evaluation Hyperparameters +`metrics.metrics_hparams` ```yaml metrics_hparams: batch_size: 16 opt_stars_rel_pix_rmse: False - l2_param: 0. + l2_param: 0.0 output_Q: 1 output_dim: 64 + + # Optimizer configuration for metrics evaluation + optimizer: + name: 'adam' # Fixed to Adam for metrics evaluation + learning_rate: 1.0e-2 + beta_1: 0.9 + beta_2: 0.999 + epsilon: 1.0e-7 + amsgrad: False ``` -**Parameter explanations:** +**Parameter descriptions:** - `batch_size`: Number of samples processed per batch during evaluation. -- `opt_stars_rel_pix_rmse`: If `True`, saves RMSE for each individual star in addition to mean across FOV. -- `l2_param`: L2 loss weight for OPD. -- `output_Q`: Downsampling rate from high-resolution pixel modeling space. -- `output_dim`: Size of the PSF postage stamp for evaluation. - +- `opt_stars_rel_pix_rmse`: (_optional individual star RMSE_) If `True`, saves the relative pixel RMSE for each individual star in the test dataset in addition to the mean across the field of view. +- `l2_param`: L2 loss weight for the OPD metric. +- `output_Q`: Downsampling rate from the high-resolution pixel modeling space to the resolution at which PSF shapes are measured. Recommended value: `1`. +- `output_dim`: Pixel dimension of the PSF postage stamp. Should be large enough to contain most of the PSF signal. The required size depends on the `output_Q` value used. Recommended value: `64` or higher. +- `optimizer`: Optimizer configuration for metrics evaluation. Unlike training, metrics evaluation always uses the standard Adam optimizer. + - `name`: Fixed to `'adam'` (no other optimizers supported for metrics). 
+ - `learning_rate`: Learning rate for optimizer. + - `beta_1, beta_2`: Exponential decay rates for moment estimates. + - `epsilon`: Small constant for numerical stability. + - `amsgrad`: If `True`, uses AMSGrad variant of Adam. (section-plotting-config)= -## `plotting_config.yaml` — Plot Configuration +## Plotting Configuration The `plotting_config.yaml` file defines how WaveDiff generates diagnostic plots from the metrics produced during model evaluation. While the plotting routines are mostly pre-configured internally, this file allows you to combine and compare metrics from multiple training runs, or simply visualize the results of the most recent `metrics` pipeline execution. @@ -373,52 +428,58 @@ This configuration controls how metric outputs from one or more WaveDiff runs ar - All plotting styles and figure settings are hard-coded and do not require user modification. - If the plotting task is executed immediately after a metrics evaluation run, all fields except `plot_show` may be left empty—the pipeline will automatically locate the outputs of the active run. - When plotting results from multiple runs, the entries in `metrics_dir` and `metrics_config` must appear **row-aligned**, with each position referring to the same run. -- If any descriptions are unclear, or if you encounter unexpected behavior, please open a GitHub issue (). +- If any descriptions are unclear, or if you encounter unexpected behavior, please open a [GitHub issue](). -### 3. Basic Structure +### 3. Configuration Structure -An example `plotting_config.yaml` is shown below: +`plotting_params` ```yaml plotting_params: - # Path to the parent folder containing wf-psf output directories (e.g. $WORK/wf-outputs/) - metrics_output_path: + # Path to the parent folder containing WaveDiff output directories + metrics_output_path: /path/to/wf-outputs/ - # List of output directories (e.g. wf-outputs-xxxxxxxxxxx) whose metrics should be plotted + # List of output directories whose metrics should be plotted + # Leave commented/empty if plotting immediately after a metrics run metrics_dir: - # - wf-outputs-xxxxxxxxxxx1 - # - wf-outputs-xxxxxxxxxxx2 + # - wf-outputs-xxxxxxxxxxxxxxxxxxx1 + # - wf-outputs-xxxxxxxxxxxxxxxxxxx2 - # List of the metric config filenames corresponding to each listed directory + # List of metrics config filenames corresponding to each directory + # Leave commented/empty if plotting immediately after a metrics run metrics_config: - # - metrics_config_1.yaml - # - metrics_config_2.yaml + # - metrics_config_1.yaml + # - metrics_config_2.yaml # If True, plots are shown interactively during execution plot_show: False ``` -### 4. Example Directory Structure -Below is an example of three WaveDiff runs stored under a single parent directory: +**Parameter descriptions:** -**Example Directory Structure** +- `metrics_output_path`: Absolute path to the parent directory containing WaveDiff output folders (e.g., `/home/user/wf-outputs/`). Can be left as `` placeholder if plotting immediately after a metrics run. +- `metrics_dir`: List of output directory names (e.g., `wf-outputs-xxxxxxxxxxxxxxxxxxx1`) whose metrics should be included in plots. **Leave empty or commented out if plotting immediately after a metrics run** — WaveDiff will automatically locate the current run's outputs. +- `metrics_config`: List of `metrics_config.yaml` filenames corresponding to each directory in `metrics_dir`. Each entry should match the config file in `/config/`. Must be row-aligned with `metrics_dir`. 
**Leave empty or commented out if plotting immediately after a metrics run.** +- `plot_show`: If `True`, displays plots interactively during execution. If `False`, plots are saved to disk without display. + +### 4. Example Directory Structure Below is an example of three WaveDiff runs stored under a single parent directory: -```arduino +``` wf-outputs/ -├── wf-outputs-202305271829 +├── wf-outputs-xxxxxxxxxxxxxxxxxxx1 │ ├── config │ │ ├── data_config.yaml │ │ └── metrics_config_200.yaml │ ├── metrics │ │ └── metrics-poly-coherent_euclid_200stars.npy -├── wf-outputs-202305271845 +├── wf-outputs-xxxxxxxxxxxxxxxxxxx2 │ ├── config │ │ ├── data_config.yaml │ │ └── metrics_config_500.yaml │ ├── metrics │ │ └── metrics-poly-coherent_euclid_500stars.npy -├── wf-outputs-202305271918 +├── wf-outputs-xxxxxxxxxxxxxxxxxxx3 │ ├── config │ │ ├── data_config.yaml │ │ └── metrics_config_1000.yaml @@ -431,12 +492,12 @@ To jointly plot metrics from the three runs shown above, the `plotting_config.ya ```yaml plotting_params: - metrics_output_path: $WORK/wf-outputs/ + metrics_output_path: /path/to/wf-outputs/ metrics_dir: - - wf-outputs-202305271829 - - wf-outputs-202305271845 - - wf-outputs-202305271918 + - wf-outputs-xxxxxxxxxxxxxxxxxxx1 + - wf-outputs-xxxxxxxxxxxxxxxxxxx2 + - wf-outputs-xxxxxxxxxxxxxxxxxxx3 metrics_config: - metrics_config_200.yaml @@ -447,18 +508,94 @@ plotting_params: ``` This configuration instructs the plotting pipeline to load the metrics from each listed run and include them together in summary plots. +(inference_config)= +## Inference Configuration + +### 1. Purpose +Configures the WaveDiff inference API for generating polychromatic PSFs from a trained model, given a set of source positions and SEDs. Unlike the CLI tasks, the inference API is designed for external use: users are expected to load their own positions and SEDs programmatically and interact with the API directly. + +### 2. Key Fields + +`inference` +| Field | Required | Description | +|---------------|--------------|---------| +| `batch_size` | Yes | Number of PSFs to process per batch. | +| `cycle` | Yes | Training cycle checkpoint to load (e.g. `2`).
WaveDiff training typically runs two cycles. |
+
+`inference.configs`
+
+| Field | Required | Description |
+|---------------|--------------|---------|
+| `trained_model_path` | Yes | Absolute path to the directory containing the trained model. |
+| `model_subdir` | Yes | Subdirectory name within `trained_model_path` containing the model weights (e.g. `model`). |
+| `trained_model_config_path` | Yes | Path to the training configuration file used to train the model, relative to `trained_model_path`. |
+| `data_config_path` | No | Path to a data configuration file supplying prior information (e.g. a Phase Diversity calibration prior) relevant to the inference context. This may differ from the data configuration used during training. Leave blank if no external prior is required. |
+
+`inference.model_params`
+
+These fields are optional. Any field left blank inherits its value from the trained model configuration file. Populated fields override the corresponding `model_params` values from the training config.
+
+| Field | Default | Description |
+|---------------|--------------|---------|
+| `n_bins_lda` | inherited | Number of wavelength bins used to reconstruct polychromatic PSFs. |
+| `output_Q` | inherited | Downsampling rate to match the oversampled model to the telescope's native sampling. |
+| `output_dim` | inherited | Pixel dimension of the output PSF postage stamp. |
+| `correct_centroids` | `False` | If `True`, applies centroid error correction within the PSF model during inference. |
+| `add_ccd_misalignments` | `False` | If `True`, incorporates CCD misalignment corrections into the PSF model during inference. Required data is retrieved
from the trained model configuration file. | + +### 3. Example + +```yaml +inference: + batch_size: 16 + cycle: 2 + configs: + trained_model_path: /path/to/trained/model/ + model_subdir: model + trained_model_config_path: config/training_config.yaml + data_config_path: + model_params: + n_bins_lda: 8 + output_Q: 1 + output_dim: 64 + correct_centroids: False + add_ccd_misalignments: True +``` + +### 4. Notes + +- `trained_model_config_path` is relative to `trained_model_path`, not to the working directory. +- All `model_params` fields are optional; omitting them inherits values from the training configuration. - Only populate fields where you explicitly want to override the trained model's parameters. +- `data_config_path` is intended for cases where inference is performed in a different data context than training, for example using an updated or alternative prior. Leave blank if the trained model's own configuration is sufficient. +- `correct_centroids` and `add_ccd_misalignments` are independent model behaviour flags that modify PSF model computation during inference. Both retrieve their required data from the trained model configuration file — no additional configuration is required to enable them. + + (master_config_file)= ## Master Configuration ### 1. Purpose -The `configs.yaml` file is the _master controller_ for WaveDiff. -It defines **which pipeline tasks** should be executed (training, metrics evaluation, plotting) and in which order. +The `configs.yaml` file is the master controller for WaveDiff CLI tasks. It defines **which pipeline tasks** should be executed (`training`, `metrics`, `plotting`) and in which order. Each task entry points to a dedicated YAML configuration file, allowing WaveDiff to run multiple jobs sequentially from a single entry point. Each task points to a dedicated YAML configuration file—allowing WaveDiff to run multiple jobs sequentially using a single entry point. -### 2. Example: Multiple Training Runs +### 2. General Notes + +`configs.yaml` may contain any combination of the three CLI task types: + +- `training` +- `metrics` +- `plotting` + +-Tasks always execute **in the order they appear** in the file. +- The current release runs all jobs sequentially on a single GPU. +- Parallel multi-GPU execution is planned for a future version. +- For questions or feedback, please open a [GitHub issue](). + +### 3. Example: Multiple Training Runs + To launch a sequence of training runs (models 1…n), list each task and its corresponding configuration file: +`configs.yaml` ```yaml --- training_conf_1: training_config_1.yaml @@ -466,10 +603,11 @@ To launch a sequence of training runs (models 1…n), list each task and its cor ... training_conf_n: training_config_n.yaml ``` -Outputs will be organized as: + +WaveDiff will execute each training task sequentially and organize outputs as: ``` -wf-outputs-20231119151932213823/ +wf-outputs-xxxxxxxxxxxxxxxxxxx1/ ├── checkpoint/ │ ├── checkpoint_callback_poly-coherent_euclid_200stars_1_cycle1.* │ ├── ... @@ -488,29 +626,34 @@ wf-outputs-20231119151932213823/ └── psf_model_poly-coherent_euclid_200stars_n_cycle1.* ``` -### 3 Example: Training + Metrics + Plotting -To evaluate metrics and generate plots for each trained model, include the corresponding configuration files: +### 4. 
Example: Training + Metrics + Plotting +To evaluate metrics and generate plots after each training run, include metrics and plotting tasks in +`configs.yaml`: + +``` +training_conf_1: training_config_1.yaml +metrics_conf_1: metrics_config_1.yaml +plotting_conf_1: plotting_config_1.yaml +training_conf_2: training_config_2.yaml +metrics_conf_2: metrics_config_2.yaml +plotting_conf_2: plotting_config_2.yaml +... +``` + +Required configuration files: ``` config/ ├── configs.yaml ├── data_config.yaml -├── metrics_config.yaml -├── plotting_config.yaml ├── training_config_1.yaml -├── ... -└── training_config_n.yaml +├── metrics_config_1.yaml +├── plotting_config_1.yaml +├── training_config_2.yaml +├── metrics_config_2.yaml +├── plotting_config_2.yaml +└── ... ``` -Note: current WaveDiff versions generate one plot per metric per model. Creating combined plots requires a separate run [Plot Configuration](section-plotting-config). A future update will support automatic combined plots. - -### 4 General Notes +**Note:** Current WaveDiff versions generate one plot per metric per model. Creating combined comparison plots across multiple runs requires a separate plotting-only run (see [Plot Configuration](section-plotting-config)). Automatic combined plots may be supported in a future release. -- `configs.yaml` may contain **any combination** of the three task types: - - `training` - - `metrics` - - `plotting` -- Tasks always execute **in the order they appear** in the file. -- The current release runs all jobs on a single GPU, sequentially. -- Parallel multi-GPU execution is planned for a future version. -- For questions or feedback, please open a [GitHub issue](https://github.com/CosmoStat/wf-psf/issues/new). diff --git a/docs/source/dependencies.md b/docs/source/dependencies.md index b1bef1da..0dea4934 100644 --- a/docs/source/dependencies.md +++ b/docs/source/dependencies.md @@ -10,7 +10,6 @@ Third-party software packages required by WaveDiff are installed automatically ( | [scipy](https://scipy.org) | {cite:t}`SciPy-NMeth:20` | | [keras](https://keras.io) | {cite:t}`chollet:2015keras`| | [tensorflow](https://www.tensorflow.org) | {cite:t}`tensorflow:15` | -| [tensorflow-addons](https://www.tensorflow.org/addons) |{cite:t}`tensorflow:15` | | [tensorflow-estimator](https://www.tensorflow.org/api_docs/python/tf/estimator) |{cite:t}`tensorflow:15` | | [zernike](https://github.com/jacopoantonello/zernike) | {cite:t}`Antonello:15` | | [opencv-python](https://docs.opencv.org/4.x/index.html) | {cite:t}`opencv_library:08` | @@ -19,4 +18,26 @@ Third-party software packages required by WaveDiff are installed automatically ( | [astropy](https://www.astropy.org) | {cite:t}`astropy:13,astropy:18`,
{cite:t}`astropy:22` | | [matplotlib](https://matplotlib.org) | {cite:t}`Hunter:07` | | [pandas](https://pandas.pydata.org) | {cite:t}`mckinney:2010pandas` | -| [seaborn](https://seaborn.pydata.org) | {cite:t}`Waskom:21` | \ No newline at end of file +| [seaborn](https://seaborn.pydata.org) | {cite:t}`Waskom:21` | + +## Optional Dependencies + +Some features in WaveDiff rely on optional third-party packages that are **not required for standard training and evaluation workflows**. + +### TensorFlow Addons (Optional) + +| Package Name | Purpose | +|--------------|---------| +| [tensorflow-addons](https://www.tensorflow.org/addons) | Optional optimizers (e.g. RectifiedAdam) | + +Starting with WaveDiff **v3.1.0**, `tensorflow-addons` is no longer a required dependency, as TensorFlow Addons reached end-of-life in May 2024. + +- By default, WaveDiff uses standard Keras/TensorFlow optimizers (e.g. `Adam`) +- TensorFlow Addons is only imported **at runtime** if explicitly requested in the configuration +- If a TensorFlow Addons optimizer is selected and the package is not installed, WaveDiff will raise a clear runtime error + +To use TensorFlow Addons optimizers, install manually: + +```bash +pip install tensorflow-addons +``` \ No newline at end of file diff --git a/src/wf_psf/data/data_handler.py b/src/wf_psf/data/data_handler.py index 052fe730..60051645 100644 --- a/src/wf_psf/data/data_handler.py +++ b/src/wf_psf/data/data_handler.py @@ -1,11 +1,14 @@ """Data Handler Module. -Provides tools for loading, preprocessing, and managing data used in both training and inference workflows. +Provides tools for loading, preprocessing, and managing data used in both +training and inference workflows. Includes: + - The `DataHandler` class for managing datasets and associated metadata - Utility functions for loading structured data products -- Preprocessing routines for spectral energy distributions (SEDs), including format conversion (e.g., to TensorFlow) and transformations +- Preprocessing routines for spectral energy distributions (SEDs), including + format conversion (e.g., to TensorFlow) and transformations This module serves as a central interface between raw data and modeling components. @@ -138,6 +141,7 @@ def __init__( @property def tf_positions(self): + """Get positions as TensorFlow tensor.""" return ensure_tensor(self.dataset["positions"]) def load_dataset(self): @@ -325,8 +329,10 @@ def get_data_array( Expected to have methods compatible with the specified run_type. run_type : {"training", "simulation", "metrics", "inference"} Execution context that determines how data is retrieved: + - "training", "simulation", "metrics": Uses extract_star_data function - "inference": Retrieves data directly from dataset using key lookup + key : str, optional Primary key for data lookup. Used directly for inference run_type. If None, falls back to train_key value. Default is None. @@ -338,6 +344,7 @@ def get_data_array( train_key value. Default is None. allow_missing : bool, default True Control behavior when data is missing or keys are not found: + - True: Return None instead of raising exceptions - False: Raise appropriate exceptions (KeyError, ValueError) @@ -358,6 +365,7 @@ def get_data_array( Notes ----- Key resolution follows this priority order: + 1. train_key = train_key or key 2. test_key = test_key or resolved_train_key 3. 
key = key or resolved_train_key (for inference fallback) @@ -372,7 +380,7 @@ def get_data_array( >>> # Inference with fallback handling >>> inference_data = get_data_array(data, "inference", key="positions", - ... allow_missing=True) + ... allow_missing=True) >>> if inference_data is None: ... print("No inference data available") diff --git a/src/wf_psf/data/data_zernike_utils.py b/src/wf_psf/data/data_zernike_utils.py index 0fad9c8e..97837558 100644 --- a/src/wf_psf/data/data_zernike_utils.py +++ b/src/wf_psf/data/data_zernike_utils.py @@ -1,9 +1,11 @@ """Utilities for Zernike Data Handling. This module provides utility functions for working with Zernike coefficients, including: + - Prior generation - Data loading - Conversions between physical displacements (e.g., defocus, centroid shifts) and modal Zernike coefficients +- Conversions between physical displacements (e.g., defocus, centroid shifts) and modal Zernike coefficients Useful in contexts where Zernike representations are used to model optical aberrations or link physical misalignments to wavefront modes. @@ -26,6 +28,21 @@ @dataclass class ZernikeInputs: + """Zernike-related inputs for PSF modeling, including priors and datasets for corrections. + + All fields are optional to allow flexibility across different run types (training, simulation, inference) and configurations. + + Parameters + ---------- + zernike_prior : Optional[np.ndarray] + The true Zernike prior, if provided (e.g., from PDC). Can be None if not used or not available. + centroid_dataset : Optional[Union[dict, "RecursiveNamespace"]] + Dataset used for computing centroid corrections. Should contain both training and test sets if + used. Can be None if centroid correction is not enabled or no dataset is available. + misalignment_positions : Optional[np.ndarray] + Positions used for computing CCD misalignment corrections. Should be available in inference mode if misalignment correction is enabled. Can be None if not used or not available. + """ + zernike_prior: Optional[np.ndarray] # true prior, if provided (e.g. from PDC) centroid_dataset: Optional[ Union[dict, "RecursiveNamespace"] @@ -34,11 +51,17 @@ class ZernikeInputs: class ZernikeInputsFactory: + """Factory class to build ZernikeInputs based on run type and dataset configuration. + + This class abstracts the logic of extracting the relevant Zernike-related inputs from the dataset based on the specified run type (training, simulation, inference) and model parameters. It handles the conditional logic for which inputs are needed and how to extract them, providing a clean interface for constructing the ZernikeInputs dataclass instance. + + """ + @staticmethod def build( data, run_type: str, model_params, prior: Optional[np.ndarray] = None ) -> ZernikeInputs: - """Builds a ZernikeInputs dataclass instance based on run type and data. + """Build a ZernikeInputs dataclass instance based on run type and data. Parameters ---------- @@ -206,9 +229,9 @@ def assemble_zernike_contributions( positions=None, batch_size=16, ): - """ - Assemble the total Zernike contribution map by combining the prior, - centroid correction, and CCD misalignment correction. + """Assemble Zernike contributions from prior, centroid correction, and CCD misalignment. + + This function checks the model parameters to determine which contributions to include, computes each contribution as needed, and combines them into a single Zernike contribution tensor. 
It handles the logic for when certain contributions are not used or not available, ensuring that the final output is correctly shaped and contains the appropriate information based on the configuration. Parameters ---------- diff --git a/src/wf_psf/inference/__init__.py b/src/wf_psf/inference/__init__.py index e69de29b..5df682d0 100644 --- a/src/wf_psf/inference/__init__.py +++ b/src/wf_psf/inference/__init__.py @@ -0,0 +1,2 @@ +# src/wf_psf/inference/__init__.py +"""Inference package for PSF generation.""" diff --git a/src/wf_psf/instrument/__init__.py b/src/wf_psf/instrument/__init__.py index e69de29b..d72618f3 100644 --- a/src/wf_psf/instrument/__init__.py +++ b/src/wf_psf/instrument/__init__.py @@ -0,0 +1 @@ +"""Wavefront-based PSF Instrument package.""" From 6da8a73ad0539f1c61f4050492f470871465b88e Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 13 Feb 2026 14:54:58 +0100 Subject: [PATCH 133/135] Uncomment tensorflow dependency --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 7ace733d..0e22269c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ description = 'A software framework to perform Differentiable wavefront-based PS dependencies = [ "numpy>=1.18,<1.24", "scipy", -# "tensorflow==2.11.0", + "tensorflow==2.11.0", "tensorflow-estimator", "zernike", "opencv-python", From c8f5c1ef702d5bf2e845c06c6cad0fe58232ae5a Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 13 Feb 2026 15:12:18 +0100 Subject: [PATCH 134/135] Update cd manual workflow to use editable install and adjust sys.path in conf.py --- .github/workflows/cd_manual.yml | 3 ++- docs/source/conf.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/cd_manual.yml b/.github/workflows/cd_manual.yml index 4064cf67..cb955fbe 100644 --- a/.github/workflows/cd_manual.yml +++ b/.github/workflows/cd_manual.yml @@ -24,7 +24,8 @@ jobs: - name: Install dependencies run: | - python -m pip install ".[docs]" + python -m pip install --upgrade pip setuptools wheel + pip install -e ".[docs] - name: Build API documentation run: | diff --git a/docs/source/conf.py b/docs/source/conf.py index d54e4c8b..ff6ae3fe 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -9,7 +9,8 @@ import os import sys -sys.path.insert(0, os.path.abspath("../../src")) +repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) +sys.path.insert(0, os.path.join(repo_root, "src")) # -- Project information ----------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information From 60f26ae63b74c105586b15a472d4099379a29ae9 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Fri, 13 Feb 2026 15:45:46 +0100 Subject: [PATCH 135/135] Correct pip install typo and fix path for uploading artifact --- .github/workflows/cd_manual.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/cd_manual.yml b/.github/workflows/cd_manual.yml index cb955fbe..9c5568a4 100644 --- a/.github/workflows/cd_manual.yml +++ b/.github/workflows/cd_manual.yml @@ -25,7 +25,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip setuptools wheel - pip install -e ".[docs] + pip install -e ".[docs]" - name: Build API documentation run: | @@ -36,4 +36,4 @@ jobs: uses: actions/upload-artifact@v4 with: name: api-docs - publish_dir: docs/build/html + path: docs/build/html
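
For reference, the optional-optimizer behaviour described in `docs/source/dependencies.md` above (TensorFlow Addons is imported only at runtime, and a clear error is raised if it is missing) amounts to a guarded import. The sketch below is hypothetical; the `get_optimizer` name and signature are illustrative assumptions, not WaveDiff's actual implementation:

```python
# Hypothetical sketch of the guarded optional import described above;
# the helper name and signature are illustrative, not wf-psf's actual code.
import tensorflow as tf


def get_optimizer(name: str, learning_rate: float = 1e-2):
    """Return a Keras optimizer, importing tensorflow-addons only on demand."""
    if name == "adam":
        return tf.keras.optimizers.Adam(learning_rate=learning_rate)
    if name == "rectified_adam":
        try:
            # Optional dependency: imported only when explicitly requested.
            import tensorflow_addons as tfa
        except ImportError as err:
            raise RuntimeError(
                "Optimizer 'rectified_adam' requires tensorflow-addons. "
                "Install it manually with: pip install tensorflow-addons"
            ) from err
        return tfa.optimizers.RectifiedAdam(learning_rate=learning_rate)
    raise ValueError(f"Unknown optimizer: {name!r}")
```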