Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,11 @@ julia = "1.6"

[extras]
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
ODEProblemLibrary = "fdc4e326-1af4-4b90-96e7-779fcce2daa5"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["ExplicitImports", "Test", "ODEProblemLibrary"]
test = ["ExplicitImports", "ModelingToolkit", "NonlinearSolve", "SymbolicIndexingInterface", "Test", "ODEProblemLibrary"]
1 change: 1 addition & 0 deletions src/ODEInterfaceDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ module ODEInterfaceDiffEq
include("integrator_types.jl")
include("integrator_utils.jl")
include("solve.jl")
include("initialize.jl")

export ODEInterfaceAlgorithm, dopri5, dop853, odex, seulex, radau, radau5, rodas,
ddeabm, ddebdf
Expand Down
24 changes: 14 additions & 10 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -226,17 +226,21 @@ end
SciMLBase.alg_order(alg::ddebdf) = 5

seulex(; jac_lower = nothing, jac_upper = nothing) = seulex(jac_lower, jac_upper)
function radau(; jac_lower = nothing, jac_upper = nothing,
M1 = nothing, M2 = nothing,
DIMOFIND1VAR = nothing, DIMOFIND2VAR = nothing, DIMOFIND3VAR = nothing,
massmatrix = nothing)
radau(jac_lower, jac_upper, M1, M2, DIMOFIND1VAR, DIMOFIND2VAR, DIMOFIND3VAR, massmatrix)
function radau(;
jac_lower = nothing, jac_upper = nothing,
M1 = nothing, M2 = nothing,
DIMOFIND1VAR = nothing, DIMOFIND2VAR = nothing, DIMOFIND3VAR = nothing,
massmatrix = nothing
)
return radau(jac_lower, jac_upper, M1, M2, DIMOFIND1VAR, DIMOFIND2VAR, DIMOFIND3VAR, massmatrix)
end
function radau5(; jac_lower = nothing, jac_upper = nothing,
M1 = nothing, M2 = nothing,
DIMOFIND1VAR = nothing, DIMOFIND2VAR = nothing, DIMOFIND3VAR = nothing,
massmatrix = nothing)
radau5(jac_lower, jac_upper, M1, M2, DIMOFIND1VAR, DIMOFIND2VAR, DIMOFIND3VAR, massmatrix)
function radau5(;
jac_lower = nothing, jac_upper = nothing,
M1 = nothing, M2 = nothing,
DIMOFIND1VAR = nothing, DIMOFIND2VAR = nothing, DIMOFIND3VAR = nothing,
massmatrix = nothing
)
return radau5(jac_lower, jac_upper, M1, M2, DIMOFIND1VAR, DIMOFIND2VAR, DIMOFIND3VAR, massmatrix)
end
rodas(; jac_lower = nothing, jac_upper = nothing) = rodas(jac_lower, jac_upper)
ddebdf(; jac_lower = nothing, jac_upper = nothing) = ddebdf(jac_lower, jac_upper)
164 changes: 164 additions & 0 deletions src/initialize.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
# DAE Initialization support for ODEInterfaceDiffEq
# Following the pattern from Sundials.jl:
# https://github.com/SciML/Sundials.jl/blob/master/src/common_interface/initialize.jl

import SciMLBase: OverrideInit, NoInit, CheckInit, has_initialization_data
import DiffEqBase: DefaultInit

# Re-export initialization algorithms (including DefaultInit from DiffEqBase)
export OverrideInit, NoInit, CheckInit, DefaultInit

# DefaultInit: OverrideInit → CheckInit pattern (matching Sundials v5)
# First run OverrideInit to compute consistent initial conditions,
# then run CheckInit to verify the algebraic constraints are satisfied.
function DiffEqBase.initialize_dae!(
integrator::ODEInterfaceIntegrator,
initializealg::DefaultInit
)
prob = integrator.sol.prob

# First: OverrideInit to compute consistent initial conditions
if has_initialization_data(prob.f)
DiffEqBase.initialize_dae!(integrator, OverrideInit())

# Check if OverrideInit failed
if integrator.sol.retcode == ReturnCode.InitialFailure
return nothing
end
end

# Then: CheckInit to verify algebraic constraints are satisfied
DiffEqBase.initialize_dae!(integrator, CheckInit())

return nothing
end

# NoInit: Do nothing, assume initial conditions are correct
function DiffEqBase.initialize_dae!(
integrator::ODEInterfaceIntegrator,
initializealg::NoInit
)
# No-op: initial conditions are assumed to be correct
return nothing
end

# CheckInit: Verify that initial conditions satisfy the algebraic constraints
function DiffEqBase.initialize_dae!(
integrator::ODEInterfaceIntegrator,
initializealg::CheckInit
)
prob = integrator.sol.prob
f = prob.f
M = f.mass_matrix

