diff --git a/Project.toml b/Project.toml index 9079151..79d53b9 100644 --- a/Project.toml +++ b/Project.toml @@ -25,8 +25,11 @@ julia = "1.6" [extras] ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" +ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78" +NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" ODEProblemLibrary = "fdc4e326-1af4-4b90-96e7-779fcce2daa5" +SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["ExplicitImports", "Test", "ODEProblemLibrary"] +test = ["ExplicitImports", "ModelingToolkit", "NonlinearSolve", "SymbolicIndexingInterface", "Test", "ODEProblemLibrary"] diff --git a/src/ODEInterfaceDiffEq.jl b/src/ODEInterfaceDiffEq.jl index 66f5ed2..f6b5005 100644 --- a/src/ODEInterfaceDiffEq.jl +++ b/src/ODEInterfaceDiffEq.jl @@ -32,6 +32,7 @@ module ODEInterfaceDiffEq include("integrator_types.jl") include("integrator_utils.jl") include("solve.jl") + include("initialize.jl") export ODEInterfaceAlgorithm, dopri5, dop853, odex, seulex, radau, radau5, rodas, ddeabm, ddebdf diff --git a/src/algorithms.jl b/src/algorithms.jl index 5484fe1..3acdb2e 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -226,17 +226,21 @@ end SciMLBase.alg_order(alg::ddebdf) = 5 seulex(; jac_lower = nothing, jac_upper = nothing) = seulex(jac_lower, jac_upper) -function radau(; jac_lower = nothing, jac_upper = nothing, - M1 = nothing, M2 = nothing, - DIMOFIND1VAR = nothing, DIMOFIND2VAR = nothing, DIMOFIND3VAR = nothing, - massmatrix = nothing) - radau(jac_lower, jac_upper, M1, M2, DIMOFIND1VAR, DIMOFIND2VAR, DIMOFIND3VAR, massmatrix) +function radau(; + jac_lower = nothing, jac_upper = nothing, + M1 = nothing, M2 = nothing, + DIMOFIND1VAR = nothing, DIMOFIND2VAR = nothing, DIMOFIND3VAR = nothing, + massmatrix = nothing + ) + return radau(jac_lower, jac_upper, M1, M2, DIMOFIND1VAR, DIMOFIND2VAR, DIMOFIND3VAR, massmatrix) end -function radau5(; jac_lower = nothing, jac_upper = nothing, - M1 = nothing, M2 = nothing, - DIMOFIND1VAR = nothing, DIMOFIND2VAR = nothing, DIMOFIND3VAR = nothing, - massmatrix = nothing) - radau5(jac_lower, jac_upper, M1, M2, DIMOFIND1VAR, DIMOFIND2VAR, DIMOFIND3VAR, massmatrix) +function radau5(; + jac_lower = nothing, jac_upper = nothing, + M1 = nothing, M2 = nothing, + DIMOFIND1VAR = nothing, DIMOFIND2VAR = nothing, DIMOFIND3VAR = nothing, + massmatrix = nothing + ) + return radau5(jac_lower, jac_upper, M1, M2, DIMOFIND1VAR, DIMOFIND2VAR, DIMOFIND3VAR, massmatrix) end rodas(; jac_lower = nothing, jac_upper = nothing) = rodas(jac_lower, jac_upper) ddebdf(; jac_lower = nothing, jac_upper = nothing) = ddebdf(jac_lower, jac_upper) diff --git a/src/initialize.jl b/src/initialize.jl new file mode 100644 index 0000000..f72a24d --- /dev/null +++ b/src/initialize.jl @@ -0,0 +1,164 @@ +# DAE Initialization support for ODEInterfaceDiffEq +# Following the pattern from Sundials.jl: +# https://github.com/SciML/Sundials.jl/blob/master/src/common_interface/initialize.jl + +import SciMLBase: OverrideInit, NoInit, CheckInit, has_initialization_data +import DiffEqBase: DefaultInit + +# Re-export initialization algorithms (including DefaultInit from DiffEqBase) +export OverrideInit, NoInit, CheckInit, DefaultInit + +# DefaultInit: OverrideInit → CheckInit pattern (matching Sundials v5) +# First run OverrideInit to compute consistent initial conditions, +# then run CheckInit to verify the algebraic constraints are satisfied. +function DiffEqBase.initialize_dae!( + integrator::ODEInterfaceIntegrator, + initializealg::DefaultInit + ) + prob = integrator.sol.prob + + # First: OverrideInit to compute consistent initial conditions + if has_initialization_data(prob.f) + DiffEqBase.initialize_dae!(integrator, OverrideInit()) + + # Check if OverrideInit failed + if integrator.sol.retcode == ReturnCode.InitialFailure + return nothing + end + end + + # Then: CheckInit to verify algebraic constraints are satisfied + DiffEqBase.initialize_dae!(integrator, CheckInit()) + + return nothing +end + +# NoInit: Do nothing, assume initial conditions are correct +function DiffEqBase.initialize_dae!( + integrator::ODEInterfaceIntegrator, + initializealg::NoInit + ) + # No-op: initial conditions are assumed to be correct + return nothing +end + +# CheckInit: Verify that initial conditions satisfy the algebraic constraints +function DiffEqBase.initialize_dae!( + integrator::ODEInterfaceIntegrator, + initializealg::CheckInit + ) + prob = integrator.sol.prob + f = prob.f + M = f.mass_matrix + + # If no mass matrix or identity, no algebraic constraints to check + M == I && return nothing + + u0 = integrator.u + p = integrator.p + t = integrator.t + + # Find algebraic equations (rows of M that are all zeros) + algebraic_eqs = [all(iszero, M[i, :]) for i in axes(M, 1)] + + # If no algebraic equations, nothing to check + !any(algebraic_eqs) && return nothing + + # Evaluate the RHS + tmp = similar(u0) + f(tmp, u0, p, t) + + # Check residuals of algebraic equations only + abstol = integrator.sol.prob isa DiffEqBase.AbstractODEProblem ? + get(Dict(integrator.opts.callback.discrete_callbacks), :abstol, 1.0e-8) : 1.0e-8 + + # Try to get abstol from the solve options, fallback to default + # Note: ODEInterface doesn't expose tolerances through the integrator in the same way + abstol = 1.0e-8 # Default tolerance + + max_residual = maximum(abs.(tmp[algebraic_eqs])) + + if max_residual > abstol + error( + """ + DAE initialization failed with CheckInit: Initial conditions do not satisfy the algebraic constraints. + + The maximum residual in algebraic equations is $(max_residual), which exceeds the tolerance $(abstol). + + To resolve this issue, you have several options: + 1. Fix your initial conditions to satisfy the algebraic constraints (M * du/dt = f(u, p, t), where algebraic rows have M[i,:] = 0) + 2. If using ModelingToolkit, use: initializealg = OverrideInit() + 3. Use initializealg = NoInit() to skip initialization checks (use with caution) + + Example: + solve(prob, radau5(); initializealg = OverrideInit()) + """ + ) + end + + return nothing +end + +# OverrideInit: Use SciMLBase's initialization system (e.g., from ModelingToolkit) +function DiffEqBase.initialize_dae!( + integrator::ODEInterfaceIntegrator, + initializealg::OverrideInit + ) + prob = integrator.sol.prob + f = prob.f + + # If no initialization data, nothing to do + if !has_initialization_data(f) + return nothing + end + + # Get initial values using SciMLBase's initialization system + # Note: ODEInterface integrators don't have a built-in nonlinear solver, + # so we rely on the user providing one via nlsolve_alg in OverrideInit, + # or the initialization being trivial (no solve needed) + isinplace = Val(SciMLBase.isinplace(prob)) + + # Default tolerances + abstol = 1.0e-8 + reltol = 1.0e-8 + + u0, p, success = SciMLBase.get_initial_values( + prob, integrator, f, initializealg, isinplace; + abstol = abstol, reltol = reltol + ) + + if !success + integrator.sol = SciMLBase.solution_new_retcode( + integrator.sol, + ReturnCode.InitialFailure + ) + return nothing + end + + # Update integrator state + if SciMLBase.isinplace(prob) + integrator.u .= u0 + if length(integrator.sol.u) >= 1 && !isempty(integrator.sol.u) + integrator.sol.u[1] .= u0 + end + else + # For out-of-place problems, we need to handle this differently + # since integrator.u might be immutable + integrator.u .= u0 + if length(integrator.sol.u) >= 1 && !isempty(integrator.sol.u) + integrator.sol.u[1] = u0 + end + end + + # Update parameters if they changed (in-place via SciMLStructures) + if p !== integrator.p + SS = SciMLBase.SciMLStructures + old_vals, _, _ = SS.canonicalize(SS.Tunable(), integrator.p) + new_vals, _, _ = SS.canonicalize(SS.Tunable(), p) + copyto!(old_vals, new_vals) + end + + integrator.u_modified = true + + return nothing +end diff --git a/src/solve.jl b/src/solve.jl index a332e8a..80a721f 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -10,6 +10,7 @@ function DiffEqBase.__solve( prob.tspan[1] in saveat, timeseries_errors = true, dense_errors = false, callback = nothing, alias_u0 = false, + initializealg = DiffEqBase.DefaultInit(), kwargs... ) where {uType, tuptType, isinplace, AlgType <: ODEInterfaceAlgorithm} @@ -107,6 +108,14 @@ function DiffEqBase.__solve( ) initialize_callbacks!(integrator) + # DAE initialization - check/compute consistent initial conditions + DiffEqBase.initialize_dae!(integrator, initializealg) + + # Check if initialization failed + if integrator.sol.retcode == ReturnCode.InitialFailure + return integrator.sol + end + outputfcn = OutputFunction(integrator) o[:OUTPUTFCN] = outputfcn if !(callbacks_internal.continuous_callbacks isa Tuple{}) || !isempty(saveat) @@ -141,8 +150,8 @@ function DiffEqBase.__solve( end if !haskey(dict, :MASSMATRIX) && prob.f.mass_matrix != I - if prob.f.mass_matrix isa Matrix && isstiff - dict[:MASSMATRIX] = prob.f.mass_matrix + if prob.f.mass_matrix isa AbstractMatrix && isstiff + dict[:MASSMATRIX] = Matrix(prob.f.mass_matrix) elseif !isstiff error("This solver does not support mass matrices") else diff --git a/test/initialization_tests.jl b/test/initialization_tests.jl new file mode 100644 index 0000000..2dd713e --- /dev/null +++ b/test/initialization_tests.jl @@ -0,0 +1,99 @@ +using ODEInterfaceDiffEq, DiffEqBase, SciMLBase +using Test +using LinearAlgebra + +@testset "DAE Initialization" begin + # Simple Robertson chemical kinetics model as a DAE (with mass matrix) + # dy1/dt = -0.04*y1 + 1e4*y2*y3 + # dy2/dt = 0.04*y1 - 1e4*y2*y3 - 3e7*y2^2 + # 0 = y1 + y2 + y3 - 1 (conservation constraint) + + function robertson!(du, u, p, t) + y1, y2, y3 = u + du[1] = -0.04 * y1 + 1.0e4 * y2 * y3 + du[2] = 0.04 * y1 - 1.0e4 * y2 * y3 - 3.0e7 * y2^2 + du[3] = y1 + y2 + y3 - 1.0 # algebraic equation + return nothing + end + + # Mass matrix: [1 0 0; 0 1 0; 0 0 0] + M = [1.0 0 0; 0 1 0; 0 0 0] + + @testset "NoInit - skip initialization" begin + # With correct initial conditions + u0_correct = [1.0, 0.0, 0.0] # satisfies y1 + y2 + y3 = 1 + f = ODEFunction(robertson!; mass_matrix = M) + prob = ODEProblem(f, u0_correct, (0.0, 1.0e-3)) + + sol = solve(prob, radau5(); initializealg = NoInit()) + @test sol.retcode == ReturnCode.Success + + # With incorrect initial conditions - NoInit should still work (no check) + u0_wrong = [1.0, 1.0, 1.0] # does NOT satisfy y1 + y2 + y3 = 1 + prob_wrong = ODEProblem(f, u0_wrong, (0.0, 1.0e-3)) + + # NoInit should not throw, even with wrong ICs + sol_wrong = solve(prob_wrong, radau5(); initializealg = NoInit()) + # The solver might fail, but not during initialization + @test true # Just checking NoInit doesn't throw + end + + @testset "CheckInit - verify constraints" begin + u0_correct = [1.0, 0.0, 0.0] # satisfies y1 + y2 + y3 = 1 + f = ODEFunction(robertson!; mass_matrix = M) + prob = ODEProblem(f, u0_correct, (0.0, 1.0e-3)) + + sol = solve(prob, radau5(); initializealg = CheckInit()) + @test sol.retcode == ReturnCode.Success + + # With incorrect initial conditions + u0_wrong = [1.0, 1.0, 1.0] # does NOT satisfy y1 + y2 + y3 = 1 + prob_wrong = ODEProblem(f, u0_wrong, (0.0, 1.0e-3)) + + @test_throws ErrorException solve(prob_wrong, radau5(); initializealg = CheckInit()) + end + + @testset "DefaultInit - dispatch based on problem" begin + # Without initialization_data, should use CheckInit + u0_correct = [1.0, 0.0, 0.0] + f = ODEFunction(robertson!; mass_matrix = M) + prob = ODEProblem(f, u0_correct, (0.0, 1.0e-3)) + + sol = solve(prob, radau5(); initializealg = DefaultInit()) + @test sol.retcode == ReturnCode.Success + end + + @testset "No mass matrix - identity case" begin + # Simple ODE without mass matrix (not a DAE) + function simple_ode!(du, u, p, t) + du[1] = -u[1] + return nothing + end + + u0 = [1.0] + f = ODEFunction(simple_ode!) + prob = ODEProblem(f, u0, (0.0, 1.0)) + + # All initialization algorithms should work + sol_noinit = solve(prob, radau5(); initializealg = NoInit()) + @test sol_noinit.retcode == ReturnCode.Success + + sol_check = solve(prob, radau5(); initializealg = CheckInit()) + @test sol_check.retcode == ReturnCode.Success + + sol_default = solve(prob, radau5(); initializealg = DefaultInit()) + @test sol_default.retcode == ReturnCode.Success + end + + @testset "Multiple solvers with initialization" begin + u0_correct = [1.0, 0.0, 0.0] + f = ODEFunction(robertson!; mass_matrix = M) + prob = ODEProblem(f, u0_correct, (0.0, 1.0e-3)) + + # Test with different implicit solvers that support mass matrices + for alg in [radau5(), radau(), rodas(), seulex(), ddebdf()] + sol = solve(prob, alg; initializealg = CheckInit()) + @test sol.retcode == ReturnCode.Success + end + end +end diff --git a/test/mtk_initialization_tests.jl b/test/mtk_initialization_tests.jl new file mode 100644 index 0000000..3ec7aa3 --- /dev/null +++ b/test/mtk_initialization_tests.jl @@ -0,0 +1,110 @@ +# ModelingToolkit initialization tests for ODEInterfaceDiffEq +# Based on Sundials.jl/test/common_interface/initialization.jl + +using ODEInterfaceDiffEq, DiffEqBase, SciMLBase, Test +using ModelingToolkit +using NonlinearSolve +using SymbolicIndexingInterface +using ModelingToolkit: t_nounits as t, D_nounits as D + +@testset "MTK ODE Initialization" begin + # ODE with missing parameters that need to be determined via initialization + # System: dx/dt = p*y + q*t, dy/dt = 5x + q + # With initialization equations to find p and q from initial state constraints + @variables x(t) [guess = 1.0] y(t) [guess = 1.0] + @parameters p = missing [guess = 1.0] q = missing [guess = 1.0] + + @mtkcompile sys = System( + [D(x) ~ p * y + q * t, D(y) ~ 5x + q], + t; + initialization_eqs = [p^2 + q^2 ~ 3, x^3 + y^3 ~ 5] + ) + + @testset "IIP: $iip" for iip in [true, false] + prob = ODEProblem{iip}(sys, [x => 1.0, p => 1.0], (0.0, 0.1)) + + # Test with implicit solvers that support OverrideInit + @testset "$alg" for alg in [radau5, radau, rodas, seulex] + sol = solve(prob, alg()) + @test SciMLBase.successful_retcode(sol) + @test sol[x, 1] ≈ 1.0 + @test sol[y, 1] ≈ cbrt(4) + @test sol.ps[p] ≈ 1.0 + @test sol.ps[q] ≈ sqrt(2) + end + end +end + +@testset "MTK DAE Initialization (Mass Matrix)" begin + # DAE system using mass matrix formulation + # System: dx/dt = p*y + q*t, x^3 + y^3 = 5 (algebraic constraint) + # With initialization equations to find the missing parameter q + @variables x(t) [guess = 1.0] y(t) [guess = 1.0] + @parameters p = missing [guess = 1.0] q = missing [guess = 1.0] + + @mtkcompile sys = System( + [D(x) ~ p * y + q * t, x^3 + y^3 ~ 5], + t; + initialization_eqs = [p^2 + q^2 ~ 3] + ) + + @testset "OverrideInit" begin + # Provide D(x) guess and p value - initialization should determine x, y, q + prob = ODEProblem(sys, [D(x) => cbrt(4), p => 1.0], (0.0, 0.1)) + + @testset "$alg" for alg in [radau5, radau] + sol = solve(prob, alg()) + @test SciMLBase.successful_retcode(sol) + @test sol[x, 1] ≈ 1.0 + @test sol[y, 1] ≈ cbrt(4) + @test sol.ps[p] ≈ 1.0 + @test sol.ps[q] ≈ sqrt(2) + end + end + + @testset "CheckInit" begin + prob = ODEProblem(sys, [D(x) => cbrt(4), p => 1.0], (0.0, 0.1)) + + # CheckInit should fail because the MTK problem has initialization_data + # and the user-provided values need processing via OverrideInit + @test_throws Any solve(prob, radau5(); initializealg = SciMLBase.CheckInit()) + end +end + +@testset "MTK Pendulum DAE" begin + # Classic pendulum as index-3 DAE (reduced to index-1 via ModelingToolkit) + # This tests a more complex DAE initialization scenario + @parameters g = 9.81 L = 1.0 + @variables begin + x(t), [guess = 1.0] + y(t), [guess = 0.0] + vx(t), [guess = 0.0] + vy(t), [guess = 0.0] + λ(t), [guess = 0.0] # Lagrange multiplier + end + + eqs = [ + D(x) ~ vx + D(y) ~ vy + D(vx) ~ -2λ * x + D(vy) ~ -2λ * y - g + x^2 + y^2 ~ L^2 # algebraic constraint + ] + + @mtkcompile pend = System(eqs, t) + + # Initial condition: pendulum at angle θ₀ = π/6 from vertical + θ₀ = π / 6 + L_val = 1.0 + x0 = L_val * sin(θ₀) + y0 = -L_val * cos(θ₀) + + prob = ODEProblem(pend, [x => x0, y => y0, vx => 0.0, vy => 0.0], (0.0, 0.5)) + + @testset "$alg" for alg in [radau5, radau] + sol = solve(prob, alg()) + @test SciMLBase.successful_retcode(sol) + # Check that the constraint is satisfied throughout + @test all(abs.((sol[x] .^ 2 .+ sol[y] .^ 2) .- L_val^2) .< 1.0e-4) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 5d27976..630a8d4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -19,3 +19,9 @@ end @time @testset "Callback Tests" begin include("callbacks.jl") end +@time @testset "Initialization Tests" begin + include("initialization_tests.jl") +end +@time @testset "MTK Initialization Tests" begin + include("mtk_initialization_tests.jl") +end