Skip to content

Commit

Permalink
Implement ForwardDiff in master solvers and general DiffEq problems o…
Browse files Browse the repository at this point in the history
…n QO types (#409)
  • Loading branch information
apkille authored Sep 14, 2024
1 parent beb0f37 commit 2939a6a
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 20 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,11 @@ WignerSymbols = "1, 2"
julia = "1.10"

[extras]
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["LinearAlgebra", "SparseArrays", "Random", "Test"]
test = ["FiniteDiff", "LinearAlgebra", "SparseArrays", "Random", "Test"]
4 changes: 4 additions & 0 deletions src/master.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ function master_h(tspan, rho0::Operator, H::AbstractOperator, J;
_check_const.(J)
_check_const.(Jdagger)
check_master(rho0, H, J, Jdagger, rates)
tspan, rho0 = _promote_time_and_state(rho0, H, J, tspan)
tmp = copy(rho0)
dmaster_(t, rho, drho) = dmaster_h!(drho, H, J, Jdagger, rates, rho, tmp)
integrate_master(tspan, dmaster_, rho0, fout; kwargs...)
Expand Down Expand Up @@ -41,6 +42,7 @@ function master_nh(tspan, rho0::Operator, Hnh::AbstractOperator, J;
_check_const.(J)
_check_const.(Jdagger)
check_master(rho0, Hnh, J, Jdagger, rates)
tspan, rho0 = _promote_time_and_state(rho0, Hnh, J, tspan)
tmp = copy(rho0)
dmaster_(t, rho, drho) = dmaster_nh!(drho, Hnh, Hnhdagger, J, Jdagger, rates, rho, tmp)
integrate_master(tspan, dmaster_, rho0, fout; kwargs...)
Expand Down Expand Up @@ -86,6 +88,7 @@ function master(tspan, rho0::Operator, H::AbstractOperator, J;
_check_const(H)
_check_const.(J)
_check_const.(Jdagger)
tspan, rho0 = _promote_time_and_state(rho0, H, J, tspan)
isreducible = check_master(rho0, H, J, Jdagger, rates)
if !isreducible
tmp = copy(rho0)
Expand Down Expand Up @@ -124,6 +127,7 @@ function master(tspan, rho0::Operator, L::SuperOperator; fout=nothing, kwargs...
b = GenericBasis(dim)
rho_ = Ket(b,reshape(rho0.data, dim))
L_ = Operator(b,b,L.data)
tspan, rho_ = _promote_time_and_state(rho_, L_, tspan)
dmaster_(t,rho,drho) = dmaster_liouville!(drho,L_,rho)

# Rewrite into density matrix when saving
Expand Down
19 changes: 0 additions & 19 deletions src/schroedinger.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,22 +122,3 @@ function check_schroedinger(psi::Bra, H)
check_multiplicable(psi, H)
check_samebases(H)
end


function _promote_time_and_state(u0, H::AbstractOperator, tspan)
Ts = eltype(H)
Tt = real(Ts)
p = Vector{Tt}(undef,0)
u0data_promote = DiffEqBase.promote_u0(u0.data, p, tspan[1])
tspan_promote = DiffEqBase.promote_tspan(u0data_promote, p, tspan, nothing, Dict{Symbol, Any}())
if u0data_promote !== u0.data
u0_promote = rebuild(u0, u0data_promote)
return tspan_promote, u0_promote
end
return tspan_promote, u0
end
_promote_time_and_state(u0, f, tspan) = _promote_time_and_state(u0, f(first(tspan), u0), tspan)

rebuild(op::Operator, new_data) = Operator(op.basis_l, op.basis_r, new_data)
rebuild(state::Ket, new_data) = Ket(state.basis, new_data)
rebuild(state::Bra, new_data) = Bra(state.basis, new_data)
44 changes: 44 additions & 0 deletions src/timeevolution_base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,47 @@ macro skiptimechecks(ex)
end

Base.@pure pure_inference(fout,T) = Core.Compiler.return_type(fout, T)

function _promote_time_and_state(u0, H::AbstractOperator, tspan)
Ts = eltype(H)
Tt = real(Ts)
p = Vector{Tt}(undef,0)
u0_promote = DiffEqBase.promote_u0(u0, p, tspan[1])
tspan_promote = DiffEqBase.promote_tspan(u0_promote.data, p, tspan, nothing, Dict{Symbol, Any}())
return tspan_promote, u0_promote
end
function _promote_time_and_state(u0, H::AbstractOperator, J, tspan)
Ts = DiffEqBase.promote_dual(eltype(H), DiffEqBase.anyeltypedual(J))
Tt = real(Ts)
p = Vector{Tt}(undef,0)
u0_promote = DiffEqBase.promote_u0(u0, p, tspan[1])
tspan_promote = DiffEqBase.promote_tspan(u0_promote.data, p, tspan, nothing, Dict{Symbol, Any}())
return tspan_promote, u0_promote
end

_promote_time_and_state(u0, f, tspan) = _promote_time_and_state(u0, f(first(tspan)..., u0), tspan)

@inline function DiffEqBase.promote_u0(u0::Ket, p, t0)
u0data_promote = DiffEqBase.promote_u0(u0.data, p, t0)
if u0data_promote !== u0.data
u0_promote = Ket(u0.basis, u0data_promote)
return u0_promote
end
return u0
end
@inline function DiffEqBase.promote_u0(u0::Bra, p, t0)
u0data_promote = DiffEqBase.promote_u0(u0.data, p, t0)
if u0data_promote !== u0.data
u0_promote = Bra(u0.basis, u0data_promote)
return u0_promote
end
return u0
end
@inline function DiffEqBase.promote_u0(u0::Operator, p, t0)
u0data_promote = DiffEqBase.promote_u0(u0.data, p, t0)
if u0data_promote !== u0.data
u0_promote = Operator(u0.basis_l, u0.basis_r, u0data_promote)
return u0_promote
end
return u0
end
56 changes: 56 additions & 0 deletions test/test_ForwardDiff.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Test
using OrdinaryDiffEq, QuantumOptics
import ForwardDiff
import FiniteDiff

# for some caese ForwardDiff.jl returns NaN due to issue with DiffEq.jl. see https://github.com/SciML/DiffEqBase.jl/issues/861
# Here we test;
Expand All @@ -12,6 +13,27 @@ import ForwardDiff
# here we partially control the gradient error by limiting step size (dtmax)


@testset "ForwardDiff on ODE Problems" begin

# schroedinger equation
b = SpinBasis(10//1)
psi0 = spindown(b)
H(p) = p[1]*sigmax(b) + p[2]*sigmam(b)
f_schrod!(dpsi, psi, p, t) = timeevolution.dschroedinger!(dpsi, H(p), psi)
function cost_schrod(p)
prob = ODEProblem(f_schrod!, psi0, (0.0, pi), p)
sol = solve(prob, DP5(); save_everystep=false)
return 1 - norm(sol[end])
end

p = [rand(), rand()]
fordiff_schrod = ForwardDiff.gradient(cost_schrod, p)
findiff_schrod = FiniteDiff.finite_difference_gradient(cost_schrod, p)

@test isapprox(fordiff_schrod, findiff_schrod; atol=1e-2)

end

@testset "ForwardDiff with schroedinger" begin

# system
Expand Down Expand Up @@ -73,3 +95,37 @@ Ftdop(1.0)
@test ForwardDiff.derivative(Ftdop, 1.0) isa Any

end # testset


@testset "ForwardDiff with master" begin

b = SpinBasis(1//2)
psi0 = spindown(b)
rho0 = dm(psi0)
params = [rand(), rand()]

for f in (:(timeevolution.master), :(timeevolution.master_h), :(timeevolution.master_nh))
# test to see if parameter propagates through Hamiltonian
H(p) = p[1]*sigmax(b) + p[2]*sigmam(b) # Hamiltonian
function cost_H(p) #
tf, psif = eval(f)((0.0, pi), rho0, H(p), [sigmax(b)])
return 1 - norm(psif)
end

forwarddiff_H = ForwardDiff.gradient(cost_H, params)
finitediff_H = FiniteDiff.finite_difference_gradient(cost_H, params)
@test isapprox(forwarddiff_H, finitediff_H; atol=1e-2)

# test to see if parameter propagates through Jump operator
J(p) = p[1]*sigmax(b) + p[2]*sigmam(b) # jump operator
function cost_J(p)
tf, psif = eval(f)((0.0, pi), rho0, sigmax(b), [J(p)])
return 1 - norm(psif)
end

forwarddiff_J = ForwardDiff.gradient(cost_J, params)
finitediff_J = FiniteDiff.finite_difference_gradient(cost_J, params)
@test isapprox(forwarddiff_J, finitediff_J; atol=1e-2)
end

end

0 comments on commit 2939a6a

Please sign in to comment.