Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Linear interpolation for algebraic variables #2048

Merged
merged 40 commits into from
Dec 11, 2023

Conversation

pepijndevos
Copy link
Contributor

Except for those solvers with specialized stiffness aware* interpolation, hermite interpolation is used which uses the derivative. For algebraic variables this ends up incorrectly using the residual. This PR aims to make it fall back to linear interpolation for algebraic variables.

*what does that actually mean?

At the moment I just pushed whatever I was doing with @oscardssmith last time

@pepijndevos
Copy link
Contributor Author

I made a very first attempt to get this to work, and am now at a stage again where I could use some help from someone more familiar with the code. There are several questions

  • Which paths into the code need to be handled to cover all uses
  • Which types of states and indexes need to be handled
  • Something about QR factorization and cacheing the computation

I'm currently testing with the rober example from https://docs.sciml.ai/DiffEqDocs/stable/tutorials/dae_example/#Mass-Matrix-Differential-Algebraic-Equations-(DAEs)

To elaborate on the first point, there are two methods of ode_interpolation that currently don't pass differential_vars anywhere, and I'm not sure when they are used. Are there other code paths that lead here that we need to handle?

The second point is a bit annoying. It seems like u can be an array or a scalar, and that idxs can be an array, nothing, or a single value, but it's not clear to me what assumptions I can make. Like, are they always numbers? Are there combinations that can't happen? It seems almost impossible to write a generic method that handles all cases, so I've started to implement special cases, but am uncertain of the total set of special cases I need to handle.

@pepijndevos pepijndevos marked this pull request as ready for review November 13, 2023 17:50
@pepijndevos
Copy link
Contributor Author

I did the implementation as best as I can, from here on it's a tone of clean up and review and handling edge cases I suppose.

I'm currently throwing an error in the base case to make sure we handle all the cases that we should, and I see that is getting hit by a NoIndexArray whatever that may be https://github.com/SciML/OrdinaryDiffEq.jl/actions/runs/6853617794/job/18634794102?pr=2048#step:6:690

@pepijndevos
Copy link
Contributor Author

pepijndevos commented Nov 14, 2023

Summary of the current failures and my ongoing understanding of them

@@ -324,6 +340,7 @@ function ode_interpolation(tvals, id::I, idxs, deriv::D, p,
@unpack ts, timeseries, ks, f, cache = id
@inbounds tdir = sign(ts[end] - ts[1])
idx = sortperm(tvals, rev = tdir < 0)
differential_vars = get_differential_vars(f, size(timeseries[begin]))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should maybe be passed somewhere, but @oscardssmith had some argument why not all cache types need it that I forgot so idk if deep down in the evaluate_interpolant machinery there is somewhere this needs to go

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we do want to cache this eventually, but at first it's not necessary.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure for now I'm talking more about covering all the places. This variable is currently unused.

end

function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{TI}}) where {TI}
function partial_hermite_interpolant!(out, Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{TI}}, differential_vars) where {TI}
throw("how did we get here")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Once I'm confident I've handled all the cases I guess either completely remove this method, or remove the error and leave it as a fallback if it turns out there is some insane edge case somewhere in the ecosystem.

@pepijndevos
Copy link
Contributor Author

Memory profile:
image

It seems like pretty much everything is getting inlined except idxs[.!sel] which allocates a new BitVector. Not sure what I can do about that.

Would be funny if the broadcasted version is faster because it doesn't allocate. Math is fast, memory is slow.

@pepijndevos
Copy link
Contributor Author

It is officially faster to broadcast both interpolations than to do the indexing, go figure.

julia> @benchmark sol(0.1)
BenchmarkTools.Trial: 10000 samples with 10 evaluations.
 Range (min … max):  1.112 μs …   9.666 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     1.205 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   1.275 μs ± 276.650 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

    ▅█▅▃                                                       
  ▂▆████▇▅▃▂▃▄▅▅▄▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▂
  1.11 μs         Histogram: frequency by time        2.09 μs <

 Memory estimate: 1.14 KiB, allocs estimate: 32.

