Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .JuliaFormatter.toml

This file was deleted.

14 changes: 10 additions & 4 deletions .github/workflows/FormatCheck.yml
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
name: "Format Check"
name: format-check

on:
push:
branches:
- 'master'
- 'main'
- 'release-'
tags: '*'
pull_request:

jobs:
format-check:
name: "Format Check"
uses: "SciML/.github/.github/workflows/format-check.yml@v1"
runic:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: fredrikekre/runic-action@v1
with:
version: '1'
36 changes: 19 additions & 17 deletions src/ODEInterfaceDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,31 @@ __precompile__()

module ODEInterfaceDiffEq

using Reexport
@reexport using DiffEqBase
using Reexport
@reexport using DiffEqBase

using ODEInterface, Compat, DataStructures, FunctionWrappers
using LinearAlgebra
using ODEInterface, Compat, DataStructures, FunctionWrappers
using LinearAlgebra

import DiffEqBase: solve
import DiffEqBase: solve

const warnkeywords = (:save_idxs, :d_discontinuities, :unstable_check, :tstops,
:calck, :progress, :dense, :save_start)
const warnkeywords = (
:save_idxs, :d_discontinuities, :unstable_check, :tstops,
:calck, :progress, :dense, :save_start,
)

function __init__()
global warnlist = Set(warnkeywords)
end
function __init__()
return global warnlist = Set(warnkeywords)
end

const KW = Dict{Symbol, Any}
const KW = Dict{Symbol, Any}

include("algorithms.jl")
include("integrator_types.jl")
include("integrator_utils.jl")
include("solve.jl")
include("algorithms.jl")
include("integrator_types.jl")
include("integrator_utils.jl")
include("solve.jl")

export ODEInterfaceAlgorithm, dopri5, dop853, odex, seulex, radau, radau5, rodas,
ddeabm, ddebdf
export ODEInterfaceAlgorithm, dopri5, dop853, odex, seulex, radau, radau5, rodas,
ddeabm, ddebdf

end # module
16 changes: 10 additions & 6 deletions src/integrator_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@ mutable struct DEOptions{SType, CType}
callback::CType
end

