diff --git a/deer/fsolve_ivp.py b/deer/fsolve_ivp.py index 860ab49..7b66973 100644 --- a/deer/fsolve_ivp.py +++ b/deer/fsolve_ivp.py @@ -1,9 +1,10 @@ from abc import abstractmethod from typing import Any, Callable, List, Optional, Tuple +import jax import jax.numpy as jnp from deer.deer_iter import deer_iteration from deer.maths import matmul_recursive -from deer.utils import get_method_meta, check_method, Result +from deer.utils import get_method_meta, check_method, Result, while_loop_scan __all__ = ["solve_ivp"] @@ -173,3 +174,237 @@ def solve_ivp_inv_lin(self, gmat: List[jnp.ndarray], rhs: jnp.ndarray, # compute the recursive matrix multiplication yt = matmul_recursive(gtbar, htbar, y0) # (nt, ny) return yt + + +class GeneralODE(SolveIVPMethod): + """ + Compute the solution of initial value problem with the ODE methods. + """ + + def compute(self, func: Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray], + y0: jnp.ndarray, xinp: jnp.ndarray, params: Any, tpts: jnp.ndarray): + + # Precompute dx and dt + dx = jnp.diff(xinp) + dt = jnp.diff(tpts) + + # Define the main function to cover all time intervals + def scan_step(carry, inputs): + yi, success = carry + xi, dx_i, dt_i = inputs + def success_fn(carry): + yi, _ = carry + yn, success = self.ode_step(func, yi, xi, dt_i, dx_i, params) + return yn, success + + def failure_fn(carry): + yi, _ = carry + return yi, False + + res = jax.lax.cond(success, success_fn, failure_fn, carry) + return res, res + + # Initial carry + initial_carry = (y0, True) + + # Pack parameters for scan + scan_inputs = (xinp[:-1], dx, dt) + + # Use jax.lax.scan to iterate over time points + _, (y, success) = jax.lax.scan(scan_step, initial_carry, scan_inputs) + + y = jnp.concatenate([y0[None, :], y], axis=0) + success = jnp.concatenate((jnp.full_like(success[:1], True, dtype=jnp.bool), success), axis=0) + + return Result(y, success[:, None]) + + +# Note that `self.ode_step` should be adjusted to properly work with jax if it isn't already. + + + @abstractmethod + def ode_step( + self, + func: Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray], + yi: jnp.ndarray, + xi: jnp.ndarray, + dt: jnp.ndarray, + dx: jnp.ndarray, + params: Any + ) -> Tuple[jnp.ndarray, jnp.ndarray]: + """ + Solve a single step of the ODE using the specific method. + + Parameters + ---------- + func : Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray] + The function defining the differential equation dy/dt = f(y, x, params) + yi : jnp.ndarray + The state of the system at the beginning of the time step. + xi : jnp.ndarray + The input at the beginning of the time step. + dt : jnp.ndarray + The size of the time step. + dx : jnp.ndarray + The change in input over the time step. + params : Any + Additional parameters for the differential equation. + + Returns + ------- + jnp.ndarray + The state of the system at the end of the time step. + """ + pass + +class ForwardEuler(GeneralODE): + """ + Compute the solution of initial value problem with the Forward Euler method. + """ + def ode_step( + self, + func: Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray], + yi: jnp.ndarray, + xi: jnp.ndarray, + dt: jnp.ndarray, + dx: jnp.ndarray, + params: Any + ) -> Tuple[jnp.ndarray, jnp.ndarray]: + k = func(yi, xi, params) + yi_new = yi + dt * k + return yi_new, True + + +class RK3(GeneralODE): + """ + Compute the solution of initial value problem with the Runge-Kutta 3rd order method. + """ + + def ode_step( + self, + func: Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray], + yi: jnp.ndarray, + xi: jnp.ndarray, + dt: jnp.ndarray, + dx: jnp.ndarray, + params: Any + ) -> Tuple[jnp.ndarray, jnp.ndarray]: + k1 = func(yi, xi, params) + k2 = func(yi + dt * k1, xi + dx, params) + k3 = func(yi + 0.25 * dt * (k1 + k2), xi + 0.5 * dx, params) + + yi = yi + (dt / 6.0) * (k1 + k2 + 4 * k3) + return yi, True + + +class RK4(GeneralODE): + """ + Compute the solution of initial value problem with the Runge-Kutta 4th order method. + """ + + def ode_step( + self, + func: Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray], + yi: jnp.ndarray, + xi: jnp.ndarray, + dt: jnp.ndarray, + dx: jnp.ndarray, + params: Any + ) -> Tuple[jnp.ndarray, jnp.ndarray]: + k1 = func(yi, xi, params) + k2 = func(yi + 0.5 * dt * k1, xi + 0.5 * dx, params) + k3 = func(yi + 0.5 * dt * k2, xi + 0.5 * dx, params) + k4 = func(yi + dt * k3, xi + dx, params) + + yi = yi + (dt / 6.0) * (k1 + 2*k2 + 2*k3 + k4) + return yi, True + + +class GeneralBackwardODE(GeneralODE): + """ + Compute the solution of initial value problem with Backward ODE methods. + + Arguments + --------- + tol: float + The tolerance for the fixed-point iteration. + max_iter: int + The maximum number of iterations for the fixed-point iteration. + """ + def __init__(self, tol: float = 1e-6, max_iter: int = 100): + self.tol = tol + self.max_iter = max_iter + + def ode_step( + self, + func: Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray], + yi: jnp.ndarray, + xi: jnp.ndarray, + dt: jnp.ndarray, + dx: jnp.ndarray, + params: Any + ) -> jnp.ndarray: + def body_fn(state): + y_new, _ = state + y_new_next = self.backward_iteration_step(func, y_new, yi, xi, dt, dx, params) + return y_new_next, jnp.linalg.norm(y_new_next - y_new) + + def cond_fn(state): + _, diff_norm = state + return diff_norm >= self.tol + + # Initial guess for fixed-point iteration (use forward Euler as an initial guess) + y_new = yi + dt * func(yi, xi, params) + + # Run the fixed-point iteration loop using lax.while_loop + state, _ = while_loop_scan(cond_fn, body_fn, (y_new, jnp.inf), self.max_iter) + y_new, diff_norm = state + return y_new, diff_norm < self.tol + + # @abstractmethod + def backward_iteration_step(self, + func: Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray], + y_new: jnp.ndarray, + yi: jnp.ndarray, + xi: jnp.ndarray, + dt: jnp.ndarray, + dx: jnp.ndarray, + params: Any) -> jnp.ndarray: + """ + The body function for the fixed-point iteration. Return the next value of yi. + + Arguments: + func: Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray] + The function defining the differential equation dy/dt = f(y, x, params). + y_new: jnp.ndarray + The current value of yi. + yi: jnp.ndarray + The initial value of yi. + xi: jnp.ndarray + The input signal at the current time step. + dt: jnp.ndarray + The time step size. + dx: jnp.ndarray + The change in input signal. + params: Any + The parameters of the function. + """ + pass + + +class BackwardEuler(GeneralBackwardODE): + """ + Compute the solution of initial value problem with the Backward Euler method. + """ + + def backward_iteration_step(self, func, y_new, yi, xi, dt, dx, params) -> jnp.ndarray: + return yi + dt * func(y_new, xi, params) + + +class TrapezoidalMethod(GeneralBackwardODE): + """ + Compute the solution of initial value problem with the Trapezoidal method. + """ + + def backward_iteration_step(self, func, y_new, yi, xi, dt, dx, params) -> jnp.ndarray: + return yi + (dt / 2.0) * (func(yi, xi, params) + func(y_new, xi + dx, params)) diff --git a/deer/tests/test_deer.py b/deer/tests/test_deer.py index 4875503..939a576 100644 --- a/deer/tests/test_deer.py +++ b/deer/tests/test_deer.py @@ -413,5 +413,38 @@ def dae_pendulum(vrdot: jnp.ndarray, vr: jnp.ndarray, t: jnp.ndarray, params) -> f4 = x ** 2 + y ** 2 - 1 # index-3 return jnp.concatenate([f0, f1, f2, f3, f4]) + +@pytest.mark.parametrize("method, atol", [ + (solve_ivp.RK3(), 1e-6), + (solve_ivp.RK4(), 1e-6), + (solve_ivp.ForwardEuler(), 1e-2), # Forward Euler has a larger error + (solve_ivp.BackwardEuler(), 1e-2), # Backward Euler has a larger error + (solve_ivp.TrapezoidalMethod(), 1e-6) +]) +def test_ODEs(method, atol: float): + def sample_func(y, x, params): + return -params * y * x + + # Initialize parameters + y0 = jnp.array([1.0, 0.0, -1.0]) + params = -2 + tpts = jnp.linspace(0, 0.5, 1000) + xinp = jnp.sin(tpts * 2 * jnp.pi) + 1e-3 + + # Compute solution using the parameterized method + yt_method = solve_ivp(sample_func, y0, xinp, params, tpts, method=method) + + # Compute reference solution using DEER method + yt_deer = solve_ivp(sample_func, y0, xinp, params, tpts, method=solve_ivp.DEER()) + + # Compare results + assert jnp.allclose(yt_deer.value, yt_method.value, atol=atol) + + # Check the gradients + def solve_ivp_wrapper(y0): + return solve_ivp(sample_func, y0, xinp, params, tpts, method=method).value + + jax.test_util.check_grads(solve_ivp_wrapper, (y0,), order=1, modes=['fwd', 'rev']) + if __name__ == "__main__": test_solve_idae()