# If no mass matrix or identity, no algebraic constraints to check
M == I && return nothing

u0 = integrator.u
p = integrator.p
t = integrator.t

# Find algebraic equations (rows of M that are all zeros)
algebraic_eqs = [all(iszero, M[i, :]) for i in axes(M, 1)]

# If no algebraic equations, nothing to check
!any(algebraic_eqs) && return nothing

# Evaluate the RHS
tmp = similar(u0)
f(tmp, u0, p, t)

# Check residuals of algebraic equations only
abstol = integrator.sol.prob isa DiffEqBase.AbstractODEProblem ?
get(Dict(integrator.opts.callback.discrete_callbacks), :abstol, 1.0e-8) : 1.0e-8

# Try to get abstol from the solve options, fallback to default
# Note: ODEInterface doesn't expose tolerances through the integrator in the same way
abstol = 1.0e-8 # Default tolerance

max_residual = maximum(abs.(tmp[algebraic_eqs]))

if max_residual > abstol
error(
"""
DAE initialization failed with CheckInit: Initial conditions do not satisfy the algebraic constraints.

The maximum residual in algebraic equations is $(max_residual), which exceeds the tolerance $(abstol).

To resolve this issue, you have several options:
1. Fix your initial conditions to satisfy the algebraic constraints (M * du/dt = f(u, p, t), where algebraic rows have M[i,:] = 0)
2. If using ModelingToolkit, use: initializealg = OverrideInit()
3. Use initializealg = NoInit() to skip initialization checks (use with caution)

Example:
solve(prob, radau5(); initializealg = OverrideInit())
"""
)
end

return nothing
end

# OverrideInit: Use SciMLBase's initialization system (e.g., from ModelingToolkit)
function DiffEqBase.initialize_dae!(
integrator::ODEInterfaceIntegrator,
initializealg::OverrideInit
)
prob = integrator.sol.prob
f = prob.f

# If no initialization data, nothing to do
if !has_initialization_data(f)
return nothing
end

# Get initial values using SciMLBase's initialization system
# Note: ODEInterface integrators don't have a built-in nonlinear solver,
# so we rely on the user providing one via nlsolve_alg in OverrideInit,
# or the initialization being trivial (no solve needed)
isinplace = Val(SciMLBase.isinplace(prob))

# Default tolerances
abstol = 1.0e-8
reltol = 1.0e-8

u0, p, success = SciMLBase.get_initial_values(
prob, integrator, f, initializealg, isinplace;
abstol = abstol, reltol = reltol
)

if !success
integrator.sol = SciMLBase.solution_new_retcode(
integrator.sol,
ReturnCode.InitialFailure
)
return nothing
end

# Update integrator state
if SciMLBase.isinplace(prob)
integrator.u .= u0
if length(integrator.sol.u) >= 1 && !isempty(integrator.sol.u)
integrator.sol.u[1] .= u0
end
else
# For out-of-place problems, we need to handle this differently
# since integrator.u might be immutable
integrator.u .= u0
if length(integrator.sol.u) >= 1 && !isempty(integrator.sol.u)
integrator.sol.u[1] = u0
end
end

# Update parameters if they changed (in-place via SciMLStructures)
if p !== integrator.p
SS = SciMLBase.SciMLStructures
old_vals, _, _ = SS.canonicalize(SS.Tunable(), integrator.p)
new_vals, _, _ = SS.canonicalize(SS.Tunable(), p)
copyto!(old_vals, new_vals)
end

integrator.u_modified = true

return nothing
end
13 changes: 11 additions & 2 deletions src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ function DiffEqBase.__solve(
prob.tspan[1] in saveat,
timeseries_errors = true, dense_errors = false,
callback = nothing, alias_u0 = false,
initializealg = DiffEqBase.DefaultInit(),
kwargs...
) where
{uType, tuptType, isinplace, AlgType <: ODEInterfaceAlgorithm}
Expand Down Expand Up @@ -107,6 +108,14 @@ function DiffEqBase.__solve(
)
initialize_callbacks!(integrator)

# DAE initialization - check/compute consistent initial conditions
DiffEqBase.initialize_dae!(integrator, initializealg)

# Check if initialization failed
if integrator.sol.retcode == ReturnCode.InitialFailure
return integrator.sol
end

