
ForwardDiff.jl fails on ODE with complex numbers #858

Closed
AmitRotem opened this issue Dec 29, 2022 · 6 comments
Comments

@AmitRotem
Contributor

Differentiating a real-valued cost function with respect to real parameters fails when the ODE inside the cost function involves complex numbers.
This happens in promote_u0:

@inline function promote_u0(u0, p, t0)
    if !(eltype(u0) <: ForwardDiff.Dual)
        T = anyeltypedual(p)
        T === Any && return u0
        if T <: ForwardDiff.Dual
            return T.(u0)
        end
    end
    u0
end

The failure occurs when trying to convert the complex u0 to ForwardDiff.Dual.
A workaround is to convert u0 to Complex{ForwardDiff.Dual{...}} instead.
Is there a reason not to switch to that conversion?
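A minimal sketch of the type behavior (the Nothing tag below is just a placeholder for illustration, not what the solver actually uses):

```julia
using ForwardDiff

# Stand-in Dual type; the tag doesn't matter for the illustration.
D = ForwardDiff.Dual{Nothing, Float64, 3}

# promote_u0 effectively attempts T.(u0); for a complex u0 that means
# constructing a real Dual from a ComplexF64, which errors (reported
# as a method ambiguity in the versions above):
err = try
    D(1.0 + 2.0im)
catch e
    e
end

# Wrapping the Dual inside Complex instead is well-defined:
z = Complex{D}(D(1.0), D(2.0))
```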

Example:

using LinearAlgebra, DifferentialEquations
import ForwardDiff as FD

H0 = randn(ComplexF64, 5, 5)
H0 -= H0'  # anti-Hermitian
A = randn(ComplexF64, 5, 5)
A -= A'    # anti-Hermitian
u0 = hcat(normalize(randn(ComplexF64, 5)), normalize(randn(5)))

function Ht!(du,u,p,t)
    a,b,c = p
    du .= (A*u)*(a*cos(b*t))
    du.+= (H0*u)*c
    return nothing
end

function loss_fail(p)
    prob = ODEProblem(Ht!,u0,(0.0,1.0),p)
    sol = solve(prob)
    abs2(tr(first(sol.u)'last(sol.u)))
end

FD.gradient(loss_fail, rand(3))

This errors with ForwardDiff.Dual{ForwardDiff.Tag{typeof(loss_fail), Float64}, Float64, 3}(::ComplexF64) is ambiguous

Workaround:

function loss_workaround(p)
    Tp = eltype(p)
    Ts = promote_type(Tp, eltype(u0))
    function f(du,u,qwe,t)
        Ht!(du,u,p,t)
    end
    prob = ODEProblem(f,Ts.(u0),Tp.((0.0,1.0)))
    sol = solve(prob)
    abs2(tr(first(sol.u)'last(sol.u)))
end

FD.gradient(loss_workaround, rand(3))
@ChrisRackauckas
Member

Changing to:

@inline function promote_u0(u0, p, t0)
    if !(eltype(u0) <: ForwardDiff.Dual)
        T = anyeltypedual(p)
        T === Any && return u0
        if T <: ForwardDiff.Dual
            Ts = promote_type(T, eltype(u0))
            return Ts.(u0)
        end
    end
    u0
end

doesn't do anything here. So I'm inclined to say this is an upstream issue with how ForwardDiff interacts with complex functions, and I'm not sure what nicety we can add here beyond an FAQ entry explaining what to do with complex numbers.

@AmitRotem
Contributor Author

For me it does resolve the ambiguity, by making Ts a Complex{Dual{...}}.
Then it fails when computing the initial dt here
https://github.com/SciML/OrdinaryDiffEq.jl/blob/891622a76db14337bea1648abb6c6d103b93540c/src/initdt.jl#L130-L134
by trying to convert a Dual to a Float.
This could be solved by adding a promote_tspan method for the Complex{Dual{...}} u0 type:

function promote_tspan(u0::AbstractArray{<:Complex{<:ForwardDiff.Dual}}, p, tspan, prob, kwargs)
    return _promote_tspan(real(eltype(u0)).(tspan), kwargs)
end

I'm not sure about the branch on having a callback, as in:

function promote_tspan(u0::AbstractArray{<:ForwardDiff.Dual}, p, tspan, prob, kwargs)
    if (haskey(kwargs, :callback) && has_continuous_callback(kwargs[:callback])) ||
       (haskey(prob.kwargs, :callback) && has_continuous_callback(prob.kwargs[:callback]))
        return _promote_tspan(eltype(u0).(tspan), kwargs)
    else
        return _promote_tspan(tspan, kwargs)
    end
end

@ChrisRackauckas
Member

But t shouldn't be complex-valued; that wouldn't make sense. d₀ and d₁ should be real-valued because they are norms. What's making that complex?

@AmitRotem
Contributor Author

I think d₀ and d₁ come back as (real) Duals, but tspan is not Dual.
The promotion I've suggested uses real(eltype(u0)), which for u0 of type Complex{Dual{...}} returns the Dual{...} type, keeping tspan real.
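A quick sketch of that type computation (again using a placeholder Nothing tag, which is hypothetical and not what the solver would generate):

```julia
using ForwardDiff

# Placeholder Dual type for illustration only.
D = ForwardDiff.Dual{Nothing, Float64, 3}
u0 = Complex{D}[Complex{D}(D(1.0), D(0.0))]

# real of a Complex{T} type recovers T, so we get the real Dual type back:
Tt = real(eltype(u0))

# tspan then promotes to real-valued Duals, never complex:
tspan = Tt.((0.0, 1.0))
```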

@ChrisRackauckas
Member

PR it with an example that's fixed. I don't think I'm following the whole train of thought, but it sounds like a reasonable fix at a high level.

@ChrisRackauckas
Member

Handled in #860
