Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions python/tarts/NeuralActiveOpticsSys.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def __init__(
# Always use checkpoint loading - the pretrained parameter doesn't matter
# when loading from checkpoint
self.wavenet_model = WaveNetSystem.load_from_checkpoint(
- wavenet_path, map_location=str(self.device_val)
+ wavenet_path, map_location=str(self.device_val), strict=False
).to(self.device_val)

if alignet_path is None:
Expand All @@ -150,7 +150,7 @@ def __init__(
# Always use checkpoint loading - the pretrained parameter doesn't matter
# when loading from checkpoint
self.alignnet_model = AlignNetSystem.load_from_checkpoint(
- alignet_path, map_location=str(self.device_val)
+ alignet_path, map_location=str(self.device_val), strict=False
).to(self.device_val)

self.max_seq_length = params["max_seq_len"]
Expand All @@ -169,9 +169,9 @@ def __init__(
max_seq_length=self.max_seq_length,
).to(self.device_val)
else:
- self.aggregatornet_model = AggregatorNet.load_from_checkpoint(aggregatornet_path).to(
- self.device_val
- )
+ self.aggregatornet_model = AggregatorNet.load_from_checkpoint(
+ aggregatornet_path, strict=False
+ ).to(self.device_val)

if final_layer is not None:
layers = [
Expand Down
97 changes: 24 additions & 73 deletions python/tarts/aggregatornet.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

# Local/application imports
from .utils import convert_zernikes_deploy
- from .utils import zernikes_to_dof_torch, dof_to_zernikes_torch


class AggregatorNet(pl.LightningModule):
Expand Down Expand Up @@ -91,6 +90,12 @@ def __init__(
"""
super().__init__()
self.save_hyperparameters() # Save model hyperparameters
+
+ # Input projection layer: (num_zernikes + 3) -> d_model
+ # The +3 accounts for field_x, field_y, and snr features
+ input_dim = num_zernikes + 3
+ self.input_proj = nn.Linear(input_dim, d_model)
+
encoder_layer = nn.TransformerEncoderLayer(
d_model=d_model,
nhead=nhead,
Expand All @@ -110,7 +115,7 @@ def forward(self, x: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
x : tuple of (torch.Tensor, torch.Tensor)
A tuple where:
- x[0] (torch.Tensor): The input sequence tensor of shape
- (batch_size, seq_length, d_model).
+ (batch_size, seq_length, num_zernikes + 3).
- x[1] (torch.Tensor): The mean tensor used for output adjustment.

Returns
Expand All @@ -120,14 +125,18 @@ def forward(self, x: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:

Notes
-----
- - The transformer encoder processes the first element of the tuple.
+ - Input features are first projected from (num_zernikes + 3) to d_model dimensions.
+ - The transformer encoder processes the projected features.
- The last token's output is extracted and passed through a linear
layer.
- The mean correction (second element) is added to the final output.

"""
x_input, mean = x
- x_tensor = self.transformer_encoder(x_input)
+ # Project input features to d_model dimensions
+ x_projected = self.input_proj(x_input)
+ # Pass through transformer
+ x_tensor = self.transformer_encoder(x_projected)
x_tensor = x_tensor[:, -1, :] # Take the last token's output
x_tensor = self.fc(x_tensor) # Predict the next token
x_tensor += mean
Expand Down Expand Up @@ -161,41 +170,11 @@ def training_step(self, batch: tuple, batch_idx: int):
- The training loss is logged for monitoring.

"""
- if not self.zk_dof_zk:
- x, y = batch # y is the target token
- x_input, x_mean, filter_name, chipid = x
- logits = self.forward((x_input, x_mean))
- loss = self.loss_fn(logits, y)
- self.log("train_loss", loss, prog_bar=True)
- else:
- x, y = batch # y is the target token
- x_input, x_mean, filter_name, chipid = x
- logits = self.forward((x_input, x_mean))
- new_logits = torch.zeros_like(logits)
- for i in range(len(filter_name)):
- filter_name_i = filter_name[i]
- sensor_names = chipid[i]
- print("old logits", logits[0, :])
- x_dof = zernikes_to_dof_torch(
- filter_name=filter_name_i,
- measured_zk=logits[i][None, :],
- sensor_names=[sensor_names],
- rotation_angle=0.0,
- device=self.device,
- verbose=False,
- )
- new_logits[i, :] = dof_to_zernikes_torch(
- filter_name=filter_name_i,
- x_dof=x_dof,
- sensor_names=[sensor_names],
- rotation_angle=0.0,
- device=self.device,
- verbose=False,
- )
- print("new_logits", new_logits[0, :])
-
- loss = self.loss_fn(new_logits, y)
- self.log("train_loss", loss, prog_bar=True)
+ x, y = batch # y is the target token
+ x_input, x_mean, filter_name, chipid = x
+ logits = self.forward((x_input, x_mean))
+ loss = self.loss_fn(logits, y)
+ self.log("train_loss", loss, prog_bar=True)
return loss

def validation_step(self, batch, batch_idx):
Expand Down Expand Up @@ -226,40 +205,12 @@ def validation_step(self, batch, batch_idx):
- The validation loss is logged for monitoring.

"""
- if not self.zk_dof_zk:
- x, y = batch # y is the target token
- x_input, x_mean, filter_name, chipid = x
- logits = self.forward((x_input, x_mean))
- loss = self.loss_fn(logits, y)
- self.log("val_loss", loss, prog_bar=True)
- self.log("val_mRSSE", loss, prog_bar=True) # mRSSE is the same as loss for this model
- else:
- x, y = batch # y is the target token
- x_input, x_mean, filter_name, chipid = x
- logits = self.forward((x_input, x_mean))
- new_logits = torch.zeros_like(logits)
- for i in range(len(filter_name)):
- filter_name_i = filter_name[i]
- sensor_names = chipid[i]
- x_dof = zernikes_to_dof_torch(
- filter_name=filter_name_i,
- measured_zk=logits[i][None, :],
- sensor_names=[sensor_names],
- rotation_angle=0.0,
- device=self.device,
- verbose=False,
- )
- new_logits[i, :] = dof_to_zernikes_torch(
- filter_name=filter_name_i,
- x_dof=x_dof,
- sensor_names=[sensor_names],
- rotation_angle=0.0,
- device=self.device,
- verbose=False,
- )
- loss = self.loss_fn(new_logits, y)
- self.log("val_loss", loss, prog_bar=True)
- self.log("val_mRSSE", loss, prog_bar=True) # mRSSE is the same as loss for this model
+ x, y = batch # y is the target token
+ x_input, x_mean, filter_name, chipid = x
+ logits = self.forward((x_input, x_mean))
+ loss = self.loss_fn(logits, y)
+ self.log("val_loss", loss, prog_bar=True)
+ self.log("val_mRSSE", loss, prog_bar=True) # mRSSE is the same as loss for this model
return loss

def loss_fn(self, x, y):
Expand Down
Loading