From f664dfcd45d3ca2ad78604ae79cbf09d3bcaffb1 Mon Sep 17 00:00:00 2001 From: Jason Zhu Date: Mon, 17 Jun 2024 10:07:22 +0000 Subject: [PATCH 1/8] add odes --- deer/fsolve_ivp.py | 234 ++++++++++++++++++++++++++++++++++++++++ deer/tests/test_deer.py | 48 ++++++++- 2 files changed, 278 insertions(+), 4 deletions(-) diff --git a/deer/fsolve_ivp.py b/deer/fsolve_ivp.py index 330c479..bdebe64 100644 --- a/deer/fsolve_ivp.py +++ b/deer/fsolve_ivp.py @@ -169,3 +169,237 @@ def solve_ivp_inv_lin(self, gmat: List[jnp.ndarray], rhs: jnp.ndarray, # compute the recursive matrix multiplication yt = matmul_recursive(gtbar, htbar, y0) # (nt, ny) return yt + + +class ForwardEuler(SolveIVPMethod): + """ + Compute the solution of initial value problem with the Forward Euler method. + + Arguments + --------- + step_size: float + The step size to use for the Euler method. If None, it will use the difference (tpts[1] - tpts[0]) divided by the number of steps. + """ + def __init__(self, step_size: Optional[float] = None): + self.step_size = step_size + + def compute(self, func: Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray], + y0: jnp.ndarray, xinp: jnp.ndarray, params: Any, tpts: jnp.ndarray): + y = [y0] + for i in range(1, len(tpts)): + yi = y[-1] + ti_prev = tpts[i-1] + ti_next = tpts[i] + xi = xinp[i-1] + interval = ti_next - ti_prev + + # Determine the step size and number of steps within this interval + if self.step_size is None: + steps = 1 + dt = interval + else: + steps = max(1, int(interval / self.step_size)) + dt = interval / steps + + for _ in range(steps): + dydt = func(yi, xi, params) + yi = yi + dt * dydt + + y.append(yi) + + return jnp.stack(y) + + +class RK3(SolveIVPMethod): + """ + Compute the solution of initial value problem with the Runge-Kutta 3rd order method. + + Arguments + --------- + step_size: float + The step size to use for the RK3 method. If None, it will use (tpts[1] - tpts[0]). + """ + def __init__(self, step_size: Optional[float] = None): + self.step_size = step_size + + def compute(self, func: Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray], + y0: jnp.ndarray, xinp: jnp.ndarray, params: Any, tpts: jnp.ndarray): + # Initialize the solution list with the initial condition + y = [y0] + + # Iterate over time points to compute the solution at each step + for i in range(1, len(tpts)): + yi = y[-1] + xi = xinp[i-1] + ti = tpts[i-1] + tf = tpts[i] + + # Determine the step size + dt = self.step_size + if dt is None: + dt = tf - ti + + # Number of steps between tpts + num_steps = int((tf - ti) / dt) + dt = (tf - ti) / num_steps # Recalculate dt to evenly divide the interval + + for _ in range(num_steps): + k1 = func(yi, xi, params) + k2 = func(yi + 0.5 * dt * k1, xi + 0.5 * dt, params) + k3 = func(yi - dt * k1 + 2 * dt * k2, xi + dt, params) + + yi = yi + (dt / 6.0) * (k1 + 4 * k2 + k3) + xi = xi + dt + + y.append(yi) + + # Stack the list of solutions into a single jax array + return jnp.stack(y) + + +class RK4(SolveIVPMethod): + """ + Compute the solution of initial value problem with the Runge-Kutta 4th order method. + + Arguments + --------- + step_size: float + The step size to use for the RK4 method. If None, it will use (tpts[1] - tpts[0]). + """ + def __init__(self, step_size: Optional[float] = None): + self.step_size = step_size + + def compute(self, func: Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray], + y0: jnp.ndarray, xinp: jnp.ndarray, params: Any, tpts: jnp.ndarray): + y = [y0] + for i in range(1, len(tpts)): + yi = y[-1] + xi = xinp[i-1] + ti = tpts[i-1] + tf = tpts[i] + + dt = self.step_size + if dt is None: + dt = tf - ti + + num_steps = int((tf - ti) / dt) + dt = (tf - ti) / num_steps + + for _ in range(num_steps): + k1 = func(yi, xi, params) + k2 = func(yi + 0.5 * dt * k1, xi + 0.5 * dt, params) + k3 = func(yi + 0.5 * dt * k2, xi + 0.5 * dt, params) + k4 = func(yi + dt * k3, xi + dt, params) + + yi = yi + (dt / 6.0) * (k1 + 2*k2 + 2*k3 + k4) + xi = xi + dt + + y.append(yi) + + return jnp.stack(y) + + +class BackwardEuler(SolveIVPMethod): + """ + Compute the solution of initial value problem with the Backward Euler method. + + Arguments + --------- + step_size: float + The step size to use for the Euler method. If None, it will use (tpts[1] - tpts[0]). + tol: float + The tolerance for the fixed-point iteration. + max_iter: int + The maximum number of iterations for the fixed-point iteration. + """ + def __init__(self, step_size: Optional[float] = None, tol: float = 1e-6, max_iter: int = 100): + self.step_size = step_size + self.tol = tol + self.max_iter = max_iter + + def compute(self, func: Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray], + y0: jnp.ndarray, xinp: jnp.ndarray, params: Any, tpts: jnp.ndarray): + y = [y0] + for i in range(1, len(tpts)): + yi = y[-1] + xi = xinp[i-1] + ti = tpts[i-1] + tf = tpts[i] + + dt = self.step_size + if dt is None: + dt = tf - ti + + num_steps = int((tf - ti) / dt) + dt = (tf - ti) / num_steps + + for _ in range(num_steps): + # Fixed-point iteration to solve: y_new = yi + dt * func(y_new, xi, params) + y_new = yi + for _ in range(self.max_iter): + y_new_next = yi + dt * func(y_new, xi, params) + if jnp.linalg.norm(y_new_next - y_new) < self.tol: + y_new = y_new_next + break + y_new = y_new_next + yi = y_new + xi = xi + dt + + y.append(yi) + + return jnp.stack(y) + + + +class TrapezoidalMethod(SolveIVPMethod): + """ + Compute the solution of initial value problem with the Trapezoidal method. + + Arguments + --------- + step_size: float + The step size to use for the Trapezoidal method. If None, it will use (tpts[1] - tpts[0]). + tol: float + The tolerance for the fixed-point iteration. + max_iter: int + The maximum number of iterations for the fixed-point iteration. + """ + def __init__(self, step_size: Optional[float] = None, tol: float = 1e-6, max_iter: int = 100): + self.step_size = step_size + self.tol = tol + self.max_iter = max_iter + + def compute(self, func: Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray], + y0: jnp.ndarray, xinp: jnp.ndarray, params: Any, tpts: jnp.ndarray): + y = [y0] + for i in range(1, len(tpts)): + yi = y[-1] + xi = xinp[i-1] + ti = tpts[i-1] + tf = tpts[i] + + dt = self.step_size + if dt is None: + dt = tf - ti + + num_steps = int((tf - ti) / dt) + dt = (tf - ti) / num_steps + + for _ in range(num_steps): + # Initial guess for fixed-point iteration (use forward Euler as an initial guess) + y_new = yi + dt * func(yi, xi, params) + + # Fixed-point iteration to solve: y_new = yi + (dt/2) * (func(yi, xi, params) + func(y_new, xi + dt, params)) + for _ in range(self.max_iter): + y_new_next = yi + (dt / 2.0) * (func(yi, xi, params) + func(y_new, xi + dt, params)) + if jnp.linalg.norm(y_new_next - y_new) < self.tol: + y_new = y_new_next + break + y_new = y_new_next + + yi = y_new + xi = xi + dt + + y.append(yi) + + return jnp.stack(y) \ No newline at end of file diff --git a/deer/tests/test_deer.py b/deer/tests/test_deer.py index c1e4419..2b716cc 100644 --- a/deer/tests/test_deer.py +++ b/deer/tests/test_deer.py @@ -34,17 +34,22 @@ def test_matmul_recursive(): result2 = matmul_recursive(mats, vecs, y0) assert jnp.allclose(result, result2) -@pytest.mark.parametrize("method", [ - solve_ivp.DEER() +@pytest.mark.parametrize("method, npts", [ + (solve_ivp.DEER(), 10000), + (solve_ivp.ForwardEuler(), 3), + (solve_ivp.RK3(), 1), + (solve_ivp.RK4(), 100), + (solve_ivp.BackwardEuler(), 1), + (solve_ivp.TrapezoidalMethod(), 100), ]) -def test_solve_ivp(method): +def test_solve_ivp(method, npts): ny = 4 dtype = jnp.float64 key = jax.random.PRNGKey(0) subkey1, subkey2, subkey3 = jax.random.split(key, 3) A0 = (jax.random.uniform(subkey1, shape=(ny, ny), dtype=dtype) * 2 - 1) / ny ** 0.5 A1 = jax.random.uniform(subkey2, shape=(ny, ny), dtype=dtype) / ny ** 0.5 - npts = 10000 # TODO: investigate why npts=1000 make nans + # npts = 5 # TODO: investigate why npts=1000 make nans tpts = jnp.linspace(0, 1.0, npts, dtype=dtype) # (ntpts,) y0 = jax.random.uniform(subkey3, shape=(ny,), dtype=dtype) @@ -349,5 +354,40 @@ def dae_pendulum(vrdot: jnp.ndarray, vr: jnp.ndarray, t: jnp.ndarray, params) -> f4 = x ** 2 + y ** 2 - 1 # index-3 return jnp.concatenate([f0, f1, f2, f3, f4]) +def test_odes(): + def sample_func(y, x, params): + return -params * y + + # Initialize parameters + step_size = 0.0001 + y0 = jnp.array([1.0, 0.0]) + xinp = jnp.linspace(0, 1, 10) + params = 0.5 + tpts = jnp.linspace(0, 1, 10) + + # Instantiate solvers + rk4_solver = solve_ivp.RK4(step_size=step_size) + trapezoidal_solver = solve_ivp.TrapezoidalMethod(step_size=step_size) + + # Compute solutions + yt_rk4 = rk4_solver.compute(sample_func, y0, xinp, params, tpts) + yt_trapezoidal = trapezoidal_solver.compute(sample_func, y0, xinp, params, tpts) + yt_rk3 = solve_ivp.RK3().compute(sample_func, y0, xinp, params, tpts) + yt_forward_euler = solve_ivp.ForwardEuler().compute(sample_func, y0, xinp, params, tpts) + yt_backward_euler = solve_ivp.BackwardEuler().compute(sample_func, y0, xinp, params, tpts) + + + # Compare results + print(f"RK4 Results: {yt_rk4}") + print(f"Trapezoidal Results: {yt_trapezoidal}") + print(f"RK3 Results: {yt_rk3}") + print(f"Forward Euler Results: {yt_forward_euler}") + print(f"Backward Euler Results: {yt_backward_euler}") + assert jnp.allclose(yt_rk4, yt_trapezoidal, atol=1e-6) + assert jnp.allclose(yt_rk4, yt_rk3, atol=1e-6) + + # assert jnp.allclose(yt_rk4, yt_forward_euler, atol=1e-6) + # assert jnp.allclose(yt_rk4, yt_backward_euler, atol=1e-6) + if __name__ == "__main__": test_solve_idae() From 294cd5507fd94c1ffc11bd0beb05796e06f48a52 Mon Sep 17 00:00:00 2001 From: Jason Zhu Date: Wed, 19 Jun 2024 17:15:32 +0000 Subject: [PATCH 2/8] Rewrite ode and test cases --- deer/fsolve_ivp.py | 341 ++++++++++++++++++++-------------------- deer/tests/test_deer.py | 63 +++----- 2 files changed, 192 insertions(+), 212 deletions(-) diff --git a/deer/fsolve_ivp.py b/deer/fsolve_ivp.py index bdebe64..3aa697b 100644 --- a/deer/fsolve_ivp.py +++ b/deer/fsolve_ivp.py @@ -171,53 +171,14 @@ def solve_ivp_inv_lin(self, gmat: List[jnp.ndarray], rhs: jnp.ndarray, return yt -class ForwardEuler(SolveIVPMethod): +class GeneralODE(SolveIVPMethod): """ - Compute the solution of initial value problem with the Forward Euler method. - - Arguments - --------- - step_size: float - The step size to use for the Euler method. If None, it will use the difference (tpts[1] - tpts[0]) divided by the number of steps. - """ - def __init__(self, step_size: Optional[float] = None): - self.step_size = step_size - - def compute(self, func: Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray], - y0: jnp.ndarray, xinp: jnp.ndarray, params: Any, tpts: jnp.ndarray): - y = [y0] - for i in range(1, len(tpts)): - yi = y[-1] - ti_prev = tpts[i-1] - ti_next = tpts[i] - xi = xinp[i-1] - interval = ti_next - ti_prev - - # Determine the step size and number of steps within this interval - if self.step_size is None: - steps = 1 - dt = interval - else: - steps = max(1, int(interval / self.step_size)) - dt = interval / steps - - for _ in range(steps): - dydt = func(yi, xi, params) - yi = yi + dt * dydt - - y.append(yi) - - return jnp.stack(y) - - -class RK3(SolveIVPMethod): - """ - Compute the solution of initial value problem with the Runge-Kutta 3rd order method. + Compute the solution of initial value problem with the ODE methods. Arguments --------- step_size: float - The step size to use for the RK3 method. If None, it will use (tpts[1] - tpts[0]). + The step size for ODE solver. If None, it will use (tpts[i] - tpts[i - 1]). """ def __init__(self, step_size: Optional[float] = None): self.step_size = step_size @@ -231,175 +192,209 @@ def compute(self, func: Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray], for i in range(1, len(tpts)): yi = y[-1] xi = xinp[i-1] + xf = xinp[i] ti = tpts[i-1] tf = tpts[i] # Determine the step size - dt = self.step_size - if dt is None: - dt = tf - ti + dt = self.step_size if self.step_size is not None else (tf - ti) - # Number of steps between tpts - num_steps = int((tf - ti) / dt) + # Number of steps between tpts, at least 1 + num_steps = max(int((tf - ti) / dt), 1) dt = (tf - ti) / num_steps # Recalculate dt to evenly divide the interval + dx = (xf - xi) / (tf - ti) * dt for _ in range(num_steps): - k1 = func(yi, xi, params) - k2 = func(yi + 0.5 * dt * k1, xi + 0.5 * dt, params) - k3 = func(yi - dt * k1 + 2 * dt * k2, xi + dt, params) - - yi = yi + (dt / 6.0) * (k1 + 4 * k2 + k3) - xi = xi + dt - + yi, xi = self.ode_step(func, yi, xi, dt, dx, params) y.append(yi) # Stack the list of solutions into a single jax array return jnp.stack(y) + @abstractmethod + def ode_step( + self, + func: Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray], + yi: jnp.ndarray, + xi: jnp.ndarray, + dt: jnp.ndarray, + dx: jnp.ndarray, + params: Any + ) -> Tuple[jnp.ndarray, jnp.ndarray]: + """ + Solve a single step of the ODE using the specific method. + + Parameters + ---------- + func : Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray] + The function defining the differential equation dy/dt = f(y, x, params) + yi : jnp.ndarray + The state of the system at the beginning of the time step. + xi : jnp.ndarray + The input at the beginning of the time step. + dt : jnp.ndarray + The size of the time step. + dx : jnp.ndarray + The change in input over the time step. + params : Any + Additional parameters for the differential equation. -class RK4(SolveIVPMethod): - """ - Compute the solution of initial value problem with the Runge-Kutta 4th order method. + Returns + ------- + Tuple[jnp.ndarray, jnp.ndarray] + The state of the system and input at the end of the time step. + """ + pass +class ForwardEuler(GeneralODE): + """ + Compute the solution of initial value problem with the Forward Euler method. + Arguments --------- step_size: float - The step size to use for the RK4 method. If None, it will use (tpts[1] - tpts[0]). + The step size to use for the Euler method. If None, it will use the difference (tpts[1] - tpts[0]) divided by the number of steps. """ - def __init__(self, step_size: Optional[float] = None): - self.step_size = step_size - - def compute(self, func: Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray], - y0: jnp.ndarray, xinp: jnp.ndarray, params: Any, tpts: jnp.ndarray): - y = [y0] - for i in range(1, len(tpts)): - yi = y[-1] - xi = xinp[i-1] - ti = tpts[i-1] - tf = tpts[i] - - dt = self.step_size - if dt is None: - dt = tf - ti - - num_steps = int((tf - ti) / dt) - dt = (tf - ti) / num_steps - - for _ in range(num_steps): - k1 = func(yi, xi, params) - k2 = func(yi + 0.5 * dt * k1, xi + 0.5 * dt, params) - k3 = func(yi + 0.5 * dt * k2, xi + 0.5 * dt, params) - k4 = func(yi + dt * k3, xi + dt, params) - - yi = yi + (dt / 6.0) * (k1 + 2*k2 + 2*k3 + k4) - xi = xi + dt - - y.append(yi) - - return jnp.stack(y) - - -class BackwardEuler(SolveIVPMethod): + def ode_step( + self, + func: Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray], + yi: jnp.ndarray, + xi: jnp.ndarray, + dt: jnp.ndarray, + dx: jnp.ndarray, + params: Any + ) -> Tuple[jnp.ndarray, jnp.ndarray]: + k = func(yi, xi, params) + yi_new = yi + dt * k + xi_new = xi + dx + return yi_new, xi_new + + +class RK3(GeneralODE): """ - Compute the solution of initial value problem with the Backward Euler method. - + Compute the solution of initial value problem with the Runge-Kutta 3rd order method. + Arguments --------- step_size: float - The step size to use for the Euler method. If None, it will use (tpts[1] - tpts[0]). - tol: float - The tolerance for the fixed-point iteration. - max_iter: int - The maximum number of iterations for the fixed-point iteration. + The step size to use for the RK3 method. If None, it will use (tpts[i] - tpts[i - 1]). """ - def __init__(self, step_size: Optional[float] = None, tol: float = 1e-6, max_iter: int = 100): - self.step_size = step_size - self.tol = tol - self.max_iter = max_iter - - def compute(self, func: Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray], - y0: jnp.ndarray, xinp: jnp.ndarray, params: Any, tpts: jnp.ndarray): - y = [y0] - for i in range(1, len(tpts)): - yi = y[-1] - xi = xinp[i-1] - ti = tpts[i-1] - tf = tpts[i] - - dt = self.step_size - if dt is None: - dt = tf - ti - - num_steps = int((tf - ti) / dt) - dt = (tf - ti) / num_steps - - for _ in range(num_steps): - # Fixed-point iteration to solve: y_new = yi + dt * func(y_new, xi, params) - y_new = yi - for _ in range(self.max_iter): - y_new_next = yi + dt * func(y_new, xi, params) - if jnp.linalg.norm(y_new_next - y_new) < self.tol: - y_new = y_new_next - break - y_new = y_new_next - yi = y_new - xi = xi + dt - - y.append(yi) - - return jnp.stack(y) - + def ode_step( + self, + func: Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray], + yi: jnp.ndarray, + xi: jnp.ndarray, + dt: jnp.ndarray, + dx: jnp.ndarray, + params: Any + ) -> Tuple[jnp.ndarray, jnp.ndarray]: + k1 = func(yi, xi, params) + k2 = func(yi + dt * k1, xi + dx, params) + k3 = func(yi + 0.25 * dt * (k1 + k2), xi + 0.5 * dx, params) + + yi = yi + (dt / 6.0) * (k1 + k2 + 4 * k3) + xi = xi + dx + return yi, xi + + +class RK4(GeneralODE): + """ + Compute the solution of initial value problem with the Runge-Kutta 4th order method. -class TrapezoidalMethod(SolveIVPMethod): + Arguments + --------- + step_size: float + The step size to use for the RK4 method. If None, it will use (tpts[i] - tpts[i - 1]). """ - Compute the solution of initial value problem with the Trapezoidal method. + + def ode_step( + self, + func: Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray], + yi: jnp.ndarray, + xi: jnp.ndarray, + dt: jnp.ndarray, + dx: jnp.ndarray, + params: Any + ) -> Tuple[jnp.ndarray, jnp.ndarray]: + k1 = func(yi, xi, params) + k2 = func(yi + 0.5 * dt * k1, xi + 0.5 * dx, params) + k3 = func(yi + 0.5 * dt * k2, xi + 0.5 * dx, params) + k4 = func(yi + dt * k3, xi + dx, params) + + yi = yi + (dt / 6.0) * (k1 + 2*k2 + 2*k3 + k4) + xi = xi + dx + return yi, xi + + +class GeneralBackwardODE(GeneralODE): + """ + Compute the solution of initial value problem with Backward ODE methods. Arguments --------- step_size: float - The step size to use for the Trapezoidal method. If None, it will use (tpts[1] - tpts[0]). + The step size for ODE solver. If None, it will use (tpts[i] - tpts[i - 1]). tol: float The tolerance for the fixed-point iteration. max_iter: int The maximum number of iterations for the fixed-point iteration. """ def __init__(self, step_size: Optional[float] = None, tol: float = 1e-6, max_iter: int = 100): - self.step_size = step_size + super().__init__(step_size) self.tol = tol self.max_iter = max_iter - def compute(self, func: Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray], - y0: jnp.ndarray, xinp: jnp.ndarray, params: Any, tpts: jnp.ndarray): - y = [y0] - for i in range(1, len(tpts)): - yi = y[-1] - xi = xinp[i-1] - ti = tpts[i-1] - tf = tpts[i] - - dt = self.step_size - if dt is None: - dt = tf - ti - - num_steps = int((tf - ti) / dt) - dt = (tf - ti) / num_steps - - for _ in range(num_steps): - # Initial guess for fixed-point iteration (use forward Euler as an initial guess) - y_new = yi + dt * func(yi, xi, params) - - # Fixed-point iteration to solve: y_new = yi + (dt/2) * (func(yi, xi, params) + func(y_new, xi + dt, params)) - for _ in range(self.max_iter): - y_new_next = yi + (dt / 2.0) * (func(yi, xi, params) + func(y_new, xi + dt, params)) - if jnp.linalg.norm(y_new_next - y_new) < self.tol: - y_new = y_new_next - break - y_new = y_new_next - - yi = y_new - xi = xi + dt - - y.append(yi) - return jnp.stack(y) \ No newline at end of file +class BackwardEuler(GeneralBackwardODE): + """ + Compute the solution of initial value problem with the Backward Euler method. + """ + + def ode_step( + self, + func: Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray], + yi: jnp.ndarray, + xi: jnp.ndarray, + dt: jnp.ndarray, + dx: jnp.ndarray, + params: Any + ) -> Tuple[jnp.ndarray, jnp.ndarray]: + y_new = yi + for _ in range(self.max_iter): + y_new_next = yi + dt * func(y_new, xi, params) + if jnp.linalg.norm(y_new_next - y_new) < self.tol: + y_new = y_new_next + break + y_new = y_new_next + xi_new = xi + dx + return y_new, xi_new + + +class TrapezoidalMethod(GeneralBackwardODE): + """ + Compute the solution of initial value problem with the Trapezoidal method. + """ + + def ode_step( + self, + func: Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray], + yi: jnp.ndarray, + xi: jnp.ndarray, + dt: jnp.ndarray, + dx: jnp.ndarray, + params: Any + ) -> Tuple[jnp.ndarray, jnp.ndarray]: + # Initial guess for fixed-point iteration (use forward Euler as an initial guess) + y_new = yi + dt * func(yi, xi, params) + + # Fixed-point iteration to solve the trapezoidal equation + for _ in range(self.max_iter): + y_new_next = yi + (dt / 2.0) * (func(yi, xi, params) + func(y_new, xi + dx, params)) + if jnp.linalg.norm(y_new_next - y_new) < self.tol: + y_new = y_new_next + break + y_new = y_new_next + + xi_new = xi + dx + return y_new, xi_new \ No newline at end of file diff --git a/deer/tests/test_deer.py b/deer/tests/test_deer.py index 2b716cc..47be4e9 100644 --- a/deer/tests/test_deer.py +++ b/deer/tests/test_deer.py @@ -34,22 +34,15 @@ def test_matmul_recursive(): result2 = matmul_recursive(mats, vecs, y0) assert jnp.allclose(result, result2) -@pytest.mark.parametrize("method, npts", [ - (solve_ivp.DEER(), 10000), - (solve_ivp.ForwardEuler(), 3), - (solve_ivp.RK3(), 1), - (solve_ivp.RK4(), 100), - (solve_ivp.BackwardEuler(), 1), - (solve_ivp.TrapezoidalMethod(), 100), -]) -def test_solve_ivp(method, npts): +@pytest.mark.parametrize("method", [solve_ivp.DEER()]) +def test_solve_ivp(method): ny = 4 dtype = jnp.float64 key = jax.random.PRNGKey(0) subkey1, subkey2, subkey3 = jax.random.split(key, 3) A0 = (jax.random.uniform(subkey1, shape=(ny, ny), dtype=dtype) * 2 - 1) / ny ** 0.5 A1 = jax.random.uniform(subkey2, shape=(ny, ny), dtype=dtype) / ny ** 0.5 - # npts = 5 # TODO: investigate why npts=1000 make nans + npts = 10000 # TODO: investigate why npts=1000 make nans tpts = jnp.linspace(0, 1.0, npts, dtype=dtype) # (ntpts,) y0 = jax.random.uniform(subkey3, shape=(ny,), dtype=dtype) @@ -354,40 +347,32 @@ def dae_pendulum(vrdot: jnp.ndarray, vr: jnp.ndarray, t: jnp.ndarray, params) -> f4 = x ** 2 + y ** 2 - 1 # index-3 return jnp.concatenate([f0, f1, f2, f3, f4]) -def test_odes(): + +@pytest.mark.parametrize("method, atol", [ + (solve_ivp.RK3(step_size=0.001), 1e-6), + (solve_ivp.RK4(step_size=0.001), 1e-6), + (solve_ivp.ForwardEuler(step_size=0.001), 1e-2), # Forward Euler has a larger error + (solve_ivp.BackwardEuler(step_size=0.001), 1e-2), # Backward Euler has a larger error + (solve_ivp.TrapezoidalMethod(step_size=0.001), 1e-6) +]) +def test_ODEs(method, atol: float): def sample_func(y, x, params): - return -params * y + return -params * y * x # Initialize parameters - step_size = 0.0001 - y0 = jnp.array([1.0, 0.0]) - xinp = jnp.linspace(0, 1, 10) - params = 0.5 - tpts = jnp.linspace(0, 1, 10) - - # Instantiate solvers - rk4_solver = solve_ivp.RK4(step_size=step_size) - trapezoidal_solver = solve_ivp.TrapezoidalMethod(step_size=step_size) - - # Compute solutions - yt_rk4 = rk4_solver.compute(sample_func, y0, xinp, params, tpts) - yt_trapezoidal = trapezoidal_solver.compute(sample_func, y0, xinp, params, tpts) - yt_rk3 = solve_ivp.RK3().compute(sample_func, y0, xinp, params, tpts) - yt_forward_euler = solve_ivp.ForwardEuler().compute(sample_func, y0, xinp, params, tpts) - yt_backward_euler = solve_ivp.BackwardEuler().compute(sample_func, y0, xinp, params, tpts) - + y0 = jnp.array([1.0, 0.0, -1.0]) + params = -2 + tpts = jnp.linspace(0, 0.5, 6) + xinp = jnp.sin(tpts * 2 * jnp.pi) + + # Compute reference solution using DEER method + yt_deer = solve_ivp(sample_func, y0, xinp, params, tpts, method=solve_ivp.DEER()) + + # Compute solution using the parameterized method + yt_method = solve_ivp(sample_func, y0, xinp, params, tpts, method=method) # Compare results - print(f"RK4 Results: {yt_rk4}") - print(f"Trapezoidal Results: {yt_trapezoidal}") - print(f"RK3 Results: {yt_rk3}") - print(f"Forward Euler Results: {yt_forward_euler}") - print(f"Backward Euler Results: {yt_backward_euler}") - assert jnp.allclose(yt_rk4, yt_trapezoidal, atol=1e-6) - assert jnp.allclose(yt_rk4, yt_rk3, atol=1e-6) - - # assert jnp.allclose(yt_rk4, yt_forward_euler, atol=1e-6) - # assert jnp.allclose(yt_rk4, yt_backward_euler, atol=1e-6) + assert jnp.allclose(yt_deer, yt_method, atol=atol) if __name__ == "__main__": test_solve_idae() From b9ca65a4954e15c4ce01143076b080f504b5b9fd Mon Sep 17 00:00:00 2001 From: Jason Zhu Date: Thu, 20 Jun 2024 19:21:09 +0000 Subject: [PATCH 3/8] formatting --- deer/tests/test_deer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/deer/tests/test_deer.py b/deer/tests/test_deer.py index 47be4e9..b1cd052 100644 --- a/deer/tests/test_deer.py +++ b/deer/tests/test_deer.py @@ -34,7 +34,9 @@ def test_matmul_recursive(): result2 = matmul_recursive(mats, vecs, y0) assert jnp.allclose(result, result2) -@pytest.mark.parametrize("method", [solve_ivp.DEER()]) +@pytest.mark.parametrize("method", [ + solve_ivp.DEER() +]) def test_solve_ivp(method): ny = 4 dtype = jnp.float64 From 8705210cb81534d16aa5bb52d8a6c617c92f89f9 Mon Sep 17 00:00:00 2001 From: Jason Zhu Date: Thu, 27 Jun 2024 23:57:50 +0000 Subject: [PATCH 4/8] remove step_size, use jax.lax.scan --- deer/fsolve_ivp.py | 104 +++++++++++++++++++++------------------- deer/tests/test_deer.py | 12 ++--- 2 files changed, 60 insertions(+), 56 deletions(-) diff --git a/deer/fsolve_ivp.py b/deer/fsolve_ivp.py index 3aa697b..65b75ce 100644 --- a/deer/fsolve_ivp.py +++ b/deer/fsolve_ivp.py @@ -1,5 +1,6 @@ from abc import abstractmethod from typing import Any, Callable, List, Optional, Tuple +import jax import jax.numpy as jnp from deer.deer_iter import deer_iteration from deer.maths import matmul_recursive @@ -180,36 +181,27 @@ class GeneralODE(SolveIVPMethod): step_size: float The step size for ODE solver. If None, it will use (tpts[i] - tpts[i - 1]). """ - def __init__(self, step_size: Optional[float] = None): - self.step_size = step_size def compute(self, func: Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray], y0: jnp.ndarray, xinp: jnp.ndarray, params: Any, tpts: jnp.ndarray): - # Initialize the solution list with the initial condition - y = [y0] - - # Iterate over time points to compute the solution at each step - for i in range(1, len(tpts)): - yi = y[-1] + + # Define the main function to cover all time intervals + def scan_step(yi, i): xi = xinp[i-1] - xf = xinp[i] - ti = tpts[i-1] - tf = tpts[i] + dx = xinp[i] - xinp[i-1] + dt = tpts[i] - tpts[i-1] - # Determine the step size - dt = self.step_size if self.step_size is not None else (tf - ti) + yn = self.ode_step(func, yi, xi, dt, dx, params) + return yn, yn - # Number of steps between tpts, at least 1 - num_steps = max(int((tf - ti) / dt), 1) - dt = (tf - ti) / num_steps # Recalculate dt to evenly divide the interval - dx = (xf - xi) / (tf - ti) * dt + _, y = jax.lax.scan(scan_step, y0, jnp.arange(1, len(tpts))) + + y = jnp.concatenate([y0[None, :], y], axis=0) - for _ in range(num_steps): - yi, xi = self.ode_step(func, yi, xi, dt, dx, params) - y.append(yi) + return y + +# Note that `self.ode_step` should be adjusted to properly work with jax if it isn't already. - # Stack the list of solutions into a single jax array - return jnp.stack(y) @abstractmethod def ode_step( @@ -241,8 +233,8 @@ def ode_step( Returns ------- - Tuple[jnp.ndarray, jnp.ndarray] - The state of the system and input at the end of the time step. + jnp.ndarray + The state of the system at the end of the time step. """ pass @@ -266,8 +258,7 @@ def ode_step( ) -> Tuple[jnp.ndarray, jnp.ndarray]: k = func(yi, xi, params) yi_new = yi + dt * k - xi_new = xi + dx - return yi_new, xi_new + return yi_new class RK3(GeneralODE): @@ -294,8 +285,7 @@ def ode_step( k3 = func(yi + 0.25 * dt * (k1 + k2), xi + 0.5 * dx, params) yi = yi + (dt / 6.0) * (k1 + k2 + 4 * k3) - xi = xi + dx - return yi, xi + return yi class RK4(GeneralODE): @@ -323,8 +313,7 @@ def ode_step( k4 = func(yi + dt * k3, xi + dx, params) yi = yi + (dt / 6.0) * (k1 + 2*k2 + 2*k3 + k4) - xi = xi + dx - return yi, xi + return yi class GeneralBackwardODE(GeneralODE): @@ -340,8 +329,7 @@ class GeneralBackwardODE(GeneralODE): max_iter: int The maximum number of iterations for the fixed-point iteration. """ - def __init__(self, step_size: Optional[float] = None, tol: float = 1e-6, max_iter: int = 100): - super().__init__(step_size) + def __init__(self, tol: float = 1e-6, max_iter: int = 100): self.tol = tol self.max_iter = max_iter @@ -360,15 +348,26 @@ def ode_step( dx: jnp.ndarray, params: Any ) -> Tuple[jnp.ndarray, jnp.ndarray]: - y_new = yi - for _ in range(self.max_iter): + + def cond(val): + y_new, y_new_next, i = val + return jnp.logical_and(i < self.max_iter, jnp.linalg.norm(y_new_next - y_new) >= self.tol) + + def body(val): + y_new, _, i = val y_new_next = yi + dt * func(y_new, xi, params) - if jnp.linalg.norm(y_new_next - y_new) < self.tol: - y_new = y_new_next - break - y_new = y_new_next - xi_new = xi + dx - return y_new, xi_new + return y_new_next, y_new_next, i + 1 + + # Initial values for y_new and y_new_next + y_new = yi + y_new_next = yi + dt * func(y_new, xi, params) + val = (y_new, y_new_next, 0) + + # Use jax.lax.while_loop to iterate while the condition is met + y_new, _, _ = jax.lax.while_loop(cond, body, val) + + return y_new + class TrapezoidalMethod(GeneralBackwardODE): @@ -384,17 +383,22 @@ def ode_step( dt: jnp.ndarray, dx: jnp.ndarray, params: Any - ) -> Tuple[jnp.ndarray, jnp.ndarray]: + ) -> jnp.ndarray: + def body_fn(state): + y_new, _, i = state + print(_, i) + y_new_next = yi + (dt / 2.0) * (func(yi, xi, params) + func(y_new, xi + dx, params)) + return y_new_next, jnp.linalg.norm(y_new_next - y_new), i + 1 + + def cond_fn(state): + _, diff_norm, i = state + print(i, diff_norm >= self.tol) + return jnp.logical_and(i < self.max_iter, diff_norm >= self.tol) + # Initial guess for fixed-point iteration (use forward Euler as an initial guess) y_new = yi + dt * func(yi, xi, params) - # Fixed-point iteration to solve the trapezoidal equation - for _ in range(self.max_iter): - y_new_next = yi + (dt / 2.0) * (func(yi, xi, params) + func(y_new, xi + dx, params)) - if jnp.linalg.norm(y_new_next - y_new) < self.tol: - y_new = y_new_next - break - y_new = y_new_next + # Run the fixed-point iteration loop using lax.while_loop + y_new, _, _ = jax.lax.while_loop(cond_fn, body_fn, (y_new, jnp.inf, 0)) - xi_new = xi + dx - return y_new, xi_new \ No newline at end of file + return y_new \ No newline at end of file diff --git a/deer/tests/test_deer.py b/deer/tests/test_deer.py index b1cd052..eb46bcc 100644 --- a/deer/tests/test_deer.py +++ b/deer/tests/test_deer.py @@ -351,11 +351,11 @@ def dae_pendulum(vrdot: jnp.ndarray, vr: jnp.ndarray, t: jnp.ndarray, params) -> @pytest.mark.parametrize("method, atol", [ - (solve_ivp.RK3(step_size=0.001), 1e-6), - (solve_ivp.RK4(step_size=0.001), 1e-6), - (solve_ivp.ForwardEuler(step_size=0.001), 1e-2), # Forward Euler has a larger error - (solve_ivp.BackwardEuler(step_size=0.001), 1e-2), # Backward Euler has a larger error - (solve_ivp.TrapezoidalMethod(step_size=0.001), 1e-6) + (solve_ivp.RK3(), 1e-6), + (solve_ivp.RK4(), 1e-6), + (solve_ivp.ForwardEuler(), 1e-2), # Forward Euler has a larger error + (solve_ivp.BackwardEuler(), 1e-2), # Backward Euler has a larger error + (solve_ivp.TrapezoidalMethod(), 1e-6) ]) def test_ODEs(method, atol: float): def sample_func(y, x, params): @@ -364,7 +364,7 @@ def sample_func(y, x, params): # Initialize parameters y0 = jnp.array([1.0, 0.0, -1.0]) params = -2 - tpts = jnp.linspace(0, 0.5, 6) + tpts = jnp.linspace(0, 0.5, 1000) xinp = jnp.sin(tpts * 2 * jnp.pi) # Compute reference solution using DEER method From 0f1c6cd269495b27ad58516084d05343a1089db3 Mon Sep 17 00:00:00 2001 From: Jason Zhu Date: Fri, 28 Jun 2024 00:28:13 +0000 Subject: [PATCH 5/8] Change return type of ODE to Result --- deer/fsolve_ivp.py | 7 +------ deer/tests/test_deer.py | 2 +- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/deer/fsolve_ivp.py b/deer/fsolve_ivp.py index 6d251e4..bcdae37 100644 --- a/deer/fsolve_ivp.py +++ b/deer/fsolve_ivp.py @@ -171,11 +171,6 @@ def solve_ivp_inv_lin(self, gmat: List[jnp.ndarray], rhs: jnp.ndarray, class GeneralODE(SolveIVPMethod): """ Compute the solution of initial value problem with the ODE methods. - - Arguments - --------- - step_size: float - The step size for ODE solver. If None, it will use (tpts[i] - tpts[i - 1]). """ def compute(self, func: Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray], @@ -194,7 +189,7 @@ def scan_step(yi, i): y = jnp.concatenate([y0[None, :], y], axis=0) - return y + return Result(y, True) # Note that `self.ode_step` should be adjusted to properly work with jax if it isn't already. diff --git a/deer/tests/test_deer.py b/deer/tests/test_deer.py index 943bce2..2afe87b 100644 --- a/deer/tests/test_deer.py +++ b/deer/tests/test_deer.py @@ -438,7 +438,7 @@ def sample_func(y, x, params): yt_method = solve_ivp(sample_func, y0, xinp, params, tpts, method=method) # Compare results - assert jnp.allclose(yt_deer, yt_method, atol=atol) + assert jnp.allclose(yt_deer.value, yt_method.value, atol=atol) if __name__ == "__main__": test_solve_idae() From e7a0639de2fdd0d6c962273f063a937356a61fb1 Mon Sep 17 00:00:00 2001 From: Jason Zhu Date: Thu, 4 Jul 2024 10:22:38 +0000 Subject: [PATCH 6/8] remove step_size in docstrings --- deer/fsolve_ivp.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/deer/fsolve_ivp.py b/deer/fsolve_ivp.py index bcdae37..6ba1de9 100644 --- a/deer/fsolve_ivp.py +++ b/deer/fsolve_ivp.py @@ -232,11 +232,6 @@ def ode_step( class ForwardEuler(GeneralODE): """ Compute the solution of initial value problem with the Forward Euler method. - - Arguments - --------- - step_size: float - The step size to use for the Euler method. If None, it will use the difference (tpts[1] - tpts[0]) divided by the number of steps. """ def ode_step( self, @@ -255,11 +250,6 @@ def ode_step( class RK3(GeneralODE): """ Compute the solution of initial value problem with the Runge-Kutta 3rd order method. - - Arguments - --------- - step_size: float - The step size to use for the RK3 method. If None, it will use (tpts[i] - tpts[i - 1]). """ def ode_step( @@ -282,11 +272,6 @@ def ode_step( class RK4(GeneralODE): """ Compute the solution of initial value problem with the Runge-Kutta 4th order method. - - Arguments - --------- - step_size: float - The step size to use for the RK4 method. If None, it will use (tpts[i] - tpts[i - 1]). """ def ode_step( @@ -313,8 +298,6 @@ class GeneralBackwardODE(GeneralODE): Arguments --------- - step_size: float - The step size for ODE solver. If None, it will use (tpts[i] - tpts[i - 1]). tol: float The tolerance for the fixed-point iteration. max_iter: int From 18645eb4196a02f0b88a0d727fa077aeeaefe10a Mon Sep 17 00:00:00 2001 From: Jason Zhu Date: Fri, 5 Jul 2024 14:56:57 +0000 Subject: [PATCH 7/8] Add grad check, improve the convergence --- deer/fsolve_ivp.py | 148 ++++++++++++++++++++++------------------ deer/tests/test_deer.py | 14 ++-- 2 files changed, 90 insertions(+), 72 deletions(-) diff --git a/deer/fsolve_ivp.py b/deer/fsolve_ivp.py index 1180b30..c5c0896 100644 --- a/deer/fsolve_ivp.py +++ b/deer/fsolve_ivp.py @@ -4,7 +4,7 @@ import jax.numpy as jnp from deer.deer_iter import deer_iteration from deer.maths import matmul_recursive -from deer.utils import get_method_meta, check_method, Result +from deer.utils import get_method_meta, check_method, Result, while_loop_scan __all__ = ["solve_ivp"] @@ -185,19 +185,27 @@ def compute(self, func: Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray], y0: jnp.ndarray, xinp: jnp.ndarray, params: Any, tpts: jnp.ndarray): # Define the main function to cover all time intervals - def scan_step(yi, i): - xi = xinp[i-1] - dx = xinp[i] - xinp[i-1] - dt = tpts[i] - tpts[i-1] - - yn = self.ode_step(func, yi, xi, dt, dx, params) - return yn, yn - - _, y = jax.lax.scan(scan_step, y0, jnp.arange(1, len(tpts))) - + def scan_step(carry, i): + _, success = carry + def success_fn(carry, i): + yi, _ = carry + xi = xinp[i-1] + dx = xinp[i] - xinp[i-1] + dt = tpts[i] - tpts[i-1] + yn, success = self.ode_step(func, yi, xi, dt, dx, params) + return yn, success + + def failure_fn(carry, i): + yi, _ = carry + return yi, False + + res = jax.lax.cond(success, success_fn, failure_fn, carry, i) + return res, res + + _, (y, success) = jax.lax.scan(scan_step, (y0, True), jnp.arange(1, len(tpts))) y = jnp.concatenate([y0[None, :], y], axis=0) - - return Result(y, True) + success = jnp.concatenate((jnp.full_like(success[:1], True, dtype=jnp.bool), success), axis=0) + return Result(y, success[:, None]) # Note that `self.ode_step` should be adjusted to properly work with jax if it isn't already. @@ -252,7 +260,7 @@ def ode_step( ) -> Tuple[jnp.ndarray, jnp.ndarray]: k = func(yi, xi, params) yi_new = yi + dt * k - return yi_new + return yi_new, True class RK3(GeneralODE): @@ -274,7 +282,7 @@ def ode_step( k3 = func(yi + 0.25 * dt * (k1 + k2), xi + 0.5 * dx, params) yi = yi + (dt / 6.0) * (k1 + k2 + 4 * k3) - return yi + return yi, True class RK4(GeneralODE): @@ -297,7 +305,7 @@ def ode_step( k4 = func(yi + dt * k3, xi + dx, params) yi = yi + (dt / 6.0) * (k1 + 2*k2 + 2*k3 + k4) - return yi + return yi, True class GeneralBackwardODE(GeneralODE): @@ -315,12 +323,6 @@ def __init__(self, tol: float = 1e-6, max_iter: int = 100): self.tol = tol self.max_iter = max_iter - -class BackwardEuler(GeneralBackwardODE): - """ - Compute the solution of initial value problem with the Backward Euler method. - """ - def ode_step( self, func: Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray], @@ -329,58 +331,68 @@ def ode_step( dt: jnp.ndarray, dx: jnp.ndarray, params: Any - ) -> Tuple[jnp.ndarray, jnp.ndarray]: - - def cond(val): - y_new, y_new_next, i = val - return jnp.logical_and(i < self.max_iter, jnp.linalg.norm(y_new_next - y_new) >= self.tol) - - def body(val): - y_new, _, i = val - y_new_next = yi + dt * func(y_new, xi, params) - return y_new_next, y_new_next, i + 1 - - # Initial values for y_new and y_new_next - y_new = yi - y_new_next = yi + dt * func(y_new, xi, params) - val = (y_new, y_new_next, 0) - - # Use jax.lax.while_loop to iterate while the condition is met - y_new, _, _ = jax.lax.while_loop(cond, body, val) + ) -> jnp.ndarray: + def body_fn(state): + y_new, _ = state + y_new_next = self.backward_iteration_step(func, y_new, yi, xi, dt, dx, params) + return y_new_next, jnp.linalg.norm(y_new_next - y_new) + + def cond_fn(state): + _, diff_norm = state + return diff_norm >= self.tol + + # Initial guess for fixed-point iteration (use forward Euler as an initial guess) + y_new = yi + dt * func(yi, xi, params) + + # Run the fixed-point iteration loop using lax.while_loop + state, _ = while_loop_scan(cond_fn, body_fn, (y_new, jnp.inf), self.max_iter) + y_new, diff_norm = state + return y_new, diff_norm < self.tol + + # @abstractmethod + def backward_iteration_step(self, + func: Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray], + y_new: jnp.ndarray, + yi: jnp.ndarray, + xi: jnp.ndarray, + dt: jnp.ndarray, + dx: jnp.ndarray, + params: Any) -> jnp.ndarray: + """ + The body function for the fixed-point iteration. Return the next value of yi. - return y_new + Arguments: + func: Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray] + The function defining the differential equation dy/dt = f(y, x, params). + y_new: jnp.ndarray + The current value of yi. + yi: jnp.ndarray + The initial value of yi. + xi: jnp.ndarray + The input signal at the current time step. + dt: jnp.ndarray + The time step size. + dx: jnp.ndarray + The change in input signal. + params: Any + The parameters of the function. + """ + pass +class BackwardEuler(GeneralBackwardODE): + """ + Compute the solution of initial value problem with the Backward Euler method. + """ + + def backward_iteration_step(self, func, y_new, yi, xi, dt, dx, params) -> jnp.ndarray: + return yi + dt * func(y_new, xi, params) + class TrapezoidalMethod(GeneralBackwardODE): """ Compute the solution of initial value problem with the Trapezoidal method. """ - def ode_step( - self, - func: Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray], - yi: jnp.ndarray, - xi: jnp.ndarray, - dt: jnp.ndarray, - dx: jnp.ndarray, - params: Any - ) -> jnp.ndarray: - def body_fn(state): - y_new, _, i = state - print(_, i) - y_new_next = yi + (dt / 2.0) * (func(yi, xi, params) + func(y_new, xi + dx, params)) - return y_new_next, jnp.linalg.norm(y_new_next - y_new), i + 1 - - def cond_fn(state): - _, diff_norm, i = state - print(i, diff_norm >= self.tol) - return jnp.logical_and(i < self.max_iter, diff_norm >= self.tol) - - # Initial guess for fixed-point iteration (use forward Euler as an initial guess) - y_new = yi + dt * func(yi, xi, params) - - # Run the fixed-point iteration loop using lax.while_loop - y_new, _, _ = jax.lax.while_loop(cond_fn, body_fn, (y_new, jnp.inf, 0)) - - return y_new \ No newline at end of file + def backward_iteration_step(self, func, y_new, yi, xi, dt, dx, params) -> jnp.ndarray: + return yi + (dt / 2.0) * (func(yi, xi, params) + func(y_new, xi + dx, params)) diff --git a/deer/tests/test_deer.py b/deer/tests/test_deer.py index 2afe87b..939a576 100644 --- a/deer/tests/test_deer.py +++ b/deer/tests/test_deer.py @@ -429,16 +429,22 @@ def sample_func(y, x, params): y0 = jnp.array([1.0, 0.0, -1.0]) params = -2 tpts = jnp.linspace(0, 0.5, 1000) - xinp = jnp.sin(tpts * 2 * jnp.pi) - - # Compute reference solution using DEER method - yt_deer = solve_ivp(sample_func, y0, xinp, params, tpts, method=solve_ivp.DEER()) + xinp = jnp.sin(tpts * 2 * jnp.pi) + 1e-3 # Compute solution using the parameterized method yt_method = solve_ivp(sample_func, y0, xinp, params, tpts, method=method) + # Compute reference solution using DEER method + yt_deer = solve_ivp(sample_func, y0, xinp, params, tpts, method=solve_ivp.DEER()) + # Compare results assert jnp.allclose(yt_deer.value, yt_method.value, atol=atol) + # Check the gradients + def solve_ivp_wrapper(y0): + return solve_ivp(sample_func, y0, xinp, params, tpts, method=method).value + + jax.test_util.check_grads(solve_ivp_wrapper, (y0,), order=1, modes=['fwd', 'rev']) + if __name__ == "__main__": test_solve_idae() From da37873b367a24900a6a24876e7559b26c44e76c Mon Sep 17 00:00:00 2001 From: Jason Zhu Date: Fri, 12 Jul 2024 12:21:52 +0000 Subject: [PATCH 8/8] make function purer --- deer/fsolve_ivp.py | 34 +++++++++++++++++++++++----------- 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/deer/fsolve_ivp.py b/deer/fsolve_ivp.py index c5c0896..7b66973 100644 --- a/deer/fsolve_ivp.py +++ b/deer/fsolve_ivp.py @@ -181,32 +181,44 @@ class GeneralODE(SolveIVPMethod): Compute the solution of initial value problem with the ODE methods. """ - def compute(self, func: Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray], + def compute(self, func: Callable[[jnp.ndarray, jnp.ndarray, Any], jnp.ndarray], y0: jnp.ndarray, xinp: jnp.ndarray, params: Any, tpts: jnp.ndarray): + # Precompute dx and dt + dx = jnp.diff(xinp) + dt = jnp.diff(tpts) + # Define the main function to cover all time intervals - def scan_step(carry, i): - _, success = carry - def success_fn(carry, i): + def scan_step(carry, inputs): + yi, success = carry + xi, dx_i, dt_i = inputs + def success_fn(carry): yi, _ = carry - xi = xinp[i-1] - dx = xinp[i] - xinp[i-1] - dt = tpts[i] - tpts[i-1] - yn, success = self.ode_step(func, yi, xi, dt, dx, params) + yn, success = self.ode_step(func, yi, xi, dt_i, dx_i, params) return yn, success - def failure_fn(carry, i): + def failure_fn(carry): yi, _ = carry return yi, False - res = jax.lax.cond(success, success_fn, failure_fn, carry, i) + res = jax.lax.cond(success, success_fn, failure_fn, carry) return res, res - _, (y, success) = jax.lax.scan(scan_step, (y0, True), jnp.arange(1, len(tpts))) + # Initial carry + initial_carry = (y0, True) + + # Pack parameters for scan + scan_inputs = (xinp[:-1], dx, dt) + + # Use jax.lax.scan to iterate over time points + _, (y, success) = jax.lax.scan(scan_step, initial_carry, scan_inputs) + y = jnp.concatenate([y0[None, :], y], axis=0) success = jnp.concatenate((jnp.full_like(success[:1], True, dtype=jnp.bool), success), axis=0) + return Result(y, success[:, None]) + # Note that `self.ode_step` should be adjusted to properly work with jax if it isn't already.