forked from probml/pyprobml
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathautodiff_demo.py
More file actions
281 lines (221 loc) · 8.94 KB
/
autodiff_demo.py
File metadata and controls
281 lines (221 loc) · 8.94 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
# Demonstrate automatic differentiation on binary logistic regression
# using JAX, Torch and TF
import numpy as np
# logsumexp was removed from scipy.misc in SciPy 1.3; scipy.special is its home.
from scipy.special import logsumexp
np.set_printoptions(precision=3)

# Which autodiff backends to demonstrate. Enabling USE_JAX also rebinds `np`
# (and `logsumexp`) to their jax equivalents further down in this script.
USE_JAX = True
USE_TORCH = True
USE_TF = True
# Random-number helpers bound to the *original* NumPy module (`onp`), so they
# keep working after `np` is rebound to jax.numpy below.
import numpy as onp  # original numpy


def set_seed(seed):
    """Seed NumPy's global random number generator with `seed`."""
    onp.random.seed(seed)


def randn(args):
    """Return standard-normal samples; `args` (e.g. a shape tuple) is unpacked into randn."""
    sampler = onp.random.randn
    return sampler(*args)


def randperm(args):
    """Return a random permutation of `args` (an int n, or a sequence to shuffle)."""
    permute = onp.random.permutation
    return permute(args)
if USE_TORCH:
    import torch
    import torchvision
    print("torch version {}".format(torch.__version__))

    # Report the GPU (if any) that torch can see.
    use_cuda = torch.cuda.is_available()
    if not use_cuda:
        print("Torch cannot find GPU")
    else:
        print(torch.cuda.get_device_name(0))
        print("current device {}".format(torch.cuda.current_device()))

    def set_seed(seed):
        """Seed NumPy's RNG together with torch's CPU and CUDA generators."""
        onp.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # Use the first GPU when available, otherwise fall back to the CPU.
    device = torch.device("cuda:0" if use_cuda else "cpu")
    # Optional speed-up, left disabled: torch.backends.cudnn.benchmark = True
if USE_JAX:
    import os
    # XLA reads XLA_FLAGS once, when its backend is initialized, so the flag
    # must be in the environment before jax is imported and the backend is
    # queried below; assigning it afterwards has no effect. setdefault keeps
    # any value the user already exported.
    # NOTE(review): this CUDA data dir is specific to the original author's machine.
    os.environ.setdefault(
        "XLA_FLAGS", "--xla_gpu_cuda_data_dir=/home/murphyk/miniconda3/lib")

    import jax
    import jax.numpy as np  # from here on, `np` refers to jax.numpy
    import numpy as onp
    from jax.scipy.special import logsumexp
    from jax import grad, hessian, jacfwd, jacrev, jit, vmap
    try:
        # Current JAX releases ship the optimizers module here...
        from jax.example_libraries import optimizers
    except ImportError:
        # ...older releases only had it under jax.experimental.
        from jax.experimental import optimizers
    print("jax version {}".format(jax.__version__))

    if hasattr(jax, "default_backend"):
        jax_backend = jax.default_backend()
    else:
        from jax.lib import xla_bridge
        jax_backend = xla_bridge.get_backend().platform
    print("jax backend {}".format(jax_backend))
if USE_TF:
    import tensorflow as tf
    from tensorflow import keras
    print("tf version {}".format(tf.__version__))
    # tf.test.is_gpu_available() is deprecated; TensorFlow recommends
    # tf.config.list_physical_devices("GPU"), which returns an empty list
    # when no GPU is visible.
    if tf.config.list_physical_devices("GPU"):
        print(tf.test.gpu_device_name())
    else:
        print("TF cannot find GPU")
### Dataset
import sklearn.datasets
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression

iris = sklearn.datasets.load_iris()
X = iris["data"]
# Binary target: 1 if Iris-Virginica, else 0. The builtin `int` replaces the
# `onp.int` alias (identical to `int`), which NumPy removed in 1.24.
y = (iris["target"] == 2).astype(int)
N, D = X.shape  # 150, 4
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.33, random_state=42)

