Skip to content

Commit

Permalink
more WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
Pepijn de Vos committed Nov 7, 2023
1 parent 3809ea4 commit a303634
Showing 1 changed file with 16 additions and 9 deletions.
25 changes: 16 additions & 9 deletions src/dense/generic_dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,9 @@ function ode_interpolation(tval::Number, id::I, idxs, deriv::D, p,
differential_vars = trues(length(timeseries[begin]))
if hasproperty(f, :mass_matrix)
mm = f.mass_matrix
if isdiag(mm)
if mm isa UniformScaling
# already correct
elseif isdiag(mm)
differential_vars = Diagonal(mm).diag .!= 0
else
@show typeof(mm)
Expand Down Expand Up @@ -512,7 +514,9 @@ function ode_interpolation!(out, tval::Number, id::I, idxs, deriv::D, p,
diferential_vars = trues(length(timeseries[begin]))
if hasproperty(f, :mass_matrix)
mm = f.mass_matrix
if isdiag(mm)
if mm isa UniformScaling
# already correct
elseif isdiag(mm)
diferential_vars = Diagonal(mm).diag .!= 0
else
# @show typeof(mm)
Expand Down Expand Up @@ -552,7 +556,7 @@ function ode_interpolation!(out, tval::Number, id::I, idxs, deriv::D, p,
_ode_addsteps!(ks[i₊], ts[i₋], timeseries[i₋], timeseries[i₊], dt, f, p,
cache) # update the kcurrent
ode_interpolant!(out, Θ, dt, timeseries[i₋], timeseries[i₊], ks[i₊], cache,
idxs, deriv)
idxs, deriv, differential_vars)
end
end

Expand Down Expand Up @@ -586,20 +590,23 @@ function ode_interpolant(Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{TI}}) w
_ode_interpolant(Θ, dt, y₀, y₁, k, cache, idxs, T)
end

struct AllDifferential end
Base.getindex(::AllDifferential, ::Any) = true

function ode_interpolant(Θ, dt, y₀, y₁, k, cache::OrdinaryDiffEqMutableCache, idxs,
T::Type{Val{TI}}, differential_vars=[]) where {TI}
T::Type{Val{TI}}, differential_vars=AllDifferential()) where {TI}
@show differential_vars
if typeof(idxs) <: Number || typeof(y₀) <: Union{Number, SArray}
# typeof(y₀) can be these if saveidxs gives a single value
_ode_interpolant(Θ, dt, y₀, y₁, k, cache, idxs, T)
_ode_interpolant(Θ, dt, y₀, y₁, k, cache, idxs, T, differential_vars)
elseif typeof(idxs) <: Nothing
if y₁ isa Array{<:Number}
out = similar(y₁, eltype(first(y₁) * oneunit(Θ)))
copyto!(out, y₁)
else
out = oneunit(Θ) .* y₁
end
_ode_interpolant!(out, Θ, dt, y₀, y₁, k, cache, idxs, T)
_ode_interpolant!(out, Θ, dt, y₀, y₁, k, cache, idxs, T, differential_vars)
else
if y₁ isa Array{<:Number}
out = similar(y₁, eltype(first(y₁) * oneunit(Θ)), axes(idxs))
Expand All @@ -609,7 +616,7 @@ function ode_interpolant(Θ, dt, y₀, y₁, k, cache::OrdinaryDiffEqMutableCach
else
out = oneunit(Θ) .* y₁[idxs]
end
_ode_interpolant!(out, Θ, dt, y₀, y₁, k, cache, idxs, T)
_ode_interpolant!(out, Θ, dt, y₀, y₁, k, cache, idxs, T, differential_vars)
end
end

Expand All @@ -620,12 +627,12 @@ end
##################### Hermite Interpolants

# If no dispatch found, assume Hermite
function _ode_interpolant(Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{TI}}) where {TI}
function _ode_interpolant(Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{TI}}, differential_vars=AllDifferential()) where {TI}
hermite_interpolant(Θ, dt, y₀, y₁, k, Val{typeof(cache) <: OrdinaryDiffEqMutableCache},
idxs, T)
end

function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{TI}}) where {TI}
function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{TI}}, differential_vars=AllDifferential()) where {TI}
hermite_interpolant!(out, Θ, dt, y₀, y₁, k, idxs, T)
end

Expand Down

0 comments on commit a303634

Please sign in to comment.