-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathpose_discriminator.py
More file actions
38 lines (27 loc) · 1.04 KB
/
pose_discriminator.py
File metadata and controls
38 lines (27 loc) · 1.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch.nn as nn
class Pos2dDiscriminator(nn.Module):
    """MLP discriminator that produces one score per 2D pose in a batch.

    Each sample is flattened and passed through a stack of fully connected
    layers with LeakyReLU activations and one skip connection. The final
    layer is a plain ``nn.Linear`` with no sigmoid, so the score is
    unbounded (how it is turned into a loss is decided by the caller).

    Args:
        num_joints: Number of joints per pose. The first layer expects
            ``num_joints * 2`` input features per sample (presumably one
            (x, y) pair per joint -- confirm against the data pipeline).
        hidden_dim: Width of every hidden layer. The default of 100 matches
            the originally hard-coded width, so existing checkpoints load.
    """

    def __init__(self, num_joints=16, hidden_dim=100):
        super().__init__()

        # Pose path
        self.pose_layer_1 = nn.Linear(num_joints * 2, hidden_dim)
        self.pose_layer_2 = nn.Linear(hidden_dim, hidden_dim)
        self.pose_layer_3 = nn.Linear(hidden_dim, hidden_dim)
        self.pose_layer_4 = nn.Linear(hidden_dim, hidden_dim)

        self.layer_last = nn.Linear(hidden_dim, hidden_dim)
        self.layer_pred = nn.Linear(hidden_dim, 1)

        self.relu = nn.LeakyReLU()

    @staticmethod
    def init_weights_real(m):
        """Kaiming-normal initialize ``m.weight`` when ``m`` is an ``nn.Linear``.

        Biases and non-Linear modules are left untouched.
        """
        # Must be static: as a plain function in the class body,
        # ``self.init_weights_real`` bound ``self`` as ``m`` and the
        # ``nn.Module.apply`` call in ``init_weights`` raised TypeError.
        if isinstance(m, nn.Linear):
            nn.init.kaiming_normal_(m.weight)

    def init_weights(self):
        """Re-initialize the weights of every ``nn.Linear`` in this module.

        Not called by ``__init__``; callers must invoke it explicitly.
        """
        self.apply(self.init_weights_real)

    def forward(self, x):
        """Score a batch of poses.

        Args:
            x: Tensor whose first dimension is the batch. All remaining
                dimensions are flattened and must total ``num_joints * 2``
                elements per sample.

        Returns:
            Tensor of shape ``(batch, 1)`` with one unnormalized score per sample.
        """
        # Pose path: flatten everything except the batch dimension.
        x = x.contiguous().view(x.size(0), -1)
        d1 = self.relu(self.pose_layer_1(x))
        d2 = self.relu(self.pose_layer_2(d1))
        # Skip connection: add d1 to layer 3's output before the activation.
        d3 = self.relu(self.pose_layer_3(d2) + d1)
        # NOTE(review): there is no activation after pose_layer_4, so it
        # composes linearly with layer_last. This is kept as-is to preserve
        # the behavior of trained checkpoints; confirm whether it is intentional.
        d4 = self.pose_layer_4(d3)
        d_last = self.relu(self.layer_last(d4))
        return self.layer_pred(d_last)