
NODE Fails with "Detected differentiation of a custom_jvp function with respect to a closed-over value" #41

@adam-hartshorne

Description


The simple neural ODE (NODE) MVE shown below produces the following error:

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/media/adam/shared_drive/PycharmProjects/deer_test/deer_ode_solver.py", line 161, in <module>
    main()
  File "/media/adam/shared_drive/PycharmProjects/deer_test/deer_ode_solver.py", line 142, in main
    loss, model, opt_state = make_step(_ts, yi, model, opt_state)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/equinox/_jit.py", line 239, in __call__
    return self._call(False, args, kwargs)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/equinox/_module.py", line 1093, in __call__
    return self.__func__(self.__self__, *args, **kwargs)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/equinox/_jit.py", line 212, in _call
    out = self._cached(dynamic_donate, dynamic_nodonate, static)
  File "/media/adam/shared_drive/PycharmProjects/deer_test/deer_ode_solver.py", line 128, in make_step
    loss, grads = grad_loss(model, ti, yi)
  File "/media/adam/shared_drive/PycharmProjects/deer_test/deer_ode_solver.py", line 123, in grad_loss
    y_pred = jax.vmap(model, in_axes=(None, 0))(ti, yi[:, 0])
  File "/media/adam/shared_drive/PycharmProjects/deer_test/deer_ode_solver.py", line 53, in __call__
    res = solve_ivp(self.func, y0, tpts[..., None], None, tpts, method=solve_ivp.DEER())
  File "/home/adam/Downloads/deer-mfk/deer/fsolve_ivp.py", line 80, in solve_ivp
    return method.compute(func, y0, xinp, params, tpts)
  File "/home/adam/Downloads/deer-mfk/deer/fsolve_ivp.py", line 127, in compute
    result = deer_iteration(

jax._src.interpreters.ad.CustomJVPException: Detected differentiation of a custom_jvp function with respect to a closed-over value. That isn't supported because the custom JVP rule only specifies how to differentiate the custom_jvp function with respect to explicit input parameters. Try passing the closed-over value into the custom_jvp function as an argument, and adapting the custom_jvp rule.
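For reference, this exception comes from JAX's custom_jvp machinery itself and can be reproduced without deer. Below is a minimal sketch of the failing pattern and of the fix the error message suggests; all names here are illustrative, not taken from deer:

import jax

def loss_closed_over(w):
    # `w` is closed over by the custom_jvp function, so differentiating
    # loss_closed_over with respect to `w` raises CustomJVPException.
    @jax.custom_jvp
    def f(x):
        return w * x

    @f.defjvp
    def f_jvp(primals, tangents):
        (x,), (x_dot,) = primals, tangents
        return f(x), w * x_dot

    return f(3.0)

# jax.grad(loss_closed_over)(2.0)  # raises CustomJVPException

# Fix: make `w` an explicit argument of the custom_jvp function and
# extend the JVP rule to cover its tangent.
@jax.custom_jvp
def g(w, x):
    return w * x

@g.defjvp
def g_jvp(primals, tangents):
    (w, x), (w_dot, x_dot) = primals, tangents
    return g(w, x), w_dot * x + w * x_dot

def loss_explicit(w):
    return g(w, 3.0)

print(jax.grad(loss_explicit)(2.0))  # 3.0

In other words, any value that gradients must flow through has to enter the custom_jvp function as an explicit input, not via closure.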

import time
import diffrax
import equinox as eqx
import jax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jr
import jax.test_util
import matplotlib.pyplot as plt
import optax
from deer import solve_ivp

# enable jax x64 for this test
jax.config.update("jax_enable_x64", True)

dtype = jnp.float64
npts = 10

class Func(eqx.Module):
    mlp: eqx.nn.MLP

    def __init__(self, data_size, width_size, depth, *, key, **kwargs):
        super().__init__(**kwargs)
        self.mlp = eqx.nn.MLP(
            in_size=data_size+1,
            out_size=data_size,
            width_size=width_size,
            depth=depth,
            activation=jnn.softplus,
            key=key,
        )

    def __call__(self, y, t, args=None):
        # concatenate the t and the y
        y = jnp.concatenate([y, jnp.full((1,), t)], axis=-1)
        return self.mlp(y)

class NeuralODE(eqx.Module):
    func: Func
    def __init__(self, data_size, width_size, depth, *, key, **kwargs):
        super().__init__(**kwargs)
        self.func = Func(data_size, width_size, depth, key=key)

    def __call__(self, ts, y0):
        tpts = jnp.linspace(0, 1.0, npts, dtype=dtype)  # (ntpts,)
        res = solve_ivp(self.func, y0, tpts[..., None], None, tpts, method=solve_ivp.DEER())
        return res.value

def _get_data(ts, *, key):
    y0 = jr.uniform(key, (2,), minval=-0.6, maxval=1)

    def f(t, y, args):
        x = y / (1 + y)
        return jnp.stack([x[1], -x[0]], axis=-1)

    solver = diffrax.Tsit5()
    dt0 = 0.1
    saveat = diffrax.SaveAt(ts=ts)
    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(f), solver, ts[0], ts[-1], dt0, y0, saveat=saveat
    )
    ys = sol.ys
    return ys


def get_data(dataset_size, *, key):
    ts = jnp.linspace(0, 10, 100)
    key = jr.split(key, dataset_size)
    ys = jax.vmap(lambda key: _get_data(ts, key=key))(key)
    return ts, ys

def dataloader(arrays, batch_size, *, key):
    dataset_size = arrays[0].shape[0]
    assert all(array.shape[0] == dataset_size for array in arrays)
    indices = jnp.arange(dataset_size)
    while True:
        perm = jr.permutation(key, indices)
        (key,) = jr.split(key, 1)
        start = 0
        end = batch_size
        while end < dataset_size:
            batch_perm = perm[start:end]
            yield tuple(array[batch_perm] for array in arrays)
            start = end
            end = start + batch_size


def main(
    dataset_size=256,
    batch_size=32,
    lr_strategy=(3e-3, 3e-3),
    steps_strategy=(500, 500),
    length_strategy=(0.1, 1),
    width_size=64,
    depth=2,
    seed=5678,
    plot=True,
    print_every=100,
):
    key = jr.PRNGKey(seed)
    data_key, model_key, loader_key = jr.split(key, 3)

    ts, ys = get_data(dataset_size, key=data_key)
    _, length_size, data_size = ys.shape

    model = NeuralODE(data_size, width_size, depth, key=model_key)


    @eqx.filter_value_and_grad
    def grad_loss(model, ti, yi):
        y_pred = jax.vmap(model, in_axes=(None, 0))(ti, yi[:, 0])
        return jnp.mean((yi - y_pred) ** 2)

    @eqx.filter_jit
    def make_step(ti, yi, model, opt_state):
        loss, grads = grad_loss(model, ti, yi)
        updates, opt_state = optim.update(grads, opt_state)
        model = eqx.apply_updates(model, updates)
        return loss, model, opt_state

    for lr, steps, length in zip(lr_strategy, steps_strategy, length_strategy):
        optim = optax.adabelief(lr)
        opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))
        _ts = ts[: int(length_size * length)]
        _ys = ys[:, : int(length_size * length)]
        for step, (yi,) in zip(
            range(steps), dataloader((_ys,), batch_size, key=loader_key)
        ):
            start = time.time()
            loss, model, opt_state = make_step(_ts, yi, model, opt_state)
            end = time.time()
            if (step % print_every) == 0 or step == steps - 1:
                print(f"Step: {step}, Loss: {loss}, Computation time: {end - start}")

    if plot:
        plt.plot(ts, ys[0, :, 0], c="dodgerblue", label="Real")
        plt.plot(ts, ys[0, :, 1], c="dodgerblue")
        model_y = model(ts, ys[0, 0])
        plt.plot(ts, model_y[:, 0], c="crimson", label="Model")
        plt.plot(ts, model_y[:, 1], c="crimson")
        plt.legend()
        plt.tight_layout()
        plt.savefig("neural_ode.png")
        plt.show()

    return ts, ys, model

if __name__ == "__main__":
    main()
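In this repro, the likely culprit is that the MLP parameters live inside self.func and are therefore closed over by deer's deer_iteration custom_jvp, while the explicit params slot of solve_ivp is passed as None. If deer forwards params to func as explicit, differentiable inputs (which the method.compute(func, y0, xinp, params, tpts) frame in the traceback suggests, but I have not verified), a workaround along these lines might avoid the closure; this is a hedged sketch against an assumed func(y, t, params) calling convention, not deer's confirmed API:

# Hypothetical workaround sketch for NeuralODE.__call__, assuming deer's
# solve_ivp passes `params` through to `func` as explicit custom_jvp inputs.
def __call__(self, ts, y0):
    tpts = jnp.linspace(0, 1.0, npts, dtype=dtype)  # (ntpts,)
    # Split the module into differentiable arrays and static structure,
    # so that only non-array parts are closed over.
    params, static = eqx.partition(self.func, eqx.is_inexact_array)

    def func(y, t, params):
        # Recombine inside the solver callback; gradients flow through
        # `params`, which deer would receive as an explicit argument.
        f = eqx.combine(params, static)
        return f(y, t)

    res = solve_ivp(func, y0, tpts[..., None], params, tpts, method=solve_ivp.DEER())
    return res.value

Whether deer accepts an arbitrary pytree for params is an assumption here; if it does not, the fix presumably has to happen inside deer_iteration's custom_jvp rule instead.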
