diff --git a/python/tarts/NeuralActiveOpticsSys.py b/python/tarts/NeuralActiveOpticsSys.py index c52bbea..0250629 100644 --- a/python/tarts/NeuralActiveOpticsSys.py +++ b/python/tarts/NeuralActiveOpticsSys.py @@ -141,7 +141,7 @@ def __init__( # Always use checkpoint loading - the pretrained parameter doesn't matter # when loading from checkpoint self.wavenet_model = WaveNetSystem.load_from_checkpoint( - wavenet_path, map_location=str(self.device_val) + wavenet_path, map_location=str(self.device_val), strict=False ).to(self.device_val) if alignet_path is None: @@ -150,7 +150,7 @@ def __init__( # Always use checkpoint loading - the pretrained parameter doesn't matter # when loading from checkpoint self.alignnet_model = AlignNetSystem.load_from_checkpoint( - alignet_path, map_location=str(self.device_val) + alignet_path, map_location=str(self.device_val), strict=False ).to(self.device_val) self.max_seq_length = params["max_seq_len"] @@ -169,9 +169,9 @@ def __init__( max_seq_length=self.max_seq_length, ).to(self.device_val) else: - self.aggregatornet_model = AggregatorNet.load_from_checkpoint(aggregatornet_path).to( - self.device_val - ) + self.aggregatornet_model = AggregatorNet.load_from_checkpoint( + aggregatornet_path, strict=False + ).to(self.device_val) if final_layer is not None: layers = [ diff --git a/python/tarts/aggregatornet.py b/python/tarts/aggregatornet.py index 3881a56..1ff15ca 100644 --- a/python/tarts/aggregatornet.py +++ b/python/tarts/aggregatornet.py @@ -15,7 +15,6 @@ # Local/application imports from .utils import convert_zernikes_deploy -from .utils import zernikes_to_dof_torch, dof_to_zernikes_torch class AggregatorNet(pl.LightningModule): @@ -91,6 +90,12 @@ def __init__( """ super().__init__() self.save_hyperparameters() # Save model hyperparameters + + # Input projection layer: (num_zernikes + 3) -> d_model + # The +3 accounts for field_x, field_y, and snr features + input_dim = num_zernikes + 3 + self.input_proj = nn.Linear(input_dim, d_model) + encoder_layer = nn.TransformerEncoderLayer( d_model=d_model, nhead=nhead, @@ -110,7 +115,7 @@ def forward(self, x: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: x : tuple of (torch.Tensor, torch.Tensor) A tuple where: - x[0] (torch.Tensor): The input sequence tensor of shape - (batch_size, seq_length, d_model). + (batch_size, seq_length, num_zernikes + 3). - x[1] (torch.Tensor): The mean tensor used for output adjustment. Returns @@ -120,14 +125,18 @@ def forward(self, x: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: Notes ----- - - The transformer encoder processes the first element of the tuple. + - Input features are first projected from (num_zernikes + 3) to d_model dimensions. + - The transformer encoder processes the projected features. - The last token's output is extracted and passed through a linear layer. - The mean correction (second element) is added to the final output. """ x_input, mean = x - x_tensor = self.transformer_encoder(x_input) + # Project input features to d_model dimensions + x_projected = self.input_proj(x_input) + # Pass through transformer + x_tensor = self.transformer_encoder(x_projected) x_tensor = x_tensor[:, -1, :] # Take the last token's output x_tensor = self.fc(x_tensor) # Predict the next token x_tensor += mean @@ -161,41 +170,11 @@ def training_step(self, batch: tuple, batch_idx: int): - The training loss is logged for monitoring. """ - if not self.zk_dof_zk: - x, y = batch # y is the target token - x_input, x_mean, filter_name, chipid = x - logits = self.forward((x_input, x_mean)) - loss = self.loss_fn(logits, y) - self.log("train_loss", loss, prog_bar=True) - else: - x, y = batch # y is the target token - x_input, x_mean, filter_name, chipid = x - logits = self.forward((x_input, x_mean)) - new_logits = torch.zeros_like(logits) - for i in range(len(filter_name)): - filter_name_i = filter_name[i] - sensor_names = chipid[i] - print("old logits", logits[0, :]) - x_dof = zernikes_to_dof_torch( - filter_name=filter_name_i, - measured_zk=logits[i][None, :], - sensor_names=[sensor_names], - rotation_angle=0.0, - device=self.device, - verbose=False, - ) - new_logits[i, :] = dof_to_zernikes_torch( - filter_name=filter_name_i, - x_dof=x_dof, - sensor_names=[sensor_names], - rotation_angle=0.0, - device=self.device, - verbose=False, - ) - print("new_logits", new_logits[0, :]) - - loss = self.loss_fn(new_logits, y) - self.log("train_loss", loss, prog_bar=True) + x, y = batch # y is the target token + x_input, x_mean, filter_name, chipid = x + logits = self.forward((x_input, x_mean)) + loss = self.loss_fn(logits, y) + self.log("train_loss", loss, prog_bar=True) return loss def validation_step(self, batch, batch_idx): @@ -226,40 +205,12 @@ def validation_step(self, batch, batch_idx): - The validation loss is logged for monitoring. """ - if not self.zk_dof_zk: - x, y = batch # y is the target token - x_input, x_mean, filter_name, chipid = x - logits = self.forward((x_input, x_mean)) - loss = self.loss_fn(logits, y) - self.log("val_loss", loss, prog_bar=True) - self.log("val_mRSSE", loss, prog_bar=True) # mRSSE is the same as loss for this model - else: - x, y = batch # y is the target token - x_input, x_mean, filter_name, chipid = x - logits = self.forward((x_input, x_mean)) - new_logits = torch.zeros_like(logits) - for i in range(len(filter_name)): - filter_name_i = filter_name[i] - sensor_names = chipid[i] - x_dof = zernikes_to_dof_torch( - filter_name=filter_name_i, - measured_zk=logits[i][None, :], - sensor_names=[sensor_names], - rotation_angle=0.0, - device=self.device, - verbose=False, - ) - new_logits[i, :] = dof_to_zernikes_torch( - filter_name=filter_name_i, - x_dof=x_dof, - sensor_names=[sensor_names], - rotation_angle=0.0, - device=self.device, - verbose=False, - ) - loss = self.loss_fn(new_logits, y) - self.log("val_loss", loss, prog_bar=True) - self.log("val_mRSSE", loss, prog_bar=True) # mRSSE is the same as loss for this model + x, y = batch # y is the target token + x_input, x_mean, filter_name, chipid = x + logits = self.forward((x_input, x_mean)) + loss = self.loss_fn(logits, y) + self.log("val_loss", loss, prog_bar=True) + self.log("val_mRSSE", loss, prog_bar=True) # mRSSE is the same as loss for this model return loss def loss_fn(self, x, y): diff --git a/python/tarts/aggregatornet_coral.py b/python/tarts/aggregatornet_coral.py new file mode 100644 index 0000000..539abd2 --- /dev/null +++ b/python/tarts/aggregatornet_coral.py @@ -0,0 +1,502 @@ +"""Aggregator Network with DARE-GRAM domain adaptation for coral data. + +This module implements a transformer-based aggregator network with DARE-GRAM loss +for unsupervised domain adaptation. The method aligns inverse Gram matrices between +source (simulation) and target (coral/real) domains without requiring target labels. +""" + +# Standard library imports +import logging +from typing import Any, Tuple + +# Third-party imports +import pytorch_lightning as pl +import torch +import torch.nn as nn +import torch.nn.functional as F_loss +from torch.optim.lr_scheduler import ReduceLROnPlateau + +# Local/application imports +from .utils import convert_zernikes_deploy + +logger = logging.getLogger(__name__) + + +class AggregatorNet_Coral(pl.LightningModule): + """Aggregator Network with DARE-GRAM domain adaptation. + + Implements a transformer encoder network to aggregate the + values of multiple donuts and performs single point estimations, + with DARE-GRAM loss for domain adaptation to real data. + + Attributes + ---------- + input_proj : nn.Linear + Projects input features to d_model dimensions + transformer_encoder : nn.TransformerEncoder + Transformer encoder architecture + fc : nn.Linear + Fully connected output layer + transformer_features : torch.Tensor + Cached transformer features for domain adaptation + + Methods + ------- + forward() + Forward propagation through the model + training_step() + Single train step with domain adaptation + validation_step() + Single validation step + loss_fn() + Regression loss function (mRSSE) + dare_gram_loss() + Domain adaptation loss between source and target features + configure_optimizers() + Setup optimizers with learning rate scheduling + """ + + def __init__( + self, + d_model: int, + nhead: int, + num_layers: int, + dim_feedforward: int, + max_seq_length: int, + lr: float = 0.002507905395321983, + num_zernikes: int = 17, + tradeoff_angle: float = 0.05, + tradeoff_scale: float = 0.001, + threshold: float = 0.9, + dare_gram_weight: float = 1.0, + ): + """Initialize the AggregatorNet_Coral model. + + Parameters + ---------- + d_model : int + The number of expected features in the input (embedding dimension). + nhead : int + The number of attention heads in the multi-head attention mechanism. + num_layers : int + The number of transformer encoder layers. + dim_feedforward : int + The dimension of the feedforward network model inside the transformer encoder. + max_seq_length : int + The maximum sequence length of input data. + lr : float, optional + The learning rate for model training (default is 0.002507905395321983). + num_zernikes : int, optional + The number of Zernike polynomial coefficients to predict (default is 17). + tradeoff_angle : float, optional, default=0.05 + Weight for the DARE-GRAM angle alignment loss. + tradeoff_scale : float, optional, default=0.001 + Weight for the DARE-GRAM scale alignment loss. + threshold : float, optional, default=0.9 + Cumulative variance threshold for low-rank approximation. + dare_gram_weight : float, optional, default=1.0 + Overall weight to scale the DARE-GRAM loss relative to regression loss. + Higher values prioritize domain adaptation over regression accuracy. + + Notes + ----- + - The transformer encoder consists of `num_layers` encoder layers. + - The model outputs a linear transformation of size `num_zernikes`. + - Domain adaptation is performed by aligning features from the transformer. + """ + super().__init__() + self.save_hyperparameters() # Save model hyperparameters + + # Input projection layer: (num_zernikes + 3) -> d_model + # The +3 accounts for field_x, field_y, and snr features + input_dim = num_zernikes + 3 + self.input_proj = nn.Linear(input_dim, d_model) + + encoder_layer = nn.TransformerEncoderLayer( + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + batch_first=True, + ) + self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) + # final layer to transform to the shape of number of zernikes + self.fc = nn.Linear(d_model, num_zernikes) + + # Cache for transformer features (for domain adaptation) + self.transformer_features = None + self.val_mRSSE: torch.Tensor | None = None + + def forward(self, x: tuple[torch.Tensor, torch.Tensor], cache_features: bool = False) -> torch.Tensor: + """Forward pass of the AggregatorNet_Coral model. + + Parameters + ---------- + x : tuple of (torch.Tensor, torch.Tensor) + A tuple where: + - x[0] (torch.Tensor): The input sequence tensor of shape + (batch_size, seq_length, num_zernikes + 3). + - x[1] (torch.Tensor): The mean tensor used for output adjustment. + cache_features : bool, optional, default=False + Whether to cache transformer features for domain adaptation. + + Returns + ------- + torch.Tensor + The transformed output tensor of shape (batch_size, num_zernikes). + + Notes + ----- + - Input features are first projected from (num_zernikes + 3) to d_model dimensions. + - The transformer encoder processes the projected features. + - The last token's output is extracted and passed through a linear layer. + - The mean correction (second element) is added to the final output. + - If cache_features=True, transformer features are stored in self.transformer_features. + """ + x_input, mean = x + # Project input features to d_model dimensions + x_projected = self.input_proj(x_input) + # Pass through transformer + x_tensor = self.transformer_encoder(x_projected) + + # Cache features for domain adaptation if requested + if cache_features: + # Use the last token's features before FC layer + # Convert to float32 for SVD operations in DARE-GRAM loss + self.transformer_features = x_tensor[:, -1, :].detach().float() + + x_tensor = x_tensor[:, -1, :] # Take the last token's output + x_tensor = self.fc(x_tensor) # Predict the next token + x_tensor += mean + return x_tensor + + def dare_gram_loss(self, features_source: torch.Tensor, features_target: torch.Tensor) -> torch.Tensor: + """Compute DARE-GRAM loss between source and target features. + + Parameters + ---------- + features_source: torch.Tensor + Source domain features of shape (batch_size, n_features). + features_target: torch.Tensor + Target domain features of shape (batch_size, n_features). + + Returns + ------- + torch.Tensor + DARE-GRAM alignment loss. + """ + # Convert to float32 for SVD operations (not supported in half precision) + features_source = features_source.float() + features_target = features_target.float() + + batch_size, n_features = features_source.shape + + # Check for NaN or Inf values + if torch.any(torch.isnan(features_source)) or torch.any(torch.isnan(features_target)): + return torch.tensor(0.0, device=self.device) + if torch.any(torch.isinf(features_source)) or torch.any(torch.isinf(features_target)): + return torch.tensor(0.0, device=self.device) + + # Add bias term (ones column) to features + A = torch.cat((torch.ones(batch_size, 1, dtype=torch.float32).to(self.device), features_source), 1) + B = torch.cat((torch.ones(batch_size, 1, dtype=torch.float32).to(self.device), features_target), 1) + + # Compute covariance matrices + cov_A = A.t() @ A + cov_B = B.t() @ B + + # SVD to get eigenvalues + _, L_A, _ = torch.linalg.svd(cov_A) + _, L_B, _ = torch.linalg.svd(cov_B) + + # Normalize eigenvalues to get cumulative variance + # Temporarily disable deterministic algorithms for cumsum (not supported on CUDA) + is_deterministic = torch.are_deterministic_algorithms_enabled() + try: + if is_deterministic: + torch.use_deterministic_algorithms(False, warn_only=False) + eigen_A = torch.cumsum(L_A, dim=0) / L_A.sum() + eigen_B = torch.cumsum(L_B, dim=0) / L_B.sum() + finally: + if is_deterministic: + torch.use_deterministic_algorithms(True) + + # Determine rank k based on threshold + T = self.hparams.threshold + + # Find index where cumulative variance reaches threshold + if eigen_A[1] > T: + T_A = eigen_A[1] + else: + T_A = T + + index_A = torch.argwhere(eigen_A <= T_A) + if len(index_A) > 0: + index_A_val = int(index_A[-1][0].item()) + else: + index_A_val = 1 + + if eigen_B[1] > T: + T_B = eigen_B[1] + else: + T_B = T + + index_B = torch.argwhere(eigen_B <= T_B) + if len(index_B) > 0: + index_B_val = int(index_B[-1][0].item()) + else: + index_B_val = 1 + + k = max(index_A_val, index_B_val) + + # Ensure k is within valid range (avoid numerical issues) + n_eigen = min(len(L_A), len(L_B)) + k = min(k, n_eigen - 1) # Ensure k < n_eigen + + # Add safety check for numerical stability + if L_A[0] < 1e-10 or L_B[0] < 1e-10: + # Near-singular matrix, return small loss + return torch.tensor(0.0, device=self.device) + + # Compute pseudo-inverse with low-rank regularization + rtol_A = max((L_A[k] / L_A[0]).item(), 1e-6) + rtol_B = max((L_B[k] / L_B[0]).item(), 1e-6) + A_pinv = torch.linalg.pinv(cov_A, rtol=rtol_A) + B_pinv = torch.linalg.pinv(cov_B, rtol=rtol_B) + + # Compute cosine similarity for angle alignment + cos_sim = nn.CosineSimilarity(dim=0, eps=1e-6) + cos_distance = torch.dist( + torch.ones(n_features + 1, dtype=torch.float32).to(self.device), cos_sim(A_pinv, B_pinv), p=1 + ) / (n_features + 1) + + # Compute scale alignment loss + scale_loss = torch.dist(L_A[:k], L_B[:k], p=1) / k + + # Clamp losses to prevent extreme values + cos_distance = torch.clamp(cos_distance, min=0.0, max=10.0) + scale_loss = torch.clamp(scale_loss, min=0.0, max=100.0) + + # Combined DARE-GRAM loss + dare_gram_loss = self.hparams.tradeoff_angle * cos_distance + self.hparams.tradeoff_scale * scale_loss + + # Final clamp to prevent explosion + dare_gram_loss = torch.clamp(dare_gram_loss, min=0.0, max=100.0) + + return dare_gram_loss + + def exp_rise_flipped(self, loss, a=6.0): + """Exponentially rises from 0 at loss=0.10 to 1 at loss=0.09. + + Then flattens at 0 above 0.10 and 1 below 0.09. + """ + loss = torch.as_tensor(loss, dtype=torch.float32, device=self.device) + + # CHANGE HERE: + x1 = torch.tensor(0.09, dtype=torch.float32, device=loss.device) # Peak (1.0) + x2 = torch.tensor(0.10, dtype=torch.float32, device=loss.device) # Start (0.0) + + a = torch.tensor(a, dtype=torch.float32, device=loss.device) + + f = torch.zeros_like(loss, dtype=torch.float32) + + # Region 1: loss <= 0.09 (Before flip = 0) + f[loss <= x1] = 0.0 + + # Region 2: 0.09 < loss < 0.10 + mask = (loss > x1) & (loss < x2) + t = (loss[mask] - x1) / (x2 - x1) + f[mask] = (1 - torch.exp(-a * t)) / (1 - torch.exp(-a)) + + # Region 3: loss >= 0.10 (Before flip = 1) + f[loss >= x2] = 1.0 + + # FLIP: + # 0.10 (which was 1) -> becomes 0 + # 0.09 (which was 0) -> becomes 1 + f = -f + 1 + return f + + def calc_losses( + self, + batch: tuple, + batch_idx: int, + use_coral: bool = False, + add_dare_gram_to_loss: bool = True, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Calculate losses with optional DARE-GRAM domain adaptation. + + Parameters + ---------- + batch: tuple + Batch of training data. Format depends on coral mode: + - Without coral: (x, y) where x=(x_input, x_mean, filter_name, chipid) + - With coral: (x, y) where x=(x_input, x_mean, filter_name, chipid, + coral_x_total, coral_x_mean, coral_filter, coral_chipid) + batch_idx: int + Batch index. + use_coral: bool, default=False + Whether to compute DARE-GRAM loss with coral/target data. + add_dare_gram_to_loss: bool, default=True + Whether to add DARE-GRAM loss to the total loss. If False, DARE-GRAM is computed + for logging but not included in the total loss. + + Returns + ------- + tuple + (total_loss, mRSSE, dare_gram_loss) + dare_gram_loss is 0 if use_coral=False or coral data not available. + """ + x, y = batch # y is the target token + + # Check if coral data is present (8 elements in x vs 4 elements) + has_coral = len(x) == 8 + + if has_coral: + x_input, x_mean, filter_name, chipid, coral_x_total, coral_x_mean, coral_filter, coral_chipid = x + else: + x_input, x_mean, filter_name, chipid = x + + # Forward pass on source data + logits = self.forward((x_input, x_mean), cache_features=True) + source_features = self.transformer_features + + # Calculate regression loss + regression_loss = self.loss_fn(logits, y) + + # Extract mRSSE for monitoring + logits_converted = convert_zernikes_deploy(logits) + y_converted = convert_zernikes_deploy(y) + sse = F_loss.mse_loss(logits_converted, y_converted, reduction="none").sum(dim=-1) + mRSSE = torch.sqrt(sse).mean() + + # DARE-GRAM loss if coral data is available + dare_gram_loss = torch.tensor(0.0, device=self.device) + if use_coral and has_coral: + try: + # Forward pass on target/coral data + # Use eval mode to prevent BN updates (keep statistics source-domain only) + was_training = self.training + try: + self.eval() + _ = self.forward((coral_x_total, coral_x_mean), cache_features=True) + target_features = self.transformer_features + finally: + if was_training: + self.train() + + # Compute DARE-GRAM loss + dare_gram_loss = self.dare_gram_loss(source_features, target_features) + except (RuntimeError, ValueError, IndexError) as e: + logger.warning(f"DARE-GRAM loss computation failed: {e}") + dare_gram_loss = torch.tensor(0.0, device=self.device) + + # Add DARE-GRAM to loss only if requested + if add_dare_gram_to_loss: + scale_loss = self.exp_rise_flipped(self.val_mRSSE if self.val_mRSSE is not None else mRSSE) + total_loss = regression_loss + self.hparams.dare_gram_weight * scale_loss * dare_gram_loss + else: + total_loss = regression_loss + + return total_loss, mRSSE, dare_gram_loss + + def training_step(self, batch: tuple, batch_idx: int): + """Perform a single training step with domain adaptation. + + Parameters + ---------- + batch : tuple + A tuple containing input data and targets. + batch_idx : int + The index of the batch in the current epoch. + + Returns + ------- + torch.Tensor + The computed loss for the batch. + + Notes + ----- + - The model processes the batch using calc_losses with coral data. + - DARE-GRAM loss is added to regression loss if coral data is available. + - Training loss, mRSSE, and DARE-GRAM loss are logged for monitoring. + """ + loss, mRSSE, dare_gram_loss = self.calc_losses(batch, batch_idx, use_coral=True) + self.log("train_loss", loss, prog_bar=True, sync_dist=True) + self.log("train_mRSSE", mRSSE, sync_dist=True) + self.log("train_dare_gram_loss", dare_gram_loss, sync_dist=True) + return loss + + def validation_step(self, batch, batch_idx): + """Perform a single validation step. + + Parameters + ---------- + batch : tuple + A tuple containing input data and targets. + batch_idx : int + The index of the batch in the current epoch. + + Returns + ------- + torch.Tensor + The computed validation loss for the batch. + + Notes + ----- + - DARE-GRAM is computed for logging but not added to validation loss. + - This allows monitoring domain adaptation without affecting validation metrics. + - Validation loss, mRSSE, and DARE-GRAM loss are logged for monitoring. + """ + # Compute DARE-GRAM for logging but don't add it to validation loss + loss, mRSSE, dare_gram_loss = self.calc_losses( + batch, batch_idx, use_coral=True, add_dare_gram_to_loss=False + ) + self.log("val_loss", loss, prog_bar=True, sync_dist=True) + self.log("val_mRSSE", mRSSE, prog_bar=True, sync_dist=True) + self.log("val_dare_gram_loss", dare_gram_loss, sync_dist=True) + self.val_mRSSE = mRSSE.clone().detach() # Store a copy to avoid tensor reference issues + return loss + + def loss_fn(self, x, y): + """Compute the loss using the Root Sum of Squared Errors (mRSSE). + + Parameters + ---------- + x : torch.Tensor + The predicted tensor of shape (batch_size, num_zernikes). + y : torch.Tensor + The target tensor of shape (batch_size, num_zernikes). + + Returns + ------- + torch.Tensor + The computed mean Root Sum of Squared Errors (mRSSE). + + Notes + ----- + - The loss is calculated as the mean of the square root of + the sum of squared errors. + - Zernikes are converted to deployment format before computing loss. + - Mean squared error (MSE) is computed first, followed by + summation along the last dimension. + - The final value is the mean of the root sum of squared + errors across the batch. + """ + x = convert_zernikes_deploy(x) + y = convert_zernikes_deploy(y) + sse = F_loss.mse_loss(x, y, reduction="none").sum(dim=-1) + mRSSe = torch.sqrt(sse).mean() + return mRSSe + + def configure_optimizers(self) -> Any: + """Configure the optimizer with learning rate scheduling.""" + optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr) + + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": ReduceLROnPlateau(optimizer), + "monitor": "val_loss", + "frequency": 1, + }, + } diff --git a/python/tarts/dataloader.py b/python/tarts/dataloader.py index 6364c0b..320f69d 100644 --- a/python/tarts/dataloader.py +++ b/python/tarts/dataloader.py @@ -5,7 +5,8 @@ import logging import os import pickle -from typing import Any, Dict, List, Optional, Tuple +import json +from typing import Any, Dict, List, Optional, Tuple, Union # Third-party imports import numpy as np @@ -671,6 +672,10 @@ class zernikeDataset(Dataset): return_true : bool, optional, default=False Whether to return the true Zernike coefficients (`True`) or the estimated coefficients (`False`). + coral_mode : bool, optional, default=False + Whether to enable coral mode for sampling real data alongside simulations. + coral_filepath : str, optional, default='.../LSST_FULL_FRAME/aggregator_real/' + Path to the directory containing real/coral aggregator data files. Attributes ---------- @@ -707,6 +712,8 @@ def __init__( data_dir="/media/peterma/mnt2/peterma/research/LSST_FULL_FRAME/aggregator/", alpha=1e-3, return_true=False, + coral_mode=False, + coral_filepath="/media/peterma/mnt2/peterma/research/LSST_FULL_FRAME/aggregator_real/", ): """Initialize the zernikeDataset. @@ -722,8 +729,15 @@ def __init__( Parameter used for adjusting Zernike coefficients during processing. return_true : bool, optional, default=False Whether to return the true Zernike coefficients or estimated coefficients. + coral_mode : bool, optional, default=False + Whether to enable coral mode for real data sampling. + coral_filepath : str, optional + Path to the real/coral aggregator dataset directory. """ self.max_seq_length = seq_length + self.coral_mode = coral_mode + self.coral_filepath = coral_filepath + # Loop through all files and subdirectories if train: self.image_dir = data_dir + "/train" @@ -744,13 +758,153 @@ def __init__( self.num_samples = len(self.filename) self.alpha = alpha self.return_true = return_true - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # Load coral files if coral_mode is enabled + if self.coral_mode: + self.coral_files = [] + coral_data_path = coral_filepath + if train: + coral_data_path += "/train" + else: + coral_data_path += "/val" + + for root, _, files in os.walk(coral_data_path): + for file in files: + file_path = os.path.join(root, file) + self.coral_files.append(file_path) + + logger.info(f"Loaded {len(self.coral_files)} coral aggregator files from {coral_data_path}") + + # Always use CPU in dataset - PyTorch Lightning handles GPU transfer + # This avoids CUDA reinitialization issues with num_workers > 0 + self.device = torch.device("cpu") def __len__(self) -> int: """Return the number of samples in the dataset.""" return self.num_samples - def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, str, str]: + def sample_coral(self) -> Dict[str, Any]: + """Sample a coral (real) aggregator data sample randomly. + + Returns + ------- + dict + Dictionary containing: + - coral_x_total: Input features tensor (seq_length, features) + - coral_mean: Mean Zernike coefficients + - coral_filter: Filter name + - coral_raftbay: Raft bay sensor name + """ + # Check if coral files are available + if not self.coral_files or len(self.coral_files) == 0: + raise RuntimeError("No coral aggregator files available for sampling.") + + # Randomly sample from coral files with retry for corrupted files + max_retries = 10 + corrupted_files = [] + + for attempt in range(max_retries): + try: + # Re-check availability in case files were deleted + if not self.coral_files: + raise RuntimeError("All coral aggregator files have been removed due to corruption.") + + idx = np.random.randint(0, len(self.coral_files)) + coral_file = self.coral_files[idx] + + # Skip already identified corrupted files + if coral_file in corrupted_files: + continue + + # Load the coral data from npz + npz_data = np.load(coral_file, allow_pickle=True) + + # Reconstruct dictionary - split stacked arrays back into lists + loaded_data = { + "estimated_zk": [torch.from_numpy(arr) for arr in npz_data["estimated_zk"]], + "zk_mean": torch.from_numpy(npz_data["zk_mean"]), + "field_x": [torch.from_numpy(arr) for arr in npz_data["field_x"]], + "field_y": [torch.from_numpy(arr) for arr in npz_data["field_y"]], + "snr": npz_data["snr"].tolist(), + "header": json.loads(str(npz_data["header_json"])), + } + break # Successfully loaded, exit retry loop + except (pickle.UnpicklingError, EOFError, IOError, OSError) as e: + # File is corrupted, truncated, or missing - delete it + logger.warning(f"Corrupted coral file detected: {coral_file}. Error: {e}. Deleting...") + try: + if os.path.exists(coral_file): + os.remove(coral_file) + logger.info(f"Deleted corrupted file: {coral_file}") + except (OSError, PermissionError) as delete_error: + logger.warning(f"Failed to delete {coral_file}: {delete_error}") + + # Remove from list to avoid trying again + if coral_file in self.coral_files: + self.coral_files.remove(coral_file) + corrupted_files.append(coral_file) + + if attempt == max_retries - 1: + # Last attempt failed, raise the error + raise RuntimeError( + f"Failed to load coral file after {max_retries} attempts. " + f"Last error: {e}. All coral files may be corrupted." + ) + # Try another random file + continue + + # Process coral data similar to __getitem__ + x = torch.stack(loaded_data["estimated_zk"]).to(self.device) / 1000 + mean = loaded_data["zk_mean"].to(self.device) + + # Track the field x/y in degrees + field_x = torch.stack(loaded_data["field_x"]) + field_y = torch.stack(loaded_data["field_y"]) + + # Load the SNR values + normalize + snr = ( + torch.tensor(loaded_data["snr"]).to(self.device)[..., None] + / torch.tensor(loaded_data["snr"]).max() + ) + + # Combine the field x/y + position = torch.concatenate([field_x, field_y], dim=-1).to(self.device) + + # Combine all into one array as an embedding + x = x.squeeze(1) # [seq_length, features] + position = position.squeeze(1) # [seq_length, 2] + x_total = torch.cat([x, position, snr], dim=1) + + # Control padding the sequence + idx_tensor = torch.randperm(x_total.size(0)) + x_total = x_total[idx_tensor] + + if x_total.shape[0] > self.max_seq_length: + x_total = x_total[: self.max_seq_length, :] + else: + padding = torch.zeros((self.max_seq_length - x_total.shape[0], x_total.shape[1])).to(self.device) + x_total = torch.cat([x_total, padding], dim=0).to(self.device).float() + + # Extract filter and raftbay info + filter_name = loaded_data["header"].get("FILTER", "unknown") + if isinstance(filter_name, str): + filter_name = filter_name.split("_")[0] + + raftbay = loaded_data["header"].get("RAFTBAY", "UNKNOWN") + "_SW0" + + coral_output = { + "coral_x_total": x_total, + "coral_mean": mean[None, ...], + "coral_filter": filter_name, + "coral_raftbay": raftbay, + } + + return coral_output + + def __getitem__(self, idx: int) -> Union[ + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, str, str], + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, str, str, torch.Tensor, torch.Tensor, str, str], + ]: """Retrieve and process a single sample from the dataset at the specified index. This method loads a data sample from the file at the given index, @@ -790,11 +944,23 @@ def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tenso otherwise, it is padded with zeros to match the specified sequence length. """ - # Load dictionary from file + # Load dictionary from npz file try: - with open(self.filename[idx], "rb") as file: - loaded_data = pickle.load(file) - except (pickle.UnpicklingError, IOError, OSError) as e: + npz_data = np.load(self.filename[idx], allow_pickle=True) + + # Reconstruct dictionary - split stacked arrays back into lists + loaded_data = { + "estimated_zk": [torch.from_numpy(arr) for arr in npz_data["estimated_zk"]], + "zk_mean": torch.from_numpy(npz_data["zk_mean"]), + "field_x": [torch.from_numpy(arr) for arr in npz_data["field_x"]], + "field_y": [torch.from_numpy(arr) for arr in npz_data["field_y"]], + "snr": npz_data["snr"].tolist(), + "header": json.loads(str(npz_data["header_json"])), + } + # Add conditional fields if they exist + if "zk_true" in npz_data: + loaded_data["zk_true"] = torch.from_numpy(npz_data["zk_true"]) + except (IOError, OSError, RuntimeError, KeyError) as e: logger.error( f"Error loading file {self.filename[idx] if idx < len(self.filename) else 'unknown'}: {e}" ) @@ -828,14 +994,47 @@ def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tenso padding = torch.zeros((self.max_seq_length - x_total.shape[0], x_total.shape[1])).to(self.device) x_total = torch.cat([x_total, padding], dim=0).to(self.device).float() y = loaded_data["zk_true"] + + # Prepare output dictionary + output = { + "x_total": x_total, + "mean": mean[None, ...], + "y": y, + "filter": loaded_data["header"]["FILTER"].split("_")[0], + "raftbay": loaded_data["header"]["RAFTBAY"] + "_SW0", + } + + # Sample coral data if coral_mode is enabled + if self.coral_mode: + try: + coral_output = self.sample_coral() + output.update(coral_output) + except RuntimeError as e: + logger.warning(f"Failed to sample coral data: {e}") + # Continue without coral data + # return the stack of embedings, mean zernike estimate and the true zernike in PSF - return ( - x_total, - mean[None, ...], - y, - loaded_data["header"]["FILTER"].split("_")[0], - loaded_data["header"]["RAFTBAY"] + "_SW0", - ) + # Return as tuple for backward compatibility + if self.coral_mode and "coral_x_total" in output: + return ( + output["x_total"], + output["mean"], + output["y"], + output["filter"], + output["raftbay"], + output["coral_x_total"], + output["coral_mean"], + output["coral_filter"], + output["coral_raftbay"], + ) + else: + return ( + output["x_total"], + output["mean"], + output["y"], + output["filter"], + output["raftbay"], + ) # Collate function for padding sequences @@ -852,18 +1051,29 @@ def zk_collate_fn(batch): the sample, shaped as `(1, features)`. - y (torch.Tensor) : True Zernike coefficients (target values) for the sample, shaped as `(1, features)`. + - filter (str) : Filter name. + - raftbay (str) : Raft bay sensor name. + And optionally (if coral_mode=True): + - coral_x_total (torch.Tensor) : Coral input features. + - coral_mean (torch.Tensor) : Coral mean Zernike coefficients. + - coral_filter (str) : Coral filter name. + - coral_raftbay (str) : Coral raft bay sensor name. Returns ------- tuple A tuple containing: - - (x_total, x_mean_total) : + - (x_total, x_mean_total, filter_total, chipid_total, + [coral_x_total, coral_x_mean_total, coral_filter_total, coral_chipid_total]) : - x_total (torch.Tensor) : A tensor of input features for the entire batch, shaped as `(batch_size, seq_length, features)`. - x_mean_total (torch.Tensor) : A tensor of mean Zernike coefficients for the entire batch, shaped as `(batch_size, features)`. + - filter_total (list) : List of filter names. + - chipid_total (list) : List of raft bay sensor names. + - [coral_*_total] : Optional coral data if available. - y_total (torch.Tensor) : - y_total (torch.Tensor) : A tensor of true Zernike coefficients (targets) for the entire batch, @@ -873,18 +1083,67 @@ def zk_collate_fn(batch): ----- - The resulting tensors (`x_total`, `x_mean_total`, and `y_total`) are returned in a format suitable for training a model. + - If coral_mode is enabled, coral data tensors are also included. """ - x_batch, x_mean_batch, y_batch, filter_batch, chipid_batch = zip(*batch) + # Check if batch contains coral data (9 elements) or not (5 elements) + has_coral = len(batch[0]) == 9 + + if has_coral: + ( + x_batch, + x_mean_batch, + y_batch, + filter_batch, + chipid_batch, + coral_x_batch, + coral_mean_batch, + coral_filter_batch, + coral_chipid_batch, + ) = zip(*batch) + else: + x_batch, x_mean_batch, y_batch, filter_batch, chipid_batch = zip(*batch) + x_total = torch.zeros((len(x_batch), x_batch[0].shape[0], x_batch[0].shape[1])) y_total = torch.zeros((len(y_batch), y_batch[0].shape[1])) x_mean_total = torch.zeros((len(x_mean_batch), x_mean_batch[0].shape[-1])) + # match the parallel arrays together to get the values filter_total = [] chipid_total = [] + for i, (x, x_mean, y, f, s) in enumerate(zip(x_batch, x_mean_batch, y_batch, filter_batch, chipid_batch)): x_total[i, :, :] = x y_total[i, :] = y[0, :] x_mean_total[i, :] = x_mean[0, 0, :] # <-- fix here filter_total.append(f) chipid_total.append(s) - return (x_total, x_mean_total, filter_total, chipid_total), y_total + + # Process coral data if available + if has_coral: + coral_x_total = torch.zeros( + (len(coral_x_batch), coral_x_batch[0].shape[0], coral_x_batch[0].shape[1]) + ) + coral_x_mean_total = torch.zeros((len(coral_mean_batch), coral_mean_batch[0].shape[-1])) + coral_filter_total = [] + coral_chipid_total = [] + + for i, (cx, cm, cf, cs) in enumerate( + zip(coral_x_batch, coral_mean_batch, coral_filter_batch, coral_chipid_batch) + ): + coral_x_total[i, :, :] = cx + coral_x_mean_total[i, :] = cm[0, 0, :] + coral_filter_total.append(cf) + coral_chipid_total.append(cs) + + return ( + x_total, + x_mean_total, + filter_total, + chipid_total, + coral_x_total, + coral_x_mean_total, + coral_filter_total, + coral_chipid_total, + ), y_total + else: + return (x_total, x_mean_total, filter_total, chipid_total), y_total diff --git a/python/tarts/dataset_params.yaml b/python/tarts/dataset_params.yaml index 88d52ee..be35512 100644 --- a/python/tarts/dataset_params.yaml +++ b/python/tarts/dataset_params.yaml @@ -1,4 +1,5 @@ version: 0 +seed: 42 # Random seed for reproducibility across all training scripts adjustment_WaveNet: 10 adjustment_AlignNet: 120 refinements: 1 @@ -25,13 +26,16 @@ aggregator_val_filepath: '/media/peterma/mnt2/peterma/research/LSST_FULL_FRAME/a fullframe_train_filepath: '/media/peterma/mnt2/peterma/research/LSST_FULL_FRAME/full_plane_ml_opsim_0925/train/' fullframe_val_filepath: '/media/peterma/mnt2/peterma/research/LSST_FULL_FRAME/full_plane_ml_opsim_0925/val/' fullframe_test_filepath: '/media/peterma/mnt2/peterma/research/LSST_FULL_FRAME/full_plane_ml_opsim_0925/test/' +aggregator_real_filepath: "/media/peterma/mnt2/peterma/research/LSST_FULL_FRAME/aggregator_real" +aggregator_real_train_filepath: "/media/peterma/mnt2/peterma/research/LSST_FULL_FRAME/aggregator_real/train" +aggregator_real_val_filepath: "/media/peterma/mnt2/peterma/research/LSST_FULL_FRAME/aggregator_real/val" # AggregatorNet model parameters aggregator_model: - d_model: 28 - nhead: 2 + d_model: 128 + nhead: 4 num_layers: 6 - dim_feedforward: 128 + dim_feedforward: 512 noll_zk: - 4 diff --git a/python/tarts/lightning_wavenet.py b/python/tarts/lightning_wavenet.py index f7a5729..70b2997 100644 --- a/python/tarts/lightning_wavenet.py +++ b/python/tarts/lightning_wavenet.py @@ -1,6 +1,7 @@ """Wrapping everything for WaveNet in Pytorch Lightning.""" # Standard library imports +import logging from typing import Any, Dict, Tuple # Third-party imports @@ -14,6 +15,7 @@ from .constants import ( BAND_MEAN, BAND_STD, + BAND_VALUES_TENSOR, CAMERA_TYPE, DEFAULT_INPUT_SHAPE, DEG_TO_RAD, @@ -21,11 +23,14 @@ FIELD_STD, INTRA_MEAN, INTRA_STD, + ZERNIKE_SCALE_FACTOR, ) from .dataloader import Donuts, Donuts_Fullframe -from .utils import convert_zernikes +from .utils import convert_zernikes_deploy from .wavenet import WaveNet +logger = logging.getLogger(__name__) + class DonutLoader(pl.LightningDataModule): """Pytorch Lightning wrapper for the simulated Donuts DataSet.""" @@ -212,7 +217,7 @@ def predict_step(self, batch: Dict[str, Any], batch_idx: int) -> Tuple[torch.Ten fy = batch["field_y"] intra = batch["intrafocal"] band = batch["band"] - zk_true = batch["zernikes"].cuda() + zk_true = batch["zernikes"].to(self.device_val) # dof_true = batch["dof"] # noqa: F841 # predict zernikes @@ -230,22 +235,42 @@ def calc_losses(self, batch: Dict[str, Any], batch_idx: int) -> Tuple[torch.Tens The mRSSE provides an estimate of the PSF degradation. """ - # predict zernikes - zk_pred, zk_true = self.predict_step(batch, batch_idx) - - # convert to FWHM contributions - zk_pred = convert_zernikes(zk_pred) - zk_true = convert_zernikes(zk_true) - - # pull out the weights from the final linear layer - *_, A, _ = self.wavenet.predictor.parameters() - - # calculate loss - sse = F.mse_loss(zk_pred, zk_true, reduction="none").sum(dim=-1) - loss = sse.mean() + self.hparams.alpha * A.square().sum() - mRSSE = torch.sqrt(sse).mean() - - return loss, mRSSE + try: + # predict zernikes + zk_pred, zk_true = self.predict_step(batch, batch_idx) + + # Check for NaN or Inf values in predictions/truth + if torch.any(torch.isnan(zk_pred)) or torch.any(torch.isnan(zk_true)): + logger.warning("NaN detected in predictions or truth values, returning zero loss") + return torch.tensor(0.0, device=self.device_val), torch.tensor(0.0, device=self.device_val) + if torch.any(torch.isinf(zk_pred)) or torch.any(torch.isinf(zk_true)): + logger.warning("Inf detected in predictions or truth values, returning zero loss") + return torch.tensor(0.0, device=self.device_val), torch.tensor(0.0, device=self.device_val) + + # convert to FWHM contributions + zk_pred = convert_zernikes_deploy(zk_pred) + zk_true = convert_zernikes_deploy(zk_true) + + # pull out the weights from the final linear layer + *_, A, _ = self.wavenet.predictor.parameters() + + # calculate loss + sse = F.mse_loss(zk_pred, zk_true, reduction="none").sum(dim=-1) + loss = sse.mean() + self.hparams.alpha * A.square().sum() + mRSSE = torch.sqrt(sse).mean() + + # Check for NaN or Inf in computed losses + if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)): + logger.warning("NaN/Inf detected in computed loss, returning zero loss") + return torch.tensor(0.0, device=self.device_val), torch.tensor(0.0, device=self.device_val) + if torch.any(torch.isnan(mRSSE)) or torch.any(torch.isinf(mRSSE)): + logger.warning("NaN/Inf detected in computed mRSSE, returning zero mRSSE") + mRSSE = torch.tensor(0.0, device=self.device_val) + + return loss, mRSSE + except (RuntimeError, ValueError, IndexError) as e: + logger.warning(f"Error in calc_losses: {e}, returning zero loss") + return torch.tensor(0.0, device=self.device_val), torch.tensor(0.0, device=self.device_val) def calc_losses_pure( self, batch: Dict[str, Any], batch_idx: int @@ -259,22 +284,55 @@ def calc_losses_pure( The mRSSE provides an estimate of the PSF degradation. """ - # predict zernikes - zk_pred, zk_true = self.predict_step(batch, batch_idx) - - # convert to FWHM contributions - zk_pred = convert_zernikes(zk_pred) - zk_true = convert_zernikes(zk_true) - - # pull out the weights from the final linear layer - *_, A, _ = self.wavenet.predictor.parameters() - - # calculate loss - sse = F.mse_loss(zk_pred, zk_true, reduction="none").sum(dim=-1) - loss = sse.mean() + self.hparams.alpha * A.square().sum() - mRSSE = torch.sqrt(sse).mean() - - return loss, mRSSE, zk_pred, zk_true + try: + # predict zernikes + zk_pred, zk_true = self.predict_step(batch, batch_idx) + + # Check for NaN or Inf values in predictions/truth + if torch.any(torch.isnan(zk_pred)) or torch.any(torch.isnan(zk_true)): + logger.warning("NaN detected in predictions or truth values, returning zero loss") + zero_tensor = torch.tensor(0.0, device=self.device_val) + return zero_tensor, zero_tensor, zk_pred, zk_true + if torch.any(torch.isinf(zk_pred)) or torch.any(torch.isinf(zk_true)): + logger.warning("Inf detected in predictions or truth values, returning zero loss") + zero_tensor = torch.tensor(0.0, device=self.device_val) + return zero_tensor, zero_tensor, zk_pred, zk_true + + # convert to FWHM contributions + zk_pred = convert_zernikes_deploy(zk_pred) + zk_true = convert_zernikes_deploy(zk_true) + + # pull out the weights from the final linear layer + *_, A, _ = self.wavenet.predictor.parameters() + + # calculate loss + sse = F.mse_loss(zk_pred, zk_true, reduction="none").sum(dim=-1) + loss = sse.mean() + self.hparams.alpha * A.square().sum() + mRSSE = torch.sqrt(sse).mean() + + # Check for NaN or Inf in computed losses + if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)): + logger.warning("NaN/Inf detected in computed loss, returning zero loss") + zero_tensor = torch.tensor(0.0, device=self.device_val) + return zero_tensor, zero_tensor, zk_pred, zk_true + if torch.any(torch.isnan(mRSSE)) or torch.any(torch.isinf(mRSSE)): + logger.warning("NaN/Inf detected in computed mRSSE, returning zero mRSSE") + mRSSE = torch.tensor(0.0, device=self.device_val) + + return loss, mRSSE, zk_pred, zk_true + except (RuntimeError, ValueError, IndexError) as e: + logger.warning(f"Error in calc_losses_pure: {e}, returning zero loss") + zero_tensor = torch.tensor(0.0, device=self.device_val) + # Return zero loss but preserve predictions for debugging + try: + zk_pred, zk_true = self.predict_step(batch, batch_idx) + return zero_tensor, zero_tensor, zk_pred, zk_true + except Exception: + # If even predict_step fails, create dummy tensors + dummy_shape = (batch.get("zernikes", torch.tensor([[]])).shape[0],) + dummy_pred = torch.zeros(dummy_shape, device=self.device_val) + dummy_true = torch.zeros(dummy_shape, device=self.device_val) + return zero_tensor, zero_tensor, dummy_pred, dummy_true def training_step(self, batch: Dict[str, Any], batch_idx: int) -> torch.Tensor: """Execute training step on a batch.""" @@ -320,12 +378,7 @@ def get_band_values(self, bands: torch.Tensor) -> torch.Tensor: Returns: torch.Tensor: A tensor of shape (batch_size, 1) with band values. """ - # Create a tensor with band values - band_values = torch.tensor([[0.3671], [0.4827], [0.6223], [0.7546], [0.8691], [0.9712]]).to( - self.device_val - ) - - return band_values[bands] + return BAND_VALUES_TENSOR.to(self.device_val)[bands] def rescale_image(self, data): """Rescale image data to the range [0, 1]. @@ -399,6 +452,6 @@ def forward( zk_pred = self.wavenet(img, fx, fy, focalFlag, band) # convert to nanometers - zk_pred *= 1_000 + zk_pred *= ZERNIKE_SCALE_FACTOR return zk_pred diff --git a/python/tarts/lightning_wavenet_coral.py b/python/tarts/lightning_wavenet_coral.py index ceea0af..472ec20 100644 --- a/python/tarts/lightning_wavenet_coral.py +++ b/python/tarts/lightning_wavenet_coral.py @@ -51,6 +51,7 @@ def __init__( alpha: float = 0, lr: float = 1e-3, lr_schedule: bool = False, + weight_decay: float = 1e-4, device: str = "cuda", pretrained: bool = False, tradeoff_angle: float = 0.05, @@ -73,9 +74,11 @@ def __init__( alpha: float, default=0 Weight for the L2 penalty. lr: float, default=1e-3 - The initial learning rate for Adam. + The initial learning rate for AdamW. lr_schedule: bool, default=False Whether to use the ReduceLROnPlateau learning rate scheduler. + weight_decay: float, default=1e-4 + The weight decay (L2 penalty) coefficient for the AdamW optimizer. device: str, default='cuda' The device to use for computation ('cuda' or 'cpu'). pretrained: bool, default=False @@ -373,7 +376,9 @@ def validation_step(self, batch: Dict[str, Any], batch_idx: int) -> torch.Tensor def configure_optimizers(self) -> Any: """Configure the optimizer.""" - optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.lr, weight_decay=1e-4) + optimizer = torch.optim.AdamW( + self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay + ) if self.hparams.lr_schedule: return {