-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathCMAPSS_TrainLoop.py
More file actions
70 lines (53 loc) · 2.38 KB
/
CMAPSS_TrainLoop.py
File metadata and controls
70 lines (53 loc) · 2.38 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
def Train(train_loader, test_loader, unshuffle_train_loader, finaltest_loader, model, optimizer, loss_func, num_epochs=100):
    """Train ``model`` and track per-epoch train/test losses plus final predictions.

    Parameters
    ----------
    train_loader : iterable of (data, label) batches used for optimization.
    test_loader : iterable of (data, label) batches scored every epoch.
    unshuffle_train_loader : optional iterable of (data, label) batches; when
        provided, the returned train predictions are recomputed on it after the
        last epoch (presumably to get predictions in dataset order rather than
        shuffled batch order — confirm against caller).
    finaltest_loader : iterable of data-only batches for final inference.
    model : the torch module being trained.
    optimizer : torch optimizer over ``model``'s parameters.
    loss_func : callable(output, label) -> scalar loss tensor.
    num_epochs : number of passes over ``train_loader`` (default 100).

    Returns
    -------
    tuple
        (train_loss_epoch, test_loss_epoch, train_output, test_output,
        finaltest_output) — per-epoch mean losses plus flattened prediction
        lists gathered on the final epoch.
    """
    import torch  # local import: needed for torch.no_grad(); file header not visible here

    train_loss_epoch = []
    test_loss_epoch = []
    train_output = []
    test_output = []
    finaltest_output = []
    print("\tEpoch | \tTrain Loss | \tTest Loss")
    for epoch in range(num_epochs):
        last_epoch = epoch == num_epochs - 1

        # --- training pass ---
        model.train()
        running_loss_tr = 0
        batch_counter_tr = 0
        for data_tr, label_tr in train_loader:
            batch_counter_tr += 1
            optimizer.zero_grad()
            output_tr = model(data_tr.float())
            loss_tr = loss_func(output_tr, label_tr)
            loss_tr.backward()
            optimizer.step()
            running_loss_tr += loss_tr.item()
            if last_epoch:
                # collect predictions from the final pass (shuffled batch order)
                train_output += output_tr.flatten().tolist()
        epoch_loss_tr = running_loss_tr / batch_counter_tr
        train_loss_epoch.append(epoch_loss_tr)

        # Optionally recompute train predictions on the unshuffled loader.
        if last_epoch and unshuffle_train_loader is not None:
            train_output = []
            model.eval()  # BUGFIX: original ran this inference in train mode
            with torch.no_grad():  # BUGFIX: no autograd graph needed for inference
                for data_uns_tr, _ in unshuffle_train_loader:
                    train_output += model(data_uns_tr.float()).flatten().tolist()

        # --- evaluation pass ---
        model.eval()
        running_loss_te = 0
        batch_counter_te = 0
        with torch.no_grad():  # BUGFIX: avoid building autograd graphs during eval
            for data_te, label_te in test_loader:
                batch_counter_te += 1
                output_te = model(data_te.float())
                running_loss_te += loss_func(output_te, label_te).item()
                if last_epoch:
                    test_output += output_te.flatten().tolist()
        epoch_loss_te = running_loss_te / batch_counter_te
        test_loss_epoch.append(epoch_loss_te)
        print("\n\t{} \t{} \t{}".format(epoch+1, epoch_loss_tr, epoch_loss_te))

    # Final inference on the unlabeled test set.
    model.eval()
    with torch.no_grad():
        for data_fte in finaltest_loader:
            finaltest_output += model(data_fte.float()).flatten().tolist()
    return train_loss_epoch, test_loss_epoch, train_output, test_output, finaltest_output