# train.py — MNIST training script (86 lines, 2.47 KB in the original file)
# NOTE(review): the lines above/below were recovered from a web scrape that
# stripped indentation and injected page chrome; the UI text and line-number
# gutter have been removed here.
from matplotlib import pyplot as plt
import numpy as np
from tqdm import tqdm
from ResidualMLP import ResidualMLP
from data import MNISTDataset, DataLoader
import autograd as ag
import optim
import nn
def train_one_epoch(dataloader, model, loss_fn, optimizer):
    """Run one optimization pass over the full training set.

    Args:
        dataloader: iterable yielding (x, y) batches; x and y are project
            tensors exposing .numpy() and .shape (assumed — confirm in data.py).
        model: network to train; switched to train mode before the loop.
        loss_fn: callable mapping (logits, y) to a scalar loss tensor.
        optimizer: optimizer exposing reset_grad() and step().

    Returns:
        Tuple of Python floats: (mean loss over all samples, accuracy).
    """
    model.train()
    total_loss, num, acc = 0.0, 0, 0
    for x, y in tqdm(dataloader):
        logits = model(x)
        loss = loss_fn(logits, y)
        # Accuracy: compare predicted class index against integer labels.
        acc += np.sum(logits.numpy().argmax(axis=1) == y.numpy())
        # Accumulate a detached Python float, not the loss tensor itself:
        # summing `loss * x.shape[0]` as autograd nodes would keep every
        # batch's computation graph alive for the whole epoch (memory leak).
        total_loss += loss.numpy().item() * x.shape[0]
        num += x.shape[0]
        optimizer.reset_grad()
        loss.backward()
        optimizer.step()
    return total_loss / num, (acc / num).item()
def test(dataloader, model, loss_fn):
    """Evaluate `model` on `dataloader` without updating parameters.

    Args:
        dataloader: iterable yielding (x, y) batches of project tensors.
        model: network to evaluate; switched to eval mode first.
        loss_fn: callable mapping (logits, y) to a scalar loss tensor.

    Returns:
        Tuple of Python floats: (mean loss over all samples, accuracy).
    """
    model.eval()
    total_loss, num, acc = 0.0, 0, 0
    for x, y in tqdm(dataloader):
        logits = model(x)
        loss = loss_fn(logits, y)
        acc += np.sum(logits.numpy().argmax(axis=1) == y.numpy())
        # Detach per batch (see train_one_epoch): avoids retaining the
        # autograd graph of every evaluation batch until the return.
        total_loss += loss.numpy().item() * x.shape[0]
        num += x.shape[0]
    return total_loss / num, (acc / num).item()
if __name__ == '__main__':
    # Hyperparameters.
    batch_size = 100
    lr = 1e-3
    epochs = 50
    num_classes = 10

    # 784 = 28x28 flattened MNIST images.
    model = ResidualMLP(dim=784, hidden_dim=128, num_classes=num_classes)

    # MNIST gzip archives: train split and the 10k held-out test split.
    train_set = MNISTDataset("MNIST/train-images-idx3-ubyte.gz", "MNIST/train-labels-idx1-ubyte.gz")
    test_set = MNISTDataset("MNIST/t10k-images-idx3-ubyte.gz", "MNIST/t10k-labels-idx1-ubyte.gz")
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_set, batch_size=batch_size)

    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=1e-4)
    # Decay: multiply the learning rate by 0.9 every 5 epochs.
    scheduler = optim.StepDecay(optimizer, 5, 0.9)

    losses, accs = [], []
    for epoch in range(1, epochs + 1):
        train_loss, train_acc = train_one_epoch(train_loader, model, loss_fn, optimizer)
        losses.append(train_loss)
        accs.append(train_acc)
        scheduler.step()
        print('epoch {}: train_loss = {:.4f} | train_acc = {:.4f}'.format(epoch, train_loss, train_acc))

    # Plot the training curves to image files.
    epoch_axis = list(range(1, epochs + 1))
    plt.figure(1)
    plt.plot(epoch_axis, losses)
    plt.xlabel('epochs')
    plt.ylabel('train loss')
    plt.savefig('loss.jpg')
    plt.figure(2)
    plt.plot(epoch_axis, accs)
    plt.xlabel('epochs')
    plt.ylabel('train acc')
    plt.savefig('acc.jpg')

    # Final held-out evaluation.
    test_loss, test_acc = test(test_loader, model, loss_fn)
    print('Test: test_loss = {:.4f} | test_acc = {:.4f}'.format(test_loss, test_acc))