Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement ForwardDiff in master solvers and general DiffEq problems on QO types #409

Merged
merged 6 commits into from
Sep 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,11 @@ WignerSymbols = "1, 2"
julia = "1.3"

[extras]
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["LinearAlgebra", "SparseArrays", "Random", "Test"]
test = ["FiniteDiff", "LinearAlgebra", "SparseArrays", "Random", "Test"]
4 changes: 4 additions & 0 deletions src/master.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ function master_h(tspan, rho0::Operator, H::AbstractOperator, J;
_check_const.(J)
_check_const.(Jdagger)
check_master(rho0, H, J, Jdagger, rates)
tspan, rho0 = _promote_time_and_state(rho0, H, J, tspan)
tmp = copy(rho0)
dmaster_(t, rho, drho) = dmaster_h!(drho, H, J, Jdagger, rates, rho, tmp)
integrate_master(tspan, dmaster_, rho0, fout; kwargs...)
Expand Down Expand Up @@ -41,6 +42,7 @@ function master_nh(tspan, rho0::Operator, Hnh::AbstractOperator, J;
_check_const.(J)
_check_const.(Jdagger)
check_master(rho0, Hnh, J, Jdagger, rates)
tspan, rho0 = _promote_time_and_state(rho0, Hnh, J, tspan)
tmp = copy(rho0)
dmaster_(t, rho, drho) = dmaster_nh!(drho, Hnh, Hnhdagger, J, Jdagger, rates, rho, tmp)
integrate_master(tspan, dmaster_, rho0, fout; kwargs...)
Expand Down Expand Up @@ -86,6 +88,7 @@ function master(tspan, rho0::Operator, H::AbstractOperator, J;
_check_const(H)
_check_const.(J)
_check_const.(Jdagger)
tspan, rho0 = _promote_time_and_state(rho0, H, J, tspan)
isreducible = check_master(rho0, H, J, Jdagger, rates)
if !isreducible
tmp = copy(rho0)
Expand Down Expand Up @@ -124,6 +127,7 @@ function master(tspan, rho0::Operator, L::SuperOperator; fout=nothing, kwargs...
b = GenericBasis(dim)
rho_ = Ket(b,reshape(rho0.data, dim))
L_ = Operator(b,b,L.data)
tspan, rho_ = _promote_time_and_state(rho_, L_, tspan)
dmaster_(t,rho,drho) = dmaster_liouville!(drho,L_,rho)

# Rewrite into density matrix when saving
Expand Down
19 changes: 0 additions & 19 deletions src/schroedinger.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,22 +122,3 @@ function check_schroedinger(psi::Bra, H)
check_multiplicable(psi, H)
check_samebases(H)
end


function _promote_time_and_state(u0, H::AbstractOperator, tspan)
Ts = eltype(H)
Tt = real(Ts)
p = Vector{Tt}(undef,0)
u0data_promote = DiffEqBase.promote_u0(u0.data, p, tspan[1])
tspan_promote = DiffEqBase.promote_tspan(u0data_promote, p, tspan, nothing, Dict{Symbol, Any}())
if u0data_promote !== u0.data
u0_promote = rebuild(u0, u0data_promote)
return tspan_promote, u0_promote
end
return tspan_promote, u0
end
_promote_time_and_state(u0, f, tspan) = _promote_time_and_state(u0, f(first(tspan), u0), tspan)

rebuild(op::Operator, new_data) = Operator(op.basis_l, op.basis_r, new_data)
rebuild(state::Ket, new_data) = Ket(state.basis, new_data)
rebuild(state::Bra, new_data) = Bra(state.basis, new_data)
44 changes: 44 additions & 0 deletions src/timeevolution_base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,47 @@ macro skiptimechecks(ex)
end

Base.@pure pure_inference(fout,T) = Core.Compiler.return_type(fout, T)

function _promote_time_and_state(u0, H::AbstractOperator, tspan)
Ts = eltype(H)
Tt = real(Ts)
p = Vector{Tt}(undef,0)
u0_promote = DiffEqBase.promote_u0(u0, p, tspan[1])
tspan_promote = DiffEqBase.promote_tspan(u0_promote.data, p, tspan, nothing, Dict{Symbol, Any}())
return tspan_promote, u0_promote
end
function _promote_time_and_state(u0, H::AbstractOperator, J, tspan)
Ts = DiffEqBase.promote_dual(eltype(H), DiffEqBase.anyeltypedual(J))
Tt = real(Ts)
p = Vector{Tt}(undef,0)
u0_promote = DiffEqBase.promote_u0(u0, p, tspan[1])
tspan_promote = DiffEqBase.promote_tspan(u0_promote.data, p, tspan, nothing, Dict{Symbol, Any}())
return tspan_promote, u0_promote
end

_promote_time_and_state(u0, f, tspan) = _promote_time_and_state(u0, f(first(tspan)..., u0), tspan)

@inline function DiffEqBase.promote_u0(u0::Ket, p, t0)
u0data_promote = DiffEqBase.promote_u0(u0.data, p, t0)
if u0data_promote !== u0.data
u0_promote = Ket(u0.basis, u0data_promote)
return u0_promote
end
return u0
end
@inline function DiffEqBase.promote_u0(u0::Bra, p, t0)
u0data_promote = DiffEqBase.promote_u0(u0.data, p, t0)
if u0data_promote !== u0.data
u0_promote = Bra(u0.basis, u0data_promote)
return u0_promote
end
return u0
end
@inline function DiffEqBase.promote_u0(u0::Operator, p, t0)
u0data_promote = DiffEqBase.promote_u0(u0.data, p, t0)
if u0data_promote !== u0.data
u0_promote = Operator(u0.basis_l, u0.basis_r, u0data_promote)
return u0_promote
end
return u0
end
56 changes: 56 additions & 0 deletions test/test_ForwardDiff.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Test
using OrdinaryDiffEq, QuantumOptics
import ForwardDiff
import FiniteDiff

# for some caese ForwardDiff.jl returns NaN due to issue with DiffEq.jl. see https://github.com/SciML/DiffEqBase.jl/issues/861
# Here we test;
Expand All @@ -12,6 +13,27 @@ import ForwardDiff
# here we partially control the gradient error by limiting step size (dtmax)


@testset "ForwardDiff on ODE Problems" begin

# schroedinger equation
b = SpinBasis(10//1)
psi0 = spindown(b)
H(p) = p[1]*sigmax(b) + p[2]*sigmam(b)
f_schrod!(dpsi, psi, p, t) = timeevolution.dschroedinger!(dpsi, H(p), psi)
function cost_schrod(p)
prob = ODEProblem(f_schrod!, psi0, (0.0, pi), p)
sol = solve(prob, DP5(); save_everystep=false)
return 1 - norm(sol[end])
end

p = [rand(), rand()]
fordiff_schrod = ForwardDiff.gradient(cost_schrod, p)
findiff_schrod = FiniteDiff.finite_difference_gradient(cost_schrod, p)

@test isapprox(fordiff_schrod, findiff_schrod; atol=1e-2)

end

@testset "ForwardDiff with schroedinger" begin

# system
Expand Down Expand Up @@ -73,3 +95,37 @@ Ftdop(1.0)
@test ForwardDiff.derivative(Ftdop, 1.0) isa Any

end # testset


@testset "ForwardDiff with master" begin

b = SpinBasis(1//2)
psi0 = spindown(b)
rho0 = dm(psi0)
params = [rand(), rand()]

for f in (:(timeevolution.master), :(timeevolution.master_h), :(timeevolution.master_nh))
# test to see if parameter propagates through Hamiltonian
H(p) = p[1]*sigmax(b) + p[2]*sigmam(b) # Hamiltonian
function cost_H(p) #
tf, psif = eval(f)((0.0, pi), rho0, H(p), [sigmax(b)])
return 1 - norm(psif)
end

forwarddiff_H = ForwardDiff.gradient(cost_H, params)
finitediff_H = FiniteDiff.finite_difference_gradient(cost_H, params)
@test isapprox(forwarddiff_H, finitediff_H; atol=1e-2)

# test to see if parameter propagates through Jump operator
J(p) = p[1]*sigmax(b) + p[2]*sigmam(b) # jump operator
function cost_J(p)
tf, psif = eval(f)((0.0, pi), rho0, sigmax(b), [J(p)])
return 1 - norm(psif)
end

forwarddiff_J = ForwardDiff.gradient(cost_J, params)
finitediff_J = FiniteDiff.finite_difference_gradient(cost_J, params)
@test isapprox(forwarddiff_J, finitediff_J; atol=1e-2)
end

end