Skip to content

Commit

Permalink
Merge pull request #860 from AmitRotem/master
Browse files Browse the repository at this point in the history
Passthrough AD in ode with complex numbers
  • Loading branch information
ChrisRackauckas authored Jan 1, 2023
2 parents c5154b8 + b3fa3d1 commit 81465e1
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 9 deletions.
16 changes: 16 additions & 0 deletions src/forwarddiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,18 @@ end
u0
end

@inline function promote_u0(u0::AbstractArray{<:Complex}, p, t0)
if !(real(eltype(u0)) <: ForwardDiff.Dual)
T = anyeltypedual(p)
T === Any && return u0
if T <: ForwardDiff.Dual
Ts = promote_type(T, eltype(u0))
return Ts.(u0)
end
end
u0
end

function promote_tspan(u0::AbstractArray{<:ForwardDiff.Dual}, p,
tspan::Tuple{<:ForwardDiff.Dual, <:ForwardDiff.Dual}, prob, kwargs)
return _promote_tspan(tspan, kwargs)
Expand All @@ -229,6 +241,10 @@ function promote_tspan(u0::AbstractArray{<:ForwardDiff.Dual}, p, tspan, prob, kw
end
end

function promote_tspan(u0::AbstractArray{<:Complex{<:ForwardDiff.Dual}}, p, tspan, prob, kwargs)
return _promote_tspan(real(eltype(u0)).(tspan), kwargs)
end

value(x::Type{ForwardDiff.Dual{T, V, N}}) where {T, V, N} = V
value(x::ForwardDiff.Dual) = value(ForwardDiff.value(x))

Expand Down
74 changes: 74 additions & 0 deletions test/downstream/complex_number_ad.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
using LinearAlgebra, OrdinaryDiffEq, Test
import ForwardDiff

# setup
pd = 3

## ode with complex numbers
H0 = rand(ComplexF64,pd,pd)
A = rand(ComplexF64,pd,pd)
function f!(du,u,p,t)
a,b,c = p
du .= (A*u) * (a*cos(b*t+c))
du.+= H0*u
return nothing
end

## time span
tspan = (0.0, 1.0)

## initial state
u0 = hcat(normalize(rand(ComplexF64,pd)), normalize(rand(pd)))

## ode problem
prob0 = ODEProblem(f!, u0, tspan, rand(3); saveat=range(tspan..., 3), reltol=1e-6, alg=Tsit5())
## final state cost
cost(u) = abs2(tr(first(u)'u[2])) - abs2(tr(first(u)'last(u)))

## real loss function via complex ode
function loss(p)
prob = remake(prob0; p)
sol = solve(prob)
cost(sol.u) + sum(p) / 10
end

## same problem via reals
### realify complex ode problem
function real_f(du,u,p,t)
complex_u = complex.(selectdim(u,3,1), selectdim(u,3,2))
complex_du = copy(complex_u)
prob0.f(complex_du, complex_u, p, t)
selectdim(du,3,1) .= real(complex_du)
selectdim(du,3,2) .= imag(complex_du)
return nothing
end
prob0_real = remake(prob0; f=real_f, u0=cat(real(prob0.u0), imag(prob0.u0); dims=3))
### real loss function via real ode
function loss_via_real(p)
prob = remake(prob0_real; p)
sol = solve(prob)
u = [complex.(selectdim(u,3,1), selectdim(u,3,2)) for u=sol.u]
cost(u) + sum(p) / 10
end

# assert
@assert eltype(last(solve(prob0 ).u)) <: Complex
@assert eltype(last(solve(prob0_real).u)) <: Real
function assert_fun()
p0 = rand(3)
isapprox(loss(p0), loss_via_real(p0); rtol=1e-4)
end
@assert all([assert_fun() for _=1:2^6])

# test ad with ForwardDiff
function test_ad()
p0 = rand(3)
grad_real = ForwardDiff.gradient(loss_via_real, p0)
grad_complex = ForwardDiff.gradient(loss, p0)
any(isnan.(grad_complex)) && @warn "NaN detected in gradient using ode with complex numbers !!"
any(isnan.(grad_real )) && @warn "NaN detected in gradient using realified ode !!"
rel_err = norm(grad_complex-grad_real)/max(norm(grad_complex), norm(grad_real))
isapprox(grad_complex, grad_real; rtol=1e-6) ? true : (@show rel_err; false)
end

@time @test all([test_ad() for _=1:2^6])
38 changes: 29 additions & 9 deletions test/forwarddiff_dual_detection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ t0 = 1.0

@test DiffEqBase.promote_u0(u0, p, t0) isa Float64
@test DiffEqBase.promote_u0(u0, p, t0) == 2.0
@test DiffEqBase.promote_u0(cis(u0), p, t0) isa ComplexF64
@test DiffEqBase.promote_u0(cis(u0), p, t0) == cis(2.0)

struct MyStruct{T, T2} <: Number
x::T
Expand Down Expand Up @@ -39,10 +41,12 @@ p_possibilities = [ForwardDiff.Dual(2.0), (ForwardDiff.Dual(2.0), 2.0),
for p in p_possibilities
@show p
@test DiffEqBase.anyeltypedual(p) <: ForwardDiff.Dual
u0 = 2.0
local u0 = 2.0
@test DiffEqBase.promote_u0(u0, p, t0) isa ForwardDiff.Dual
@test DiffEqBase.promote_u0([cis(u0)], p, t0) isa AbstractArray{<:Complex{<:ForwardDiff.Dual}}
u0 = ForwardDiff.Dual(2.0)
@test DiffEqBase.promote_u0(u0, p, t0) isa ForwardDiff.Dual
@test DiffEqBase.promote_u0([cis(u0)], p, t0) isa AbstractArray{<:Complex{<:ForwardDiff.Dual}}
@inferred DiffEqBase.anyeltypedual(p)
end

Expand All @@ -60,10 +64,12 @@ for p in higher_order_p_possibilities
@test DiffEqBase.anyeltypedual(p) <: ForwardDiff.Dual
@test DiffEqBase.anyeltypedual(p) <:
ForwardDiff.Dual{Nothing, ForwardDiff.Dual{MyStruct, Float64, 0}, 0}
u0 = 2.0
local u0 = 2.0
@test DiffEqBase.promote_u0(u0, p, t0) isa ForwardDiff.Dual
@test DiffEqBase.promote_u0([cis(u0)], p, t0) isa AbstractArray{<:Complex{<:ForwardDiff.Dual}}
u0 = ForwardDiff.Dual(2.0)
@test DiffEqBase.promote_u0(u0, p, t0) isa ForwardDiff.Dual
@test DiffEqBase.promote_u0([cis(u0)], p, t0) isa AbstractArray{<:Complex{<:ForwardDiff.Dual}}
@inferred DiffEqBase.anyeltypedual(p)
end

Expand All @@ -81,10 +87,12 @@ VERSION >= v"1.7" &&
for p in p_possibilities17
@show p
@test DiffEqBase.anyeltypedual(p) <: ForwardDiff.Dual
u0 = 2.0
local u0 = 2.0
@test DiffEqBase.promote_u0(u0, p, t0) isa ForwardDiff.Dual
@test DiffEqBase.promote_u0([cis(u0)], p, t0) isa AbstractArray{<:Complex{<:ForwardDiff.Dual}}
u0 = ForwardDiff.Dual(2.0)
@test DiffEqBase.promote_u0(u0, p, t0) isa ForwardDiff.Dual
@test DiffEqBase.promote_u0([cis(u0)], p, t0) isa AbstractArray{<:Complex{<:ForwardDiff.Dual}}

if VERSION >= v"1.7"
# v1.6 does not infer `getproperty` mapping
Expand Down Expand Up @@ -118,10 +126,12 @@ p_possibilities_uninferrred = [
for p in p_possibilities_uninferrred
@show p
@test DiffEqBase.anyeltypedual(p) <: ForwardDiff.Dual
u0 = 2.0
local u0 = 2.0
@test DiffEqBase.promote_u0(u0, p, t0) isa ForwardDiff.Dual
@test DiffEqBase.promote_u0([cis(u0)], p, t0) isa AbstractArray{<:Complex{<:ForwardDiff.Dual}}
u0 = ForwardDiff.Dual(2.0)
@test DiffEqBase.promote_u0(u0, p, t0) isa ForwardDiff.Dual
@test DiffEqBase.promote_u0([cis(u0)], p, t0) isa AbstractArray{<:Complex{<:ForwardDiff.Dual}}
end

p_possibilities_missed = [
Expand All @@ -133,10 +143,12 @@ p_possibilities_missed = [
for p in p_possibilities_missed
@show p
@test_broken DiffEqBase.anyeltypedual(p) <: ForwardDiff.Dual
u0 = 2.0
local u0 = 2.0
@test_broken DiffEqBase.promote_u0(u0, p, t0) isa ForwardDiff.Dual
@test_broken DiffEqBase.promote_u0([cis(u0)], p, t0) isa AbstractArray{<:Complex{<:ForwardDiff.Dual}}
u0 = ForwardDiff.Dual(2.0)
@test DiffEqBase.promote_u0(u0, p, t0) isa ForwardDiff.Dual
@test DiffEqBase.promote_u0([cis(u0)], p, t0) isa AbstractArray{<:Complex{<:ForwardDiff.Dual}}
end

p_possibilities_notdual = [
Expand All @@ -146,10 +158,12 @@ p_possibilities_notdual = [
for p in p_possibilities_notdual
@show p
@test !(DiffEqBase.anyeltypedual(p) <: ForwardDiff.Dual)
u0 = 2.0
local u0 = 2.0
@test !(DiffEqBase.promote_u0(u0, p, t0) isa ForwardDiff.Dual)
@test !(DiffEqBase.promote_u0([cis(u0)], p, t0) isa AbstractArray{<:Complex{<:ForwardDiff.Dual}})
u0 = ForwardDiff.Dual(2.0)
@test DiffEqBase.promote_u0(u0, p, t0) isa ForwardDiff.Dual
@test DiffEqBase.promote_u0([cis(u0)], p, t0) isa AbstractArray{<:Complex{<:ForwardDiff.Dual}}
@inferred DiffEqBase.anyeltypedual(p)
end

Expand Down Expand Up @@ -187,10 +201,12 @@ push!(p_possibilities_notdual_uninferred, x)

for p in p_possibilities_notdual_uninferred
@test !(DiffEqBase.anyeltypedual(p) <: ForwardDiff.Dual)
u0 = 2.0
local u0 = 2.0
@test !(DiffEqBase.promote_u0(u0, p, t0) isa ForwardDiff.Dual)
@test !(DiffEqBase.promote_u0([cis(u0)], p, t0) isa AbstractArray{<:Complex{<:ForwardDiff.Dual}})
u0 = ForwardDiff.Dual(2.0)
@test DiffEqBase.promote_u0(u0, p, t0) isa ForwardDiff.Dual
@test DiffEqBase.promote_u0([cis(u0)], p, t0) isa AbstractArray{<:Complex{<:ForwardDiff.Dual}}
end

f(du, u, p, t) = du .= u
Expand All @@ -203,10 +219,12 @@ p_possibilities_configs = [
for p in p_possibilities_configs
@show p
@test !(DiffEqBase.anyeltypedual(p) <: ForwardDiff.Dual)
u0 = 2.0
local u0 = 2.0
@test !(DiffEqBase.promote_u0(u0, p, t0) isa ForwardDiff.Dual)
@test !(DiffEqBase.promote_u0([cis(u0)], p, t0) isa AbstractArray{<:Complex{<:ForwardDiff.Dual}})
u0 = ForwardDiff.Dual(2.0)
@test DiffEqBase.promote_u0(u0, p, t0) isa ForwardDiff.Dual
@test DiffEqBase.promote_u0([cis(u0)], p, t0) isa AbstractArray{<:Complex{<:ForwardDiff.Dual}}
@inferred DiffEqBase.anyeltypedual(p)
end

Expand All @@ -217,10 +235,12 @@ p_possibilities_configs_not_inferred = [
for p in p_possibilities_configs_not_inferred
@show p
@test !(DiffEqBase.anyeltypedual(p) <: ForwardDiff.Dual)
u0 = 2.0
local u0 = 2.0
@test !(DiffEqBase.promote_u0(u0, p, t0) isa ForwardDiff.Dual)
@test !(DiffEqBase.promote_u0([cis(u0)], p, t0) isa AbstractArray{<:Complex{<:ForwardDiff.Dual}})
u0 = ForwardDiff.Dual(2.0)
@test DiffEqBase.promote_u0(u0, p, t0) isa ForwardDiff.Dual
@test DiffEqBase.promote_u0([cis(u0)], p, t0) isa AbstractArray{<:Complex{<:ForwardDiff.Dual}}
end

# use `getfield` on `Pairs`, see https://github.com/JuliaLang/julia/pull/39448
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ end
@time @safetestset "Ensemble AD Tests" begin include("downstream/ensemble_ad.jl") end
@time @safetestset "Community Callback Tests" begin include("downstream/community_callback_tests.jl") end
@time @testset "Distributed Ensemble Tests" begin include("downstream/distributed_ensemble.jl") end
@time @safetestset "AD via ode with complex numbers" begin include("downstream/complex_number_ad.jl") end
end

if !is_APPVEYOR && GROUP == "GPU"
Expand Down

0 comments on commit 81465e1

Please sign in to comment.