Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
129 commits
Select commit Hold shift + click to select a range
0884737
Correct doc string, format errors, unused imports, type hints, etc
Nov 18, 2025
532e985
Update documentation for v3.0.0: replace modules.rst with api.rst, cl…
Nov 18, 2025
be61b0f
Update API reference and clean up package/module docstrings
Nov 21, 2025
afa2923
Remove auto-generated _autosummary from repo; will be built in CD
Nov 21, 2025
6ecb9ba
Add arduino formatting to directory structure example
Nov 24, 2025
11bad16
Revise WaveDiff how-to guide in doc (part 1)
Nov 25, 2025
59dbccd
Fix myst error with invalid code format
Nov 25, 2025
12b8969
Correct syntax in docstring and generalise exception message
May 13, 2025
eaa0e8e
Add inference and test_inference packages
May 13, 2025
1aebacb
Refactor: Encapsulate logic in psf_models package with subpackages: m…
May 13, 2025
d66d2c9
Remove unused module with duplicate zernike_generator function
May 13, 2025
b9cc937
Correct syntax in docstrings and logger messages
May 13, 2025
f72e0e5
Refactor file structure; update import statements in tests; remove un…
May 14, 2025
a86352c
Update package name in import statement
May 14, 2025
2b26228
Reorder imports; Refactor MetricsConfigHandler class attributes, meth…
May 14, 2025
b8375ba
Move psf_model weights loader to psf_model_loader.py module
May 14, 2025
f3876fb
Add psf_model_loader module
May 14, 2025
ffb49e3
Remove weights_path arg from evaluate_model method; Update logger.inf…
jeipollack May 15, 2025
1001277
Update variable name and logger statement
jeipollack May 15, 2025
8e659d0
Add import logging and create logger object
jeipollack May 15, 2025
6d9c406
Create psf_inference.py module
May 15, 2025
47e9c22
Remove arg from evaluate_model unit test
May 15, 2025
e3ad71a
Refactor: reorganise modules, relocate utility functions, rename modu…
May 16, 2025
0cfb8df
Update import statements to new module names
May 16, 2025
2946d47
Update DataHandler class docstring to include option for inference da…
May 16, 2025
0657e24
Refactor data_handler with new utility functions to validate and proc…
May 18, 2025
e28fe00
Update unit tests associated to changes in data_handler.py
May 18, 2025
29e3abf
Change exception handling in DataConfigHandler; modify args to DataHa…
May 18, 2025
d1c8673
Add data and psf_model_imports into inference and sketch out methods
May 19, 2025
0672d33
add base psf inference
tobias-liaudat May 19, 2025
9fbf8b1
add common call interface through PSF models
tobias-liaudat May 19, 2025
b88d562
add handling of inference params
tobias-liaudat May 19, 2025
aa9bad3
automatic formatting
tobias-liaudat May 19, 2025
98e4806
add first completed class draft
tobias-liaudat May 19, 2025
e3ec2d2
add inference config file
tobias-liaudat May 19, 2025
634c81c
remove unused code
tobias-liaudat May 19, 2025
9075fa6
update params
tobias-liaudat May 19, 2025
1cc14cc
update params
tobias-liaudat May 19, 2025
72c80b6
update inference
tobias-liaudat May 19, 2025
566668b
reduce arguments and add compute psfs when appropiate
tobias-liaudat May 19, 2025
068d78e
add config handler class
tobias-liaudat May 19, 2025
ec44530
set up inference config handler and simplify PSFInferenc init
tobias-liaudat May 19, 2025
6b073bd
remove unused imports
tobias-liaudat May 19, 2025
be08c3b
update inference
tobias-liaudat May 19, 2025
ec87d79
Add single-space lines to improve readability; Remove duplicated stat…
May 23, 2025
b5b31f0
Add additional PSFInference class attributes; update set_source_param…
May 26, 2025
4c8db2b
Add checks to convert to np.ndarray and expand dimensions if needed
jeipollack May 27, 2025
fbdd32a
Update pyproject.toml with numpy dependency limits - sdc-uk
jeipollack Jun 5, 2025
94a424e
Correct name of psf_inference_test module to follow repo naming conve…
Jun 6, 2025
44cf375
Correct config subkey names for defining trained_model_path and train…
Jun 8, 2025
cfd3baf
Refactor psf_inference adding PSFInferenceEngine to separate concerns…
Jun 8, 2025
d1a4fb5
Add unit tests for psf_inference
Jun 8, 2025
5ea8bcb
Bugfix: Ensure updated training_config.training.model_params are pass…
Jun 12, 2025
c82836b
test(simPSF): add unit test to verify updated model_params are passed
Jun 12, 2025
c99b1e8
Bug: replace self.data_conf with self.data_config
jeipollack Jun 12, 2025
cf0e8e1
Change logger.warnings to ValueErrors for missing fields in datasets …
Aug 5, 2025
814d047
Update unit tests with changes to data_handler.py
Aug 5, 2025
dfae6ce
Refactor: reorganise modules, relocate utility functions, rename modu…
May 16, 2025
c49ebc7
Refactor data_handler with new utility functions to validate and proc…
May 18, 2025
4903928
Update unit tests associated to changes in data_handler.py
May 18, 2025
152b0d8
automatic formatting
tobias-liaudat May 19, 2025
de0892c
Refactor: add ZernikeInputs dataclass, ZernikeInputsFactory, helper m…
Jun 18, 2025
8db8de7
Update docstring describing data_conf types permitted
Jun 18, 2025
9f7f19e
Move imports to method to avoid circular imports
Jun 21, 2025
3c84c1f
Remove batch_size arg from ZernikeInputsFactory ; raise ValueError to…
Jun 21, 2025
362ed76
Add and set run_type attribute to DataConfigHandler object in Trainin…
Jun 21, 2025
c8f060e
Add and set run_type attribute ; Replace var name end with end_sample…
Jun 21, 2025
8458a9a
Refactor TFPhysicalPolychromaticField to lazy load property objects a…
Jun 21, 2025
61fb8dc
Update/Add unit tests to test refactoring changes to psf_model_physic…
Jun 21, 2025
60dfe56
Replace arg: data in compute_ccd_misalignment with positions
Jun 21, 2025
29f9fc5
Correct object attributes for DataConfigHandler in ZernikeInputsFactory
jeipollack Jun 21, 2025
a943dbe
Add missing return for tf_physical_layer property
jeipollack Jun 21, 2025
e15444f
Add tf_utils.py module to tf_modules subpackage
jeipollack Jun 22, 2025
00b07f7
Use ensure_tensor method from tf_utils.py to check/convert to tensorf…
jeipollack Jun 22, 2025
3446280
Refactor: Add eager-mode helpers and avoid lazy-loading obscurations …
jeipollack Jun 22, 2025
278479f
Replace deprecated get_obs_positions with get_np_obs_positions and ap…
jeipollack Jun 22, 2025
0eacd98
Remove tf.convert_to_tensor from all Zernike list contributors
jeipollack Jun 23, 2025
98fde06
Add and set self.data_conf.run_type value to 'metrics' in MetricsConf…
jeipollack Jun 23, 2025
5c3b585
Eagerly precompute Zernike components; add support for 'metrics' run_…
jeipollack Jun 23, 2025
e3aea26
Correct value error: train in dataset_type with training
jeipollack Jun 25, 2025
ee0f8e0
fix: pass random seed to TFNonParametricPolynomialVariationsOPD const…
jeipollack Jun 25, 2025
776c019
Refactor to suppress TensorFlow debug msgs: replace lambda in call me…
jeipollack Jun 26, 2025
ab8856e
Match old behaviour with conditional and float64 accumulation
jeipollack Jun 26, 2025
ea85477
Add helper to stack x/y field coordinates into (N, 2) positions array
jeipollack Jul 8, 2025
7b24c45
Add helper method to prepare dataset for inference & handle empty/Non…
jeipollack Jul 10, 2025
aeeafa1
Update data_handler_test replacing "get_obs_positions" (deprecation) …
Jul 22, 2025
66b6e6b
Remove deprecated code from rebase
Aug 6, 2025
2997009
Remove duplicated checks on arg existance
Aug 6, 2025
f4adcd4
Improve Zernike prior handling in assemble_zernike_contributions
Aug 8, 2025
e8f6075
Fix bug where Tensor zernike_prior was not appended after eager conve…
Aug 8, 2025
f8ba288
Update unit tests with latest changes to fixtures and data_zernike_ut…
Aug 8, 2025
a97cc66
Set mock Zernike priors to None in test_data_utils.py helper module
Aug 8, 2025
88e49bc
Remove -1.0 multiplicative factor applied to Zernike tip and tilt values
Aug 8, 2025
a1d055b
Move TFPhysicalPolychromaticField.pad_zernikes to helper method pad_t…
Aug 8, 2025
b721f2f
Correct bug in test_load_inference_model
Aug 8, 2025
98d92a0
Revert sign change applied to compute_centroid_correction
Aug 18, 2025
2e644d7
Refactor _prepare_positions_and_seds to enforce shape consistency and…
Aug 19, 2025
dd99fe2
Fix tensor handling in ZernikeInputsFactory
jeipollack Aug 21, 2025
818dc73
Reformat and remove unused import
Aug 21, 2025
f3d4f16
Correct zernike_prior extraction when dataset is a dict, reformat file
Aug 22, 2025
829136a
Replace np.array with Tensorflow tensors in unit test and fixtures, r…
Aug 22, 2025
96460a5
Eagerly initialise trainable layers in physical poly model constructo…
jeipollack Aug 27, 2025
06d715b
fix: use expect_partial() when loading model weights for evaluation
jeipollack Aug 27, 2025
5b3dee3
Add memory cleanup after training completion
jeipollack Aug 27, 2025
6a2d24d
refactor: centralise PSF data extraction in data_handler
jeipollack Sep 3, 2025
8376fe1
Add and options to inference_config.yaml (forgot to stage with prev…
jeipollack Sep 4, 2025
77434be
Update PSFInference doc string with new optional attributes
jeipollack Sep 4, 2025
76a449c
Rename _get_inference_data to _get_direct_data
jeipollack Sep 4, 2025
b0874e7
Reformat with black
Sep 5, 2025
0e85c0f
Correct type hint errors
jeipollack Sep 5, 2025
2bdd622
Remove unused import
jeipollack Sep 5, 2025
6134d8e
Replace call to deprecated get_np_obs_positions with get_data_array
jeipollack Sep 5, 2025
e309487
Remove unused imports and reformat
Oct 31, 2025
8832584
Update fixtures and unit tests
Oct 31, 2025
529fcd3
Reformat with black
Oct 31, 2025
6c3e9d0
Remove outdated back module
Dec 8, 2025
af12a6c
Remove unused import left after rebase
Dec 8, 2025
07e88c7
Update doc strings and cache handling
Dec 12, 2025
43f894f
Update psf_test_setup fixture for reusability and add unit tests for …
Dec 12, 2025
a88c455
Remove unneeded mock of logger in evaluate_model
Dec 12, 2025
6f0e21d
Add wf_psf.inference to list of main packages
Dec 12, 2025
cd715c3
Update version to 3.1.0
Dec 12, 2025
5d54518
docs: Correct syntax error and code-block formatting in doc string ex…
Dec 12, 2025
b5da83f
Remove weights_path references missed during rebase
Jan 23, 2026
98b66e9
Add changelog fragment
Jan 23, 2026
33ad4f5
Fix logger formatting for relative RMSE metrics
jeipollack Jan 26, 2026
63e6acc
Update changelog with entry under Bug Fixes
jeipollack Jan 26, 2026
0c74bfe
Correct rebase error in configuration.md
Feb 2, 2026
8fd6502
Add new changelog fragment for deprecated import
Feb 10, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
<!--
A new scriv changelog fragment.

Uncomment the section that is right (remove the HTML comment wrapper).
For top level release notes, leave all the headers commented out.
-->

<!--
### Breaking changes

- A bullet item for the Breaking changes category.

-->

### New features

- Added PSF inference capabilities for generating broadband (polychromatic) PSFs from trained models given star positions and SEDs
- Introduced `PSFInferenceEngine` class to centralize training, simulation, metrics, and inference workflows
- Added `run_type` attribute to `DataHandler` supporting training, simulation, metrics, and inference modes
- Implemented `ZernikeInputsFactory` class for building `ZernikeInputs` instances based on run type
- Added `psf_model_loader.py` module for centralized model weights loading


### Bug fixes

- Fix logger formatting for relative RMSE metrics in `metrics.py` (values were not being displayed)

<!--
### Performance improvements

- A bullet item for the Performance improvements category.

-->

### Internal changes

- Refactored `TFPhysicalPolychromatic` and related modules to separate training vs. inference behavior
- Enhanced `ZernikeInputs` data class with intelligent assembly based on run type and available data
- Implemented hybrid loading pattern with eager loading in constructors and lazy-loading via property decorators
- Centralized PSF data extraction in `data_handler` module
- Improved code organization with new `tf_utils.py` module in `psf_models` sub-package
- Updated configuration handling to support inference workflows via `inference_config.yaml`
- Fixed incorrect argument name in `DataHandler` that prevented proper TensorFlow data type conversion
- Removed deprecated `get_obs_positions` method
- Updated documentation to include inference package
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
<!--
A new scriv changelog fragment.

Uncomment the section that is right (remove the HTML comment wrapper).
For top level release notes, leave all the headers commented out.
-->

<!--
### Breaking changes

- A bullet item for the Breaking changes category.

-->
<!--
### New features

- A bullet item for the New features category.

-->
<!--
### Bug fixes

- A bullet item for the Bug fixes category.

-->
<!--
### Performance improvements

- A bullet item for the Performance improvements category.

-->

### Internal changes

- Remove deprecated/optional import tensorflow-addons statement from tf_layers.py


37 changes: 37 additions & 0 deletions config/inference_conf.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@

inference:
# Inference batch size
batch_size: 16
# Cycle to use for inference. Can be: 1, 2, ...
cycle: 2

# Paths to the configuration files and trained model directory
configs:
# Path to the directory containing the trained model
trained_model_path: /path/to/trained/model/

# Subdirectory name of the trained model, e.g. psf_model
model_subdir: model

# Relative Path to the training configuration file used to train the model
trained_model_config_path: config/training_config.yaml

# Path to the data config file (this could contain prior information)
data_config_path:

# The following parameters will overwrite the `model_params` in the training config file.
model_params:
# Num of wavelength bins to reconstruct polychromatic objects.
n_bins_lda: 8

# Downsampling rate to match the oversampled model to the specified telescope's sampling.
output_Q: 1

# Dimension of the pixel PSF postage stamp
output_dim: 64

# Flag to perform centroid error correction
correct_centroids: False

# Flag to perform CCD misalignment error correction
add_ccd_misalignments: True
1 change: 1 addition & 0 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ This section contains the API reference for the main packages in WaveDiff.
:recursive:

wf_psf.data
wf_psf.inference
wf_psf.metrics
wf_psf.plotting
wf_psf.psf_models
Expand Down
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
else:
copyright = f"{start_year}, CosmoStat"
author = "CosmoStat"
release = "3.0.0"
release = "3.1.0"

# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
Expand Down
12 changes: 7 additions & 5 deletions docs/source/configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -137,13 +137,13 @@ model_params:
reference_shifts: [-1/3, -1/3] # Euclid-like default shifts

# Obscuration / geometry
obscuration_rotation_angle: 0 # Degrees (multiple of 90); counterclockwise rotation.
obscuration_rotation_angle: 0 # Degrees (multiple of 90); counterclockwise rotation.

# CCD misalignments input file path
ccd_misalignments_input_path: /path/to/ccd_misalignments_file.txt

# Boolean to use sample weights based on the noise standard deviation estimation
use_sample_weights: True
use_sample_weights: True

# Sample weight generalised sigmoid function
sample_weights_sigmoid:
Expand Down Expand Up @@ -220,7 +220,6 @@ training_hparams:
n_epochs_non_params: [100, 120]
```


(metrics_config)=
## `metrics_config.yaml` — Metrics Configuration

Expand Down Expand Up @@ -402,7 +401,10 @@ plotting_params:
### 4. Example Directory Structure
Below is an example of three WaveDiff runs stored under a single parent directory:

```
**Example Directory Structure**
Below is an example of three WaveDiff runs stored under a single parent directory:

```arduino
wf-outputs/
├── wf-outputs-202305271829
│ ├── config
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ maintainers = [

description = 'A software framework to perform Differentiable wavefront-based PSF modelling.'
dependencies = [
"numpy>=1.26.4,<2.0",
"numpy>=1.18,<1.24",
"scipy",
"tensorflow==2.11.0",
"tensorflow-estimator",
Expand All @@ -24,7 +24,7 @@ dependencies = [
"seaborn",
]

version = "3.0.0"
version = "3.1.0"

[project.optional-dependencies]
docs = [
Expand Down
6 changes: 3 additions & 3 deletions src/wf_psf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@

# Dynamically import modules to trigger side effects when wf_psf is imported
importlib.import_module("wf_psf.psf_models.psf_models")
importlib.import_module("wf_psf.psf_models.psf_model_semiparametric")
importlib.import_module("wf_psf.psf_models.psf_model_physical_polychromatic")
importlib.import_module("wf_psf.psf_models.tf_psf_field")
importlib.import_module("wf_psf.psf_models.models.psf_model_semiparametric")
importlib.import_module("wf_psf.psf_models.models.psf_model_physical_polychromatic")
importlib.import_module("wf_psf.psf_models.tf_modules.tf_psf_field")
86 changes: 85 additions & 1 deletion src/wf_psf/utils/centroids.py → src/wf_psf/data/centroids.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,80 @@

import numpy as np
import scipy.signal as scisig
from wf_psf.utils.preprocessing import shift_x_y_to_zk1_2_wavediff
from fractions import Fraction
from typing import Optional


def compute_centroid_correction(
model_params, centroid_dataset, batch_size: int = 1
) -> np.ndarray:
"""Compute centroid corrections using Zernike polynomials.

This function calculates the Zernike contributions required to match the centroid
of the WaveDiff PSF model to the observed star centroids, processing in batches.

Parameters
----------
model_params : RecursiveNamespace
An object containing parameters for the PSF model, including pixel sampling and initial centroid window parameters.

centroid_dataset : dict
Dictionary containing star data needed for centroiding:
- "stamps" : np.ndarray
Array of star postage stamps (required).
- "masks" : Optional[np.ndarray]
Array of star masks (optional, can be None).

batch_size : int, optional
The batch size to use when processing the stars. Default is 16.

Returns
-------
zernike_centroid_array : np.ndarray
A 2D NumPy array of shape `(n_stars, 3)`, where `n_stars` is the number of
observed stars. The array contains the computed Zernike (Z1, Z2) contributions,
with zero padding applied to the first column to ensure a consistent shape.
"""
# Retrieve stamps and masks from centroid_dataset
star_postage_stamps = centroid_dataset.get("stamps")
star_masks = centroid_dataset.get("masks") # may be None

if star_postage_stamps is None:
raise ValueError("centroid_dataset must contain 'stamps'")

pix_sampling = model_params.pix_sampling * 1e-6 # Change units from [um] to [m]

reference_shifts = [
float(Fraction(value)) for value in model_params.reference_shifts
]

n_stars = len(star_postage_stamps)
zernike_centroid_array = []

# Batch process the stars
for i in range(0, n_stars, batch_size):
batch_postage_stamps = star_postage_stamps[i : i + batch_size]
batch_masks = star_masks[i : i + batch_size] if star_masks is not None else None

# Compute Zernike 1 and Zernike 2 for the batch
zk1_2_batch = -1.0 * compute_zernike_tip_tilt(
batch_postage_stamps, batch_masks, pix_sampling, reference_shifts
)

# Zero pad array for each batch and append
zernike_centroid_array.append(
np.pad(
zk1_2_batch,
pad_width=[(0, 0), (1, 0)],
mode="constant",
constant_values=0,
)
)

# Combine all batches into a single array
return np.concatenate(zernike_centroid_array, axis=0)


def compute_zernike_tip_tilt(
star_images: np.ndarray,
star_masks: Optional[np.ndarray] = None,
Expand Down Expand Up @@ -58,6 +128,8 @@ def compute_zernike_tip_tilt(
- This function processes all images at once using vectorized operations.
- The Zernike coefficients are computed in the WaveDiff convention.
"""
from wf_psf.data.data_zernike_utils import shift_x_y_to_zk1_2_wavediff

# Vectorize the centroid computation
centroid_estimator = CentroidEstimator(
im=star_images, mask=star_masks, sigma_init=sigma_init, n_iter=n_iter
Expand Down Expand Up @@ -178,6 +250,18 @@ def __init__(
self, im, mask=None, sigma_init=7.5, n_iter=5, auto_run=True, xc=None, yc=None
):
"""Initialize class attributes."""
# Convert to np.ndarray if not already
im = np.asarray(im)
if mask is not None:
mask = np.asarray(mask)

# Check im dimensions convert to batch, if 2D
if im.ndim == 2:
# Single stamp → convert to batch of one
im = np.expand_dims(im, axis=0)
elif im.ndim != 3:
raise ValueError(f"Expected 2D or 3D input, got shape {im.shape}")

self.im = im
self.mask = mask
if self.mask is not None:
Expand Down
Loading
Loading