-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinit.py
More file actions
57 lines (37 loc) · 1.83 KB
/
init.py
File metadata and controls
57 lines (37 loc) · 1.83 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import math
import numpy as np
from autograd import Tensor
def rand(*shape, low=0.0, high=1.0, dtype="float32", requires_grad=False):
    """Return a Tensor of samples drawn uniformly from [low, high).

    Args:
        *shape: integer dimensions, forwarded unchanged to ``np.random.rand``.
        low: lower end of the sampling interval.
        high: upper end of the sampling interval.
        dtype: dtype passed through to ``Tensor``.
        requires_grad: passed through to ``Tensor``.

    Returns:
        A ``Tensor`` wrapping the sampled array.
    """
    # np.random.rand samples [0, 1); scale by the interval width, then shift.
    span = high - low
    samples = np.random.rand(*shape) * span + low
    return Tensor(samples, dtype=dtype, requires_grad=requires_grad)
def randn(*shape, mean=0.0, std=1.0, dtype="float32", requires_grad=False):
    """Return a Tensor of samples from a normal distribution N(mean, std**2).

    Args:
        *shape: integer dimensions, forwarded unchanged to ``np.random.randn``.
        mean: mean of the distribution.
        std: standard deviation of the distribution.
        dtype: dtype passed through to ``Tensor``.
        requires_grad: passed through to ``Tensor``.

    Returns:
        A ``Tensor`` wrapping the sampled array.
    """
    # Draw standard-normal values, then rescale and shift them.
    standard = np.random.randn(*shape)
    scaled = standard * std + mean
    return Tensor(scaled, dtype=dtype, requires_grad=requires_grad)
def constant(*shape, c=1.0, dtype="float32", requires_grad=False):
    """Return a Tensor of the given shape with every element equal to ``c``.

    The shape may be given either as separate integers (``constant(2, 3)``),
    matching ``rand``/``randn``, or as a single tuple/list
    (``constant((2, 3))``).

    Args:
        *shape: integer dimensions, or a single tuple/list of them.
        c: fill value.
        dtype: dtype of the underlying array, also passed to ``Tensor``.
        requires_grad: passed through to ``Tensor``.

    Returns:
        A ``Tensor`` wrapping the filled array.
    """
    # Normalize to one shape tuple. Spreading *shape into np.ones (as the
    # previous version did) fed the second int into np.ones' positional
    # ``dtype`` parameter, so constant(2, 3) raised TypeError.
    if len(shape) == 1 and isinstance(shape[0], (tuple, list)):
        shape = tuple(shape[0])
    array = np.full(shape, c, dtype=dtype)
    return Tensor(array, dtype=dtype, requires_grad=requires_grad)
def zeros(*shape, dtype="float32", requires_grad=False):
    """Return a Tensor of the given shape filled with 0.0 (delegates to ``constant``)."""
    fill_value = 0.0
    return constant(
        *shape,
        c=fill_value,
        dtype=dtype,
        requires_grad=requires_grad,
    )
def ones(*shape, dtype="float32", requires_grad=False):
    """Return a Tensor of the given shape filled with 1.0 (delegates to ``constant``)."""
    fill_value = 1.0
    return constant(
        *shape,
        c=fill_value,
        dtype=dtype,
        requires_grad=requires_grad,
    )
def one_hot(num, idx, dtype="float32", requires_grad=False):
    """Return rows of the ``num x num`` identity matrix selected by ``idx``.

    Args:
        num: size of the identity matrix (number of classes).
        idx: index or indices used to select rows; presumably an int or an
            integer array of class labels — confirm against callers.
        dtype: dtype passed through to ``Tensor``.
        requires_grad: passed through to ``Tensor``.

    Returns:
        A ``Tensor`` wrapping the selected identity rows.
    """
    identity = np.eye(num)
    encoded = identity[idx]
    return Tensor(encoded, dtype=dtype, requires_grad=requires_grad)