julia> @benchmark sol(0.1)
BenchmarkTools.Trial: 10000 samples with 49 evaluations.
 Range (min … max):  889.204 ns …   3.239 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):       1.030 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):     1.041 μs ± 118.043 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

              ▂▁▁▁▅█▅▃▅▆▄▂ ▂                                     
  ▁▁▁▁▁▂▃▃▃▄▅███████████████▇▇▄▄▄▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▃
  889 ns           Histogram: frequency by time         1.33 μs <

 Memory estimate: 640 bytes, allocs estimate: 16.

@oscardssmith
Copy link
Contributor

You should probably have a fast path for no algebraic variables though. Lots of people only solve ODEs.

@pepijndevos
Copy link
Contributor Author

Ok so basically a complete rewrite to avoid allocations and use broadcasting and optimize for ODE.

I think if we default differential_vars::Nothing we can dispatch on that and do the optimal thing. Is it sufficient to have nothing in case of no mass matrix or uniform scaling, or is it common to have some other type of diagonal matrix with all ones that we need to test explicitly?

For broadcasting we need to change differential_vars to be output/index shaped.

And I fear that for completely avoiding allocations in the in-place version I need to write a fully integrated and broadcasted interpolation.

For the out of place version you can just do

    h = hermite_interpolant(Θ, dt, y₀, y₁, k, Val{cache isa OrdinaryDiffEqMutableCache}, idxs, T)
    l = linear_interpolant(Θ, dt, y₀, y₁, idxs, T)
    @.. broadcast=false h*differential_vars + l*!differential_vars

but for the in-place version you can't have both write to out without indexing or allocating:

    hermite_interpolant!(out, Θ, dt, y₀, y₁, k, idxs, T)
    l = linear_interpolant(Θ, dt, y₀, y₁, idxs, T)
    @.. broadcast=false out=out*differential_vars + l*!differential_vars

So I'm thinking to change

    return (1 - Θ) * y₀[idxs] + Θ * y₁[idxs] +
           Θ *- 1) *
           ((1 - 2Θ) * (y₁[idxs] - y₀[idxs]) +- 1) * dt * k[1][idxs] +
            Θ * dt * k[2][idxs])

to

    return (1 - Θ) * y₀[idxs] + Θ * y₁[idxs] +
           isdiff * Θ *- 1) *
           ((1 - 2Θ) * (y₁[idxs] - y₀[idxs]) +- 1) * dt * k[1][idxs] +
            Θ * dt * k[2][idxs])

That adds like one multiplication. Maybe that's insignificant enough that we don't need a complete fast path?

Would it make sense to pass it as a type-level boolean so in the fast path you can compile away the multiplication, or would that cause more trouble than it's worth when it can't be statically inferred?

I'd really like to avoid having two complete copies of hermite interpolation for a single multiplication.

@pepijndevos
Copy link
Contributor Author

That wasn't so bad for the primal. For derivatives it's a bit more tricky but if it passes the tests I'll do the rest.

Comment on lines 630 to 632
hermite_interpolant!(out, Θ, dt, y₀, y₁, k, idxs, T)
else
hermite_interpolant!(out, Θ, dt, y₀, y₁, k, idxs, T, differential_vars)
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not just use dispatch?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ambiguous method error

if idxs === nothing || differential_vars === nothing
return differential_vars
else
return differential_vars[idxs]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not a view?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea I suppose

@pepijndevos
Copy link
Contributor Author

I'm honestly a bit puzzled by some of these CI errors. Like the foop one.

@oscardssmith
Copy link
Contributor

The foop one isn't your fault.

@pepijndevos
Copy link
Contributor Author

Is this my fault? https://github.com/SciML/OrdinaryDiffEq.jl/actions/runs/6940604309/job/18879835886?pr=2048#step:6:637

It seems to be a regression test for #2055 and I don't really see the connection or the bug.

@ChrisRackauckas
Copy link
Member

