-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathMLRU.py
More file actions
279 lines (187 loc) · 7.59 KB
/
MLRU.py
File metadata and controls
279 lines (187 loc) · 7.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
# %%
import numpy as np
import jax
import jax.numpy as jnp
from jax import jit, grad, vmap
import pickle as pkl
from alive_progress import alive_bar
# Alias for JAX's parallel prefix scan (jax.lax.associative_scan); forward_LRU
# uses it to evaluate the linear recurrence over all time steps at once.
parallel_scan = jax.lax.associative_scan
# %% [markdown]
# ## Make a single LRU
# %%
def binary_operator_diag(element_i, element_j):
    """Compose two steps of a diagonal linear recurrence for an associative scan.

    Each element is a pair ``(a, bu)`` describing the affine map
    ``x -> a * x + bu``. Applying ``element_i`` first and then ``element_j``
    gives ``x -> (a_j * a_i) * x + (a_j * bu_i + bu_j)``, which is returned.

    Args:
        element_i: ``(a_i, bu_i)`` for the earlier step.
        element_j: ``(a_j, bu_j)`` for the later step.

    Returns:
        Tuple ``(a_j * a_i, a_j * bu_i + bu_j)``.
    """
    decay_early, drive_early = element_i
    decay_late, drive_late = element_j
    composed_decay = decay_late * decay_early
    composed_drive = decay_late * drive_early + drive_late
    return composed_decay, composed_drive
def init_lru_parameters(N, H, r_min = 0.0, r_max = 1, max_phase = 6.28):
    """Randomly initialise the parameters of one LRU layer.

    The diagonal state matrix Lambda is sampled on the complex ring
    ``r_min <= |Lambda| <= r_max`` with phase in ``[0, max_phase]`` and stored
    in log-parameterised form (``nu_log``, ``theta_log``).

    Args:
        N: state dimension.
        H: model (input/output) dimension.
        r_min, r_max: inner/outer radius of the sampling ring.
        max_phase: upper bound of the sampled phase.

    Returns:
        Tuple ``(nu_log, theta_log, B_re, B_im, C_re, C_im, D, gamma_log)`` of
        NumPy arrays with shapes ``(N,), (N,), (N, H), (N, H), (H, N), (H, N),
        (H,), (N,)``.

    Note:
        Draws from NumPy's global RNG in a fixed order, so results depend on
        the current ``np.random`` state.
    """
    radius_draw = np.random.uniform(size=(N,))
    phase_draw = np.random.uniform(size=(N,))

    # Map the uniform draw onto |Lambda|^2 in [r_min^2, r_max^2], then store
    # nu = -log|Lambda| in log space.
    squared_radius = radius_draw * (r_max**2 - r_min**2) + r_min**2
    nu_log = np.log(-0.5 * np.log(squared_radius))
    theta_log = np.log(max_phase * phase_draw)

    # Glorot-style scaling for the complex input/output projections.
    input_scale = np.sqrt(2 * H)
    output_scale = np.sqrt(N)
    B_re = np.random.normal(size=(N, H)) / input_scale
    B_im = np.random.normal(size=(N, H)) / input_scale
    C_re = np.random.normal(size=(H, N)) / output_scale
    C_im = np.random.normal(size=(H, N)) / output_scale
    D = np.random.normal(size=(H,))

    # gamma = sqrt(1 - |Lambda|^2) normalises the input so that states stay
    # well-scaled as |Lambda| approaches 1.
    lam = np.exp(-np.exp(nu_log) + 1j * np.exp(theta_log))
    gamma_log = np.log(np.sqrt(1 - np.abs(lam) ** 2))
    return nu_log, theta_log, B_re, B_im, C_re, C_im, D, gamma_log
