Description
The simple MVE (minimal verifiable example) of a neural ODE shown below produces the following error:
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/media/adam/shared_drive/PycharmProjects/deer_test/deer_ode_solver.py", line 161, in
main()
File "/media/adam/shared_drive/PycharmProjects/deer_test/deer_ode_solver.py", line 142, in main
loss, model, opt_state = make_step(_ts, yi, model, opt_state)
File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/equinox/_jit.py", line 239, in call
return self._call(False, args, kwargs)
File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/equinox/_module.py", line 1093, in call
return self.func(self.self, *args, **kwargs)
File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/equinox/_jit.py", line 212, in _call
out = self._cached(dynamic_donate, dynamic_nodonate, static)
File "/media/adam/shared_drive/PycharmProjects/deer_test/deer_ode_solver.py", line 128, in make_step
loss, grads = grad_loss(model, ti, yi)
File "/media/adam/shared_drive/PycharmProjects/deer_test/deer_ode_solver.py", line 123, in grad_loss
y_pred = jax.vmap(model, in_axes=(None, 0))(ti, yi[:, 0])
File "/media/adam/shared_drive/PycharmProjects/deer_test/deer_ode_solver.py", line 53, in call
res = solve_ivp(self.func, y0, tpts[..., None], None, tpts, method=solve_ivp.DEER())
File "/home/adam/Downloads/deer-mfk/deer/fsolve_ivp.py", line 80, in solve_ivp
return method.compute(func, y0, xinp, params, tpts)
File "/home/adam/Downloads/deer-mfk/deer/fsolve_ivp.py", line 127, in compute
result = deer_iteration(jax._src.interpreters.ad.CustomJVPException: Detected differentiation of a custom_jvp function with respect to a closed-over value. That isn't supported because the custom JVP rule only specifies how to differentiate the custom_jvp function with respect to explicit input parameters. Try passing the closed-over value into the custom_jvp function as an argument, and adapting the custom_jvp rule.
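
For context, the CustomJVPException at the bottom of the traceback is the error JAX raises whenever a custom_jvp function is differentiated with respect to a value it closes over rather than receives as an explicit argument. The sketch below is illustrative only and does not use deer; the function names are made up for the example. The first function triggers the same exception class, and the second shows the remedy the error message suggests: pass the closed-over value in as an argument and extend the JVP rule accordingly.

import jax

# Illustrative pattern, NOT deer's code: `f` closes over `w`, so differentiating
# the enclosing function with respect to `w` hits CustomJVPException.
def loss_closed_over(w, x):
    @jax.custom_jvp
    def f(x):
        return w * x  # `w` is captured from the enclosing scope

    @f.defjvp
    def f_jvp(primals, tangents):
        (x,), (x_dot,) = primals, tangents
        return f(x), w * x_dot

    return f(x)

# jax.grad(loss_closed_over)(2.0, 3.0)
# -> CustomJVPException: Detected differentiation of a custom_jvp function
#    with respect to a closed-over value.

# The remedy the error message suggests: make the closed-over value an
# explicit argument and extend the JVP rule to cover it.
def loss_explicit(w, x):
    @jax.custom_jvp
    def f(w, x):
        return w * x

    @f.defjvp
    def f_jvp(primals, tangents):
        w, x = primals
        w_dot, x_dot = tangents
        return f(w, x), w_dot * x + w * x_dot

    return f(w, x)

print(jax.grad(loss_explicit)(2.0, 3.0))  # 3.0

The full reproduction script from the report follows.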
import time
import diffrax
import equinox as eqx
import jax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jr
import jax.test_util
import matplotlib.pyplot as plt
import optax
from deer import solve_ivp
# enable jax x64 for this test
jax.config.update("jax_enable_x64", True)
dtype = jnp.float64
npts = 10
class Func(eqx.Module):
    mlp: eqx.nn.MLP

    def __init__(self, data_size, width_size, depth, *, key, **kwargs):
        super().__init__(**kwargs)
        self.mlp = eqx.nn.MLP(
            in_size=data_size + 1,
            out_size=data_size,
            width_size=width_size,
            depth=depth,
            activation=jnn.softplus,
            key=key,
        )

    def __call__(self, y, t, args=None):
        # concatenate t onto y so the MLP sees time as an input
        y = jnp.concatenate([y, jnp.full((1,), t)], axis=-1)
        return self.mlp(y)


class NeuralODE(eqx.Module):
    func: Func

    def __init__(self, data_size, width_size, depth, *, key, **kwargs):
        super().__init__(**kwargs)
        self.func = Func(data_size, width_size, depth, key=key)

    def __call__(self, ts, y0):
        tpts = jnp.linspace(0, 1.0, npts, dtype=dtype)  # (ntpts,)
        # integrate with deer's DEER method; per the traceback, this call is
        # where the CustomJVPException is raised
        res = solve_ivp(self.func, y0, tpts[..., None], None, tpts, method=solve_ivp.DEER())
        return res.value


def _get_data(ts, *, key):
    # ground-truth trajectories from a small nonlinear ODE, solved with diffrax
    y0 = jr.uniform(key, (2,), minval=-0.6, maxval=1)

    def f(t, y, args):
        x = y / (1 + y)
        return jnp.stack([x[1], -x[0]], axis=-1)

    solver = diffrax.Tsit5()
    dt0 = 0.1
    saveat = diffrax.SaveAt(ts=ts)
    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(f), solver, ts[0], ts[-1], dt0, y0, saveat=saveat
    )
    ys = sol.ys
    return ys


def get_data(dataset_size, *, key):
    ts = jnp.linspace(0, 10, 100)
    key = jr.split(key, dataset_size)
    ys = jax.vmap(lambda key: _get_data(ts, key=key))(key)
    return ts, ys


def dataloader(arrays, batch_size, *, key):
    dataset_size = arrays[0].shape[0]
    assert all(array.shape[0] == dataset_size for array in arrays)
    indices = jnp.arange(dataset_size)
    while True:
        perm = jr.permutation(key, indices)
        (key,) = jr.split(key, 1)
        start = 0
        end = batch_size
        while end < dataset_size:
            batch_perm = perm[start:end]
            yield tuple(array[batch_perm] for array in arrays)
            start = end
            end = start + batch_size


def main(
    dataset_size=256,
    batch_size=32,
    lr_strategy=(3e-3, 3e-3),
    steps_strategy=(500, 500),
    length_strategy=(0.1, 1),
    width_size=64,
    depth=2,
    seed=5678,
    plot=True,
    print_every=100,
):
    key = jr.PRNGKey(seed)
    data_key, model_key, loader_key = jr.split(key, 3)

    ts, ys = get_data(dataset_size, key=data_key)
    _, length_size, data_size = ys.shape

    model = NeuralODE(data_size, width_size, depth, key=model_key)

    @eqx.filter_value_and_grad
    def grad_loss(model, ti, yi):
        # MSE loss; the model is vmapped over the batch of initial conditions
        y_pred = jax.vmap(model, in_axes=(None, 0))(ti, yi[:, 0])
        return jnp.mean((yi - y_pred) ** 2)

    @eqx.filter_jit
    def make_step(ti, yi, model, opt_state):
        loss, grads = grad_loss(model, ti, yi)
        updates, opt_state = optim.update(grads, opt_state)
        model = eqx.apply_updates(model, updates)
        return loss, model, opt_state

    for lr, steps, length in zip(lr_strategy, steps_strategy, length_strategy):
        optim = optax.adabelief(lr)
        opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))
        _ts = ts[: int(length_size * length)]
        _ys = ys[:, : int(length_size * length)]
        for step, (yi,) in zip(
            range(steps), dataloader((_ys,), batch_size, key=loader_key)
        ):
            start = time.time()
            loss, model, opt_state = make_step(_ts, yi, model, opt_state)
            end = time.time()
            if (step % print_every) == 0 or step == steps - 1:
                print(f"Step: {step}, Loss: {loss}, Computation time: {end - start}")

    if plot:
        plt.plot(ts, ys[0, :, 0], c="dodgerblue", label="Real")
        plt.plot(ts, ys[0, :, 1], c="dodgerblue")
        model_y = model(ts, ys[0, 0])
        plt.plot(ts, model_y[:, 0], c="crimson", label="Model")
        plt.plot(ts, model_y[:, 1], c="crimson")
        plt.legend()
        plt.tight_layout()
        plt.savefig("neural_ode.png")
        plt.show()

    return ts, ys, model


if __name__ == "__main__":
    main()