Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions mart/models/modular.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,16 @@ def training_step(self, batch, batch_idx):
# We need to manually log loss on the progress bar in newer PL.
self.log("loss", loss, prog_bar=True)

return loss
if self.output_loss_key != "loss":
if "loss" in output:
raise ValueError(
'The key "loss" is preserved in Pytorch Lightning as the training loss. Please change the module name if it does not output the training loss.'
)
else:
output["loss"] = output[self.output_loss_key]

# Look out if we will have memory leak by returning the whole dictionary with attached tensors.
return output

def on_train_epoch_end(self):
if self.training_metrics is not None:
Expand All @@ -186,7 +195,7 @@ def validation_step(self, batch, batch_idx):

self.validation_metrics(output[self.output_preds_key], output[self.output_target_key])

return None
return output

def on_validation_epoch_end(self):
metrics = self.validation_metrics.compute()
Expand All @@ -210,7 +219,7 @@ def test_step(self, batch, batch_idx):

self.test_metrics(output[self.output_preds_key], output[self.output_target_key])

return None
return output

def on_test_epoch_end(self):
metrics = self.test_metrics.compute()
Expand Down
Loading