diff --git a/mart/models/modular.py b/mart/models/modular.py
index 605b237c..67cb9d5e 100644
--- a/mart/models/modular.py
+++ b/mart/models/modular.py
@@ -160,7 +160,16 @@ def training_step(self, batch, batch_idx):
         # We need to manually log loss on the progress bar in newer PL.
         self.log("loss", loss, prog_bar=True)
 
-        return loss
+        if self.output_loss_key != "loss":
+            if "loss" in output:
+                raise ValueError(
+                    'The key "loss" is reserved in Pytorch Lightning as the training loss. Please change the module name if it does not output the training loss.'
+                )
+            else:
+                output["loss"] = output[self.output_loss_key]
+
+        # TODO: Check whether returning the whole dictionary, with its tensors still attached, leaks memory.
+        return output
 
     def on_train_epoch_end(self):
         if self.training_metrics is not None:
@@ -186,7 +195,7 @@ def validation_step(self, batch, batch_idx):
 
         self.validation_metrics(output[self.output_preds_key], output[self.output_target_key])
 
-        return None
+        return output
 
     def on_validation_epoch_end(self):
         metrics = self.validation_metrics.compute()
@@ -210,7 +219,7 @@ def test_step(self, batch, batch_idx):
 
         self.test_metrics(output[self.output_preds_key], output[self.output_target_key])
 
-        return None
+        return output
 
     def on_test_epoch_end(self):
         metrics = self.test_metrics.compute()