# Reference MLE from sklearn. C is set to a large number to effectively turn
# off regularization, and no bias term is fit, to simplify the comparison below.
log_reg = LogisticRegression(solver="lbfgs", C=1e5, fit_intercept=False)
log_reg.fit(X_train, y_train)
w_mle_sklearn = np.ravel(log_reg.coef_)
set_seed(0)
w = w_mle_sklearn  # weights at which every backend's loss/gradient is evaluated
## Compute gradient of loss "by hand" using numpy
def BCE_with_logits(logits, targets):
    """Mean binary cross-entropy between `targets` and sigmoid(`logits`).

    Works on logits directly for numerical stability:
      log sigmoid(z)       = -log(1 + e^{-z}) = -logsumexp([0, -z])
      log (1 - sigmoid(z)) = -log(1 + e^{z})  = -logsumexp([0,  z])
    (the 0 column supplies the e^0 = 1 term).

    Args:
      logits: array whose first dimension is N and which holds exactly N
        values (e.g. shape (N,) or (N, 1)).
      targets: length-N array of labels (0/1 in this script).
    Returns:
      The negative log-likelihood averaged over the N examples.
    """
    n_examples = logits.shape[0]
    z = logits.reshape(n_examples, 1)
    zero_col = np.zeros((n_examples, 1))
    log_p1 = -logsumexp(np.hstack([zero_col, -z]), axis=1)  # log P(y=1)
    log_p0 = -logsumexp(np.hstack([zero_col, z]), axis=1)   # log P(y=0)
    log_lik = targets * log_p1 + (1 - targets) * log_p0
    total = np.sum(log_lik)
    return -total / n_examples
if True:
    # Reference implementation written with explicit array expressions; the
    # autodiff gradients below are all checked against grad_np.
    # Note: `np` is jax.numpy at this point when USE_JAX is set.
    def sigmoid(x):
        """Logistic function, computed as 0.5 * (tanh(x / 2) + 1)."""
        return 0.5 * (np.tanh(x / 2.) + 1)

    def predict_logit(weights, inputs):
        """Linear logits inputs @ weights (already vectorized over rows)."""
        return np.dot(inputs, weights)

    def predict_prob(weights, inputs):
        """Predicted probability of class 1 for each row of `inputs`."""
        return sigmoid(predict_logit(weights, inputs))

    def NLL(weights, batch):
        """Mean negative log-likelihood of batch = (X, y) under `weights`."""
        X, y = batch
        logits = predict_logit(weights, X)
        return BCE_with_logits(logits, y)

    def NLL_grad(weights, batch):
        """Analytic gradient of NLL w.r.t. weights: X^T (sigmoid(X w) - y) / N.

        Computed as one matrix-vector product, which is O(N*D), instead of
        materializing np.diag(mu - y) (an N x N matrix) to scale X's rows.
        """
        X, y = batch
        N = X.shape[0]
        mu = predict_prob(weights, X)
        return np.dot(X.T, mu - y) / N

    y_pred = predict_prob(w, X_test)
    loss = NLL(w, (X_test, y_test))
    grad_np = NLL_grad(w, (X_test, y_test))
    print("params {}".format(w))
    print("loss {}".format(loss))
    print("grad {}".format(grad_np))
if USE_JAX:
    print("Starting JAX demo")
    grad_jax = grad(NLL)(w, (X_test, y_test))
    print("grad {}".format(grad_jax))
    assert np.allclose(grad_np, grad_jax)

    print("Starting STAX demo")
    # stax moved from jax.experimental to jax.example_libraries; newer JAX
    # releases only provide the latter, so try it first.
    try:
        from jax.example_libraries import stax
    except ImportError:
        from jax.experimental import stax

    def const_init(params):
        """Return a stax initializer that ignores (rng_key, shape) and returns `params`."""
        def init(rng_key, shape):
            return params
        return init

    # One Dense unit with no activation, so the network outputs logits. Its
    # kernel is pinned to w (as a D x 1 column) and its bias to 0, so the
    # weight gradient is expected to match grad_np.
    dense_layer = stax.Dense(1, W_init=const_init(np.reshape(w, (D, 1))),
                             b_init=const_init(np.array([0.0])))
    net_init, net_apply = stax.serial(dense_layer)
    rng = jax.random.PRNGKey(0)
    in_shape = (-1, D)
    out_shape, net_params = net_init(rng, in_shape)

    def NLL_model(net_params, net_apply, batch):
        """Mean BCE loss of the stax network `net_apply` on batch = (X, y)."""
        X, y = batch
        logits = net_apply(net_params, X)
        return BCE_with_logits(logits, y)

    y_pred2 = net_apply(net_params, X_test)
    loss2 = NLL_model(net_params, net_apply, (X_test, y_test))
    # grad differentiates w.r.t. the first argument (net_params) only.
    grad_jax2 = grad(NLL_model)(net_params, net_apply, (X_test, y_test))
    grad_jax3 = grad_jax2[0][0]  # layer 0, block 0 (weights not bias)
    grad_jax4 = grad_jax3[:, 0]  # column vector
    assert np.allclose(grad_np, grad_jax4)
    print("params {}".format(net_params))
    print("loss {}".format(loss2))
    print("grad {}".format(grad_jax2))