def kaiming_uniform(fan_in, fan_out, mode='fan_in', **kwargs):
    """Return a (fan_in, fan_out) Tensor with Kaiming/He uniform initialization.

    Values are drawn from U(-bound, bound) with ``bound = sqrt(6 / fan)``,
    where ``fan`` is ``fan_in`` or ``fan_out`` depending on ``mode``.
    (sqrt(6 / fan) == sqrt(2) * sqrt(3 / fan), i.e. the ReLU gain of sqrt(2).)

    Args:
        fan_in: number of input units; also the first dimension of the result.
        fan_out: number of output units; also the second dimension.
        mode: 'fan_in' or 'fan_out' — which fan scales the bound.
        **kwargs: forwarded to ``rand`` (e.g. ``dtype``, ``requires_grad``).

    Raises:
        ValueError: if ``mode`` is not 'fan_in' or 'fan_out'.
    """
    # Previously any unrecognized mode silently fell back to fan_out.
    if mode == 'fan_in':
        fan = fan_in
    elif mode == 'fan_out':
        fan = fan_out
    else:
        raise ValueError(f"mode must be 'fan_in' or 'fan_out', got {mode!r}")
    bound = math.sqrt(6 / fan)
    return rand(fan_in, fan_out, low=-bound, high=bound, **kwargs)
def kaiming_normal(fan_in, fan_out, mode='fan_in', **kwargs):
    """Return a (fan_in, fan_out) Tensor with Kaiming/He normal initialization.

    Values are drawn from N(0, std**2) with ``std = sqrt(2 / fan)``, where
    ``fan`` is ``fan_in`` or ``fan_out`` depending on ``mode``.

    Args:
        fan_in: number of input units; also the first dimension of the result.
        fan_out: number of output units; also the second dimension.
        mode: 'fan_in' or 'fan_out' — which fan scales the std.
        **kwargs: forwarded to ``randn`` (e.g. ``dtype``, ``requires_grad``).

    Raises:
        ValueError: if ``mode`` is not 'fan_in' or 'fan_out'.
    """
    # Previously any unrecognized mode silently fell back to fan_out.
    if mode == 'fan_in':
        fan = fan_in
    elif mode == 'fan_out':
        fan = fan_out
    else:
        raise ValueError(f"mode must be 'fan_in' or 'fan_out', got {mode!r}")
    std = math.sqrt(2 / fan)
    return randn(fan_in, fan_out, mean=0, std=std, **kwargs)
def xavier_uniform(fan_in, fan_out, gain=1.0, **kwargs):
    """Return a (fan_in, fan_out) Tensor with Xavier/Glorot uniform initialization.

    Values are drawn from U(-a, a) with
    ``a = gain * sqrt(6 / (fan_in + fan_out))``.

    Args:
        fan_in: number of input units; also the first dimension of the result.
        fan_out: number of output units; also the second dimension.
        gain: scaling factor; defaults to 1.0 (previously it was required).
        **kwargs: forwarded to ``rand`` (e.g. ``dtype``, ``requires_grad``).
    """
    a = gain * math.sqrt(6 / (fan_in + fan_out))
    return rand(fan_in, fan_out, low=-a, high=a, **kwargs)
def xavier_normal(fan_in, fan_out, gain=1.0, **kwargs):
    """Return a (fan_in, fan_out) Tensor with Xavier/Glorot normal initialization.

    Values are drawn from N(0, std**2) with
    ``std = gain * sqrt(2 / (fan_in + fan_out))``.

    Args:
        fan_in: number of input units; also the first dimension of the result.
        fan_out: number of output units; also the second dimension.
        gain: scaling factor; defaults to 1.0 (previously it was required).
        **kwargs: forwarded to ``randn`` (e.g. ``dtype``, ``requires_grad``).
    """
    std = gain * math.sqrt(2 / (fan_in + fan_out))
    return randn(fan_in, fan_out, mean=0, std=std, **kwargs)