mutable struct ODEInterfaceIntegrator{F, algType, uType, uPrevType, oType, SType, solType,
P, CallbackCacheType} <:
DiffEqBase.AbstractODEIntegrator{algType, true, uType, Float64}
mutable struct ODEInterfaceIntegrator{
F, algType, uType, uPrevType, oType, SType, solType,
P, CallbackCacheType,
} <:
DiffEqBase.AbstractODEIntegrator{algType, true, uType, Float64}
f::F
u::uType
uprev::uPrevType
Expand All @@ -27,9 +29,11 @@ mutable struct ODEInterfaceIntegrator{F, algType, uType, uPrevType, oType, SType
last_event_error::Float64
end

@inline function (integrator::ODEInterfaceIntegrator)(t, deriv::Type{Val{N}} = Val{0};
idxs = nothing) where {N}
@assert N==0 "ODEInterface does not support dense derivative"
@inline function (integrator::ODEInterfaceIntegrator)(
t, deriv::Type{Val{N}} = Val{0};
idxs = nothing
) where {N}
@assert N == 0 "ODEInterface does not support dense derivative"
sol = integrator.eval_sol_fcn(t)
return idxs == nothing ? sol : sol[idxs]
end
52 changes: 32 additions & 20 deletions src/integrator_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,38 +9,46 @@ function handle_callbacks!(integrator, eval_sol_fcn)
saved_in_cb = false
if !(continuous_callbacks isa Tuple{})
time, upcrossing,
event_occurred,
event_idx,
idx,
counter = DiffEqBase.find_first_continuous_callback(integrator,
continuous_callbacks...)
event_occurred,
event_idx,
idx,
counter = DiffEqBase.find_first_continuous_callback(
integrator,
continuous_callbacks...
)
if event_occurred
integrator.event_last_time = idx
integrator.vector_event_last_time = event_idx
continuous_modified,
saved_in_cb = DiffEqBase.apply_callback!(integrator,
saved_in_cb = DiffEqBase.apply_callback!(
integrator,
continuous_callbacks[idx],
time, upcrossing,
event_idx)
event_idx
)
else
integrator.event_last_time = 0
integrator.vector_event_last_time = 1
end
end
if !(discrete_callbacks isa Tuple{})
discrete_modified,
saved_in_cb = DiffEqBase.apply_discrete_callback!(integrator,
discrete_callbacks...)
saved_in_cb = DiffEqBase.apply_discrete_callback!(
integrator,
discrete_callbacks...
)
end
if !saved_in_cb
savevalues!(integrator)
end

integrator.u_modified = continuous_modified || discrete_modified
return integrator.u_modified = continuous_modified || discrete_modified
end

function DiffEqBase.savevalues!(integrator::ODEInterfaceIntegrator,
force_save = false)::Tuple{Bool, Bool}
function DiffEqBase.savevalues!(
integrator::ODEInterfaceIntegrator,
force_save = false
)::Tuple{Bool, Bool}
saved, savedexactly = false, false
!integrator.opts.save_on && return saved, savedexactly
uType = eltype(integrator.sol.u)
Expand All @@ -52,7 +60,7 @@ function DiffEqBase.savevalues!(integrator::ODEInterfaceIntegrator,
end

while !isempty(integrator.opts.saveat) &&
integrator.tdir * first(integrator.opts.saveat) < integrator.tdir * integrator.t
integrator.tdir * first(integrator.opts.saveat) < integrator.tdir * integrator.t
saved = true
curt = pop!(integrator.opts.saveat)
tmp = integrator(curt)::Vector{Float64}
Expand All @@ -63,17 +71,19 @@ function DiffEqBase.savevalues!(integrator::ODEInterfaceIntegrator,
return saved, savedexactly
end

function DiffEqBase.change_t_via_interpolation!(integrator::ODEInterfaceIntegrator, t,
function DiffEqBase.change_t_via_interpolation!(
integrator::ODEInterfaceIntegrator, t,
modify_save_endpoint::Type{Val{T}} = Val{false},
reinitialize_alg = nothing) where {T}
reinitialize_alg = nothing
) where {T}
integrator.t = t
tmp = integrator(integrator.t)::Vector{Float64}
if eltype(integrator.sol.u) <: Vector
integrator.u .= tmp
else
integrator.u .= reshape(tmp, integrator.sizeu)
end
nothing
return nothing
end
DiffEqBase.get_tmp_cache(i::ODEInterfaceIntegrator, args...) = nothing

Expand All @@ -86,7 +96,7 @@ DiffEqBase.get_tmp_cache(i::ODEInterfaceIntegrator, args...) = nothing
end

@inline function DiffEqBase.u_modified!(integrator::ODEInterfaceIntegrator, bool::Bool)
integrator.u_modified = bool
return integrator.u_modified = bool
end

function initialize_callbacks!(integrator, initialize_save = true)
Expand All @@ -100,14 +110,16 @@ function initialize_callbacks!(integrator, initialize_save = true)
# if the user modifies u, we need to fix current values
if u_modified
if initialize_save &&
(any((c) -> c.save_positions[2], callbacks.discrete_callbacks) ||
any((c) -> c.save_positions[2], callbacks.continuous_callbacks))
(
any((c) -> c.save_positions[2], callbacks.discrete_callbacks) ||
any((c) -> c.save_positions[2], callbacks.continuous_callbacks)
)
savevalues!(integrator, true)
end
end

# reset this as it is now handled so the integrators should proceed as normal
integrator.u_modified = false
return integrator.u_modified = false
end

DiffEqBase.set_proposed_dt!(integrator::ODEInterfaceIntegrator, dt) = nothing
Loading
Loading