Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions src/deepdiagnostics/plots/lossPlot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import matplotlib.pyplot as plt
from typing import Dict, List

from deepdiagnostics.plots.plot import Display

class LossPlot(Display):
def __init__(self, model=None, data=None):
...

def _data_setup(self):
raise NotImplementedError

def plot_name(self):
return "loss_curves.png"

def plot(
self,
training_history: Dict[str, List[float]],
epochs: int = None,
best_val_loss: float = None
):
try:
training_loss_data = training_history["train_loss"]
validation_loss_data = training_history["val_loss"]
except Exception as e:
raise KeyError(f"Key {e} not found in supplied training history")

if epochs is None:
if len(training_loss_data) != len(validation_loss_data):
raise ValueError("Inconsistent training history data supplied [length of train and validation losses not equal]")
epochs = len(training_loss_data)
print(f"Number of epochs determined: {epochs}")
else:
if epochs != len(training_loss_data) or epochs != len(validation_loss_data):
raise ValueError("Epochs supplied inconsistent with training history data [epochs and loss history not equal]")

if best_val_loss is None:
best_val_loss = min(validation_loss_data)
print(f"Best validation loss found: {best_val_loss}")

epochs_trained = [ x for x in range(1, epochs+1) ]
plt.plot(epochs_trained, training_loss_data, label='Training Loss')
plt.plot(epochs_trained, validation_loss_data, label='Validation Loss')
plt.axhline(y=best_val_loss, color='m', linestyle='--', label="Best Val. Loss")

plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Training and Validation Loss Over Epochs')

plt.legend()
plt.grid()
plt.show()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Two things:

  • Please use the super _finish to either save or show the plot
  • Please return the figure, axes object so the user can edit them outside of the plot method if they want


def __call__(self, **kwargs):
raise NotImplementedError("Plotting loss is not supported in pipeline mode")
Loading