Skip to content

Commit

Permalink
oscar wip
Browse files Browse the repository at this point in the history
  • Loading branch information
Pepijn de Vos committed Nov 2, 2023
1 parent 880221b commit 3809ea4
Showing 1 changed file with 35 additions and 6 deletions.
41 changes: 35 additions & 6 deletions src/dense/generic_dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,14 @@ end
end

@inline function ode_interpolant(Θ, integrator::DiffEqBase.DEIntegrator, idxs, deriv)
@show integrator.f.mass_matrix
DiffEqBase.addsteps!(integrator)
if !(typeof(integrator.cache) <: CompositeCache)
val = ode_interpolant(Θ, integrator.dt, integrator.uprev, integrator.u,
integrator.k, integrator.cache, idxs, deriv)
else
if integrator.cache isa CompositeCache
val = composite_ode_interpolant(Θ, integrator, integrator.cache.caches,
integrator.cache.current, idxs, deriv)
else
val = ode_interpolant(Θ, integrator.dt, integrator.uprev, integrator.u,
integrator.k, integrator.cache, idxs, deriv)
end
val
end
Expand All @@ -119,6 +120,7 @@ end
end

@inline function ode_interpolant!(val, Θ, integrator::DiffEqBase.DEIntegrator, idxs, deriv)
@show "!" integrator.f.mass_matrix
DiffEqBase.addsteps!(integrator)
if !(typeof(integrator.cache) <: CompositeCache)
ode_interpolant!(val, Θ, integrator.dt, integrator.uprev, integrator.u,
Expand Down Expand Up @@ -446,6 +448,19 @@ function ode_interpolation(tval::Number, id::I, idxs, deriv::D, p,
@unpack ts, timeseries, ks, f, cache = id
@inbounds tdir = sign(ts[end] - ts[1])

differential_vars = trues(length(timeseries[begin]))
if hasproperty(f, :mass_matrix)
mm = f.mass_matrix
if isdiag(mm)
differential_vars = Diagonal(mm).diag .!= 0
else
@show typeof(mm)
error("QR factorizations is annoying")
end
end
@show differential_vars


if continuity === :left
# we have i₋ = i₊ = 1 if tval = ts[1], i₊ = i₋ + 1 = lastindex(ts) if tval > ts[end],
# and otherwise i₋ and i₊ satisfy ts[i₋] < tval ≤ ts[i₊]
Expand Down Expand Up @@ -476,7 +491,7 @@ function ode_interpolation(tval::Number, id::I, idxs, deriv::D, p,
_ode_addsteps!(ks[i₊], ts[i₋], timeseries[i₋], timeseries[i₊], dt, f, p,
cache) # update the kcurrent
val = ode_interpolant(Θ, dt, timeseries[i₋], timeseries[i₊], ks[i₊], cache,
idxs, deriv)
idxs, deriv, differential_vars)
end
end

Expand All @@ -494,6 +509,19 @@ function ode_interpolation!(out, tval::Number, id::I, idxs, deriv::D, p,
@unpack ts, timeseries, ks, f, cache = id
@inbounds tdir = sign(ts[end] - ts[1])

diferential_vars = trues(length(timeseries[begin]))
if hasproperty(f, :mass_matrix)
mm = f.mass_matrix
if isdiag(mm)
diferential_vars = Diagonal(mm).diag .!= 0
else
# @show typeof(mm)
error("QR factorizations is annoying")
end
end
@show "!" diferential_vars


if continuity === :left
# we have i₋ = i₊ = 1 if tval = ts[1], i₊ = i₋ + 1 = lastindex(ts) if tval > ts[end],
# and otherwise i₋ and i₊ satisfy ts[i₋] < tval ≤ ts[i₊]
Expand Down Expand Up @@ -559,7 +587,8 @@ function ode_interpolant(Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{TI}}) w
end

function ode_interpolant(Θ, dt, y₀, y₁, k, cache::OrdinaryDiffEqMutableCache, idxs,
T::Type{Val{TI}}) where {TI}
T::Type{Val{TI}}, differential_vars=[]) where {TI}
@show differential_vars
if typeof(idxs) <: Number || typeof(y₀) <: Union{Number, SArray}
# typeof(y₀) can be these if saveidxs gives a single value
_ode_interpolant(Θ, dt, y₀, y₁, k, cache, idxs, T)
Expand Down

0 comments on commit 3809ea4

Please sign in to comment.