Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/OrdinaryDiffEqBDF/src/OrdinaryDiffEqBDF.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import OrdinaryDiffEqCore: alg_order, calculate_residuals!,
DAEAlgorithm, _unwrap_val, DummyController,
get_fsalfirstlast, generic_solver_docstring, _ad_chunksize_int, _ad_fdtype, _fixup_ad,
_ode_interpolant, _ode_interpolant!, has_stiff_interpolation,
_ode_addsteps!, DerivativeOrderNotPossibleError
_ode_addsteps!, DerivativeOrderNotPossibleError, set_discontinuity
using OrdinaryDiffEqSDIRK: ImplicitEulerConstantCache, ImplicitEulerCache

using TruncatedStacktraces: @truncate_stacktrace
Expand Down
11 changes: 10 additions & 1 deletion lib/OrdinaryDiffEqBDF/src/controllers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,15 @@ end
function bdf_step_reject_controller!(integrator, cache, EEst1)
k = cache.order
h = integrator.dt
integrator.cache.consfailcnt += 1
integrator.cache.nconsteps = 0

disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache)
if disco_dt != -1
integrator.dt = disco_dt
return integrator.dt
end

cache.consfailcnt += 1
cache.nconsteps = 0
if cache.consfailcnt > 1
Expand Down Expand Up @@ -495,4 +504,4 @@ function step_accept_controller!(
cache.qwait -= 1 # countdown
end
return integrator.dt / q
end
end
1 change: 1 addition & 0 deletions lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ include("cache_utils.jl")
include("initialize_dae.jl")

include("perform_step/composite_perform_step.jl")
include("disco.jl")

include("dense/generic_dense.jl")

Expand Down
74 changes: 74 additions & 0 deletions lib/OrdinaryDiffEqCore/src/disco.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
function set_discontinuity(u, uprev, integrator, cache)
breakpointθ = find_discontinuity(u, uprev, integrator, cache)
dt = integrator.dt
t = integrator.t
if !isnan(breakpointθ) && 1e-6 < breakpointθ < 1.0
#println("Discontinuity detected at t = ", t + breakpointθ * dt)
return breakpointθ * dt
end
return -1
end

function find_discontinuity(u, uprev, integrator, cache)
cb = integrator.opts.callback
cb === nothing && return -1
isempty(cb.continuous_callbacks) && return -1
p = integrator.p
t = integrator.t
dt = integrator.dt
save_idxs = integrator.opts.save_idxs
k = integrator.k
cache = integrator.cache
differential_vars = integrator.differential_vars
θlo = zero(dt)
θhi = one(dt)
bracket = [θlo, θhi]
breakpointθ = -one(dt)
idx = 1
for i in cb.continuous_callbacks
if (!(i.is_discontinuity))
continue
end
disco_prob = integrator.disco_probs[idx]
disco_zero = disco_prob.f.f
disco_zero.dt = dt
disco_zero.uprev = uprev
disco_zero.u = u
disco_zero.k = k
disco_zero.cache = cache
disco_zero.differential_vars = differential_vars
disco_zero.idxs = save_idxs
disco_zero.tprev = t
disco_zero.f = integrator.f
disco_zero.p = p
if (i isa VectorContinuousCallback)
len_cb = i.len
out_prev = similar(u)
out_curr = similar(u)
i.condition(out_prev, uprev, t, integrator)
i.condition(out_curr, u, t + dt, integrator)
for j in 1:len_cb
if (out_prev[j] * out_curr[j] < zero(out_prev[j]))
disco_zero.ind = j
sol = solve(disco_prob; bracket = bracket)
tmp = sol[]
if (!isnan(tmp) && (breakpointθ == -1 || tmp < breakpointθ))
breakpointθ = tmp
end
end
end
else
out_prev = i.condition(uprev, t, integrator)
out_curr = i.condition(u, t + dt, integrator)
if (out_prev * out_curr < zero(out_prev))
sol = solve(disco_prob; bracket = bracket)
tmp = sol[]
if (!isnan(tmp) && (breakpointθ == -1 || tmp < breakpointθ))
breakpointθ = tmp
end
end
end
idx += 1
end
breakpointθ
end
27 changes: 26 additions & 1 deletion lib/OrdinaryDiffEqCore/src/integrators/controllers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,11 @@ end
end

@inline function step_reject_controller!(integrator, alg)
disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache)
if disco_dt != -1
integrator.dt = disco_dt
return integrator.dt
end
step_reject_controller!(integrator, integrator.controller_cache, alg)
cache = integrator.cache
if hasfield(typeof(cache), :nlsolve)
Expand Down Expand Up @@ -320,7 +325,12 @@ function step_accept_controller!(integrator, cache::IControllerCache, alg, q)
end

function step_reject_controller!(integrator, cache::IControllerCache, alg)
return integrator.dt = cache.dtreject
disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache)
if disco_dt != -1
integrator.dt = disco_dt
return integrator.dt
end
return integrator.dt = cache.dtreject # TODO this does not look right.
end

