use Differentiation Interface #2482

Open · wants to merge 1 commit into base: master
1 change: 1 addition & 0 deletions lib/OrdinaryDiffEqDifferentiation/Project.toml
@@ -7,6 +7,7 @@ version = "1.1.0"
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
@@ -48,6 +48,8 @@ import OrdinaryDiffEqCore: get_chunksize, resize_J_W!, resize_nlsolver!, alg_aut

using FastBroadcast: @..

import DifferentiationInterface as DI

@static if isdefined(DiffEqBase, :OrdinaryDiffEqTag)
import DiffEqBase: OrdinaryDiffEqTag
else
3 changes: 2 additions & 1 deletion lib/OrdinaryDiffEqDifferentiation/src/alg_utils.jl
@@ -19,7 +19,8 @@ function alg_autodiff(alg)
if autodiff == Val(false)
return AutoFiniteDiff()
elseif autodiff == Val(true)
return AutoForwardDiff()
tag = ForwardDiff.Tag(OrdinaryDiffEqTag(), Float64) # FIXME
return AutoForwardDiff{1, typeof(tag)}(tag)
else
return _unwrap_val(autodiff)
end
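
Not part of the diff: a minimal sketch of what the tagged backend constructed above evaluates to, assuming ADTypes and ForwardDiff are loaded. The OrdinaryDiffEqTag struct below is a local stand-in for the tag type imported from DiffEqBase, and the chunk size of 1 is the value the FIXME line hard-codes.

    using ADTypes, ForwardDiff

    struct OrdinaryDiffEqTag end  # stand-in for DiffEqBase.OrdinaryDiffEqTag

    tag = ForwardDiff.Tag(OrdinaryDiffEqTag(), Float64)
    backend = AutoForwardDiff{1, typeof(tag)}(tag)
    # equivalently, via the keyword constructor: AutoForwardDiff(chunksize = 1, tag = tag)
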
9 changes: 6 additions & 3 deletions lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl
@@ -39,7 +39,9 @@ function calc_tderivative!(integrator, cache, dtd1, repeat_step)
else
tf.uprev = uprev
tf.p = p
derivative!(dT, tf, t, du2, integrator, cache.grad_config)
alg = unwrap_alg(integrator, true)
# DI.derivative!(f!, y, dy_dt, prep, backend, t) for y(t)
DI.derivative!(tf, linsolve_tmp, dT, cache.grad_config, alg_autodiff(alg), t)
end
end

@@ -57,7 +59,8 @@ function calc_tderivative(integrator, cache)
tf = cache.tf
tf.u = uprev
tf.p = p
dT = derivative(tf, t, integrator)
alg = unwrap_alg(integrator, true)
dT = DI.derivative(tf, alg_autodiff(alg), t)
end
dT
end
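
Not part of the diff: a self-contained sketch of the two DI entry points used above, the in-place DI.derivative! call from calc_tderivative! and the out-of-place DI.derivative call from calc_tderivative. The wrapper tf!, the buffers y and dT, and the sample time 0.5 are illustrative only.

    import DifferentiationInterface as DI
    import ForwardDiff                      # backend package must be loaded for DI's extension
    using ADTypes: AutoForwardDiff

    tf!(y, t) = (y .= sin.(t .* [1.0, 2.0]); nothing)   # toy stand-in for cache.tf

    y, dT   = zeros(2), zeros(2)
    backend = AutoForwardDiff()

    prep = DI.prepare_derivative(tf!, y, backend, 0.5)  # analogue of cache.grad_config
    DI.derivative!(tf!, y, dT, prep, backend, 0.5)      # fills dT with dy/dt, as in calc_tderivative!
    d = DI.derivative(t -> sin(2t), backend, 0.5)       # out-of-place form, as in calc_tderivative
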
@@ -117,7 +120,7 @@ Update the Jacobian object `J`.

If `integrator.f` has a custom Jacobian update function, then it will be called. Otherwise,
either automatic or finite differencing will be used depending on the `cache`.
If `next_step`, then it will evaluate the Jacobian at the next step.
If `next_step`, theOrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl:n it will evaluate the Jacobian at the next step.
Review comment (Member): This looks like a mistake! 🤔

"""
function calc_J!(J, integrator, cache, next_step::Bool = false)
@unpack dt, t, uprev, f, p, alg = integrator
258 changes: 5 additions & 253 deletions lib/OrdinaryDiffEqDifferentiation/src/derivative_wrappers.jl
@@ -75,239 +75,22 @@ function Base.showerror(io::IO, e::FirstAutodiffJacError)
Base.showerror(io, e.e)
end

function derivative!(df::AbstractArray{<:Number}, f,
x::Union{Number, AbstractArray{<:Number}}, fx::AbstractArray{<:Number},
integrator, grad_config)
alg = unwrap_alg(integrator, true)
tmp = length(x) # We calculate derivative for all elements in gradient
if alg_autodiff(alg) isa AutoForwardDiff
T = if standardtag(alg)
typeof(ForwardDiff.Tag(OrdinaryDiffEqTag(), eltype(df)))
else
typeof(ForwardDiff.Tag(f, eltype(df)))
end

xdual = Dual{T, eltype(df), 1}(convert(eltype(df), x),
ForwardDiff.Partials((one(eltype(df)),)))

if integrator.iter == 1
try
f(grad_config, xdual)
catch e
throw(FirstAutodiffTgradError(e))
end
else
f(grad_config, xdual)
end

df .= first.(ForwardDiff.partials.(grad_config))
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
elseif alg_autodiff(alg) isa AutoFiniteDiff
FiniteDiff.finite_difference_gradient!(df, f, x, grad_config,
dir = diffdir(integrator))
fdtype = alg_difftype(alg)
if fdtype == Val{:forward} || fdtype == Val{:central}
tmp *= 2
if eltype(df) <: Complex
tmp *= 2
end
end
integrator.stats.nf += tmp
else
error("$alg_autodiff not yet supported in derivative! function")
end
nothing
end

function derivative(f, x::Union{Number, AbstractArray{<:Number}},
integrator)
local d
tmp = length(x) # We calculate derivative for all elements in gradient
alg = unwrap_alg(integrator, true)
if alg_autodiff(alg) isa AutoForwardDiff
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
if integrator.iter == 1
try
d = ForwardDiff.derivative(f, x)
catch e
throw(FirstAutodiffTgradError(e))
end
else
d = ForwardDiff.derivative(f, x)
end
elseif alg_autodiff(alg) isa AutoFiniteDiff
d = FiniteDiff.finite_difference_derivative(f, x, alg_difftype(alg),
dir = diffdir(integrator))
if alg_difftype(alg) === Val{:central} || alg_difftype(alg) === Val{:forward}
tmp *= 2
end
integrator.stats.nf += tmp
d
else
error("$alg_autodiff not yet supported in derivative function")
end
end

jacobian_autodiff(f, x, odefun, alg) = (ForwardDiff.derivative(f, x), 1, alg)
function jacobian_autodiff(f, x::AbstractArray, odefun, alg)
jac_prototype = odefun.jac_prototype
sparsity, colorvec = sparsity_colorvec(odefun, x)
maxcolor = maximum(colorvec)
chunk_size = get_chunksize(alg) === Val(0) ? nothing : get_chunksize(alg)
num_of_chunks = chunk_size === nothing ?
Int(ceil(maxcolor / getsize(ForwardDiff.pickchunksize(maxcolor)))) :
Int(ceil(maxcolor / _unwrap_val(chunk_size)))
(
forwarddiff_color_jacobian(f, x, colorvec = colorvec, sparsity = sparsity,
jac_prototype = jac_prototype, chunksize = chunk_size),
num_of_chunks)
end

function _nfcount(N, ::Type{diff_type}) where {diff_type}
if diff_type === Val{:complex}
tmp = N
elseif diff_type === Val{:forward}
tmp = N + 1
else
tmp = 2N
end
tmp
end

function jacobian_finitediff(f, x, ::Type{diff_type}, dir, colorvec, sparsity,
jac_prototype) where {diff_type}
(FiniteDiff.finite_difference_derivative(f, x, diff_type, eltype(x), dir = dir), 2)
end
function jacobian_finitediff(f, x::AbstractArray, ::Type{diff_type}, dir, colorvec,
sparsity, jac_prototype) where {diff_type}
f_in = diff_type === Val{:forward} ? f(x) : similar(x)
ret_eltype = eltype(f_in)
J = FiniteDiff.finite_difference_jacobian(f, x, diff_type, ret_eltype, f_in,
dir = dir, colorvec = colorvec,
sparsity = sparsity,
jac_prototype = jac_prototype)
return J, _nfcount(maximum(colorvec), diff_type)
end
function jacobian(f, x, integrator)
alg = unwrap_alg(integrator, true)
local tmp
if alg_autodiff(alg) isa AutoForwardDiff
if integrator.iter == 1
try
J, tmp = jacobian_autodiff(f, x, integrator.f, alg)
catch e
throw(FirstAutodiffJacError(e))
end
else
J, tmp = jacobian_autodiff(f, x, integrator.f, alg)
end
elseif alg_autodiff(alg) isa AutoFiniteDiff
jac_prototype = integrator.f.jac_prototype
sparsity, colorvec = sparsity_colorvec(integrator.f, x)
dir = diffdir(integrator)
J, tmp = jacobian_finitediff(f, x, alg_difftype(alg), dir, colorvec, sparsity,
jac_prototype)
else
bleh
end
integrator.stats.nf += tmp
J
end

function jacobian_finitediff_forward!(J, f, x, jac_config, forwardcache, integrator)
(FiniteDiff.finite_difference_jacobian!(J, f, x, jac_config, forwardcache,
dir = diffdir(integrator));
maximum(jac_config.colorvec))
end
function jacobian_finitediff!(J, f, x, jac_config, integrator)
(FiniteDiff.finite_difference_jacobian!(J, f, x, jac_config,
dir = diffdir(integrator));
2 * maximum(jac_config.colorvec))
return DI.jacobian(f, alg_autodiff(alg), x)
end
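
Not part of the diff: with the deletions and additions interleaved above, the DI-based jacobian presumably reduces to a single DI.jacobian call on the chosen backend. A hedged standalone sketch with a toy out-of-place function f; ForwardDiff and FiniteDiff are loaded so the corresponding DI backend extensions are available.

    import DifferentiationInterface as DI
    import ForwardDiff, FiniteDiff
    using ADTypes: AutoForwardDiff, AutoFiniteDiff

    f(x) = [x[1]^2 + x[2], 3x[2]]                  # toy out-of-place RHS wrapper
    x = [1.0, 2.0]

    J_ad = DI.jacobian(f, AutoForwardDiff(), x)    # forward-mode AD
    J_fd = DI.jacobian(f, AutoFiniteDiff(), x)     # finite differences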

function jacobian!(J::AbstractMatrix{<:Number}, f, x::AbstractArray{<:Number},
fx::AbstractArray{<:Number}, integrator::DiffEqBase.DEIntegrator,
jac_config)
alg = unwrap_alg(integrator, true)
if alg_autodiff(alg) isa AutoForwardDiff
if integrator.iter == 1
try
forwarddiff_color_jacobian!(J, f, x, jac_config)
catch e
throw(FirstAutodiffJacError(e))
end
else
forwarddiff_color_jacobian!(J, f, x, jac_config)
end
OrdinaryDiffEqCore.increment_nf!(integrator.stats, maximum(jac_config.colorvec))
elseif alg_autodiff(alg) isa AutoFiniteDiff
isforward = alg_difftype(alg) === Val{:forward}
if isforward
forwardcache = get_tmp_cache(integrator, alg, unwrap_cache(integrator, true))[2]
f(forwardcache, x)
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
tmp = jacobian_finitediff_forward!(J, f, x, jac_config, forwardcache,
integrator)
else # not forward difference
tmp = jacobian_finitediff!(J, f, x, jac_config, integrator)
end
integrator.stats.nf += tmp
else
error("$alg_autodiff not yet supported in jacobian! function")
end
DI.jacobian!(f, fx, J, jac_config, alg_autodiff(alg), x)
nothing
end
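
Not part of the diff: the in-place path above passes six arguments to DI.jacobian! (function, primal output, Jacobian buffer, preparation, backend, input). A hedged sketch with jac_config standing in for whatever build_jac_config now returns; f!, y, J, and x are illustrative only.

    import DifferentiationInterface as DI
    import ForwardDiff
    using ADTypes: AutoForwardDiff

    f!(y, x) = (y[1] = x[1]^2 + x[2]; y[2] = 3x[2]; nothing)   # toy in-place RHS wrapper
    x, y = [1.0, 2.0], zeros(2)
    J = zeros(2, 2)
    backend = AutoForwardDiff()

    jac_config = DI.prepare_jacobian(f!, y, backend, x)   # analogue of build_jac_config's return
    DI.jacobian!(f!, y, J, jac_config, backend, x)        # mirrors the call in jacobian! above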

function build_jac_config(alg, f::F1, uf::F2, du1, uprev, u, tmp, du2) where {F1, F2}
haslinsolve = hasfield(typeof(alg), :linsolve)

if !DiffEqBase.has_jac(f) && # No Jacobian if has analytical solution
(!DiffEqBase.has_Wfact_t(f)) &&
((concrete_jac(alg) === nothing && (!haslinsolve || (haslinsolve && # No Jacobian if linsolve doesn't want it
(alg.linsolve === nothing || LinearSolve.needs_concrete_A(alg.linsolve))))) ||
(concrete_jac(alg) !== nothing && concrete_jac(alg))) # Jacobian if explicitly asked for
jac_prototype = f.jac_prototype

if jac_prototype isa SparseMatrixCSC
if f.mass_matrix isa UniformScaling
idxs = diagind(jac_prototype)
@. @view(jac_prototype[idxs]) = 1
else
idxs = findall(!iszero, f.mass_matrix)
@. @view(jac_prototype[idxs]) = @view(f.mass_matrix[idxs])
end
end

sparsity, colorvec = sparsity_colorvec(f, u)

if alg_autodiff(alg) isa AutoForwardDiff
_chunksize = get_chunksize(alg) === Val(0) ? nothing : get_chunksize(alg) # SparseDiffEq uses different convection...
T = if standardtag(alg)
typeof(ForwardDiff.Tag(OrdinaryDiffEqTag(), eltype(u)))
else
typeof(ForwardDiff.Tag(uf, eltype(u)))
end
jac_config = ForwardColorJacCache(uf, uprev, _chunksize; colorvec = colorvec,
sparsity = sparsity, tag = T)
elseif alg_autodiff(alg) isa AutoFiniteDiff
if alg_difftype(alg) !== Val{:complex}
jac_config = FiniteDiff.JacobianCache(tmp, du1, du2, alg_difftype(alg),
colorvec = colorvec,
sparsity = sparsity)
else
jac_config = FiniteDiff.JacobianCache(Complex{eltype(tmp)}.(tmp),
Complex{eltype(du1)}.(du1), nothing,
alg_difftype(alg), eltype(u),
colorvec = colorvec,
sparsity = sparsity)
end
else
error("$alg_autodiff not yet supported in build_jac_config function")
end
else
jac_config = nothing
end
jac_config
# DI.prepare_jacobian(f!, target, backend, input)
return DI.prepare_jacobian(uf, du1, alg_autodiff(alg), u)
end
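
Not part of the diff: the deleted build_jac_config branch wired in sparsity patterns and color vectors by hand (sparsity_colorvec, ForwardColorJacCache), while plain AutoForwardDiff/AutoFiniteDiff backends treat the Jacobian as dense. If sparsity handling is still wanted through DI, the usual route is an AutoSparse wrapper; a hedged sketch, assuming SparseConnectivityTracer and SparseMatrixColorings are available.

    import DifferentiationInterface as DI
    import ForwardDiff
    using ADTypes: AutoForwardDiff, AutoSparse
    using SparseConnectivityTracer: TracerSparsityDetector
    using SparseMatrixColorings: GreedyColoringAlgorithm

    sparse_backend = AutoSparse(AutoForwardDiff();
        sparsity_detector = TracerSparsityDetector(),
        coloring_algorithm = GreedyColoringAlgorithm())

    f!(y, x) = (y .= x .^ 2; nothing)                      # toy diagonal system
    y, x = zeros(3), ones(3)
    prep = DI.prepare_jacobian(f!, y, sparse_backend, x)
    J = DI.jacobian(f!, y, prep, sparse_backend, x)        # sparse (diagonal) Jacobian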

function get_chunksize(jac_config::ForwardDiff.JacobianConfig{
@@ -356,38 +139,7 @@ function resize_grad_config!(grad_config::FiniteDiff.GradientCache, i)
end

function build_grad_config(alg, f::F1, tf::F2, du1, t) where {F1, F2}
if !DiffEqBase.has_tgrad(f)
if alg_autodiff(alg) isa AutoForwardDiff
T = if standardtag(alg)
typeof(ForwardDiff.Tag(OrdinaryDiffEqTag(), eltype(du1)))
else
typeof(ForwardDiff.Tag(f, eltype(du1)))
end

if du1 isa Array
dualt = Dual{T, eltype(du1), 1}(first(du1) * t,
ForwardDiff.Partials((one(eltype(du1)),)))
grad_config = similar(du1, typeof(dualt))
fill!(grad_config, false)
else
grad_config = ArrayInterface.restructure(du1,
Dual{
T,
eltype(du1),
1
}.(du1,
(ForwardDiff.Partials((one(eltype(du1)),)),)) .*
false)
end
elseif alg_autodiff(alg) isa AutoFiniteDiff
grad_config = FiniteDiff.GradientCache(du1, t, alg_difftype(alg))
else
error("$alg_autodiff not yet supported in build_grad_config function")
end
else
grad_config = nothing
end
grad_config
return DI.prepare_derivative(tf, du1, alg_autodiff(alg), t)
end
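
Not part of the diff: build_grad_config now returns a DI preparation object, and per DI's documented contract a preparation is only reusable with the same function, backend, and argument types it was built for, so the backend here has to match the alg_autodiff(alg) that calc_tderivative! later passes to DI.derivative!. A hedged sketch of that pairing; tf!, y, and dT are illustrative only.

    import DifferentiationInterface as DI
    import ForwardDiff
    using ADTypes: AutoForwardDiff

    tf!(y, t) = (y .= cos.(t .* [1.0, 2.0]); nothing)    # toy time wrapper
    y, dT = zeros(2), zeros(2)

    backend = AutoForwardDiff()                          # must match the later derivative! call
    grad_config = DI.prepare_derivative(tf!, y, backend, 0.0)
    DI.derivative!(tf!, y, dT, grad_config, backend, 1.0)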

function sparsity_colorvec(f, x)
7 changes: 7 additions & 0 deletions lib/OrdinaryDiffEqDifferentiation/test/runtests.jl
@@ -1 +1,8 @@
using OrdinaryDiffEqDifferentiation
using Test


@testset "OrdinaryDiffEqDifferentiation" begin


end
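
Not part of the diff: the new testset is empty. A hedged sketch of the kind of smoke test it could eventually hold, assuming DifferentiationInterface, ADTypes, and ForwardDiff end up as test dependencies; the function f and the expected Jacobian are illustrative only.

    import DifferentiationInterface as DI
    import ForwardDiff
    using ADTypes: AutoForwardDiff
    using Test

    @testset "DI smoke test (illustrative)" begin
        f(x) = [x[1]^2, x[1] * x[2]]
        J = DI.jacobian(f, AutoForwardDiff(), [1.0, 2.0])
        @test J ≈ [2.0 0.0; 2.0 1.0]
    end
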
@@ -257,6 +257,7 @@ PrecompileTools.@compile_workload begin
Float64[]))
end

prob_list = []
for prob in prob_list, solver in solver_list
solve(prob, solver)(5.0)
end