I just merged a PR that was all green except for the format check (which the formatter is still having issues, I'm going to put a bounty on that) #2069, so if interpolation things are failing I'd venture to guess there was a merge issue.

@pepijndevos
Copy link
Contributor Author

I'm hoping UndefVarError: linearizing_save_finalize not defined isn't my fault?

Meanwhile I've started adding all the derivative implementations.

Comment on lines -789 to -794
@muladd function hermite_interpolant!(out::Array, Θ, dt, y₀, y₁, k, idxs, T::Type{Val{2}}) # Default interpolant is Hermite
@views @.. broadcast=false out=(-4 * dt * k[1][idxs] - 2 * dt * k[2][idxs] -
6 * y₀[idxs] +
Θ * (6 * dt * k[1][idxs] + 6 * dt * k[2][idxs] +
12 * y₀[idxs] - 12 * y₁[idxs]) + 6 * y₁[idxs]) /
(dt * dt)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this one had two complete copies of the interpolation??? I just kept the last one which was effectively being used I guess.

@topolarity
Copy link

topolarity commented Nov 28, 2023

Seems to throw an error for sol(0.0, Val{1}, idxs=3):

julia> sol(t[1], Val{1}, idxs=3)
ERROR: BoundsError: attempt to access 0-dimensional view(::BitVector, 3) with eltype Bool at index [3]
Stacktrace:
  [1] throw_boundserror(A::SubArray{Bool, 0, BitVector, Tuple{Int64}, true}, I::Tuple{Int64})
    @ Base ./essentials.jl:14
  [2] checkbounds
    @ ./abstractarray.jl:699 [inlined]
  [3] getindex
    @ ./subarray.jl:308 [inlined]
  [4] hermite_interpolant(Θ::Float64, dt::Float64, y₀::Vector{…}, y₁::Vector{…}, k::Vector{…}, cache::Type, idxs::Int64, T::Type{…}, dv::SubArray{…})
    @ OrdinaryDiffEq ~/repos/OrdinaryDiffEq.jl/src/dense/generic_dense.jl:729
  [5] _ode_interpolant
    @ ~/repos/OrdinaryDiffEq.jl/src/dense/generic_dense.jl:624 [inlined]
  [6] ode_interpolant
    @ ~/repos/OrdinaryDiffEq.jl/src/dense/generic_dense.jl:591 [inlined]
  [7] ode_interpolation(tval::Float64, id::OrdinaryDiffEq.InterpolationData{…}, idxs::Int64, deriv::Type{…}, p::Vector{…}, continuity::Symbol)
    @ OrdinaryDiffEq ~/repos/OrdinaryDiffEq.jl/src/dense/generic_dense.jl:502
  [8] InterpolationData
    @ OrdinaryDiffEq ~/repos/OrdinaryDiffEq.jl/src/interp_func.jl:168 [inlined]
  [9] AbstractODESolution
    @ SciMLBase ~/.julia/packages/SciMLBase/XEPyX/src/solutions/ode_solutions.jl:158 [inlined]
 [10] #_#439
    @ SciMLBase ~/.julia/packages/SciMLBase/XEPyX/src/solutions/ode_solutions.jl:139 [inlined]
 [11] top-level scope
    @ REPL[84]:1
Some type information was truncated. Use `show(err)` to see complete types.

@pepijndevos
Copy link
Contributor Author

For what problem/solver? Anyway, I'll add a bunch of tests to gain some confidence that all of those dozens of methods I changed actually work.

@pepijndevos
Copy link
Contributor Author

I'm pretty sure I've seen matrix-shaped idxs and outputs, but when I just pass a matrix to sol, it errors. Is there some other code path that results in a matrix?

@pepijndevos
Copy link
Contributor Author

It seems like master fails the same checks as this branch.

If the tests I've just added pass on CI and didn't break anything else, this should be good for another round of reviews and hopefully merge.

@topolarity
Copy link

This is definitely a step in the right direction, but it's worth mentioning that the derivatives will still need some work after this PR:
image

In general, the derivative of an interpolant doesn't have anything to do with the derivative of the interpolated function, unless that property was guaranteed by construction (e.g. the first-derivative of a hermite interpolation), so I think we need to use a different trick here.

@@ -282,7 +282,7 @@ end

@muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k,
cache::Union{Rosenbrock5ConstantCache, Rosenbrock5Cache},
idxs, T::Type{Val{1}})
idxs, T::Type{Val{1}}, dv=nothing)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why would dv show up here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we pass it as nothing to all interpolants

