From a3036340f7346d2ccbe001ff474c1f23f6f57cd1 Mon Sep 17 00:00:00 2001 From: Pepijn de Vos Date: Tue, 7 Nov 2023 14:58:55 +0100 Subject: [PATCH] more WIP --- src/dense/generic_dense.jl | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/src/dense/generic_dense.jl b/src/dense/generic_dense.jl index c8e1c11b20..451279e227 100644 --- a/src/dense/generic_dense.jl +++ b/src/dense/generic_dense.jl @@ -451,7 +451,9 @@ function ode_interpolation(tval::Number, id::I, idxs, deriv::D, p, differential_vars = trues(length(timeseries[begin])) if hasproperty(f, :mass_matrix) mm = f.mass_matrix - if isdiag(mm) + if mm isa UniformScaling + # already correct + elseif isdiag(mm) differential_vars = Diagonal(mm).diag .!= 0 else @show typeof(mm) @@ -512,7 +514,9 @@ function ode_interpolation!(out, tval::Number, id::I, idxs, deriv::D, p, diferential_vars = trues(length(timeseries[begin])) if hasproperty(f, :mass_matrix) mm = f.mass_matrix - if isdiag(mm) + if mm isa UniformScaling + # already correct + elseif isdiag(mm) diferential_vars = Diagonal(mm).diag .!= 0 else # @show typeof(mm) @@ -552,7 +556,7 @@ function ode_interpolation!(out, tval::Number, id::I, idxs, deriv::D, p, _ode_addsteps!(ks[i₊], ts[i₋], timeseries[i₋], timeseries[i₊], dt, f, p, cache) # update the kcurrent ode_interpolant!(out, Θ, dt, timeseries[i₋], timeseries[i₊], ks[i₊], cache, - idxs, deriv) + idxs, deriv, differential_vars) end end @@ -586,12 +590,15 @@ function ode_interpolant(Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{TI}}) w _ode_interpolant(Θ, dt, y₀, y₁, k, cache, idxs, T) end +struct AllDifferential end +Base.getindex(::AllDifferential, ::Any) = true + function ode_interpolant(Θ, dt, y₀, y₁, k, cache::OrdinaryDiffEqMutableCache, idxs, - T::Type{Val{TI}}, differential_vars=[]) where {TI} + T::Type{Val{TI}}, differential_vars=AllDifferential()) where {TI} @show differential_vars if typeof(idxs) <: Number || typeof(y₀) <: Union{Number, SArray} # typeof(y₀) can be these if saveidxs gives a single value - _ode_interpolant(Θ, dt, y₀, y₁, k, cache, idxs, T) + _ode_interpolant(Θ, dt, y₀, y₁, k, cache, idxs, T, differential_vars) elseif typeof(idxs) <: Nothing if y₁ isa Array{<:Number} out = similar(y₁, eltype(first(y₁) * oneunit(Θ))) @@ -599,7 +606,7 @@ function ode_interpolant(Θ, dt, y₀, y₁, k, cache::OrdinaryDiffEqMutableCach else out = oneunit(Θ) .* y₁ end - _ode_interpolant!(out, Θ, dt, y₀, y₁, k, cache, idxs, T) + _ode_interpolant!(out, Θ, dt, y₀, y₁, k, cache, idxs, T, differential_vars) else if y₁ isa Array{<:Number} out = similar(y₁, eltype(first(y₁) * oneunit(Θ)), axes(idxs)) @@ -609,7 +616,7 @@ function ode_interpolant(Θ, dt, y₀, y₁, k, cache::OrdinaryDiffEqMutableCach else out = oneunit(Θ) .* y₁[idxs] end - _ode_interpolant!(out, Θ, dt, y₀, y₁, k, cache, idxs, T) + _ode_interpolant!(out, Θ, dt, y₀, y₁, k, cache, idxs, T, differential_vars) end end @@ -620,12 +627,12 @@ end ##################### Hermite Interpolants # If no dispatch found, assume Hermite -function _ode_interpolant(Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{TI}}) where {TI} +function _ode_interpolant(Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{TI}}, differential_vars=AllDifferential()) where {TI} hermite_interpolant(Θ, dt, y₀, y₁, k, Val{typeof(cache) <: OrdinaryDiffEqMutableCache}, idxs, T) end -function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{TI}}) where {TI} +function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{TI}}, differential_vars=AllDifferential()) where {TI} hermite_interpolant!(out, Θ, dt, y₀, y₁, k, idxs, T) end