Skip to content

Commit

Permalink
Extend extrapolation methods to allow for arbitrary arrays.
Browse files Browse the repository at this point in the history
  • Loading branch information
michakraus committed Aug 6, 2024
1 parent 7da27cb commit 3ffce0d
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 45 deletions.
20 changes: 10 additions & 10 deletions src/extrapolation/aitken_neville.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,18 @@ where
* `ti`: interpolation nodes
* `xi`: interpolation values
"""
function aitken_neville!(x::AbstractVector, t::TT, ti::AbstractVector{TT}, xi::AbstractMatrix) where {TT}
@assert length(ti) == size(xi,2)
@assert length(x) == size(xi,1)
function aitken_neville!(x::AbstractArray, t::TT, ti::AbstractVector{TT}, xi::AbstractVector) where {TT}
@assert length(ti) == length(xi)

for _xi in xi
@assert axes(x) == axes(_xi)
end

for j in eachindex(ti)
for i in 1:(length(ti)-j)
for k in axes(x,1)
xi[k,i] = xi[k,i+1] + (xi[k,i] - xi[k,i+1]) * (ti[i+j] - t) / (ti[i+j] - ti[i])
end
for i in eachindex(ti)[begin:end-j]
@. xi[i] = xi[i+1] + (xi[i] - xi[i+1]) * (ti[i+j] - t) / (ti[i+j] - ti[i])
end
end
for k in eachindex(x)
x[k] = xi[k,1]
end

copyto!(x, xi[1])
end
18 changes: 6 additions & 12 deletions src/extrapolation/euler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,30 +34,24 @@ struct EulerExtrapolation <: Extrapolation
end


function extrapolate!(t₀::TT, x₀::AbstractVector{DT},
t₁::TT, x₁::AbstractVector{DT},
function extrapolate!(t₀::TT, x₀::AbstractArray{DT},
t₁::TT, x₁::AbstractArray{DT},
problem::AbstractProblemODE,
extrap::EulerExtrapolation) where {DT,TT}

@assert axes(x₀) == axes(x₁)

local F = collect(1:(extrap.s+1))
local σ = (t₁ - t₀) ./ F
local pts = repeat(x₀, outer = [1, extrap.s+1])
local pts = [copy(x₀) for _ in F]

local xᵢ = zero(x₀)
local vᵢ = zero(x₀)

for i in F
for _ in 1:(F[i]-1)
tᵢ = t₀ + σ[i]
for k in axes(pts,1)
xᵢ[k] = pts[k,i]
end
initialguess(problem).v(vᵢ, tᵢ, xᵢ, parameters(problem))
for k in axes(pts,1)
pts[k,i] += σ[i] * vᵢ[k]
end
tᵢ = t₀ + σ[i]
initialguess(problem).v(vᵢ, tᵢ, pts[i], parameters(problem))
pts[i] .+= σ[i] * vᵢ
end
end

Expand Down
36 changes: 13 additions & 23 deletions src/extrapolation/midpoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@ end


function extrapolate!(
t₀::TT, x₀::AbstractVector{DT},
t₁::TT, x₁::AbstractVector{DT},
t₀::TT, x₀::AbstractArray{DT},
t₁::TT, x₁::AbstractArray{DT},
problem::Union{AbstractProblemODE, SODEProblem},
extrap::MidpointExtrapolation) where {DT,TT}

Expand All @@ -102,7 +102,7 @@ function extrapolate!(
local F = [2i*one(TT) for i in 1:extrap.s+1]
local σ = (t₁ - t₀) ./ F
local σ² = σ.^2
local pts = zeros(DT, axes(x₀)..., extrap.s+1)
local pts = [zero(x₀) for _ in 1:extrap.s+1]

local xᵢ₁ = zero(x₀)
local xᵢ₂ = zero(x₀)
Expand All @@ -112,7 +112,7 @@ function extrapolate!(

initialguess(problem).v(v₀, t₀, x₀, parameters(problem))

for i in 1:extrap.s+1
for i in eachindex(pts)
tᵢ = t₀ + σ[i]
xᵢ₁ .= x₀
xᵢ₂ .= x₀ .+ σ[i] .* v₀
Expand All @@ -122,9 +122,7 @@ function extrapolate!(
xᵢ₁ .= xᵢ₂
xᵢ₂ .= xᵢₜ
end
for k in axes(pts,1)
pts[k,i] += xᵢ₂[k]
end
pts[i] .+= xᵢ₂
end

aitken_neville!(x₁, zero(TT), σ², pts)
Expand All @@ -150,8 +148,8 @@ function extrapolate!(t₀::TT, q₀::AbstractVector{DT}, p₀::AbstractVector{D
local σ = (t₁ - t₀) ./ F
local σ2 = σ.^2

local qts = zeros(DT, axes(q₀)..., extrap.s+1)
local pts = zeros(DT, axes(p₀)..., extrap.s+1)
local qts = [zero(q₀) for _ in 1:extrap.s+1]
local pts = [zero(p₀) for _ in 1:extrap.s+1]

local qᵢ₁= zero(q₀)
local qᵢ₂= zero(q₀)
Expand Down Expand Up @@ -186,12 +184,8 @@ function extrapolate!(t₀::TT, q₀::AbstractVector{DT}, p₀::AbstractVector{D
pᵢ₁ .= pᵢ₂
pᵢ₂ .= pᵢₜ
end
for k in axes(qts,1)
qts[k,i] += qᵢ₂[k]
end
for k in axes(pts,1)
pts[k,i] += pᵢ₂[k]
end
qts[i] .+= qᵢ₂
pts[i] .+= pᵢ₂
end

aitken_neville!(q₁, zero(TT), σ2, qts)
Expand Down Expand Up @@ -219,8 +213,8 @@ function extrapolate!(
local σ = (t₁ - t₀) ./ F
local σ2 = σ.^2

local qts = zeros(DT, axes(q₀)..., extrap.s+1)
local pts = zeros(DT, axes(p₀)..., extrap.s+1)
local qts = [zero(q₀) for _ in 1:extrap.s+1]
local pts = [zero(p₀) for _ in 1:extrap.s+1]

local qᵢ₁= zero(q₀)
local qᵢ₂= zero(q₀)
Expand Down Expand Up @@ -255,12 +249,8 @@ function extrapolate!(
pᵢ₁ .= pᵢ₂
pᵢ₂ .= pᵢₜ
end
for k in axes(qts,1)
qts[k,i] += qᵢ₂[k]
end
for k in axes(pts,1)
pts[k,i] += pᵢ₂[k]
end
qts[i] .+= qᵢ₂
pts[i] .+= pᵢ₂
end

aitken_neville!(q₁, zero(TT), σ2, qts)
Expand Down

0 comments on commit 3ffce0d

Please sign in to comment.