@ChrisRackauckas
Copy link
Member

using StaticArrays, LinearAlgebra
mm = SMatrix{2,2}(rand(2,2))
u = SA[1.0,2.0]
reshape(diag(mm)  .!= 0, size(u)) isa SVector # true

@ChrisRackauckas ChrisRackauckas merged commit 7d500e5 into SciML:master Dec 11, 2023
42 of 56 checks passed
@ChrisRackauckas
Copy link
Member

I'll need to do DelayDiffEq.jl because it reaches into internals that are changed on this.

@pepijndevos
Copy link
Contributor Author

Wow you've been busy, thanks!

@topolarity
Copy link

In what way? I think the key is really just that the free interpolation cannot give you any more insight and you have to compute more (like a higher order Hermite)

A higher-order analytic derivative would be lovely, but this is effectively a finite-differencing scheme anyway, so choosing an improved scheme would already provide a significant improvement. For example, the purple line here is just a linear interpolation of the "staircase" this PR provides for sol(t, Val{1}) and pushes down the error of the derivative by more than an order of magnitude:
image (6)

Similarly, I don't see why we'd give up for sol(t, Val{2}) and report an incorrect zero derivative, instead of continuing to do a finite-differencing scheme of some kind.

@topolarity
Copy link

BTW by finite-differencing, I mean that the coefficients of the polynomial interpolants and their derivatives for algebraic variables are (usually linear) functions of u(tᵢ). They don't analytically measure u'(t) anywhere, so they can at best perform like a decent finite-differencing scheme rather than a proper analytic derivative

@ChrisRackauckas
Copy link
Member

That's not a local scheme though, as you're using more than just the current step information? That can only be done in the post solution interpolation.

@topolarity
Copy link

That's not a local scheme though, as you're using more than just the current step information? That can only be done in the post solution interpolation.

That just means you have to restrict to one-sided finite differencing techniques

I was kind of expecting sol(t, Val{n}) to just "do the right thing" in terms of using derivatives if they were included in the system solution, exploiting the solver's interpolation, etc.

But if there are scenarios where it uses a poor finite-differencing scheme or returns zero results when it shouldn't then maybe we just need a way to know when sol(t, Val{n}) is the right thing to use and when it isn't (esp. for post-solution interpolation)

@ChrisRackauckas
Copy link
Member

That just means you have to restrict to one-sided finite differencing techniques

But you only have u_n and u_{n+1} and the k's, which means if you finite difference the u's in a step then it must be a linear interpolation.

@devmotion
Copy link
Member

I'll need to do DelayDiffEq.jl because it reaches into internals that are changed on this.

Someone apparently already ran into this problem 😅 SciML/DelayDiffEq.jl#274

@ChrisRackauckas
Copy link
Member

Yeah it was a fundamental break and just needs a downstream fix. I'll put that in ASAP. Just been doing grant writing all week and was using this as a sidepiece while avoiding the real work.

@jebej
Copy link

jebej commented Dec 18, 2023

I suspect this PR caused a regression: #2086

@pepijndevos
Copy link
Contributor Author

What makes you say so?

If that's the case the thing to look for is a broadcast in Hermite interpolant that returns an array rather than a scalar.

@jebej
Copy link

jebej commented Dec 18, 2023

There are only two changes between 6.60 and 6.61, and the other doesn't appear related, but I might be wrong. v6.60.0...v6.61.0

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants