-
Notifications
You must be signed in to change notification settings - Fork 3
Open
Description
An example problem that can be scaled to a large number of dimensions (e.g., by varying the `multiplier` value below):
import jax
import jax.numpy as jnp
import deer
# Turn on JAX's 64-bit mode so the arrays created below default to float64
# instead of float32.
jax.config.update("jax_enable_x64", True)
def test_solve_idae():
    """Solve ``multiplier`` stacked copies of the pendulum DAE with DEER.

    Every copy starts at angle pi / 2 with zero velocity.  The stacked
    initial state is handed to ``deer.solve_idae`` together with
    ``dae_pendulum`` and 10000 time points spanning [0, 2].

    Returns:
        The ``value`` attribute of the solver result; the original note
        gives its shape as (npts, ny).
    """
    # can be varied to see how the number of dimensions affects the memory
    # usage or speed
    multiplier = 5
    gval = 10.0
    npts = 10000

    # NOTE(review): the method is passed as the class object itself, not an
    # instance -- confirm against deer's API whether it should be called.
    solver_method = deer.solve_idae.BwdEulerDEER

    # initial state, one entry per pendulum copy
    angle = jnp.ones(multiplier) * jnp.pi / 2
    x_init = jnp.sin(angle)
    y_init = -jnp.cos(angle)
    u_init = jnp.zeros(multiplier)
    v_init = jnp.zeros(multiplier)
    lam_init = -gval * y_init  # initial value of the T term used by dae_pendulum
    state_init = jnp.concatenate(
        [x_init, y_init, u_init, v_init, lam_init], axis=-1
    )  # (5 * multiplier,)

    gravity = jnp.array(gval)
    times = jnp.linspace(0, 2.0, npts)  # (npts,)

    # The time grid is passed twice: with a trailing axis as the third
    # argument, and unchanged as the evaluation points.
    result = deer.solve_idae(
        dae_pendulum,
        state_init,
        times[..., None],
        gravity,
        times,
        method=solver_method,
    )
    return result.value
def dae_pendulum(vrdot: jnp.ndarray, vr: jnp.ndarray, t: jnp.ndarray, params) -> jnp.ndarray:
    """Return the residual of the pendulum DAE written as ``f(vr', vr, t) = 0``.

    The state packs five equally sized groups ``[x, y, u, v, T]`` along its
    last axis (``T`` plays the role of ``lambda`` in the equations below)::

        x' = u
        y' = v
        u' = -T * x
        v' = -T * y - g
        0  = x ** 2 + y ** 2 - 1

    Args:
        vrdot: time derivative of the state, shape ``(..., 5 * multiplier)``.
        vr: state, same shape as ``vrdot``.
        t: time; not used by the residual.
        params: the gravity constant ``g`` itself (it is used directly, not
            unpacked from a tuple).  Must broadcast against
            ``(..., multiplier)``.

    Returns:
        The five residual groups concatenated along the last axis, with the
        same shape as ``vr``.
    """
    g = params
    # Split and re-join along the last axis: identical to the default axis
    # for 1-D states, and also correct when leading batch axes are present.
    x, y, u, v, T = jnp.split(vr, 5, axis=-1)  # each: (..., multiplier)
    # The derivative of T does not enter any equation, so it is discarded.
    xdot, ydot, udot, vdot, _ = jnp.split(vrdot, 5, axis=-1)
    f0 = xdot - u
    f1 = ydot - v
    f2 = udot + T * x
    f3 = vdot + T * y + g
    f4 = x ** 2 + y ** 2 - 1  # algebraic constraint (index-3 per the original note)
    return jnp.concatenate([f0, f1, f2, f3, f4], axis=-1)
Metadata
Metadata
Assignees
Labels
No labels