Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
237 changes: 236 additions & 1 deletion deer/fsolve_ivp.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from abc import abstractmethod
from typing import Any, Callable, List, Optional, Tuple
import jax
import jax.numpy as jnp
from deer.deer_iter import deer_iteration
from deer.maths import matmul_recursive
from deer.utils import get_method_meta, check_method, Result
from deer.utils import get_method_meta, check_method, Result, while_loop_scan


__all__ = ["solve_ivp"]
Expand Down Expand Up @@ -173,3 +174,237 @@ def solve_ivp_inv_lin(self, gmat: List[jnp.ndarray], rhs: jnp.ndarray,
# compute the recursive matrix multiplication
yt = matmul_recursive(gtbar, htbar, y0) # (nt, ny)
return yt


class GeneralODE(SolveIVPMethod):
"""
Compute the solution of initial value problem with the ODE methods.
"""

def compute(self, func: Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray],
y0: jnp.ndarray, xinp: jnp.ndarray, params: Any, tpts: jnp.ndarray):

# Precompute dx and dt
dx = jnp.diff(xinp)
dt = jnp.diff(tpts)

# Define the main function to cover all time intervals
def scan_step(carry, inputs):
yi, success = carry
xi, dx_i, dt_i = inputs
def success_fn(carry):
yi, _ = carry
yn, success = self.ode_step(func, yi, xi, dt_i, dx_i, params)
return yn, success

def failure_fn(carry):
yi, _ = carry
return yi, False

res = jax.lax.cond(success, success_fn, failure_fn, carry)
return res, res

# Initial carry
initial_carry = (y0, True)

# Pack parameters for scan
scan_inputs = (xinp[:-1], dx, dt)

# Use jax.lax.scan to iterate over time points
_, (y, success) = jax.lax.scan(scan_step, initial_carry, scan_inputs)

y = jnp.concatenate([y0[None, :], y], axis=0)
success = jnp.concatenate((jnp.full_like(success[:1], True, dtype=jnp.bool), success), axis=0)

return Result(y, success[:, None])


# Note that `self.ode_step` should be adjusted to properly work with jax if it isn't already.


@abstractmethod
def ode_step(
self,
func: Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray],
yi: jnp.ndarray,
xi: jnp.ndarray,
dt: jnp.ndarray,
dx: jnp.ndarray,
params: Any
) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""
Solve a single step of the ODE using the specific method.

Parameters
----------
func : Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray]
The function defining the differential equation dy/dt = f(y, x, params)
yi : jnp.ndarray
The state of the system at the beginning of the time step.
xi : jnp.ndarray
The input at the beginning of the time step.
dt : jnp.ndarray
The size of the time step.
dx : jnp.ndarray
The change in input over the time step.
params : Any
Additional parameters for the differential equation.

Returns
-------
jnp.ndarray
The state of the system at the end of the time step.
"""
pass

class ForwardEuler(GeneralODE):
"""
Compute the solution of initial value problem with the Forward Euler method.
"""
def ode_step(
self,
func: Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray],
yi: jnp.ndarray,
xi: jnp.ndarray,
dt: jnp.ndarray,
dx: jnp.ndarray,
params: Any
) -> Tuple[jnp.ndarray, jnp.ndarray]:
k = func(yi, xi, params)
yi_new = yi + dt * k
return yi_new, True


class RK3(GeneralODE):
"""
Compute the solution of initial value problem with the Runge-Kutta 3rd order method.
"""

def ode_step(
self,
func: Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray],
yi: jnp.ndarray,
xi: jnp.ndarray,
dt: jnp.ndarray,
dx: jnp.ndarray,
params: Any
) -> Tuple[jnp.ndarray, jnp.ndarray]:
k1 = func(yi, xi, params)
k2 = func(yi + dt * k1, xi + dx, params)
k3 = func(yi + 0.25 * dt * (k1 + k2), xi + 0.5 * dx, params)

yi = yi + (dt / 6.0) * (k1 + k2 + 4 * k3)
return yi, True


class RK4(GeneralODE):
"""
Compute the solution of initial value problem with the Runge-Kutta 4th order method.
"""

def ode_step(
self,
func: Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray],
yi: jnp.ndarray,
xi: jnp.ndarray,
dt: jnp.ndarray,
dx: jnp.ndarray,
params: Any
) -> Tuple[jnp.ndarray, jnp.ndarray]:
k1 = func(yi, xi, params)
k2 = func(yi + 0.5 * dt * k1, xi + 0.5 * dx, params)
k3 = func(yi + 0.5 * dt * k2, xi + 0.5 * dx, params)
k4 = func(yi + dt * k3, xi + dx, params)

yi = yi + (dt / 6.0) * (k1 + 2*k2 + 2*k3 + k4)
return yi, True


class GeneralBackwardODE(GeneralODE):
"""
Compute the solution of initial value problem with Backward ODE methods.

