Skip to content

Commit

Permalink
Fix definitions to always pass and make missing more robust
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Dec 10, 2023
1 parent 4580a96 commit 9ab430e
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions src/dense/generic_dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,7 @@ function evaluate_interpolant(f, Θ, dt, timeseries, i₋, i₊, cache, idxs,
end
end

struct DifferentialVarsUndefined end
function get_differential_vars(f, idxs, timeseries)
differential_vars = nothing
if hasproperty(f, :mass_matrix)
Expand All @@ -322,7 +323,7 @@ function get_differential_vars(f, idxs, timeseries)
elseif isdiag(mm) && all(x -> size(x) == size(timeseries[begin]), timeseries)
differential_vars = reshape(diag(mm) .!= 0, size(timeseries[begin]))
else
return missing # interpret missing downstream as not implemented
return DifferentialVarsUndefined()
end
end
if idxs === nothing
Expand Down Expand Up @@ -413,10 +414,10 @@ function ode_interpolation!(vals, tvals, id::I, idxs, deriv::D, p,
if cache isa (FunctionMapCache) || cache isa FunctionMapConstantCache
if eltype(vals) <: AbstractArray
ode_interpolant!(vals[j], Θ, dt, timeseries[i₋], timeseries[i₊], 0, cache,
idxs, deriv)
idxs, deriv, differential_vars)
else
vals[j] = ode_interpolant(Θ, dt, timeseries[i₋], timeseries[i₊], 0, cache,
idxs, deriv)
idxs, deriv, differential_vars)
end
elseif !id.dense
if eltype(vals) <: AbstractArray
Expand All @@ -435,10 +436,10 @@ function ode_interpolation!(vals, tvals, id::I, idxs, deriv::D, p,
cache_i₊) # update the kcurrent
if eltype(vals) <: AbstractArray
ode_interpolant!(vals[j], Θ, dt, timeseries[i₋], timeseries[i₊], ks[i₊],
cache_i₊, idxs, deriv)
cache_i₊, idxs, deriv, differential_vars)
else
vals[j] = ode_interpolant(Θ, dt, timeseries[i₋], timeseries[i₊], ks[i₊],
cache_i₊, idxs, deriv)
cache_i₊, idxs, deriv, differential_vars)
end
else
_ode_addsteps!(ks[i₊], ts[i₋], timeseries[i₋], timeseries[i₊], dt, f, p,
Expand Down Expand Up @@ -494,7 +495,7 @@ function ode_interpolation(tval::Number, id::I, idxs, deriv::D, p,
_ode_addsteps!(ks[i₊], ts[i₋], timeseries[i₋], timeseries[i₊], dt, f, p,
cache.caches[id.alg_choice[i₊]]) # update the kcurrent
val = ode_interpolant(Θ, dt, timeseries[i₋], timeseries[i₊], ks[i₊],
cache.caches[id.alg_choice[i₊]], idxs, deriv)
cache.caches[id.alg_choice[i₊]], idxs, deriv, differential_vars)
else
_ode_addsteps!(ks[i₊], ts[i₋], timeseries[i₋], timeseries[i₊], dt, f, p,
cache) # update the kcurrent
Expand Down Expand Up @@ -635,7 +636,7 @@ end

# If no dispatch found, assume Hermite
function _ode_interpolant(Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{TI}}, differential_vars) where {TI}
differential_vars === missing && throw(HermiteInterpolationNonDiagonalError())
differential_vars isa DifferentialVarsUndefined && throw(HermiteInterpolationNonDiagonalError())

differential_vars = if differential_vars === nothing
if y₀ isa Number
Expand All @@ -651,7 +652,7 @@ function _ode_interpolant(Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{TI}},
end

function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{TI}}, differential_vars) where {TI}
differential_vars === missing && throw(HermiteInterpolationNonDiagonalError())
differential_vars isa DifferentialVarsUndefined && throw(HermiteInterpolationNonDiagonalError())

differential_vars = if differential_vars === nothing
if y₀ isa Number
Expand Down

0 comments on commit 9ab430e

Please sign in to comment.