reinit_controller!(integrator::SciMLBase.DEIntegrator, cache::IControllerCache) = nothing
Expand Down Expand Up @@ -465,6 +475,11 @@ end
function step_reject_controller!(integrator, cache::PIControllerCache, alg)
(; controller, q11) = cache
(; qmin, gamma) = controller
disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache)
if disco_dt != -1
integrator.dt = disco_dt
return integrator.dt
end
return integrator.dt /= min(inv(qmin), q11 / gamma)
end

Expand Down Expand Up @@ -683,6 +698,11 @@ function step_accept_controller!(integrator, cache::PIDControllerCache, alg, dt_
end

function step_reject_controller!(integrator, cache::PIDControllerCache, alg)
disco_dt = set_discontinuity(integrator.u, integrator.uprev, integrator, integrator.cache)
if disco_dt != -1
integrator.dt = disco_dt
return integrator.dt
end
return integrator.dt *= cache.dt_factor
end

Expand Down Expand Up @@ -868,6 +888,11 @@ end
function step_reject_controller!(integrator, cache::PredictiveControllerCache, alg)
(; dt, success_iter) = integrator
(; qold) = cache
if (integrator.disco_dt_set)
println("using fixed dt from discontinuity handling")
integrator.disco_dt_set = false
return integrator.dt
end
return integrator.dt = success_iter == 0 ? 0.1 * dt : dt / qold
end

Expand Down
3 changes: 3 additions & 0 deletions lib/OrdinaryDiffEqCore/src/integrators/type.jl
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ mutable struct ODEIntegrator{
dtcache::tType
dtchangeable::Bool
dtpropose::tType
disco_dt_set::Bool
tdir::tdirType
eigen_est::eigenType
controller_cache::CC
Expand Down Expand Up @@ -175,6 +176,8 @@ mutable struct ODEIntegrator{
fsalfirst::FSALType
fsallast::FSALType
rng::RNGType
#disco_prob::IntervalNonlinearProblem
disco_probs::Vector{IntervalNonlinearProblem} #should we change this?
W::WType
P::PType
sqdt::SqdtType
Expand Down
55 changes: 52 additions & 3 deletions lib/OrdinaryDiffEqCore/src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,36 @@ determine_controller_datatype(u, internalnorm, ts::Tuple{<:Number, <:Number}) =
determine_controller_datatype(u::AbstractVector{<:Number}, internalnorm, ts::Tuple{<:Integer, <:Integer}) = promote_type(typeof(DiffEqBase.value(internalnorm(u, ts[1]))), typeof(DiffEqBase.value(internalnorm(u, ts[2]))), eltype(float.(DiffEqBase.value(ts))))
determine_controller_datatype(u, internalnorm, ts::Tuple{<:Integer, <:Integer}) = promote_type(typeof(float(DiffEqBase.value(ts[1]))), typeof(float(DiffEqBase.value(ts[2])))) # This seems to be an assumption implicitly taken somewhere

mutable struct zero_func_struct{uType, tType, kType, CacheType, idxsType, varsType, callbackType, outType, FunctionType, tType2, ParameterType}
#integrator_ref::IntegratorType
u₁::uType
callback::callbackType
dt::tType
uprev::uType
u::uType
k::kType
cache::CacheType
idxs::idxsType
differential_vars::varsType
ind::Int
out::outType
f::FunctionType
tprev::tType2
p::ParameterType
end

function (z::zero_func_struct)(θ, p)
_ode_addsteps!(z.k, z.tprev, z.uprev, z.u, z.dt, z.f, z.p, z.cache, false, true, false)
ode_interpolant!(z.u₁, θ, z.dt, z.uprev, z.u, z.k, z.cache, z.idxs, Val{0}, z.differential_vars)
return zero_condition(z.callback, z.out, z.u₁, z.dt + θ * z.dt, z, z.ind)
end

@inline zero_condition(cb::ContinuousCallback, out::Nothing, u, t, z, ind) = cb.condition(u, t, z)
@inline function zero_condition(cb::VectorContinuousCallback, out, u, t, z, ind)
cb.condition(out, u, t, z)
return out[ind]
end

Base.@constprop :aggressive function SciMLBase.__init(
prob::Union{
SciMLBase.AbstractODEProblem,
Expand Down Expand Up @@ -57,6 +87,7 @@ Base.@constprop :aggressive function _ode_init(
save_everystep = isempty(saveat),
save_on = true,
save_discretes = true,
disco_dt_set = false,
save_start = save_everystep || isempty(saveat) ||
saveat isa Number || prob.tspan[1] in saveat,
save_end = nothing,
Expand Down Expand Up @@ -99,6 +130,7 @@ Base.@constprop :aggressive function _ode_init(
alias = ODEAliasSpecifier(),
initializealg = DefaultInit(),
rng = nothing,
disco_probs = nothing,
# SDE/RODE fields: accepted here so that SDE packages can delegate to
# _ode_init and construct an ODEIntegrator with noise populated.
save_noise = false,
Expand Down Expand Up @@ -626,6 +658,23 @@ Base.@constprop :aggressive function _ode_init(

_rng = rng === nothing ? Random.default_rng() : rng

num_probs = 0
for i in callbacks_internal.continuous_callbacks
if i.is_discontinuity
num_probs += 1
end
end
disco_probs = Vector{IntervalNonlinearProblem}(undef, num_probs)
idx = 1
for i in callbacks_internal.continuous_callbacks
if i.is_discontinuity
u₁ = similar(u)
out = i isa VectorContinuousCallback ? similar(u) : nothing
zero_func = zero_func_struct(u₁, i, _dt, uprev, u, k, cache, save_idxs, differential_vars, 1, out, f, tprev, p)
disco_probs[idx] = IntervalNonlinearProblem(zero_func, [zero(tType), one(tType)], p)
idx += 1
end
end
# Seed the initial EEst on the controller cache (was previously
# `integrator.EEst = oneunit(EEstT)`).
set_EEst!(controller_cache, EEst)
Expand All @@ -640,12 +689,12 @@ Base.@constprop :aggressive function _ode_init(
typeof(initializealg), typeof(differential_vars),
typeof(controller_cache), typeof(_rng),
typeof(W), typeof(P), typeof(sqdt),
typeof(noise), typeof(c), typeof(rate_constants),
typeof(noise), typeof(c), typeof(rate_constants)
}(
sol, u, du, k, t, tType(_dt), f, p,
uprev, uprev2, duprev, tprev,
_alg, dtcache, dtchangeable,
dtpropose, tdir, eigen_est,
dtpropose, disco_dt_set, tdir, eigen_est,
controller_cache,
success_iter,
iter, saveiter, saveiter_dense, cache,
Expand All @@ -659,7 +708,7 @@ Base.@constprop :aggressive function _ode_init(
isout, reeval_fsal,
derivative_discontinuity, reinitialize, isdae,
opts, stats, initializealg, differential_vars,
fsalfirst, fsallast, _rng,
fsalfirst, fsallast, _rng, disco_probs,
W, P, sqdt,
noise, c, rate_constants
)
Expand Down
83 changes: 83 additions & 0 deletions lib/OrdinaryDiffEqCore/test/disco_benchmarks.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
using DiffEqDevTools, Test, LinearAlgebra
using OrdinaryDiffEqTsit5, OrdinaryDiffEqRosenbrock, OrdinaryDiffEqLowOrderRK
using OrdinaryDiffEqRadau, OrdinaryDiffEqBS3
using Logging
global_logger(ConsoleLogger(stderr, Logging.Error))
using BenchmarkTools


#tests against Hairer's RADAR problems
h(p, t) = 0.5

# state-dependent delay: τ(t) = y(t)
function delay(p, t, u)
return u[1]
end

# DDE: y'(t) = y(y(t))
function f(du, u, h, p, t)
τ = u[1]
du[1] = h(p, τ)
end

# initial condition at t = 0 (must match tspan start)
u0 = [1.0]
tspan = (1.0, 5.5)

prob = DDEProblem(f, h, delay, u0, tspan)

sol = solve(prob, MethodOfSteps(Tsit5()))


# https://dieci.math.gatech.edu/preps/DieciLopez-Fili4.pdf
# vector fields
function f1!(du, u, p, t)
x1, x2 = u
du[1] = x2
du[2] = -x1 + 1/(1.2 - x2)
end

function f2!(du, u, p, t)
x1, x2 = u
du[1] = x2
du[2] = -x1 - 1/(0.8 + x2)
end

# switching surface Σ: x2 = 0.2
condition(u, p, t) = u[2] - 0.2

# mode indicator (which vector field is active)
mode = Ref(1)

function f!(du, u, p, t)
if mode[] == 1
f1!(du, u, p, t)
else
f2!(du, u, p, t)
end
end

# switch dynamics when crossing Σ
function affect!(integrator)
mode[] = 2 # toggle 1 ↔ 2
end

cb = ContinuousCallback(condition, affect!, is_discontinuity = true;)
cb2 = ContinuousCallback(condition, affect!, is_discontinuity = false;)

u0 = [-0.4, -0.5]
tspan = (0.0, 10.0)

prob = ODEProblem(f!, u0, tspan)

sol_disco_tsit5 = solve(prob, Tsit5(), callback=cb)
sol_no_disco_tsit5 = solve(prob, Tsit5(), callback=cb2)

sol_disco_radau = solve(prob, RadauIIA5(), callback=cb)
sol_no_disco_radau = solve(prob, RadauIIA5(), callback=cb2)

sol_disco_rosenbrock = solve(prob, Rodas5P(), callback=cb)
sol_no_disco_rosenbrock = solve(prob, Rodas5P(), callback=cb2)

sol_disco_bs3 = solve(prob, BS3(), callback=cb)
sol_no_disco_bs3 = solve(prob, BS3(), callback=cb2)
Loading
Loading