diff --git a/src/dense/generic_dense.jl b/src/dense/generic_dense.jl index f457828358..c8e1c11b20 100644 --- a/src/dense/generic_dense.jl +++ b/src/dense/generic_dense.jl @@ -87,13 +87,14 @@ end end @inline function ode_interpolant(Θ, integrator::DiffEqBase.DEIntegrator, idxs, deriv) + @show integrator.f.mass_matrix DiffEqBase.addsteps!(integrator) - if !(typeof(integrator.cache) <: CompositeCache) - val = ode_interpolant(Θ, integrator.dt, integrator.uprev, integrator.u, - integrator.k, integrator.cache, idxs, deriv) - else + if integrator.cache isa CompositeCache val = composite_ode_interpolant(Θ, integrator, integrator.cache.caches, integrator.cache.current, idxs, deriv) + else + val = ode_interpolant(Θ, integrator.dt, integrator.uprev, integrator.u, + integrator.k, integrator.cache, idxs, deriv) end val end @@ -119,6 +120,7 @@ end end @inline function ode_interpolant!(val, Θ, integrator::DiffEqBase.DEIntegrator, idxs, deriv) + @show "!" integrator.f.mass_matrix DiffEqBase.addsteps!(integrator) if !(typeof(integrator.cache) <: CompositeCache) ode_interpolant!(val, Θ, integrator.dt, integrator.uprev, integrator.u, @@ -446,6 +448,19 @@ function ode_interpolation(tval::Number, id::I, idxs, deriv::D, p, @unpack ts, timeseries, ks, f, cache = id @inbounds tdir = sign(ts[end] - ts[1]) + differential_vars = trues(length(timeseries[begin])) + if hasproperty(f, :mass_matrix) + mm = f.mass_matrix + if isdiag(mm) + differential_vars = Diagonal(mm).diag .!= 0 + else + @show typeof(mm) + error("QR factorizations is annoying") + end + end + @show differential_vars + + if continuity === :left # we have i₋ = i₊ = 1 if tval = ts[1], i₊ = i₋ + 1 = lastindex(ts) if tval > ts[end], # and otherwise i₋ and i₊ satisfy ts[i₋] < tval ≤ ts[i₊] @@ -476,7 +491,7 @@ function ode_interpolation(tval::Number, id::I, idxs, deriv::D, p, _ode_addsteps!(ks[i₊], ts[i₋], timeseries[i₋], timeseries[i₊], dt, f, p, cache) # update the kcurrent val = ode_interpolant(Θ, dt, timeseries[i₋], timeseries[i₊], ks[i₊], cache, - idxs, deriv) + idxs, deriv, differential_vars) end end @@ -494,6 +509,19 @@ function ode_interpolation!(out, tval::Number, id::I, idxs, deriv::D, p, @unpack ts, timeseries, ks, f, cache = id @inbounds tdir = sign(ts[end] - ts[1]) + diferential_vars = trues(length(timeseries[begin])) + if hasproperty(f, :mass_matrix) + mm = f.mass_matrix + if isdiag(mm) + diferential_vars = Diagonal(mm).diag .!= 0 + else + # @show typeof(mm) + error("QR factorizations is annoying") + end + end + @show "!" diferential_vars + + if continuity === :left # we have i₋ = i₊ = 1 if tval = ts[1], i₊ = i₋ + 1 = lastindex(ts) if tval > ts[end], # and otherwise i₋ and i₊ satisfy ts[i₋] < tval ≤ ts[i₊] @@ -559,7 +587,8 @@ function ode_interpolant(Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{TI}}) w end function ode_interpolant(Θ, dt, y₀, y₁, k, cache::OrdinaryDiffEqMutableCache, idxs, - T::Type{Val{TI}}) where {TI} + T::Type{Val{TI}}, differential_vars=[]) where {TI} + @show differential_vars if typeof(idxs) <: Number || typeof(y₀) <: Union{Number, SArray} # typeof(y₀) can be these if saveidxs gives a single value _ode_interpolant(Θ, dt, y₀, y₁, k, cache, idxs, T)