Skip to content

Commit

Permalink
Duck typing and consistent return types (Vector{typeof(y0)}) for all …
Browse files Browse the repository at this point in the history
…solvers; see SciML#7. Tests adjusted accordingly.
  • Loading branch information
acroy committed Feb 26, 2014
1 parent cd9ccff commit 2db3a89
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 50 deletions.
100 changes: 54 additions & 46 deletions src/ODE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ export ode23, ode4, ode45, ode4s, ode4ms
# Initialize variables.
# Adapted from Cleve Moler's textbook
# http://www.mathworks.com/moler/ncm/ode23tx.m
function ode23{T}(F::Function, tspan::AbstractVector, y0::AbstractVector{T})
function ode23(F, tspan, y0)

rtol = 1.e-5
atol = 1.e-8
Expand All @@ -56,7 +56,8 @@ function ode23{T}(F::Function, tspan::AbstractVector, y0::AbstractVector{T})
y = y0

tout = t
yout = y.'
yout = Array(typeof(y0),1)
yout[1] = y

tlen = length(t)

Expand Down Expand Up @@ -99,7 +100,7 @@ function ode23{T}(F::Function, tspan::AbstractVector, y0::AbstractVector{T})
t = tnew
y = ynew
tout = [tout; t]
yout = [yout; y.']
push!(yout, y)
s1 = s4 # Reuse final function value to start new step
end

Expand All @@ -116,7 +117,7 @@ function ode23{T}(F::Function, tspan::AbstractVector, y0::AbstractVector{T})

end # while (t != tfinal)

return (tout, yout)
return tout, yout

end # ode23

Expand Down Expand Up @@ -178,7 +179,7 @@ end # ode23
# created : 06 October 1999
# modified: 17 January 2001

function oderkf{T}(F::Function, tspan::AbstractVector, x0::AbstractVector{T}, a, b4, b5)
function oderkf(F, tspan, x0, a, b4, b5)

This comment has been minimized.

Copy link
@tomasaschan

tomasaschan Feb 27, 2014

While we're changing this, maybe it would make sense to s/x0/y0 here as well, to have consistent argument naming across all the methods. Since it's not uncommon to talk about either x = f(t), y = f(x) or y = f(t), (or even y = f(t,x) I think calling the solution y is the least ambiguous - y is almost always used as whatever f is describing, and t is almost always used as the (/an) argument to f, while x can be either.

tol = 1.0e-5

# see p.91 in the Ascher & Petzold reference for more infomation.
Expand All @@ -194,9 +195,10 @@ function oderkf{T}(F::Function, tspan::AbstractVector, x0::AbstractVector{T}, a,
h = (tfinal - t)/100 # initial guess at a step size
x = x0
tout = t # first output time
xout = x.' # first output solution
xout = Array(typeof(x0), 1)
xout[1] = x # first output solution

k = zeros(eltype(x), length(c), length(x))
k = Array(typeof(x0), length(c))

while t < tfinal && h >= hmin
if t + h > tfinal
Expand All @@ -211,20 +213,20 @@ function oderkf{T}(F::Function, tspan::AbstractVector, x0::AbstractVector{T}, a,
# This is part of the Dormand-Prince pair caveat.
# k[:,7] has already been computed, so use it instead of recomputing it
# again as k[:,1] during the next step.
k[1,:] = k[end,:]
k[1] = k[end]
else
k[1,:] = F(t,x) # first stage
k[1] = F(t,x) # first stage
end

for j = 2:length(c)
k[j,:] = F(t + h.*c[j], x + h.*(a[j,1:j-1]*k[1:j-1,:]).')
k[j] = F(t + h.*c[j], x + h.*(a[j,1:j-1]*k[1:j-1])[1])
end

# compute the 4th order estimate
x4 = x + h.*(b4*k).'
x4 = x + h.*(b4*k)[1]

# compute the 5th order estimate
x5 = x + h.*(b5*k).'
x5 = x + h.*(b5*k)[1]

# estimate the local truncation error
gamma1 = x5 - x4
Expand All @@ -238,7 +240,7 @@ function oderkf{T}(F::Function, tspan::AbstractVector, x0::AbstractVector{T}, a,
t = t + h
x = x5 # <-- using the higher order estimate is called 'local extrapolation'
tout = [tout; t]
xout = [xout; x.']
push!(xout, x)
end

# Update the step size
Expand All @@ -252,7 +254,7 @@ function oderkf{T}(F::Function, tspan::AbstractVector, x0::AbstractVector{T}, a,
println("Step size grew too small. t=", t, ", h=", h, ", x=", x)
end

return (tout, xout)
return tout, xout
end

# Both the Dormand-Prince and Fehlberg 4(5) coefficients are from a tableau in
Expand Down Expand Up @@ -316,59 +318,65 @@ const ode45 = ode45_dp
# ODEFUN(T,X) must return a column vector corresponding to f(t,x). Each
# row in the solution array X corresponds to a time returned in the
# column vector T.
function ode4{T}(F::Function, tspan::AbstractVector, x0::AbstractVector{T})
function ode4(F, tspan, x0)
h = diff(tspan)
x = Array(T, length(tspan), length(x0))
x[1,:] = x0
x = Array(typeof(x0), length(tspan))
x[1] = x0

midxdot = Array(T, 4, length(x0))
midxdot = Array(typeof(x0), 4)
for i = 1:length(tspan)-1
# Compute midstep derivatives
midxdot[1,:] = F(tspan[i], x[i,:]')
midxdot[2,:] = F(tspan[i]+h[i]./2, x[i,:]' + midxdot[1,:]'.*h[i]./2)
midxdot[3,:] = F(tspan[i]+h[i]./2, x[i,:]' + midxdot[2,:]'.*h[i]./2)
midxdot[4,:] = F(tspan[i]+h[i], x[i,:]' + midxdot[3,:]'.*h[i])
midxdot[1] = F(tspan[i], x[i])
midxdot[2] = F(tspan[i]+h[i]./2, x[i] + midxdot[1].*h[i]./2)
midxdot[3] = F(tspan[i]+h[i]./2, x[i] + midxdot[2].*h[i]./2)
midxdot[4] = F(tspan[i]+h[i], x[i] + midxdot[3].*h[i])

# Integrate
x[i+1,:] = x[i,:] + 1./6.*h[i].*[1 2 2 1]*midxdot
x[i+1] = x[i] + 1./6.*(h[i].*[1 2 2 1]*midxdot)[1]
end
return (tspan, x)
return tspan, x
end

#ODEROSENBROCK Solve stiff differential equations, Rosenbrock method
# with provided coefficients.
function oderosenbrock{T}(F::Function, G::Function, tspan::AbstractVector, x0::AbstractVector{T}, gamma, a, b, c)
function oderosenbrock(F, G, tspan, x0, gamma, a, b, c)
h = diff(tspan)
x = Array(T, length(tspan), length(x0))
x[1,:] = x0
x = Array(typeof(x0), length(tspan))
x[1] = x0

solstep = 1
while tspan[solstep] < maximum(tspan)
ts = tspan[solstep]
hs = h[solstep]
xs = reshape(x[solstep,:], size(x0))
xs = x[solstep]
dFdx = G(ts, xs)
jac = eye(size(dFdx,1))./gamma./hs-dFdx
# FIXME
if size(dFdx,1) == 1
jac = 1/gamma/hs - dFdx[1]
else
jac = eye(dFdx)./gamma./hs - dFdx
end

g = zeros(size(a,1), length(x0))
g[1,:] = jac \ F(ts + b[1].*hs, xs)
g = Array(typeof(x0), size(a,1))
g[1] = (jac \ F(ts + b[1].*hs, xs))
for i = 2:size(a,1)
g[i,:] = jac \ (F(ts + b[i].*hs, xs + (a[i,1:i-1]*g[1:i-1,:]).') + (c[i,1:i-1]*g[1:i-1,:]).'./hs)
g[i] = (jac \ (F(ts + b[i].*hs, xs + (a[i,1:i-1]*g[1:i-1])[1]) + (c[i,1:i-1]*g[1:i-1])[1]./hs))
end

x[solstep+1,:] = x[solstep,:] + b*g
x[solstep+1] = x[solstep] + (b*g)[1]
solstep += 1
end
return (tspan, x)
return tspan, x
end

function oderosenbrock{T}(F::Function, tspan::AbstractVector, x0::AbstractVector{T}, gamma, a, b, c)
function oderosenbrock(F, tspan, x0, gamma, a, b, c)
# Crude forward finite differences estimator as fallback
function jacobian(F::Function, t::Number, x::AbstractVector)
# FIXME: This doesn't really work if x is anything but a Vector or a scalar
function jacobian(F, t, x)
ftx = F(t, x)
dFdx = zeros(length(x), length(x))
for j = 1:length(x)
dx = zeros(size(x))
lx = max(length(x),1)
dFdx = zeros(eltype(x), lx, lx)
for j = 1:lx
dx = zeros(eltype(x), lx)
# The 100 below is heuristic
dx[j] = (x[j]+(x[j]==0))./100
dFdx[:,j] = (F(t,x+dx)-ftx)./dx[j]
Expand Down Expand Up @@ -412,10 +420,10 @@ ode4s_s(F, G, tspan, x0) = oderosenbrock(F, G, tspan, x0, s4_coefficients...)
const ode4s = ode4s_s

# ODE_MS Fixed-step, fixed-order multi-step numerical method with Adams-Bashforth-Moulton coefficients
function ode_ms{T}(F::Function, tspan::AbstractVector, x0::AbstractVector{T}, order::Integer)
function ode_ms(F, tspan, x0, order::Integer)
h = diff(tspan)
x = zeros(T, length(tspan), length(x0))
x[1,:] = x0
x = Array(typeof(x0), length(tspan))
x[1] = x0

if 1 <= order <= 4
b = [ 1 0 0 0
Expand All @@ -438,10 +446,10 @@ function ode_ms{T}(F::Function, tspan::AbstractVector, x0::AbstractVector{T}, or
for i = 1:length(tspan)-1
# Need to run the first several steps at reduced order
steporder = min(i, order)
xdot[i,:] = F(tspan[i], x[i,:]')
x[i+1,:] = x[i,:] + b[steporder,1:steporder]*xdot[i-(steporder-1):i,:].*h[i]
xdot[i] = F(tspan[i], x[i])
x[i+1] = x[i] + (b[steporder,1:steporder]*xdot[i-(steporder-1):i])[1].*h[i]
end
return (tspan, x)
return tspan, x
end

# Use order 4 by default
Expand Down
9 changes: 5 additions & 4 deletions test/tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,19 @@ for solver in solvers
# dy
# -- = 6 ==> y = 6t
# dt
t,y=solver((t,y)->6, [0:.1:1], [0.])
t,y=solver((t,y)->6, [0:.1:1], 0.)
@test maximum(abs(y-6t)) < tol

# dy
# -- = 2t ==> y = t.^2
# dt
t,y=solver((t,y)->2t, [0:.001:1], [0.])
t,y=solver((t,y)->2t, [0:.001:1], 0.)
@test maximum(abs(y-t.^2)) < tol

# dy
# -- = y ==> y = y0*e.^t
# dt
t,y=solver((t,y)->y, [0:.001:1], [1.])
t,y=solver((t,y)->y, [0:.001:1], 1.)
@test maximum(abs(y-e.^t)) < tol

# dv dw
Expand All @@ -40,7 +40,8 @@ for solver in solvers
#
# y = [v, w]
t,y=solver((t,y)->[-y[2], y[1]], [0:.001:2*pi], [1., 2.])
@test maximum(abs(y-[cos(t)-2*sin(t) 2*cos(t)+sin(t)])) < tol
ys = hcat(y...).' # convert Vector{Vector{Float}} to Matrix{Float}
@test maximum(abs(ys-[cos(t)-2*sin(t) 2*cos(t)+sin(t)])) < tol
end

println("All looks OK")

0 comments on commit 2db3a89

Please sign in to comment.