Description
When loading a trained model for inference, the TFPhysicalPolychromatic class still requires access to the dataset used for training, including observed positions. This causes unnecessary coupling and can break inference pipelines.
Code in question:
https://github.com/CosmoStat/wf-psf/blob/f15c0f2068a48f24ae1962cd1791de31989f7c28/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py#L77C7-L77C35
Problem
Currently, instantiating a trained PSF model for inference fails unless the original data object (including observed positions) is provided. This happens even when config flags such as correct_centroids: False are set.
Specific issues:
- `self.obs_pos` is computed and passed into a `TFPhysicalLayer` even during inference.
- `self.zks_prior` is computed via `get_zernike_prior(...)`, which always queries the full training/test dataset, including centroid correction and CCD misalignment contributions, even if those options are disabled.
- Even with `correct_centroids: False`, the call to `get_obs_positions()` is still executed.
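To make the failure mode concrete, here is a simplified stand-in for the pattern described above. It is not the actual wf-psf code: the class `EagerModel`, the dictionary-based dataset, and the helper signatures are assumptions made for illustration only. It shows how an unconditional dataset access in the constructor fails when no data object is supplied, regardless of the config flags.

```python
def get_obs_positions(data):
    # Hypothetical stand-in: gathers training and test positions from the dataset.
    return data["train_positions"] + data["test_positions"]


def get_zernike_prior(config, data):
    # Hypothetical stand-in: reads the prior from the dataset.
    return data["zernike_prior"]


class EagerModel:
    """Simplified illustration of the coupling (not the wf-psf implementation)."""

    def __init__(self, config, data):
        # The flag is read but never used to guard the dataset access below.
        self.correct_centroids = config.get("correct_centroids", False)
        # Both calls touch the dataset unconditionally, so passing data=None
        # fails even when every correction is disabled.
        self.obs_pos = get_obs_positions(data)
        self.zks_prior = get_zernike_prior(config, data)


if __name__ == "__main__":
    try:
        EagerModel({"correct_centroids": False}, None)
    except TypeError as err:
        print(f"Inference without the training data fails: {err}")
```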
Reason
Originally, the class was tightly coupled to training-time behavior, as described in this Slack comment by Tobias:
"The original idea of the Zernike prior was that you need the prior at the positions you’re going to compute the PSF. That motivated the development of the physical layer. Then we extended that layer to handle the centroids, CCD misalignments, etc. [...] The issue is that the position is used for indexing as it allows to gather the correct priors in the Zernike contribution list."
This design works for training, but breaks modularity and hinders clean inference from pre-trained weights.
Desired behaviour
Inference mode should:
- Not require loading the full training dataset.
- Be decoupled from training-time artifacts like centroid correction and full dataset access.
- Support inference at arbitrary positions, using:
  - A precomputed Zernike prior (if used), matched to inference positions.
  - Optional CCD misalignment correction (computed on the fly using inference positions).
  - SED and position inputs at prediction time.
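The inference-mode inputs listed above could be combined roughly as follows. This is a sketch under assumptions, written in plain Python rather than TensorFlow: the function name `inference_zernike_contribution`, its parameters, and the toy misalignment model are hypothetical and do not exist in wf-psf. The point is only that the prior arrives already matched row by row to the inference positions, and the CCD misalignment term is evaluated from those positions without any training dataset.

```python
from typing import Callable, List, Optional, Sequence, Tuple

Position = Tuple[float, float]


def inference_zernike_contribution(
    positions: Sequence[Position],
    n_zernikes: int,
    prior: Optional[Sequence[Sequence[float]]] = None,
    ccd_misalignment: Optional[Callable[[Position], Sequence[float]]] = None,
) -> List[List[float]]:
    """Sum the optional Zernike contributions for each inference position."""
    if prior is not None and len(prior) != len(positions):
        raise ValueError("prior must have one row per inference position")

    total = []
    for i, pos in enumerate(positions):
        row = [0.0] * n_zernikes
        if prior is not None:
            # Precomputed prior, already matched to the inference positions.
            row = [a + b for a, b in zip(row, prior[i])]
        if ccd_misalignment is not None:
            # Computed on the fly from the inference position only.
            row = [a + b for a, b in zip(row, ccd_misalignment(pos))]
        total.append(row)
    return total


def toy_ccd_misalignment(pos: Position) -> List[float]:
    # Toy model (assumption): a defocus-like term that scales with x.
    return [0.0, 0.0, 0.5 * pos[0]]


if __name__ == "__main__":
    print(
        inference_zernike_contribution(
            positions=[(2.0, 0.0)],
            n_zernikes=3,
            prior=[[1.0, 0.0, 0.25]],
            ccd_misalignment=toy_ccd_misalignment,
        )
    )
```

Both optional terms default to `None`, so a model trained without a prior or without CCD misalignment correction needs nothing beyond the positions themselves.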
Training mode should:
- Retain full control over dataset access and augmentation.
- Aggregate Zernike contributions (including centroids and CCD misalignment).
Proposed solution
Refactor the class to distinguish clearly between training and inference use cases.
- Support lazy loading of the positions and the physical layer, and possibly of the prior as well, instead of computing them eagerly during model creation.
- Avoid calling dataset-related methods during inference unless explicitly requested.
- Introduce a factory method for inference setup.
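One possible shape for this refactor is sketched below. Everything here is hypothetical: `PSFModelSketch`, `for_inference`, `data_loader`, and `weights_path` are illustrative names, not existing wf-psf API, and the weight loading and prediction steps are placeholders. The sketch combines the three points above: dataset-dependent attributes are resolved lazily via `functools.cached_property`, nothing touches the dataset unless it is explicitly requested, and a factory classmethod builds an inference-only instance.

```python
from functools import cached_property


class PSFModelSketch:
    """Illustrative only: separates training and inference construction."""

    def __init__(self, config, data_loader=None):
        self.config = config
        # Callable returning the dataset, or None when no dataset is available.
        self._data_loader = data_loader
        self.weights_path = None

    @classmethod
    def for_inference(cls, config, weights_path):
        # Factory for inference: no dataset is attached.
        model = cls(config, data_loader=None)
        # Placeholder: real code would restore the trained weights here.
        model.weights_path = weights_path
        return model

    @cached_property
    def obs_pos(self):
        # Resolved on first access only, so inference never triggers it
        # unless something explicitly asks for the observed positions.
        if self._data_loader is None:
            raise RuntimeError("observed positions require a dataset; none was provided")
        return self._data_loader()["positions"]

    def predict(self, positions, seds):
        # Placeholder prediction: pairs inputs without any dataset access.
        return list(zip(positions, seds))


if __name__ == "__main__":
    model = PSFModelSketch.for_inference({"correct_centroids": False}, "weights.h5")
    print(model.predict([(0.0, 0.0)], ["sed_0"]))
```

In training mode the same class would be built with a `data_loader`, and `obs_pos` would be computed once on first use and then cached.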