outputfcn = OutputFunction(integrator)
o[:OUTPUTFCN] = outputfcn
if !(callbacks_internal.continuous_callbacks isa Tuple{}) || !isempty(saveat)
Expand Down Expand Up @@ -141,8 +150,8 @@ function DiffEqBase.__solve(
end

if !haskey(dict, :MASSMATRIX) && prob.f.mass_matrix != I
if prob.f.mass_matrix isa Matrix && isstiff
dict[:MASSMATRIX] = prob.f.mass_matrix
if prob.f.mass_matrix isa AbstractMatrix && isstiff
dict[:MASSMATRIX] = Matrix(prob.f.mass_matrix)
elseif !isstiff
error("This solver does not support mass matrices")
else
Expand Down
99 changes: 99 additions & 0 deletions test/initialization_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
using ODEInterfaceDiffEq, DiffEqBase, SciMLBase
using Test
using LinearAlgebra

@testset "DAE Initialization" begin
# Simple Robertson chemical kinetics model as a DAE (with mass matrix)
# dy1/dt = -0.04*y1 + 1e4*y2*y3
# dy2/dt = 0.04*y1 - 1e4*y2*y3 - 3e7*y2^2
# 0 = y1 + y2 + y3 - 1 (conservation constraint)

function robertson!(du, u, p, t)
y1, y2, y3 = u
du[1] = -0.04 * y1 + 1.0e4 * y2 * y3
du[2] = 0.04 * y1 - 1.0e4 * y2 * y3 - 3.0e7 * y2^2
du[3] = y1 + y2 + y3 - 1.0 # algebraic equation
return nothing
end

# Mass matrix: [1 0 0; 0 1 0; 0 0 0]
M = [1.0 0 0; 0 1 0; 0 0 0]

@testset "NoInit - skip initialization" begin
# With correct initial conditions
u0_correct = [1.0, 0.0, 0.0] # satisfies y1 + y2 + y3 = 1
f = ODEFunction(robertson!; mass_matrix = M)
prob = ODEProblem(f, u0_correct, (0.0, 1.0e-3))

sol = solve(prob, radau5(); initializealg = NoInit())
@test sol.retcode == ReturnCode.Success

# With incorrect initial conditions - NoInit should still work (no check)
u0_wrong = [1.0, 1.0, 1.0] # does NOT satisfy y1 + y2 + y3 = 1
prob_wrong = ODEProblem(f, u0_wrong, (0.0, 1.0e-3))

# NoInit should not throw, even with wrong ICs
sol_wrong = solve(prob_wrong, radau5(); initializealg = NoInit())
# The solver might fail, but not during initialization
@test true # Just checking NoInit doesn't throw
end

@testset "CheckInit - verify constraints" begin
u0_correct = [1.0, 0.0, 0.0] # satisfies y1 + y2 + y3 = 1
f = ODEFunction(robertson!; mass_matrix = M)
prob = ODEProblem(f, u0_correct, (0.0, 1.0e-3))

sol = solve(prob, radau5(); initializealg = CheckInit())
@test sol.retcode == ReturnCode.Success

# With incorrect initial conditions
u0_wrong = [1.0, 1.0, 1.0] # does NOT satisfy y1 + y2 + y3 = 1
prob_wrong = ODEProblem(f, u0_wrong, (0.0, 1.0e-3))

@test_throws ErrorException solve(prob_wrong, radau5(); initializealg = CheckInit())
end

@testset "DefaultInit - dispatch based on problem" begin
# Without initialization_data, should use CheckInit
u0_correct = [1.0, 0.0, 0.0]
f = ODEFunction(robertson!; mass_matrix = M)
prob = ODEProblem(f, u0_correct, (0.0, 1.0e-3))

sol = solve(prob, radau5(); initializealg = DefaultInit())
@test sol.retcode == ReturnCode.Success
end

@testset "No mass matrix - identity case" begin
# Simple ODE without mass matrix (not a DAE)
function simple_ode!(du, u, p, t)
du[1] = -u[1]
return nothing
end

u0 = [1.0]
f = ODEFunction(simple_ode!)
prob = ODEProblem(f, u0, (0.0, 1.0))

# All initialization algorithms should work
sol_noinit = solve(prob, radau5(); initializealg = NoInit())
@test sol_noinit.retcode == ReturnCode.Success

sol_check = solve(prob, radau5(); initializealg = CheckInit())
@test sol_check.retcode == ReturnCode.Success

sol_default = solve(prob, radau5(); initializealg = DefaultInit())
@test sol_default.retcode == ReturnCode.Success
end

@testset "Multiple solvers with initialization" begin
u0_correct = [1.0, 0.0, 0.0]
f = ODEFunction(robertson!; mass_matrix = M)
prob = ODEProblem(f, u0_correct, (0.0, 1.0e-3))

# Test with different implicit solvers that support mass matrices
for alg in [radau5(), radau(), rodas(), seulex(), ddebdf()]
sol = solve(prob, alg; initializealg = CheckInit())
@test sol.retcode == ReturnCode.Success
end
end
end
Loading
Loading