-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathiamtrask_recurrentNN_tutorial
More file actions
140 lines (105 loc) · 4.07 KB
/
iamtrask_recurrentNN_tutorial
File metadata and controls
140 lines (105 loc) · 4.07 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
#link to tutorial: https://iamtrask.github.io/2015/11/15/anyone-can-code-lstm/
import copy, numpy as np
np.random.seed(0)
# compute sigmoid nonlinearity
def sigmoid(x):
    """Return the logistic function 1 / (1 + e**(-x)) evaluated at ``x``.

    ``x`` is negated and handed to ``np.exp``, so a number or a numpy
    array is accepted; arrays are processed element-wise.
    """
    denominator = 1 + np.exp(-x)
    return 1 / denominator
# convert output of sigmoid function to its derivative
def sigmoid_output_to_derivative(output):
    """Return ``output * (1 - output)``.

    ``output`` is expected to already be a sigmoid activation; the
    product is then the sigmoid's slope at that point.
    """
    complement = 1 - output
    return output * complement
def leaky_relu(x):
    """Leaky ReLU: ``x`` where ``x >= 0``, otherwise ``x / 100.0``.

    Args:
        x: a number, or an array-like of numbers of any shape.

    Returns:
        For a scalar (zero-dimensional) input, ``x`` itself when it is
        non-negative, else ``x / 100.0``. For anything with one or more
        dimensions, a new numpy array of the same shape with the rule
        applied element-wise.
    """
    if np.ndim(x) == 0:
        # Scalar path: return exactly what the plain if/else would.
        return x if x >= 0 else x / 100.0

    # A bare `if x >= 0` raises for arrays holding more than one element,
    # so the branch is taken per element instead.
    values = np.asarray(x)
    return np.where(values >= 0, values, values / 100.0)
def relu_output_to_derivative(output):
    """Slope of the leaky ReLU, selected by the sign of ``output``.

    Args:
        output: a number, or an array-like of numbers of any shape.

    Returns:
        ``1`` where ``output >= 0`` and ``0.01`` elsewhere. A scalar
        input gives a scalar; an input with one or more dimensions gives
        a float numpy array of the same shape.
    """
    if np.ndim(output) == 0:
        # Scalar path: return exactly what the plain if/else would.
        return 1 if output >= 0 else 0.01

    # Element-wise selection; a bare `if output >= 0` raises for arrays
    # holding more than one element.
    return np.where(np.asarray(output) >= 0, 1.0, 0.01)
def int2trinary(intgr, dim=None):
    """Return the base-3 digits of ``intgr``, most significant digit first.

    Args:
        intgr: integer to encode.
        dim: highest power of three that gets a digit. When omitted, the
            module-level ``trinary_dim`` is looked up at call time.

    Returns:
        A list of ``dim + 1`` integers holding the digits for
        ``3**dim`` down to ``3**0``. Values of ``3**(dim + 1)`` or more
        are not wrapped; the excess shows up in the leading digit.
    """
    if dim is None:
        dim = trinary_dim

    remainder = intgr
    trinary = []
    for i in range(dim, -1, -1):
        # divmod floors the quotient, so digits stay integers on both
        # Python 2 and Python 3 (plain `/` yields floats on Python 3).
        digit, remainder = divmod(remainder, 3 ** i)
        trinary.append(digit)
    return trinary
# problem size used when generating the training data
trinary_dim = 5
largest_number = 3 ** trinary_dim

# learning rate and layer widths
alpha = 0.1
input_dim, hidden_dim, output_dim = 2, 10, 1

# weights start uniformly in [-1, 1); the three draws stay in this order
# so that a seeded run produces the same matrices
synapse_0 = np.random.random((input_dim, hidden_dim)) * 2 - 1
synapse_1 = np.random.random((hidden_dim, output_dim)) * 2 - 1
synapse_h = np.random.random((hidden_dim, hidden_dim)) * 2 - 1

