-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathplotting.py
More file actions
48 lines (41 loc) · 1.29 KB
/
plotting.py
File metadata and controls
48 lines (41 loc) · 1.29 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import os
# Load Model
# Path of the saved training checkpoint whose logged metrics are plotted below.
network_path = 'models/fused_epoch6'
# This script only reads the metric histories stored in the checkpoint, so it
# is always mapped onto the CPU. That works identically on machines with and
# without CUDA, avoids allocating GPU memory for data that is only plotted,
# and keeps any tensors saved in the checkpoint off the GPU (CUDA tensors
# cannot be converted to NumPy arrays for plotting without an explicit copy).
checkpoint = torch.load(network_path, map_location=torch.device('cpu'))
# Load Saved Data
# Pull the recorded training history out of the checkpoint dictionary.
num_epochs = checkpoint['epoch']
(
    train_total_loss_list,
    epoch_total_loss_list,
    test_loss_list,
    train_counter,
    accuracy_list,
) = (
    checkpoint[field]
    for field in (
        'train_total_loss_list',
        'epoch_total_loss_list',
        'test_loss_list',
        'train_counter',
        'accuracy_list',
    )
)
# Shared x-axis for the curves below: 0, 1, ..., num_epochs.
# NOTE(review): assumes each plotted list holds num_epochs + 1 entries - confirm
# against the training script that wrote the checkpoint.
epoch_list = np.arange(0, num_epochs + 1)
def _draw_curve(values, colour, legend_label, legend_loc, y_label):
    """Plot ``values`` against ``epoch_list`` on a new figure.

    The figure gets a single-entry legend placed at ``legend_loc``, an
    'Epochs' x-axis label and ``y_label`` on the y-axis. Returns the newly
    created figure.
    """
    figure = plt.figure()
    plt.plot(epoch_list, values, color=colour)
    plt.legend([legend_label], loc=legend_loc)
    plt.xlabel('Epochs')
    plt.ylabel(y_label)
    return figure


# Training Loss
fig = _draw_curve(
    epoch_total_loss_list, 'blue', 'FuseNet Train Loss', 'upper right', 'Total Loss'
)
# Test Loss
fig = _draw_curve(
    test_loss_list, 'red', 'Validation Loss', 'upper right', 'Total Loss'
)
# Accuracy
fig = _draw_curve(
    accuracy_list, 'red', 'FuseNet Validation Accuracy', 'lower right', 'Accuracy'
)
# Display all three figures at once.
plt.show()