-
Notifications
You must be signed in to change notification settings - Fork 7
Expand file tree
/
Copy pathmodels.py
More file actions
90 lines (76 loc) · 2.76 KB
/
models.py
File metadata and controls
90 lines (76 loc) · 2.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import torch
import torch.nn as nn
import torch.nn.functional as F
# Fully connected network, size: input_size, hidden_size,... , output_size
class MLP(nn.Module):
    """Fully connected network.

    Layer widths are given by ``size``: input_size, hidden_size, ..., output_size.
    Each hidden Linear layer is followed by the chosen activation; the final
    Linear layer has no activation (raw logits/regression output).

    Args:
        size: sequence of at least two layer widths.
        act: hidden activation, 'sigmoid' (default) or 'relu'.

    Raises:
        ValueError: if ``act`` is not one of the implemented activations
            (only triggered when there is at least one hidden layer).
    """
    def __init__(self, size, act='sigmoid'):
        # Fix: super(type(self), self).__init__() recurses infinitely when
        # this class is subclassed; the zero-argument super() is correct.
        super().__init__()
        self.num_layers = len(size) - 1
        lower_modules = []
        # Build all hidden layers (everything except the final Linear).
        for i in range(self.num_layers - 1):
            lower_modules.append(nn.Linear(size[i], size[i+1]))
            if act == 'relu':
                lower_modules.append(nn.ReLU())
            elif act == 'sigmoid':
                lower_modules.append(nn.Sigmoid())
            else:
                raise ValueError("%s activation layer hasn't been implemented in this code" %act)
        self.layer_1 = nn.Sequential(*lower_modules)
        # Final projection, applied without an activation.
        self.layer_2 = nn.Linear(size[-2], size[-1])

    def forward(self, x):
        """Run the hidden stack then the output layer; returns (batch, size[-1])."""
        o = self.layer_1(x)
        o = self.layer_2(o)
        return o
class SplitMLP(nn.Module):
    """Fully connected network whose output is restricted to a label subset.

    Identical to ``MLP`` except that ``forward`` takes a ``label_set`` and
    returns only those output columns — useful for task/label-split training.

    Args:
        size: sequence of at least two layer widths
            (input_size, hidden_size, ..., output_size).
        act: hidden activation, 'sigmoid' (default) or 'relu'.

    Raises:
        ValueError: if ``act`` is not one of the implemented activations
            (only triggered when there is at least one hidden layer).
    """
    def __init__(self, size, act='sigmoid'):
        # Fix: super(type(self), self).__init__() recurses infinitely when
        # this class is subclassed; the zero-argument super() is correct.
        super().__init__()
        self.num_layers = len(size) - 1
        lower_modules = []
        # Build all hidden layers (everything except the final Linear).
        for i in range(self.num_layers - 1):
            lower_modules.append(nn.Linear(size[i], size[i+1]))
            if act == 'relu':
                lower_modules.append(nn.ReLU())
            elif act == 'sigmoid':
                lower_modules.append(nn.Sigmoid())
            else:
                raise ValueError("%s activation layer hasn't been implemented in this code" %act)
        self.layer_1 = nn.Sequential(*lower_modules)
        # Final projection over the full label space.
        self.layer_2 = nn.Linear(size[-2], size[-1])

    def forward(self, x, label_set):
        """Return logits only for the columns selected by ``label_set``.

        ``label_set`` is any index accepted by tensor column indexing
        (e.g. a list of label indices); output is (batch, len(label_set)).
        """
        o = self.layer_1(x)
        o = self.layer_2(o)
        o = o[:, label_set]
        return o
class CifarNet(nn.Module):
    """Small VGG-style CNN with a label-subset output, sized for 32x32 input.

    Two conv stages (32 then 64 channels, each: pad-conv, conv, 2x2 max-pool,
    dropout) reduce a 32x32 image to 64x6x6, followed by a 512-unit FC layer
    and a linear output head. ``forward`` returns only the ``label_set``
    columns of the head's output.

    Args:
        in_channels: number of input image channels.
        out_channels: size of the full label space (output head width).
    """
    def __init__(self, in_channels, out_channels):
        # Fix: super(type(self), self).__init__() recurses infinitely when
        # this class is subclassed; the zero-argument super() is correct.
        super().__init__()
        self.conv_block = nn.Sequential(
            nn.Conv2d(in_channels, 32, 3, padding=1),   # 32x32 -> 32x32
            nn.ReLU(),
            nn.Conv2d(32, 32, 3),                       # 32x32 -> 30x30
            nn.ReLU(),
            nn.MaxPool2d(2),                            # 30x30 -> 15x15
            nn.Dropout(p=0.25),
            nn.Conv2d(32, 64, 3, padding=1),            # 15x15 -> 15x15
            nn.ReLU(),
            nn.Conv2d(64, 64, 3),                       # 15x15 -> 13x13
            nn.ReLU(),
            nn.MaxPool2d(2),                            # 13x13 -> 6x6
            nn.Dropout(p=0.25)
        )
        self.linear_block = nn.Sequential(
            nn.Linear(64*6*6, 512),
            nn.ReLU(),
            nn.Dropout(p=0.5)
        )
        self.out_block = nn.Linear(512, out_channels)

    def weight_init(self):
        """Zero-initialize the output head (weights and bias)."""
        nn.init.constant_(self.out_block.weight, 0)
        nn.init.constant_(self.out_block.bias, 0)

    def forward(self, x, label_set):
        """Return head outputs restricted to ``label_set`` columns.

        ``x`` is expected to be (batch, in_channels, 32, 32) — the flatten
        size 64*6*6 hard-codes that spatial resolution.
        """
        o = self.conv_block(x)
        o = torch.flatten(o, 1)   # keep batch dim, flatten C*H*W
        o = self.linear_block(o)
        o = self.out_block(o)
        o = o[:, label_set]
        return o