# per-weight accumulators for the updates gathered during one example
synapse_0_update = np.zeros(synapse_0.shape)
synapse_1_update = np.zeros(synapse_1.shape)
synapse_h_update = np.zeros(synapse_h.shape)
# training logic: each iteration builds one random addition problem, runs
# the network over its digits, back-propagates through time and applies
# the accumulated weight updates.
for j in range(70000):
    # generate a simple addition problem (a + b = c); floor division
    # keeps the bound an integer on Python 3 as well
    a_int = np.random.randint(largest_number // 2)  # int version
    a = int2trinary(a_int)  # trinary encoding, most significant digit first
    b_int = np.random.randint(largest_number // 2)  # int version
    b = int2trinary(b_int)  # trinary encoding, most significant digit first

    # true answer
    c_int = a_int + b_int
    c = int2trinary(c_int)

    # Walk every digit the encoder actually produced. Using the encoding's
    # own length (rather than trinary_dim) guarantees that no digit, in
    # particular the units digit at the end of the list, is skipped.
    num_digits = len(c)

    # where we'll store our best guess (trinary encoded)
    d = np.zeros_like(c)
    overallError = 0

    layer_2_deltas = list()
    layer_1_values = list()
    # hidden state seen by the first timestep
    layer_1_values.append(np.zeros(hidden_dim))

    # forward pass: move along the digits, least significant first
    for position in range(num_digits):
        digit = num_digits - position - 1

        # generate input and output
        X = np.array([[a[digit], b[digit]]])
        y = np.array([[c[digit]]]).T

        # hidden layer (input ~+ prev_hidden)
        layer_1 = sigmoid(np.dot(X, synapse_0) + np.dot(layer_1_values[-1], synapse_h))

        # output layer (predicted trinary digit)
        layer_2 = leaky_relu(np.dot(layer_1, synapse_1))

        # did we miss?... if so, by how much?
        layer_2_error = y - layer_2
        layer_2_deltas.append((layer_2_error) * relu_output_to_derivative(layer_2))
        overallError += np.abs(layer_2_error[0])

        # decode estimate so we can print it out
        d[digit] = np.round(layer_2[0][0])

        # store hidden layer so we can use it in the next timestep
        layer_1_values.append(copy.deepcopy(layer_1))

    future_layer_1_delta = np.zeros(hidden_dim)

    # backward pass: visit the timesteps in reverse. The last timestep
    # consumed digit index 0, so index `position` pairs with the stored
    # values at [-position - 1].
    for position in range(num_digits):
        X = np.array([[a[position], b[position]]])
        layer_1 = layer_1_values[-position - 1]
        prev_layer_1 = layer_1_values[-position - 2]

        # error at output layer
        layer_2_delta = layer_2_deltas[-position - 1]
        # error at hidden layer
        layer_1_delta = (
            future_layer_1_delta.dot(synapse_h.T) + layer_2_delta.dot(synapse_1.T)
        ) * sigmoid_output_to_derivative(layer_1)

        # accumulate the weight updates for this timestep
        synapse_1_update += np.atleast_2d(layer_1).T.dot(layer_2_delta)
        synapse_h_update += np.atleast_2d(prev_layer_1).T.dot(layer_1_delta)
        synapse_0_update += X.T.dot(layer_1_delta)

        future_layer_1_delta = layer_1_delta

    # apply the accumulated updates, then clear the accumulators in place
    synapse_0 += synapse_0_update * alpha
    synapse_1 += synapse_1_update * alpha
    synapse_h += synapse_h_update * alpha

    synapse_0_update *= 0
    synapse_1_update *= 0
    synapse_h_update *= 0

    # print out progress (single-argument print(...) runs on Python 2 and 3)
    if j % 10000 == 0:
        print("Error:" + str(overallError))
        print("Pred:" + str(d))
        print("True:" + str(c))
        # convert the predicted digits back to an integer
        out = 0
        for index, x in enumerate(reversed(d)):
            out += x * pow(3, index)
        print(str(a_int) + " + " + str(b_int) + " = " + str(out))
        print("------------")