Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
146 commits
Select commit Hold shift + click to select a range
d7cbe1a
Correct syntax in docstring and generalise exception message
May 13, 2025
18ee533
Add inference and test_inference packages
May 13, 2025
3358e7b
Refactor: Encapsulate logic in psf_models package with subpackages: m…
May 13, 2025
d00acd4
Remove unused module with duplicate zernike_generator function
May 13, 2025
22edf80
Correct syntax in docstrings and logger messages
May 13, 2025
0959398
Refactor file structure; update import statements in tests; remove un…
May 14, 2025
16d4f94
Update package name in import statement
May 14, 2025
a0fe57b
Reorder imports; Refactor MetricsConfigHandler class attributes, meth…
May 14, 2025
7e34077
Move psf_model weights loader to psf_model_loader.py module
May 14, 2025
d6c8dfd
Add psf_model_loader module
May 14, 2025
fa4f890
Remove weights_path arg from evaluate_model method; Update logger.inf…
jeipollack May 15, 2025
0eb4800
Update variable name and logger statement
jeipollack May 15, 2025
079b990
Add import logging and create logger object
jeipollack May 15, 2025
2aa8db7
Create psf_inference.py module
May 15, 2025
b33f171
Remove arg from evaluate_model unit test
May 15, 2025
cc57511
Refactor: reorganise modules, relocate utility functions, rename modu…
May 16, 2025
478dfb2
Update import statements to new module names
May 16, 2025
d79e89d
Update DataHandler class docstring to include option for inference da…
May 16, 2025
e9e066e
Refactor data_handler with new utility functions to validate and proc…
May 18, 2025
cfddcdf
Update unit tests associated to changes in data_handler.py
May 18, 2025
9479146
Change exception handling in DataConfigHandler; modify args to DataHa…
May 18, 2025
abe53ed
Add data and psf_model_imports into inference and sketch out methods
May 19, 2025
3d5ffb0
add base psf inference
tobias-liaudat May 19, 2025
2e9cbd2
add common call interface through PSF models
tobias-liaudat May 19, 2025
b6e3f44
add handling of inference params
tobias-liaudat May 19, 2025
0414a50
automatic formatting
tobias-liaudat May 19, 2025
34cd121
add first completed class draft
tobias-liaudat May 19, 2025
bc79560
add inference config file
tobias-liaudat May 19, 2025
95e592c
remove unused code
tobias-liaudat May 19, 2025
aa9f46c
update params
tobias-liaudat May 19, 2025
5a945e0
update params
tobias-liaudat May 19, 2025
95c1201
update inference
tobias-liaudat May 19, 2025
f9d7c53
reduce arguments and add compute psfs when appropiate
tobias-liaudat May 19, 2025
6c84c0b
add config handler class
tobias-liaudat May 19, 2025
52a6676
set up inference config handler and simplify PSFInferenc init
tobias-liaudat May 19, 2025
0f7eea6
remove unused imports
tobias-liaudat May 19, 2025
0a0fc2d
update inference
tobias-liaudat May 19, 2025
b1fdf29
Add single-space lines to improve readability; Remove duplicated stat…
May 23, 2025
8aedf77
Add additional PSFInference class attributes; update set_source_param…
May 26, 2025
0968e0f
Add checks to convert to np.ndarray and expand dimensions if needed
jeipollack May 27, 2025
dc464fb
Update pyproject.toml with numpy dependency limits - sdc-uk
jeipollack Jun 5, 2025
4d3bdbe
Revert "Update pyproject.toml with numpy dependency limits - sdc-uk"
jeipollack Jun 5, 2025
75cdaac
Correct name of psf_inference_test module to follow repo naming conve…
Jun 6, 2025
1171b7b
Correct config subkey names for defining trained_model_path and train…
Jun 8, 2025
c750904
Refactor psf_inference adding PSFInferenceEngine to separate concerns…
Jun 8, 2025
aeeefd2
Add unit tests for psf_inference
Jun 8, 2025
e791771
Bugfix: Ensure updated training_config.training.model_params are pass…
Jun 12, 2025
447f554
test(simPSF): add unit test to verify updated model_params are passed
Jun 12, 2025
89da2e0
Bug: replace self.data_conf with self.data_config
jeipollack Jun 12, 2025
2cb7515
Change logger.warnings to ValueErrors for missing fields in datasets …
Aug 5, 2025
355f030
Update unit tests with changes to data_handler.py
Aug 5, 2025
be7ddb6
Refactor: reorganise modules, relocate utility functions, rename modu…
May 16, 2025
869900f
Refactor data_handler with new utility functions to validate and proc…
May 18, 2025
149a02c
Update unit tests associated to changes in data_handler.py
May 18, 2025
f0c7abe
automatic formatting
tobias-liaudat May 19, 2025
44460e4
Refactor: add ZernikeInputs dataclass, ZernikeInputsFactory, helper m…
Jun 18, 2025
23a86cc
Update docstring describing data_conf types permitted
Jun 18, 2025
b2fc928
Move imports to method to avoid circular imports
Jun 21, 2025
34da05d
Remove batch_size arg from ZernikeInputsFactory ; raise ValueError to…
Jun 21, 2025
c683aa2
Add and set run_type attribute to DataConfigHandler object in Trainin…
Jun 21, 2025
df9a4fe
Add and set run_type attribute ; Replace var name end with end_sample…
Jun 21, 2025
5d12b55
Refactor TFPhysicalPolychromaticField to lazy load property objects a…
Jun 21, 2025
8d6a726
Update/Add unit tests to test refactoring changes to psf_model_physic…
Jun 21, 2025
8eb4454
Replace arg: data in compute_ccd_misalignment with positions
Jun 21, 2025
5fa9aaf
Correct object attributes for DataConfigHandler in ZernikeInputsFactory
jeipollack Jun 21, 2025
e0361b7
Add missing return for tf_physical_layer property
jeipollack Jun 21, 2025
9ec8bc0
Add tf_utils.py module to tf_modules subpackage
jeipollack Jun 22, 2025
fdc8c6c
Use ensure_tensor method from tf_utils.py to check/convert to tensorf…
jeipollack Jun 22, 2025
d37b36d
Refactor: Add eager-mode helpers and avoid lazy-loading obscurations …
jeipollack Jun 22, 2025
1076215
Replace deprecated get_obs_positions with get_np_obs_positions and ap…
jeipollack Jun 22, 2025
b586cf6
Remove tf.convert_to_tensor from all Zernike list contributors
jeipollack Jun 23, 2025
78b3db1
Add and set self.data_conf.run_type value to 'metrics' in MetricsConf…
jeipollack Jun 23, 2025
733760d
Eagerly precompute Zernike components; add support for 'metrics' run_…
jeipollack Jun 23, 2025
c515666
Correct value error: train in dataset_type with training
jeipollack Jun 25, 2025
245d412
fix: pass random seed to TFNonParametricPolynomialVariationsOPD const…
jeipollack Jun 25, 2025
b91c006
Refactor to suppress TensorFlow debug msgs: replace lambda in call me…
jeipollack Jun 26, 2025
89dddc5
Match old behaviour with conditional and float64 accumulation
jeipollack Jun 26, 2025
7c38895
Add helper to stack x/y field coordinates into (N, 2) positions array
jeipollack Jul 8, 2025
19c98a0
Add helper method to prepare dataset for inference & handle empty/Non…
jeipollack Jul 10, 2025
87dc863
Update data_handler_test replacing "get_obs_positions" (deprecation) …
Jul 22, 2025
f8bc8bd
Remove deprecated code from rebase
Aug 6, 2025
0a748a2
Remove duplicated checks on arg existance
Aug 6, 2025
105bf1a
Improve Zernike prior handling in assemble_zernike_contributions
Aug 8, 2025
3ee97b4
Fix bug where Tensor zernike_prior was not appended after eager conve…
Aug 8, 2025
b94c703
Update unit tests with latest changes to fixtures and data_zernike_ut…
Aug 8, 2025
2576208
Set mock Zernike priors to None in test_data_utils.py helper module
Aug 8, 2025
d0309d1
Remove -1.0 multiplicative factor applied to Zernike tip and tilt values
Aug 8, 2025
90ba68e
Update unit tests with changes to compute_centroid_correction
Aug 8, 2025
3bd18d2
Move TFPhysicalPolychromaticField.pad_zernikes to helper method pad_t…
Aug 8, 2025
7ca574a
Correct bug in test_load_inference_model
Aug 8, 2025
47223a9
Revert sign change applied to compute_centroid_correction
Aug 18, 2025
e10edc7
Refactor _prepare_positions_and_seds to enforce shape consistency and…
Aug 19, 2025
339bfab
Fix tensor handling in ZernikeInputsFactory
jeipollack Aug 21, 2025
adb88ea
Reformat and remove unused import
Aug 21, 2025
b147be4
Correct zernike_prior extraction when dataset is a dict, reformat file
Aug 22, 2025
8747300
Replace np.array with Tensorflow tensors in unit test and fixtures, r…
Aug 22, 2025
633266a
Eagerly initialise trainable layers in physical poly model constructo…
jeipollack Aug 27, 2025
c72cb84
fix: use expect_partial() when loading model weights for evaluation
jeipollack Aug 27, 2025
fea4e3e
Add memory cleanup after training completion
jeipollack Aug 27, 2025
1233590
refactor: centralise PSF data extraction in data_handler
jeipollack Sep 3, 2025
c3a522a
Add and options to inference_config.yaml (forgot to stage with prev…
jeipollack Sep 4, 2025
0d1aa6c
Update PSFInference doc string with new optional attributes
jeipollack Sep 4, 2025
2a44500
Rename _get_inference_data to _get_direct_data
jeipollack Sep 4, 2025
1fc9cdd
Reformat with black
Sep 5, 2025
4e9add1
Correct type hint errors
jeipollack Sep 5, 2025
7c07a7e
Remove unused import
jeipollack Sep 5, 2025
71a50ea
Replace call to deprecated get_np_obs_positions with get_data_array
jeipollack Sep 5, 2025
0fd8b9e
Remove unused imports and reformat
Oct 31, 2025
660cee2
Update fixtures and unit tests
Oct 31, 2025
76c17f3
Reformat with black
Oct 31, 2025
d11f0bb
refactor the noise std dev calculation
tobias-liaudat Jun 27, 2025
45e2044
improve refactoring
tobias-liaudat Jun 27, 2025
1ac4241
add chi2 metric function
tobias-liaudat Jun 27, 2025
43213b0
added chi2 metric to list and automatic black formating
tobias-liaudat Jun 27, 2025
c54928e
improve naming
tobias-liaudat Jun 27, 2025
56b5e72
improve docsting
tobias-liaudat Jun 27, 2025
9fd138d
add chi2 evaluation function
tobias-liaudat Jun 27, 2025
6303290
add evalutation flag
tobias-liaudat Jun 27, 2025
22bf695
fix parameter name bug
tobias-liaudat Jun 27, 2025
4a0580d
fix data type problem
tobias-liaudat Jun 27, 2025
5dbdecb
Refactor: reorganise modules, relocate utility functions, rename modu…
May 16, 2025
545b96c
Refactor data_handler with new utility functions to validate and proc…
May 18, 2025
09da166
Update unit tests associated to changes in data_handler.py
May 18, 2025
a1f215c
Refactor TFPhysicalPolychromaticField to lazy load property objects a…
Jun 21, 2025
3aa325c
Use ensure_tensor method from tf_utils.py to check/convert to tensorf…
jeipollack Jun 22, 2025
8383deb
add median noise calculation and todo comment
tobias-liaudat Jul 7, 2025
eaacdb5
update docstring
tobias-liaudat Jul 7, 2025
f17add0
update tabs
tobias-liaudat Jul 7, 2025
ba4cfa4
add info about the reference simulated datasets
tobias-liaudat Jul 7, 2025
15c22fe
add noisy stars and masked noisy stars for the test dataset
tobias-liaudat Jul 7, 2025
9c2da96
add noshift generation for the test dataset
tobias-liaudat Jul 7, 2025
8fe11d7
update to new paths
tobias-liaudat Jul 7, 2025
624b101
update the intrapixel shift range
tobias-liaudat Jul 7, 2025
91c4c84
fix small bug for unitary masks
tobias-liaudat Jul 8, 2025
91748cd
fix chi2 bug
tobias-liaudat Jul 8, 2025
1e1308e
add chi2 per image calculation
tobias-liaudat Jul 8, 2025
b01c4d7
add per image results
tobias-liaudat Jul 8, 2025
3b94c6f
add more output statistics for the chi2 metric
tobias-liaudat Jul 8, 2025
aa52b50
update default metric config to include chi2 option
tobias-liaudat Jul 8, 2025
d5d08ac
merge feature/159-psf-output-from-trained-model with real data metrics
tobias-liaudat Sep 19, 2025
59e72ac
compute the metrics on `noisy_stars` if `stars` is not in the data. R…
tobias-liaudat Sep 24, 2025
6e119c4
fix small bug in nested conditionals
tobias-liaudat Sep 24, 2025
76373fc
Update evaluate_model unit test with patch for chi2 metric
Nov 26, 2025
81a94fc
Update train_util tests with mock data including binary mask
Nov 26, 2025
b9cbd6b
Remove duplicated function and unnecessary variable re-assignment
Nov 26, 2025
1e500b9
Remove functions left over from previous rebase
Nov 26, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions config/inference_conf.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@

inference:
# Inference batch size
batch_size: 16
# Cycle to use for inference. Can be: 1, 2, ...
cycle: 2

# Paths to the configuration files and trained model directory
configs:
# Path to the directory containing the trained model
trained_model_path: /path/to/trained/model/

# Subdirectory name of the trained model, e.g. psf_model
model_subdir: model

# Relative Path to the training configuration file used to train the model
trained_model_config_path: config/training_config.yaml

# Path to the data config file (this could contain prior information)
data_config_path:

# The following parameters will overwrite the `model_params` in the training config file.
model_params:
# Num of wavelength bins to reconstruct polychromatic objects.
n_bins_lda: 8

# Downsampling rate to match the oversampled model to the specified telescope's sampling.
output_Q: 1

# Dimension of the pixel PSF postage stamp
output_dim: 64

# Flag to perform centroid error correction
correct_centroids: False

# Flag to perform CCD misalignment error correction
add_ccd_misalignments: True
3 changes: 3 additions & 0 deletions config/metrics_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ metrics:
# Name of the Trained Model Config file stored in config sub-directory in the trained_model_path parent directory
trained_model_config: <enter name of trained model config file>

# Evaluate the chi2 metric.
eval_chi2_metric: True

# Evaluate the monchromatic RMSE metric.
eval_mono_metric: True

Expand Down
12 changes: 12 additions & 0 deletions data/generation/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,16 @@ To run the script:
python data_generation_script.py -c data_generation_params_v.0.1.0.yml
```

> ⚠️ Warning
>
> There are some differences with the original dataset used for the WaveDiff paper (Liaudat et al. 2023) even if we use it as a reference:
> - The assingment of SEDs for each star will not match that one of the original dataset. Although the same templates are used.
> - The assigned noise level (SNR) for each star will not match the original dataset. Although, the same distribution will be used.
>
> Nevertheless, the `C_poly` will match and the `positions` will match. Therefore, results from the new datasets will not be exactly the same as in the original datasets.


### Dataset description

**Dataset 0.x.x:**

Expand Down Expand Up @@ -56,3 +66,5 @@ python data_generation_script.py -c data_generation_params_v.0.1.0.yml
- v3.2.1/2/3/4 with dummy (unitary) masks
- v3.3.1/2/3/4 with realistic masks



2 changes: 1 addition & 1 deletion data/generation/data_generation_params_v.0.1.0.yml
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ dataset_features:
# Centroid shift options
add_intrapixel_shifts: False
# In pixels (should be abs(limits)<0.5)
intrapixel_shift_range: [-0.5, 0.5]
intrapixel_shift_range: [-0.3, 0.3]

# CCD misalignment options
add_ccd_misalignments: False
Expand Down
2 changes: 1 addition & 1 deletion data/generation/data_generation_params_v.0.2.0.yml
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ dataset_features:
# Centroid shift options
add_intrapixel_shifts: False
# In pixels (should be abs(limits)<0.5)
intrapixel_shift_range: [-0.5, 0.5]
intrapixel_shift_range: [-0.3, 0.3]

# CCD misalignment options
add_ccd_misalignments: False
Expand Down
2 changes: 1 addition & 1 deletion data/generation/data_generation_params_v.0.3.0.yml
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ dataset_features:
# Centroid shift options
add_intrapixel_shifts: False
# In pixels (should be abs(limits)<0.5)
intrapixel_shift_range: [-0.5, 0.5]
intrapixel_shift_range: [-0.3, 0.3]

# CCD misalignment options
add_ccd_misalignments: False
Expand Down
2 changes: 1 addition & 1 deletion data/generation/data_generation_params_v.1.1.0.yml
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ dataset_features:
# Centroid shift options
add_intrapixel_shifts: True
# In pixels (should be abs(limits)<0.5)
intrapixel_shift_range: [-0.5, 0.5]
intrapixel_shift_range: [-0.3, 0.3]

# CCD misalignment options
add_ccd_misalignments: True
Expand Down
2 changes: 1 addition & 1 deletion data/generation/data_generation_params_v.1.2.0.yml
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ dataset_features:
# Centroid shift options
add_intrapixel_shifts: True
# In pixels (should be abs(limits)<0.5)
intrapixel_shift_range: [-0.5, 0.5]
intrapixel_shift_range: [-0.3, 0.3]

# CCD misalignment options
add_ccd_misalignments: True
Expand Down
2 changes: 1 addition & 1 deletion data/generation/data_generation_params_v.1.3.0.yml
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ dataset_features:
# Centroid shift options
add_intrapixel_shifts: True
# In pixels (should be abs(limits)<0.5)
intrapixel_shift_range: [-0.5, 0.5]
intrapixel_shift_range: [-0.3, 0.3]

# CCD misalignment options
add_ccd_misalignments: True
Expand Down
2 changes: 1 addition & 1 deletion data/generation/data_generation_params_v.2.1.0.yml
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ dataset_features:
# Centroid shift options
add_intrapixel_shifts: True
# In pixels (should be abs(limits)<0.5)
intrapixel_shift_range: [-0.5, 0.5]
intrapixel_shift_range: [-0.3, 0.3]

# CCD misalignment options
add_ccd_misalignments: True
Expand Down
2 changes: 1 addition & 1 deletion data/generation/data_generation_params_v.2.2.0.yml
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ dataset_features:
# Centroid shift options
add_intrapixel_shifts: True
# In pixels (should be abs(limits)<0.5)
intrapixel_shift_range: [-0.5, 0.5]
intrapixel_shift_range: [-0.3, 0.3]

# CCD misalignment options
add_ccd_misalignments: True
Expand Down
2 changes: 1 addition & 1 deletion data/generation/data_generation_params_v.2.3.0.yml
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ dataset_features:
# Centroid shift options
add_intrapixel_shifts: True
# In pixels (should be abs(limits)<0.5)
intrapixel_shift_range: [-0.5, 0.5]
intrapixel_shift_range: [-0.3, 0.3]

# CCD misalignment options
add_ccd_misalignments: True
Expand Down
2 changes: 1 addition & 1 deletion data/generation/data_generation_params_v.3.1.0.yml
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ dataset_features:
# Centroid shift options
add_intrapixel_shifts: True
# In pixels (should be abs(limits)<0.5)
intrapixel_shift_range: [-0.5, 0.5]
intrapixel_shift_range: [-0.3, 0.3]

# CCD misalignment options
add_ccd_misalignments: True
Expand Down
2 changes: 1 addition & 1 deletion data/generation/data_generation_params_v.3.2.0.yml
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ dataset_features:
# Centroid shift options
add_intrapixel_shifts: True
# In pixels (should be abs(limits)<0.5)
intrapixel_shift_range: [-0.5, 0.5]
intrapixel_shift_range: [-0.3, 0.3]

# CCD misalignment options
add_ccd_misalignments: True
Expand Down
2 changes: 1 addition & 1 deletion data/generation/data_generation_params_v.3.3.0.yml
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ dataset_features:
# Centroid shift options
add_intrapixel_shifts: True
# In pixels (should be abs(limits)<0.5)
intrapixel_shift_range: [-0.5, 0.5]
intrapixel_shift_range: [-0.3, 0.3]

# CCD misalignment options
add_ccd_misalignments: True
Expand Down
60 changes: 58 additions & 2 deletions data/generation/data_generation_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@
)
from wf_psf.utils.read_config import read_conf, RecursiveNamespace
from wf_psf.sims.psf_simulator import PSFSimulator
from wf_psf.utils.preprocessing import shift_x_y_to_zk1_2_wavediff
from wf_psf.data.data_zernike_utils import shift_x_y_to_zk1_2_wavediff
from wf_psf.sims.spatial_varying_psf import SpatialVaryingPSF, ZernikeHelper
from wf_psf.utils.ccd_misalignments import CCDMisalignmentCalculator
from wf_psf.instrument.ccd_misalignments import CCDMisalignmentCalculator


# Pre-defined colormap
Expand Down Expand Up @@ -561,6 +561,7 @@ def main(args):

# ------------ #
# Centroid shifts
no_shift_test_zks = None

if add_intrapixel_shifts:

Expand Down Expand Up @@ -599,6 +600,7 @@ def main(args):

# Add the centroid shifts to the Zernike coefficients
train_zks += train_delta_centroid_shifts
no_shift_test_zks = np.copy(test_zks)
test_zks += test_delta_centroid_shifts

# ------------ #
Expand Down Expand Up @@ -752,12 +754,24 @@ def main(args):
sim_PSF_toolkit.generate_poly_PSF(test_SED_list[it], n_bins=n_bins)
)

# Generate test polychromatic PSFs without shifts
if no_shift_test_zks is not None:
test_poly_psf_noshift_list = []
print("Generate test PSFs at observation resolution without shifts")
for it in tqdm(range(no_shift_test_zks.shape[0])):
sim_PSF_toolkit.set_z_coeffs(no_shift_test_zks[it, :])
test_poly_psf_noshift_list.append(
sim_PSF_toolkit.generate_poly_PSF(test_SED_list[it], n_bins=n_bins)
)

# Generate numpy arrays from the lists
train_poly_psf_np = np.array(train_poly_psf_list)
train_SED_np = np.array(train_SED_list)

test_poly_psf_np = np.array(test_poly_psf_list)
test_SED_np = np.array(test_SED_list)
if no_shift_test_zks is not None:
test_poly_psf_noshift_np = np.array(test_poly_psf_noshift_list)

# Generate the noisy train stars
# Copy the training stars
Expand All @@ -776,11 +790,28 @@ def main(args):
axis=0,
)

# Also add noise to the test stars
noisy_test_poly_psf_np = np.copy(test_poly_psf_np)
# Generate a dataset with a SNR varying randomly within the desired range
rand_SNR = (
np.random.rand(noisy_test_poly_psf_np.shape[0]) * (SNR_range[1] - SNR_range[0])
) + SNR_range[0]
# Add Gaussian noise to the observations
noisy_test_poly_psf_np = np.stack(
[
add_noise(_im, desired_SNR=_SNR)
for _im, _SNR in zip(noisy_test_poly_psf_np, rand_SNR)
],
axis=0,
)

# ------------ #
# Generate masks

if add_masks:

masked_noisy_test_poly_psf_np = np.copy(noisy_test_poly_psf_np)

if mask_type == "random":
# Generate random train masks
train_masks = generate_n_mask(
Expand All @@ -797,6 +828,12 @@ def main(args):
noisy_train_poly_psf_np.dtype
)

# Apply the random masks to the test stars
masked_noisy_test_poly_psf_np = (
masked_noisy_test_poly_psf_np
* test_masks.astype(noisy_test_poly_psf_np.dtype)
)

# Turn masks to SHE convention. 1 (True) means to mask and 0 (False) means to keep
train_masks = ~train_masks
test_masks = ~test_masks
Expand Down Expand Up @@ -1093,8 +1130,21 @@ def main(args):
SR_sim_PSF_toolkit.generate_poly_PSF(test_SED_list[it_j], n_bins=n_bins)
)

# Generate the test super resolved (SR) polychromatic PSFs without shifts
if no_shift_test_zks is not None:
SR_test_poly_psf_noshift_list = []

print("Generate testing SR PSFs no shifts")
for it_j in tqdm(range(n_test_stars)):
SR_sim_PSF_toolkit.set_z_coeffs(no_shift_test_zks[it_j, :])
SR_test_poly_psf_noshift_list.append(
SR_sim_PSF_toolkit.generate_poly_PSF(test_SED_list[it_j], n_bins=n_bins)
)

# Generate numpy arrays from the lists
SR_test_poly_psf_np = np.array(SR_test_poly_psf_list)
if no_shift_test_zks is not None:
SR_test_poly_psf_noshift_np = np.array(SR_test_poly_psf_noshift_list)

# ------------ #
# Save test datasets
Expand Down Expand Up @@ -1131,17 +1181,23 @@ def main(args):
test_psf_dataset = {
"stars": test_poly_psf_np,
"SR_stars": SR_test_poly_psf_np,
"noisy_stars": noisy_test_poly_psf_np,
"positions": test_positions,
"SEDs": test_SED_np,
"zernike_GT": test_zks,
}

if add_masks:
test_psf_dataset["masks"] = test_masks
test_psf_dataset["masked_noisy_stars"] = masked_noisy_test_poly_psf_np

if add_ccd_misalignments:
test_psf_dataset["zernike_ccd_misalignments"] = test_delta_Z3_arr

if no_shift_test_zks is not None:
test_psf_dataset["stars_noshift"] = test_poly_psf_noshift_np
test_psf_dataset["SR_stars_noshift"] = SR_test_poly_psf_noshift_np

if add_intrapixel_shifts:
test_psf_dataset["zernike_centroid_shifts"] = test_delta_centroid_shifts
test_psf_dataset["pix_centroid_shifts"] = np.stack(
Expand Down
8 changes: 4 additions & 4 deletions src/wf_psf/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import importlib

# Dynamically import modules to trigger side effects when wf_psf is imported
importlib.import_module('wf_psf.psf_models.psf_models')
importlib.import_module('wf_psf.psf_models.psf_model_semiparametric')
importlib.import_module('wf_psf.psf_models.psf_model_physical_polychromatic')
importlib.import_module('wf_psf.psf_models.tf_psf_field')
importlib.import_module("wf_psf.psf_models.psf_models")
importlib.import_module("wf_psf.psf_models.models.psf_model_semiparametric")
importlib.import_module("wf_psf.psf_models.models.psf_model_physical_polychromatic")
importlib.import_module("wf_psf.psf_models.tf_modules.tf_psf_field")
Loading