
Refactor TFPhysicalPolychromatic to cleanly separate training vs inference behaviour #165

@jeipollack

Description


When loading a trained model for inference, the TFPhysicalPolychromatic class still requires access to the dataset used for training, including observed positions. This causes unnecessary coupling and can break inference pipelines.

Code in question:
https://github.com/CosmoStat/wf-psf/blob/f15c0f2068a48f24ae1962cd1791de31989f7c28/src/wf_psf/psf_models/models/psf_model_physical_polychromatic.py#L77C7-L77C35

Problem

Currently, instantiating a trained PSF model for inference fails unless the original data object (including observed positions) is provided. This happens even when config flags such as correct_centroids: False are set.

Specific issues:

  • self.obs_pos is computed and passed into a TFPhysicalLayer even during inference.
  • self.zks_prior is computed via get_zernike_prior(...), which always queries the full training/test dataset—including centroid correction and CCD misalignment contributions—even if those options are disabled.
  • Even with correct_centroids: False, the call to get_obs_positions() is still executed.
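As a toy illustration of this coupling (all names here are hypothetical stand-ins for `get_obs_positions()` / `get_zernike_prior(...)`; the real class lives in `psf_model_physical_polychromatic.py`), the constructor touches the dataset unconditionally:

```python
# Toy sketch of the coupling described above. `CoupledModel`, the `data`
# dict, and the stand-in computations are hypothetical, not the wf-psf API.
class CoupledModel:
    def __init__(self, data, correct_centroids=False):
        # These run unconditionally, even with correct_centroids=False
        # and even when the model is only needed for inference:
        self.obs_pos = data["positions"]      # stands in for get_obs_positions()
        self.zks_prior = sum(data["priors"])  # stands in for get_zernike_prior(...)

# Loading trained weights for inference provides no training dataset,
# so instantiation fails:
try:
    CoupledModel(data=None)
except TypeError:
    print("inference without the training dataset fails")
```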

Reason

Originally, the class was tightly coupled to training-time behaviour, as Tobias described in this Slack comment:

"The original idea of the Zernike prior was that you need the prior at the positions you’re going to compute the PSF. That motivated the development of the physical layer. Then we extended that layer to handle the centroids, CCD misalignments, etc. [...] The issue is that the position is used for indexing as it allows to gather the correct priors in the Zernike contribution list."

This design works for training, but breaks modularity and hinders clean inference from pre-trained weights.

Desired behaviour

Inference mode should:

  • Not require loading the full training dataset.
  • Be decoupled from training-time artifacts like centroid correction and full dataset access.
  • Support inference at arbitrary positions, using:
    • A precomputed Zernike prior (if used), matched to inference positions.
    • Optional CCD misalignment correction (computed on the fly using inference positions).
    • SED and position inputs at prediction time.

Training mode should:

  • Retain full control over dataset access and augmentation.
  • Aggregate Zernike contributions (including centroids and CCD misalignment).
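The inference-side requirements above can be sketched as a function that consumes only prediction-time inputs (the function and argument names are illustrative, not the actual wf-psf API):

```python
import numpy as np

# Hypothetical sketch: inference takes positions, SEDs, an optional
# precomputed prior matched to those positions, and an optional on-the-fly
# CCD misalignment correction. No training dataset is involved.
def infer_psf(positions, seds, zks_prior=None, ccd_misalignment=None):
    """Compute a toy per-position quantity from inference-time inputs only."""
    zks = np.zeros_like(positions[:, 0]) if zks_prior is None else zks_prior
    if ccd_misalignment is not None:
        # Optional correction computed on the fly from inference positions.
        zks = zks + ccd_misalignment(positions)
    return zks + seds.sum(axis=1)  # placeholder for the polychromatic sum

positions = np.array([[0.1, 0.2], [0.5, 0.9]])  # arbitrary inference positions
seds = np.ones((2, 3))
print(infer_psf(positions, seds))
```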

Proposed solution

Refactor the class to distinguish clearly between training and inference use cases.

  • Lazily load positions and the physical layer (and possibly even the prior), rather than resolving everything at model creation.
  • Avoid calling dataset-related methods during inference unless explicitly requested.
  • Introduce a factory method for inference setup.
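A minimal sketch of the proposed split, assuming a classmethod factory plus a lazy property (`PhysicalModel`, `for_training`, and `for_inference` are hypothetical names; the real refactor would target `TFPhysicalPolychromatic`):

```python
# Hypothetical sketch of the training/inference split proposed above.
class PhysicalModel:
    def __init__(self):
        self._data = None          # only set in training mode
        self._zks_prior = None     # lazily computed or supplied

    @classmethod
    def for_training(cls, data):
        model = cls()
        model._data = data
        return model

    @classmethod
    def for_inference(cls, zks_prior=None):
        # No dataset access: the prior, if any, is supplied precomputed
        # and matched to the inference positions.
        model = cls()
        model._zks_prior = zks_prior
        return model

    @property
    def zks_prior(self):
        # Lazy: derived from the dataset only when actually needed,
        # and only if a dataset was provided.
        if self._zks_prior is None:
            if self._data is None:
                raise RuntimeError("no prior supplied and no dataset to derive it from")
            self._zks_prior = sum(self._data["priors"])  # stand-in for get_zernike_prior(...)
        return self._zks_prior

infer = PhysicalModel.for_inference(zks_prior=[0.0, 0.1])
print(infer.zks_prior)  # returned without touching any dataset
```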