def forward_LRU(lru_parameters, input_sequence):
    """Run one LRU layer over a single input sequence.

    Computes the hidden states ``x_k = Lambda * x_{k-1} + B_norm @ u_k`` for all
    steps with a parallel associative scan (the first state is ``B_norm @ u_0``,
    i.e. a zero initial state), then reads out
    ``y_k = Re(C @ x_k) + D * u_k``.

    Args:
        lru_parameters: tuple as returned by ``init_lru_parameters``.
        input_sequence: array of shape ``(T, H)``; ``T`` steps of ``H`` features.

    Returns:
        Array of shape ``(T, H)`` with the real-valued outputs.
    """
    nu_log, theta_log, B_re, B_im, C_re, C_im, D, gamma_log = lru_parameters
    n_steps = input_sequence.shape[0]

    # Diagonal state matrix rebuilt from its log-magnitude / log-phase form.
    diag_lambda = jnp.exp(-jnp.exp(nu_log) + 1j * jnp.exp(theta_log))
    # Scale each state row of B by gamma = exp(gamma_log).
    row_scale = jnp.expand_dims(jnp.exp(gamma_log), axis=-1)
    B_norm = (B_re + 1j * B_im) * row_scale
    C = C_re + 1j * C_im

    def project_input(u_k):
        return B_norm @ u_k

    def read_out(x_k, u_k):
        return (C @ x_k).real + D * u_k

    decays = jnp.repeat(diag_lambda[None, ...], n_steps, axis=0)
    drives = jax.vmap(project_input)(input_sequence)
    # The second component of every scanned element is a hidden state x_k.
    _, hidden_states = parallel_scan(binary_operator_diag, (decays, drives))
    return jax.vmap(read_out)(hidden_states, input_sequence)
def loss_fn(lru_parameters, input_sequence, target_sequence):
    """Mean squared error between the LRU output and ``target_sequence``."""
    residual = forward_LRU(lru_parameters, input_sequence) - target_sequence
    return jnp.mean(residual**2)
def update(lru_parameters, input_sequence, target_sequence):
    """Return the gradient of ``loss_fn`` with respect to ``lru_parameters``.

    Despite its name, this does not apply an update; it only returns the
    gradients (same structure as ``lru_parameters``).
    """
    loss_gradient = grad(loss_fn)
    return loss_gradient(lru_parameters, input_sequence, target_sequence)
# %%
# Create the MLP encoder
# %% [markdown]
# # Make the MLP model
# %%
def init_mlp_parameters(layers):
    """Create random ``(W, b)`` pairs for an MLP with the given layer widths.

    Args:
        layers: sequence of layer widths, e.g. ``[1, 3, 5]``.

    Returns:
        List of ``(W, b)`` tuples, one per consecutive pair of widths. ``W`` has
        shape ``(fan_in, fan_out)`` and is drawn from a normal distribution
        scaled by ``1 / sqrt(fan_in)``; ``b`` is zeros of shape ``(fan_out,)``.
        Fewer than two widths gives an empty list.
    """
    return [
        (np.random.normal(size=(fan_in, fan_out)) / np.sqrt(fan_in), np.zeros((fan_out,)))
        for fan_in, fan_out in zip(layers[:-1], layers[1:])
    ]
def forward_mlp(mlp_parameters, input, activation_function = jnp.tanh):
    """Run an MLP, applying ``activation_function`` after every layer (including the last).

    Args:
        mlp_parameters: sequence of ``(W, b)`` pairs, e.g. from ``init_mlp_parameters``.
        input: array whose last axis matches the first layer's fan-in.
        activation_function: elementwise callable. It is a static (compile-time)
            argument to ``jit``, so each distinct function is traced once.

    Returns:
        The activations of the final layer.
    """
    x = input
    for W, b in mlp_parameters:
        x = activation_function(x @ W + b)
    return x


# A Python function is not a valid JAX array argument, so with a plain @jit any
# caller passing a custom activation_function got a TypeError. Marking it
# static makes the parameter actually usable.
forward_mlp = jit(forward_mlp, static_argnames=("activation_function",))
def forward_mlp_linear_layer(mlp_parameters, input, activation_function = jnp.tanh):
    """Run an MLP classifier head and return softmax probabilities.

    Hidden layers use ``activation_function``; the final ``(W, b)`` layer is
    affine only, and its output is passed through ``jax.nn.softmax`` (last axis).

    Args:
        mlp_parameters: non-empty sequence of ``(W, b)`` pairs.
        input: array whose last axis matches the first layer's fan-in.
        activation_function: elementwise activation for the hidden layers.

    Returns:
        Softmax probabilities from the final layer.
    """
    hidden_layers = mlp_parameters[:-1]
    W_out, b_out = mlp_parameters[-1]
    hidden = input
    for W, b in hidden_layers:
        hidden = activation_function(hidden @ W + b)
    logits = hidden @ W_out + b_out
    return jax.nn.softmax(logits)
