
Investigate where most of the memory is consumed in jitted DEER iterations with a large number of dimensions #29

@mfkasim1


An example problem that can be scaled to a large number of dimensions (e.g., by varying `multiplier` below):

```python
import jax
import jax.numpy as jnp
import deer


jax.config.update("jax_enable_x64", True)

def test_solve_idae():
    method = deer.solve_idae.BwdEulerDEER
    multiplier = 5  # vary this to see how the number of dimensions affects memory usage and speed

    gval = 10.0
    theta = jnp.ones(multiplier) * jnp.pi / 2
    x0 = jnp.sin(theta)
    y0 = -jnp.cos(theta)
    u0 = jnp.zeros(multiplier)
    v0 = jnp.zeros(multiplier)
    T0 = -gval * y0
    vr0 = jnp.concatenate([x0, y0, u0, v0, T0], axis=-1)  # (5 * multiplier,)

    g = jnp.array(gval)
    params = g
    npts = 10000
    tpts = jnp.linspace(0, 2.0, npts)  # (npts,)
    res = deer.solve_idae(dae_pendulum, vr0, tpts[..., None], params, tpts, method=method)  # (npts, 5 * multiplier)
    vrt = res.value
    return vrt

def dae_pendulum(vrdot: jnp.ndarray, vr: jnp.ndarray, t: jnp.ndarray, params) -> jnp.ndarray:
    # pendulum problem (multiplier independent copies), with T playing the role of lambda:
    # x', y' = u, v
    # u' = -T * x
    # v' = -T * y - g
    # 0 = x ** 2 + y ** 2 - 1
    # vrdot, vr: both (5 * multiplier,)
    # t: (1,) is time
    # params: g, a scalar array with shape ()
    # returns: (5 * multiplier,)
    g = params
    x, y, u, v, T = jnp.split(vr, 5)  # each: (multiplier,)
    xdot, ydot, udot, vdot, Tdot = jnp.split(vrdot, 5)
    f0 = xdot - u
    f1 = ydot - v
    f2 = udot + T * x
    f3 = vdot + T * y + g
    f4 = x ** 2 + y ** 2 - 1  # index-3
    return jnp.concatenate([f0, f1, f2, f3, f4])

if __name__ == "__main__":
    vrt = test_solve_idae()
    print(vrt.shape)  # (npts, 5 * multiplier)
```
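To see why memory can grow quickly with `multiplier`, here is a back-of-the-envelope sketch. It assumes (this is what the investigation should confirm) that each DEER iteration materializes one dense `(ny, ny)` Jacobian per time point, with `ny = 5 * multiplier` and float64 entries because `jax_enable_x64` is on. `dense_jacobian_bytes` is a hypothetical helper, not part of `deer`; `num_jacobians` lets you account for, e.g., separate Jacobians with respect to `vrdot` and `vr` if the implicit solver keeps both.

```python
def dense_jacobian_bytes(npts: int, multiplier: int, num_jacobians: int = 1, itemsize: int = 8) -> int:
    """Hypothetical helper: bytes for `num_jacobians` dense (ny, ny) matrices per time point.

    Assumes ny = 5 * multiplier (as in dae_pendulum) and float64 entries (itemsize=8).
    """
    ny = 5 * multiplier
    return npts * num_jacobians * ny * ny * itemsize


if __name__ == "__main__":
    npts = 10000  # same number of time points as in test_solve_idae
    for multiplier in (5, 20, 100):
        nbytes = dense_jacobian_bytes(npts, multiplier)
        print(f"multiplier={multiplier:3d}  ny={5 * multiplier:3d}  {nbytes / 2**30:.2f} GiB per Jacobian stack")
```

The estimate is quadratic in `multiplier`: at `multiplier = 100` a single stack is already 20 GB (about 18.6 GiB). If the solver also keeps intermediate matrix products of the same shape (for example inside an associative scan), the peak would be a small multiple of that.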

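For this test problem most of those entries would be zero: the `multiplier` pendulums never interact, so each residual component depends on at most two state components. The sketch below counts the nonzeros of d(residual)/d(vr) with `jax.jacfwd`. The residual is copied from `dae_pendulum` above; `jacobian_nonzeros` is a hypothetical helper, and the state is an arbitrary one chosen so that no entry is zero.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def dae_pendulum(vrdot, vr, t, params):
    # Same residual as in the issue: `multiplier` independent pendulums.
    g = params
    x, y, u, v, T = jnp.split(vr, 5)
    xdot, ydot, udot, vdot, Tdot = jnp.split(vrdot, 5)
    return jnp.concatenate([xdot - u, ydot - v, udot + T * x, vdot + T * y + g, x ** 2 + y ** 2 - 1])


def jacobian_nonzeros(multiplier):
    """Hypothetical helper: (nonzeros, total entries) of d(residual)/d(vr)."""
    ny = 5 * multiplier
    vr = jnp.linspace(0.1, 1.0, ny)  # arbitrary state with every entry nonzero
    vrdot = jnp.zeros(ny)
    t = jnp.zeros(1)
    g = jnp.array(10.0)
    jac = jax.jacfwd(dae_pendulum, argnums=1)(vrdot, vr, t, g)  # dense (ny, ny)
    return int(jnp.count_nonzero(jac)), ny * ny


if __name__ == "__main__":
    for multiplier in (1, 5, 20):
        nnz, total = jacobian_nonzeros(multiplier)
        print(f"multiplier={multiplier:2d}: {nnz} of {total} entries are nonzero")
```

Each pendulum contributes 8 nonzeros (`-1` for u and v, `T` and `x` in f2, `T` and `y` in f3, `2x` and `2y` in f4), so the count is `8 * multiplier` out of `(5 * multiplier) ** 2`. At `multiplier = 100` that is 800 of 250,000 entries, i.e. a dense per-step Jacobian would be about 99.7% zeros.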