import torch
from torch import nn
from data import mnist_generator, cifar_generator
from xox import XOXHyperNetwork
from hyper import RandomBasisHyperNetwork
from utils import reset_parameters
from train import train
'''
This script trains an ordinary MLP on MNIST, then trains the same MLP through
a hypernetwork that generates its weights. The hypernetwork version reaches a
lower accuracy (about 90% instead of 95%), which is still reasonable, while
training far fewer parameters (8346 instead of 50890). A random-basis
hypernetwork is trained last. Output for the first two runs should look
something like this:
Training 4 weights with total 50890 parameters:
[(64, 784), (64,), (10, 64), (10,)]
500 0.108
1000 0.064
1500 0.040
2000 0.030
final accuracy = 0.945
Hypernetwork: 50890 -> 8346
Training 5 weights with total 8346 parameters:
[(784, 9), (64, 9), (64,), (10, 64), (10,)]
500 0.275
1000 0.191
1500 0.156
2000 0.140
final accuracy = 0.906
'''
# build a simple MLP: 784 inputs (28x28 pixels) -> 64 hidden units -> 10 classes
net = nn.Sequential(
    nn.Linear(28*28, 64),
    nn.ReLU(),
    nn.Linear(64, 10)
)
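# Sanity check (an illustrative addition, not part of the original demo): the
# "50890 parameters" in the sample output follow from the layer shapes,
# 784*64 + 64 + 64*10 + 10 = 50890. `count_parameters` is a hypothetical helper.
from torch import nn


def count_parameters(module: nn.Module) -> int:
    """Total number of scalar entries across all of a module's parameters."""
    return sum(p.numel() for p in module.parameters())


assert count_parameters(nn.Sequential(nn.Linear(28*28, 64), nn.ReLU(), nn.Linear(64, 10))) == 50890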
# train the MLP on MNIST for 2000 batches (this should achieve around 95% final accuracy)
train(net, mnist_generator, 2000, title='ordinary')
# reset the weights and biases in the net to random values
reset_parameters(net)
# create an XOX hypernetwork that produces the weights and biases of the network
hyper = XOXHyperNetwork(net, num_genes=10, fix_gene_matrices=True, fix_o_matrix=True)
# train the network via the hypernetwork (this should achieve around 90% final accuracy)
train(net, mnist_generator, 10000, title='xox', hyper_net=hyper)
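# Illustrative sketch (an assumption inferred from the sample output, not the
# actual XOXHyperNetwork code): the (64, 784) first-layer weight appears to be
# replaced by trainable "gene" matrices of shapes (784, g) and (64, g). One way
# to rebuild the full weight is W = X_out @ O @ X_in.T with a fixed (g, g)
# matrix O, which needs (784 + 64) * g trainable entries instead of 784 * 64.
# `xox_weight_sketch` is a hypothetical helper.
import torch


def xox_weight_sketch(x_in: torch.Tensor, o: torch.Tensor, x_out: torch.Tensor) -> torch.Tensor:
    """Assemble an (n_out, n_in) weight from x_in (n_in, g), o (g, g) and x_out (n_out, g)."""
    return x_out @ o @ x_in.T


# with g = 9 this matches the "50890 -> 8346" count: 7632 gene entries + 714 for the rest
assert xox_weight_sketch(torch.randn(28*28, 9), torch.randn(9, 9), torch.randn(64, 9)).shape == (64, 28*28)
assert (28*28 + 64) * 9 + 64 + 64*10 + 10 == 8346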
# create a random-basis hypernetwork (ndims=400) that produces the weights and biases of the network
hyper = RandomBasisHyperNetwork(net, ndims=400)
# train the network via the hypernetwork (this should achieve around 90% final accuracy)
train(net, mnist_generator, 2000, title='random', hyper_net=hyper)
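# Illustrative sketch (an assumption, not the actual RandomBasisHyperNetwork
# code): a common random-basis scheme trains only a vector z of length ndims and
# maps it to every network parameter as theta = theta0 + P @ z, where theta0
# holds the flattened initial parameters and P is a fixed random
# (num_params, ndims) matrix. `random_basis_forward` is a hypothetical helper;
# torch.func.functional_call requires PyTorch >= 2.0.
import torch
from torch import nn
from torch.func import functional_call


def random_basis_forward(net: nn.Module, theta0: torch.Tensor, basis: torch.Tensor,
                         z: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Run `net` on `x` with parameters theta0 + basis @ z (gradients reach z, not net's own parameters)."""
    theta = theta0 + basis @ z
    params, offset = {}, 0
    for name, p in net.named_parameters():
        # slice this parameter's entries out of the flat vector and restore its shape
        params[name] = theta[offset:offset + p.numel()].view_as(p)
        offset += p.numel()
    return functional_call(net, params, (x,))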