# %%
def max_pooling(sequence_to_pool):
    """Return the elementwise maximum over axis 0 (the sequence axis in model_forward)."""
    leading_axis = 0
    return jnp.max(sequence_to_pool, axis=leading_axis)
def mean_pooling(sequence_to_pool):
    """Return the elementwise mean over axis 0 (the leading axis)."""
    leading_axis = 0
    return jnp.mean(sequence_to_pool, axis=leading_axis)
def sum_pooling(sequence_to_pool):
    """Return the elementwise sum over axis 0 (the leading axis)."""
    leading_axis = 0
    return jnp.sum(sequence_to_pool, axis=leading_axis)
# %% [markdown]
# # Create the complete model
# %%
# Create the model parameters: encoder MLP (widths 1 -> 3 -> 5), LRU layer
# (N=5 state, H=5 model dims), secondary MLP (5 -> 5 -> 5) and decoder MLP
# (5 -> 3 -> 2 class probabilities).
# NOTE: every init call draws from NumPy's global RNG, so the order of these
# statements determines the initial weights.
Linear_encoder_parameter = init_mlp_parameters([1,3,5])
seconday_parameters = init_mlp_parameters([5,5,5])  # (sic) "secondary"; name kept as-is
LRU = init_lru_parameters(5, 5)
Linear_decoder_parameter = init_mlp_parameters([5,3,2])
# Group order must match the unpacking in model_forward.
model_parameters = [Linear_encoder_parameter, LRU, seconday_parameters, Linear_decoder_parameter]
def model_forward(input_sequence, parameters):
    """Classify one sequence: encoder MLP -> LRU -> MLP -> max-pool -> decoder.

    Args:
        input_sequence: a single sequence; the training code reshapes each
            sequence to ``(seq_len, 1)`` before calling this.
        parameters: ``[encoder, lru, secondary, decoder]`` parameter groups, in
            the order used to build ``model_parameters``.

    Returns:
        Class probabilities from the decoder's softmax output.
    """
    encoder, lru, secondary, decoder = parameters
    encoded = forward_mlp(encoder, input_sequence)
    recurrent = forward_LRU(lru, encoded)
    features = forward_mlp(secondary, recurrent)
    # Collapse the sequence axis into one feature vector.
    pooled = max_pooling(features)
    return forward_mlp_linear_layer(decoder, pooled)
def one_hot(x, k, dtype=jnp.float32):
    """Create a one-hot encoding of x of size k.

    Args:
        x: 1-D array of class indices (integer- or float-valued).
        k: number of classes.
        dtype: dtype of the returned array.

    Returns:
        Array of shape ``(len(x), k)`` with a 1 at each row's class index.
    """
    class_ids = jnp.arange(k)
    matches = x[:, None] == class_ids
    return jnp.asarray(matches, dtype=dtype)
def model_loss(input_sequence, target_sequence, parameters, eps=1e-7):
    """Binary cross-entropy between the model's class probabilities and the target.

    Args:
        input_sequence: a single input sequence for ``model_forward``.
        target_sequence: one-hot target(s), broadcastable against the model
            output (the training loop passes shape ``(1, 2)``).
        parameters: model parameter groups for ``model_forward``.
        eps: probabilities are clipped to ``[eps, 1 - eps]`` before taking logs.

    Returns:
        Scalar mean binary cross-entropy.
    """
    y = model_forward(input_sequence, parameters)
    # A saturated softmax can emit exactly 0 or 1; log() of that is -inf and
    # the resulting NaN gradients would corrupt every parameter. Clipping keeps
    # the loss finite (gradient is zero only in the saturated region).
    y = jnp.clip(y, eps, 1.0 - eps)
    return -jnp.mean(target_sequence * jnp.log(y) + (1 - target_sequence) * jnp.log(1 - y))
@jit
def model_grad(input_sequence, target_sequence, parameters):
    """Return d(model_loss)/d(parameters), with the same structure as ``parameters``."""
    loss_grad_wrt_params = grad(model_loss, argnums=2)
    return loss_grad_wrt_params(input_sequence, target_sequence, parameters)