if USE_TORCH:
    import torch
    print("Starting torch demo")
    # Weights as a (D, 1) column marked as a leaf requiring gradients.
    w_torch = torch.Tensor(np.reshape(w, [D, 1])).to(device)
    w_torch.requires_grad_()
    x_test_tensor = torch.Tensor(X_test).to(device)
    y_test_tensor = torch.Tensor(y_test).to(device)
    y_pred = torch.sigmoid(torch.matmul(x_test_tensor, w_torch))[:, 0]
    criterion = torch.nn.BCELoss(reduction='mean')
    loss_torch = criterion(y_pred, y_test_tensor)
    loss_torch.backward()
    # Tensor.numpy() only works on CPU tensors, so copy the gradient to the
    # host first; without .cpu() this fails whenever `device` is a GPU.
    grad_torch = w_torch.grad[:, 0].cpu().numpy()
    assert np.allclose(grad_np, grad_torch)
    print("params {}".format(w_torch))
    print("loss {}".format(loss_torch))
    print("grad {}".format(grad_torch))
if USE_TORCH:
    print("Starting torch demo: Model version")

    class Model(torch.nn.Module):
        """Logistic regression: a bias-free linear layer followed by a sigmoid."""

        def __init__(self):
            super(Model, self).__init__()
            self.linear = torch.nn.Linear(D, 1, bias=False)

        def forward(self, x):
            y_pred = torch.sigmoid(self.linear(x))
            return y_pred

    model = Model()
    print(model.state_dict())
    # Manually set parameters to desired values. nn.Linear stores its weight
    # as (out_features, in_features) = (1, D), i.e. a row vector.
    from collections import OrderedDict
    w1 = torch.Tensor(np.reshape(w, [1, D])).to(device)  # row vector
    new_state_dict = OrderedDict({'linear.weight': w1})
    model.load_state_dict(new_state_dict, strict=False)
    model.to(device)  # make sure new params are on same device as data
    criterion = torch.nn.BCELoss(reduction='mean')
    y_pred2 = model(x_test_tensor)[:, 0]
    loss_torch2 = criterion(y_pred2, y_test_tensor)
    loss_torch2.backward()
    params_torch2 = list(model.parameters())
    # .cpu() is required before .numpy() when the model lives on a GPU.
    grad_torch2 = params_torch2[0].grad[0].cpu().numpy()
    assert np.allclose(grad_np, grad_torch2)
    print("params {}".format(w1))
    print("loss {}".format(loss_torch2))  # this model's own loss
    print("grad {}".format(grad_torch2))
if USE_TF:
    print("Starting TF demo")
    # Pin every tensor to float64 using TF's own dtype objects. When USE_JAX is
    # set, `np` is jax.numpy and `w` is typically float32 (JAX's default
    # precision), so without the explicit dtype the Variable would be float32
    # and the float64 x float32 matmul below would fail.
    w_tf = tf.Variable(np.reshape(w, (D, 1)), dtype=tf.float64)
    x_test_tf = tf.convert_to_tensor(X_test, dtype=tf.float64)
    y_test_tf = tf.convert_to_tensor(np.reshape(y_test, (-1, 1)), dtype=tf.float64)
    with tf.GradientTape() as tape:
        logits = tf.linalg.matmul(x_test_tf, w_tf)
        y_pred = tf.math.sigmoid(logits)
        loss_batch = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=y_test_tf, logits=logits)
        loss_tf = tf.reduce_mean(loss_batch, axis=0)
    grad_tf = tape.gradient(loss_tf, [w_tf])
    grad_tf = grad_tf[0][:, 0].numpy()
    assert np.allclose(grad_np, grad_tf)
    print("params {}".format(w_tf))
    print("loss {}".format(loss_tf))
    print("grad {}".format(grad_tf))
if USE_TF:
    print("Starting TF demo: keras version")
    # A single bias-free Dense unit with no activation, so the model outputs logits.
    model = tf.keras.models.Sequential([
        tf.keras.layers.Dense(1, input_shape=(D,), activation=None, use_bias=False)
    ])
    model.build()
    w_tf2 = tf.convert_to_tensor(np.reshape(w, (D, 1)))
    model.set_weights([w_tf2])
    # Use TF's dtype object rather than np.float32, since `np` may be jax.numpy.
    y_test_tf2 = tf.convert_to_tensor(np.reshape(y_test, (-1, 1)), dtype=tf.float32)
    # predict() is a forward pass only and contributes nothing to the
    # gradient, so it does not need to run under the tape.
    logits_temp = model.predict(x_test_tf)
    with tf.GradientTape() as tape:
        logits2 = model(x_test_tf, training=True)  # OO version enables backprop
        loss_batch2 = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=y_test_tf2, logits=logits2)
        loss_tf2 = tf.reduce_mean(loss_batch2, axis=0)
    grad_tf2 = tape.gradient(loss_tf2, model.trainable_variables)
    grad_tf2 = grad_tf2[0][:, 0].numpy()
    assert np.allclose(grad_np, grad_tf2)
    print("params {}".format(w_tf2))
    print("loss {}".format(loss_tf2))
    print("grad {}".format(grad_tf2))