forked from singhgautam/slate
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
104 lines (65 loc) · 2.74 KB
/
utils.py
File metadata and controls
104 lines (65 loc) · 2.74 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
def gumbel_max(logits, dim=-1):
    """Draw categorical sample indices from ``logits`` with the Gumbel-max trick.

    Standard Gumbel noise ``-log(E)``, with ``E ~ Exponential(1)``, is added to
    the logits. The arg-max along ``dim`` is then returned.

    Args:
        logits: Floating-point tensor of unnormalized log-probabilities.
        dim: Dimension holding the categories. Defaults to the last one.

    Returns:
        Integer index tensor with ``dim`` reduced away.
    """
    # Smallest positive normal number for this dtype. Adding it keeps log() finite
    # if exponential_() ever produces an exact 0.
    tiny = torch.finfo(logits.dtype).tiny
    noise = torch.empty_like(logits).exponential_().add(tiny).log().neg()
    perturbed = logits + noise
    return perturbed.argmax(dim)
def gumbel_softmax(logits, tau=1., hard=False, dim=-1):
    """Sample a relaxed one-hot vector from ``logits`` (Gumbel-softmax).

    Args:
        logits: Floating-point tensor of unnormalized log-probabilities.
        tau: Softmax temperature. The Gumbel-perturbed logits are divided by it.
        hard: If True, return an exact one-hot vector in the forward pass.
            Gradients still flow through the soft sample (straight-through).
        dim: Dimension holding the categories. Defaults to the last one.

    Returns:
        Tensor with the same shape as ``logits``.
    """
    tiny = torch.finfo(logits.dtype).tiny
    # Standard Gumbel noise: -log(E) with E ~ Exponential(1). `tiny` guards log(0).
    noise = torch.empty_like(logits).exponential_().add(tiny).log().neg()
    y_soft = F.softmax((logits + noise) / tau, dim)

    if not hard:
        return y_soft

    # Straight-through estimator: the forward value equals `one_hot`, while the
    # gradient is that of `y_soft` (the detached copy contributes no gradient).
    winner = y_soft.argmax(dim, keepdim=True)
    one_hot = torch.zeros_like(logits).scatter_(dim, winner, 1.)
    return one_hot - y_soft.detach() + y_soft
def log_prob_gaussian(value, mean, std):
    """Elementwise log-density of a Normal(mean, std**2) evaluated at ``value``.

    Computes ``-0.5 * ((value - mean)**2 / var + log(var) + log(2*pi))`` with
    ``var = std**2``.

    Args:
        value: Point(s) at which to evaluate the density (tensor or number).
        mean: Mean of the Gaussian (tensor or number, broadcast with ``value``).
        std: Standard deviation. Either a tensor or any plain real scalar
            (``int``, ``float``, NumPy scalar).

    Returns:
        Log-density, with the broadcast shape/type of the operands.
    """
    var = std ** 2
    # Dispatch on "is it a tensor" rather than "is it a float". The old check
    # sent ints (e.g. std=1) and non-float NumPy scalars to `var.log()`, which
    # those types do not have.
    if torch.is_tensor(var):
        log_var = var.log()
    else:
        log_var = math.log(var)
    return -0.5 * (((value - mean) ** 2) / var + log_var + math.log(2 * math.pi))
def conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0,
           dilation=1, groups=1, bias=True, padding_mode='zeros',
           weight_init='xavier'):
    """Build an ``nn.Conv2d`` with explicit weight and bias initialization.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation,
        groups, bias, padding_mode: Forwarded unchanged to ``nn.Conv2d``.
        weight_init: ``'kaiming'`` selects Kaiming-uniform (ReLU gain). Any
            other value falls back to Xavier-uniform.

    Returns:
        The initialized ``nn.Conv2d``. When a bias exists it is zeroed.
    """
    layer = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=groups,
        bias=bias,
        padding_mode=padding_mode,
    )

    use_kaiming = weight_init == 'kaiming'
    if not use_kaiming:
        nn.init.xavier_uniform_(layer.weight)
    else:
        nn.init.kaiming_uniform_(layer.weight, nonlinearity='relu')

    if bias:
        nn.init.zeros_(layer.bias)
    return layer
class Conv2dBlock(nn.Module):
    """Convolution -> single-group GroupNorm (learnable affine) -> ReLU.

    The convolution comes from the module-level ``conv2d`` factory, with no bias
    and Kaiming initialization. The learnable ``weight``/``bias`` here are
    per-channel scale and shift applied by ``F.group_norm``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()
        # Attribute names (m, weight, bias) and their registration order define
        # the state_dict keys, so they are kept as-is.
        self.m = conv2d(in_channels, out_channels, kernel_size, stride, padding,
                        bias=False, weight_init='kaiming')
        self.weight = nn.Parameter(torch.ones(out_channels))
        self.bias = nn.Parameter(torch.zeros(out_channels))

    def forward(self, x):
        """Apply conv, one-group group norm and ReLU.

        Assumes batched (N, C, H, W) input. TODO confirm callers never pass
        unbatched images, which group_norm would misinterpret.
        """
        features = self.m(x)
        normed = F.group_norm(features, num_groups=1, weight=self.weight, bias=self.bias)
        return F.relu(normed)
def linear(in_features, out_features, bias=True, weight_init='xavier', gain=1.):
    """Build an ``nn.Linear`` with explicit weight and bias initialization.

    Args:
        in_features, out_features, bias: Forwarded to ``nn.Linear``.
        weight_init: ``'kaiming'`` selects Kaiming-uniform (ReLU gain). Any
            other value falls back to Xavier-uniform.
        gain: Scaling factor for Xavier-uniform. It is ignored on the Kaiming path.

    Returns:
        The initialized ``nn.Linear``. When a bias exists it is zeroed.
    """
    layer = nn.Linear(in_features, out_features, bias=bias)

    if weight_init != 'kaiming':
        nn.init.xavier_uniform_(layer.weight, gain=gain)
    else:
        nn.init.kaiming_uniform_(layer.weight, nonlinearity='relu')

    if bias:
        nn.init.zeros_(layer.bias)
    return layer
def gru_cell(input_size, hidden_size, bias=True):
    """Build an ``nn.GRUCell`` with Xavier input weights and orthogonal recurrent weights.

    Per the PyTorch GRUCell docs, ``weight_ih`` and ``weight_hh`` store the
    three gate matrices stacked together. Each initializer therefore acts on the
    whole stacked matrix, not gate by gate.

    Args:
        input_size, hidden_size, bias: Forwarded to ``nn.GRUCell``.

    Returns:
        The initialized ``nn.GRUCell``. When biases exist they are zeroed.
    """
    cell = nn.GRUCell(input_size, hidden_size, bias=bias)
    nn.init.xavier_uniform_(cell.weight_ih)
    nn.init.orthogonal_(cell.weight_hh)
    if bias:
        for b in (cell.bias_ih, cell.bias_hh):
            nn.init.zeros_(b)
    return cell