@jit
def parameter_update(parameters, gradients, learning_rate = 0.01):
    """Apply one plain SGD step to every array in the parameter pytree.

    Args:
        parameters: pytree of arrays (e.g. ``model_parameters``).
        gradients: pytree with the same structure, e.g. from ``model_grad``.
        learning_rate: step size.

    Returns:
        New pytree with the same structure as ``parameters`` holding
        ``param - learning_rate * gradient`` for each leaf.
    """
    # tree_map replaces the hand-written per-group loops. Those loops assumed
    # the exact 4-group layout and rebuilt the LRU tuple as a list, changing
    # the pytree structure and forcing jitted callers to retrace.
    return jax.tree_util.tree_map(
        lambda param, grad_leaf: param - learning_rate * grad_leaf,
        parameters,
        gradients,
    )
# Batched forward pass: map model_forward over the leading (batch) axis of the
# inputs while sharing a single set of parameters (in_axes None) across the batch.
batch_model_forward = jax.vmap(model_forward, in_axes=(0, None))
@jit
def accuracy(input_sequences, target_sequences, parameters):
    """Fraction of sequences whose arg-max prediction matches the one-hot target.

    Args:
        input_sequences: batch of input sequences (batch on axis 0).
        target_sequences: one-hot targets of shape ``(batch, n_classes)``.
        parameters: model parameter groups.

    Returns:
        Scalar accuracy in ``[0, 1]``.
    """
    probabilities = batch_model_forward(input_sequences, parameters)
    predicted_class = jnp.argmax(probabilities, axis=1)
    true_class = jnp.argmax(target_sequences, axis=1)
    hits = predicted_class == true_class
    return jnp.mean(hits)
# %%
# %%
# Prepare the dataset.
# NOTE(security): pickle.load can execute arbitrary code; only load a
# waveforms.pkl that comes from a trusted source.
with open("waveforms.pkl", "rb") as waveform_file:
    waveforms = pkl.load(waveform_file)
# Use np.concatenate directly instead of monkeypatching a `concat` alias onto
# the numpy module.
sequences = np.concatenate((waveforms["waveform1"], waveforms["waveform2"]))
# Class 0 for every waveform1 sample and class 1 for every waveform2 sample;
# the counts come from the data instead of assuming 500 of each.
labels = np.concatenate((
    np.zeros(len(waveforms["waveform1"])),
    np.ones(len(waveforms["waveform2"])),
))
# Shuffle sequences and labels with the same permutation.
perm = np.random.permutation(len(sequences))
sequences = sequences[perm]
labels = labels[perm]
labels = one_hot(labels, 2)
# %%
# Smoke test: run the batched model on the first two sequences.
sample_batch = sequences[:2].reshape((2, len(sequences[0]), 1))
print(sample_batch.shape)
try:
    print(batch_model_forward(sample_batch, model_parameters).shape)
except Exception as err:
    # Report the cause instead of silently swallowing it; the bare `except:`
    # also trapped KeyboardInterrupt/SystemExit.
    print(f"Model forward failed: {err!r}")
# %%
epochs = 10
batchsize = 1  # NOTE: currently unused; the loop below does per-sample SGD.
n_train = 800  # first n_train shuffled samples train, the remainder test
seq_len = len(sequences[0])
train_acc = []
test_acc = []
# Derive every split size from n_train so the split, reshapes and progress-bar
# total cannot drift apart.
train_sequences = sequences[:n_train].reshape((-1, seq_len, 1))
train_labels = labels[:n_train]
test_sequences = sequences[n_train:].reshape((-1, seq_len, 1))
test_labels = labels[n_train:]
for i in range(epochs):
    print(f"Epoch {i}")
    with alive_bar(len(train_sequences)) as bar:
        for x, y in zip(train_sequences, train_labels):
            # Wrap the single one-hot label as a (1, n_classes) target.
            gradients = model_grad(x, jnp.array([y]), model_parameters)
            model_parameters = parameter_update(model_parameters, gradients)
            bar()
    train_acc.append(accuracy(train_sequences, train_labels, model_parameters))
    test_acc.append(accuracy(test_sequences, test_labels, model_parameters))
    print(f"Train Accuracy: {train_acc[i]}")
    print(f"Test Accuracy: {test_acc[i]}")
# %%
# Scratch check on the first sample: gradient of the loss and a forward pass.
x = np.reshape(sequences[0], (len(sequences[0]), 1))
y = jnp.asarray(labels[0])
value = model_grad(x, y, model_parameters)
model_forward(x, model_parameters)