-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathModel.py
More file actions
78 lines (67 loc) · 2.16 KB
/
Model.py
File metadata and controls
78 lines (67 loc) · 2.16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import matplotlib.image as image
import numpy as np
import os
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import torchvision.datasets as dset
import torch.nn as nn
import torch.optim as optim
from torch import nn
import math
class VGG(nn.Module):
    """VGG-style classifier: a convolutional ``features`` trunk followed by a
    three-layer fully connected head.

    Args:
        features: module applied first; its output is flattened per sample
            (all dims after the first) and must yield exactly ``in_features``
            values per sample.
        num_classes: number of output logits.
        init_weights: if True, re-initialise conv / batch-norm / linear
            parameters with ``_initialize_weights``.
        in_features: size of the flattened ``features`` output. Defaults to
            the previously hard-coded 8192 (512 * 4 * 4 — presumably 64x64
            inputs through four 2x2 max-pools; confirm against the data).
        hidden_features: width of the two hidden linear layers. Defaults to
            the previously hard-coded 8192.
    """

    def __init__(self, features, num_classes=100, init_weights=True,
                 in_features=8192, hidden_features=8192):
        super().__init__()
        self.features = features
        # Layer order is unchanged (Linear layers at indices 0, 3, 6), so
        # state_dict keys stay compatible with existing checkpoints.
        self.classifier = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(hidden_features, hidden_features),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(hidden_features, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        """Run ``features``, flatten all dims after the first, and return the
        classifier logits (last dim is ``num_classes``)."""
        x = self.features(x)
        # flatten, unlike view, also accepts a non-contiguous features output.
        x = torch.flatten(x, 1)
        return self.classifier(x)

    def _initialize_weights(self):
        """Initialise parameters of every submodule in place.

        - Conv2d: weights ~ N(0, sqrt(2 / (k_h * k_w * out_channels)))
          (He initialisation, fan-out mode); bias zeroed.
        - BatchNorm2d: weight set to 1, bias zeroed.
        - Linear: weights ~ N(0, 0.01); bias zeroed.

        Parameters that do not exist (``bias=False`` layers or
        ``BatchNorm2d(affine=False)``) are skipped instead of raising.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                nn.init.normal_(m.weight, 0, math.sqrt(2.0 / n))
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                if m.weight is not None:
                    nn.init.ones_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
def make_layers(cfg, batch_norm=False, in_channels=3):
    """Build the convolutional feature trunk described by ``cfg``.

    Args:
        cfg: sequence of layer specs. ``'M'`` appends a 2x2 max-pool with
            stride 2. An int ``v`` appends a 3x3 convolution (padding 1) with
            ``v`` output channels, optionally ``BatchNorm2d(v)``, then ReLU.
        batch_norm: if True, insert ``BatchNorm2d`` between each convolution
            and its ReLU. ``vgg19_bn`` relies on this keyword.
        in_channels: channel count of the input images (3 for RGB).

    Returns:
        ``nn.Sequential`` containing the layers in ``cfg`` order.
    """
    layers = []
    for v in cfg:
        if v == 'M':
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue
        layers.append(nn.Conv2d(in_channels, v, kernel_size=3, padding=1))
        if batch_norm:
            layers.append(nn.BatchNorm2d(v))
        layers.append(nn.ReLU(True))
        # The next convolution consumes this one's output channels.
        in_channels = v
    return nn.Sequential(*layers)
# Layer specifications consumed by make_layers: an int is the output-channel
# count of a 3x3 convolution followed by ReLU; 'M' is a 2x2, stride-2 max-pool.
# NOTE(review): the reference VGG-19 ('E') spec ends with a fifth 'M'. It is
# absent here, which leaves four pools. That appears to pair with the 8192-wide
# classifier input in VGG (512 * 4 * 4, e.g. 64x64 images) — confirm.
cfg = {
    'E': [
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 256, 256, 'M',
        512, 512, 512, 512, 'M',
        512, 512, 512, 512,
    ],
}
def vgg19():
    """Return a VGG model whose feature trunk is built from ``cfg['E']``
    without batch normalization; all other VGG arguments use their defaults."""
    trunk = make_layers(cfg['E'])
    return VGG(trunk)
def vgg19_bn():
    """Return a VGG model over ``cfg['E']`` with batch normalization requested.

    Passes ``batch_norm=True`` to ``make_layers``. That keyword must be
    accepted by ``make_layers`` for this factory to succeed.
    """
    trunk = make_layers(cfg['E'], batch_norm=True)
    return VGG(trunk)