Arguments
---------
tol: float
The tolerance for the fixed-point iteration.
max_iter: int
The maximum number of iterations for the fixed-point iteration.
"""
def __init__(self, tol: float = 1e-6, max_iter: int = 100):
self.tol = tol
self.max_iter = max_iter

def ode_step(
self,
func: Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray],
yi: jnp.ndarray,
xi: jnp.ndarray,
dt: jnp.ndarray,
dx: jnp.ndarray,
params: Any
) -> jnp.ndarray:
def body_fn(state):
y_new, _ = state
y_new_next = self.backward_iteration_step(func, y_new, yi, xi, dt, dx, params)
return y_new_next, jnp.linalg.norm(y_new_next - y_new)

def cond_fn(state):
_, diff_norm = state
return diff_norm >= self.tol

# Initial guess for fixed-point iteration (use forward Euler as an initial guess)
y_new = yi + dt * func(yi, xi, params)

# Run the fixed-point iteration loop using lax.while_loop
state, _ = while_loop_scan(cond_fn, body_fn, (y_new, jnp.inf), self.max_iter)
y_new, diff_norm = state
return y_new, diff_norm < self.tol

# @abstractmethod
def backward_iteration_step(self,
func: Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray],
y_new: jnp.ndarray,
yi: jnp.ndarray,
xi: jnp.ndarray,
dt: jnp.ndarray,
dx: jnp.ndarray,
params: Any) -> jnp.ndarray:
"""
The body function for the fixed-point iteration. Return the next value of yi.

Arguments:
func: Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray]
The function defining the differential equation dy/dt = f(y, x, params).
y_new: jnp.ndarray
The current value of yi.
yi: jnp.ndarray
The initial value of yi.
xi: jnp.ndarray
The input signal at the current time step.
dt: jnp.ndarray
The time step size.
dx: jnp.ndarray
The change in input signal.
params: Any
The parameters of the function.
"""
pass


class BackwardEuler(GeneralBackwardODE):
"""
Compute the solution of initial value problem with the Backward Euler method.
"""

def backward_iteration_step(self, func, y_new, yi, xi, dt, dx, params) -> jnp.ndarray:
return yi + dt * func(y_new, xi, params)


class TrapezoidalMethod(GeneralBackwardODE):
"""
Compute the solution of initial value problem with the Trapezoidal method.
"""

def backward_iteration_step(self, func, y_new, yi, xi, dt, dx, params) -> jnp.ndarray:
return yi + (dt / 2.0) * (func(yi, xi, params) + func(y_new, xi + dx, params))
33 changes: 33 additions & 0 deletions deer/tests/test_deer.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,5 +413,38 @@ def dae_pendulum(vrdot: jnp.ndarray, vr: jnp.ndarray, t: jnp.ndarray, params) ->
f4 = x ** 2 + y ** 2 - 1 # index-3
return jnp.concatenate([f0, f1, f2, f3, f4])


@pytest.mark.parametrize("method, atol", [
(solve_ivp.RK3(), 1e-6),
(solve_ivp.RK4(), 1e-6),
(solve_ivp.ForwardEuler(), 1e-2), # Forward Euler has a larger error
(solve_ivp.BackwardEuler(), 1e-2), # Backward Euler has a larger error
(solve_ivp.TrapezoidalMethod(), 1e-6)
])
def test_ODEs(method, atol: float):
def sample_func(y, x, params):
return -params * y * x

# Initialize parameters
y0 = jnp.array([1.0, 0.0, -1.0])
params = -2
tpts = jnp.linspace(0, 0.5, 1000)
xinp = jnp.sin(tpts * 2 * jnp.pi) + 1e-3

# Compute solution using the parameterized method
yt_method = solve_ivp(sample_func, y0, xinp, params, tpts, method=method)

# Compute reference solution using DEER method
yt_deer = solve_ivp(sample_func, y0, xinp, params, tpts, method=solve_ivp.DEER())

# Compare results
assert jnp.allclose(yt_deer.value, yt_method.value, atol=atol)

# Check the gradients
def solve_ivp_wrapper(y0):
return solve_ivp(sample_func, y0, xinp, params, tpts, method=method).value

jax.test_util.check_grads(solve_ivp_wrapper, (y0,), order=1, modes=['fwd', 'rev'])

if __name__ == "__main__":
test_solve_idae()