Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Semiclassical mcwf #255

Merged
merged 5 commits into from
Sep 23, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions src/mcwf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ function dmcwf_h_dynamic(t::Float64, psi::T, f::Function, rates::DecayRates,
H, J, Jdagger, rates_ = result
end
QO_CHECKS[] && check_mcwf(psi, H, J, Jdagger, rates_)
dmcwf_h(psi, H, J, Jdagger, dpsi, tmp, rates)
dmcwf_h(psi, H, J, Jdagger, dpsi, tmp, rates_)
end

function dmcwf_nh_dynamic(t::Float64, psi::T, f::Function, dpsi::T) where T<:Ket
Expand All @@ -224,7 +224,7 @@ function jump_dynamic(rng, t::Float64, psi::T, f::Function, psi_new::T, rates::D
else
rates_ = result[4]
end
jump(rng, t, psi, J, psi_new, rates)
jump(rng, t, psi, J, psi_new, rates_)
end

"""
Expand Down Expand Up @@ -373,6 +373,7 @@ function jump(rng, t::Float64, psi::T, J::Vector, psi_new::T, rates::Nothing) wh
if length(J)==1
operators.gemv!(complex(1.), J[1], psi, complex(0.), psi_new)
psi_new.data ./= norm(psi_new)
i=1
else
probs = zeros(Float64, length(J))
for i=1:length(J)
Expand All @@ -384,13 +385,14 @@ function jump(rng, t::Float64, psi::T, J::Vector, psi_new::T, rates::Nothing) wh
i = findfirst(cumprobs.>r)
operators.gemv!(complex(1.)/sqrt(probs[i]), J[i], psi, complex(0.), psi_new)
end
return nothing
return i
end

function jump(rng, t::Float64, psi::T, J::Vector, psi_new::T, rates::Vector{Float64}) where T<:Ket
if length(J)==1
operators.gemv!(complex(sqrt(rates[1])), J[1], psi, complex(0.), psi_new)
psi_new.data ./= norm(psi_new)
i=1
else
probs = zeros(Float64, length(J))
for i=1:length(J)
Expand All @@ -402,7 +404,7 @@ function jump(rng, t::Float64, psi::T, J::Vector, psi_new::T, rates::Vector{Floa
i = findfirst(cumprobs.>r)
operators.gemv!(complex(sqrt(rates[i]/probs[i])), J[i], psi, complex(0.), psi_new)
end
return nothing
return i
end

"""
Expand Down
181 changes: 180 additions & 1 deletion src/semiclassical.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,16 @@ module semiclassical

import Base: ==
import ..bases, ..operators, ..operators_dense
import ..timeevolution: integrate, recast!
import ..timeevolution: integrate, recast!, QO_CHECKS
import ..timeevolution.timeevolution_mcwf: jump
import LinearAlgebra: normalize!

using Random, LinearAlgebra
import OrdinaryDiffEq

# TODO: Remove imports
import DiffEqCallbacks, RecursiveArrayTools.copyat_or_push!
Base.@pure pure_inference(fout,T) = Core.Compiler.return_type(fout, T)

using ..bases, ..states, ..operators, ..operators_dense, ..timeevolution

Expand All @@ -26,6 +35,7 @@ end

Base.length(state::State) = length(state.quantum) + length(state.classical)
Base.copy(state::State) = State(copy(state.quantum), copy(state.classical))
normalize!(state::State{B,T}) where {B,T<:Ket} = normalize!(state.quantum)

function ==(a::State, b::State)
samebases(a.quantum, b.quantum) &&
Expand Down Expand Up @@ -111,6 +121,44 @@ function master_dynamic(tspan, state0::State{B,T}, fquantum, fclassical; kwargs.
master_dynamic(tspan, dm(state0), fquantum, fclassical; kwargs...)
end

"""
semiclassical.mcwf_dynamic(tspan, psi0, fquantum, fclassical, fjump_classical; <keyword arguments>)

Calculate MCWF trajectories coupled to a classical system.

# Arguments
* `tspan`: Vector specifying the points of time for which output should
be displayed.
* `rho0`: Initial semi-classical state [`semiclassical.State`](@ref).
* `fquantum`: Function `f(t, rho, u) -> (H, J, Jdagger)` returning the time
and/or state dependent Hamiltonian and Jump operators.
* `fclassical`: Function `f(t, rho, u, du)` calculating the possibly time and
state dependent derivative of the classical equations and storing it
in the complex vector `du`.
* `fjump_classical`: Function `f(t, rho, u, i)` making a classical jump when a
quantum jump of the i-th jump operator occurs.
* `fout=nothing`: If given, this function `fout(t, state)` is called every time
an output should be displayed. ATTENTION: The given state is not
permanent!
* `kwargs...`: Further arguments are passed on to the ode solver.
"""
function mcwf_dynamic(tspan, psi0::State{B,T}, fquantum, fclassical, fjump_classical;
seed=rand(UInt),
rates::DecayRates=nothing,
fout::Union{Function,Nothing}=nothing,
kwargs...) where {B<:Basis,T<:Ket{B}}
tspan_ = convert(Vector{Float64}, tspan)
tmp=copy(psi0.quantum)
function dmcwf_(t::Float64, psi::S, dpsi::S) where {B<:Basis,T<:Ket{B},S<:State{B,T}}
dmcwf_h_dynamic(t, psi, fquantum, fclassical, rates, dpsi, tmp)
end
j_(rng, t::Float64, psi, psi_new) = jump_dynamic(rng, t, psi, fquantum, fclassical, fjump_classical, psi_new, rates)
x0 = Vector{ComplexF64}(undef, length(psi0))
recast!(psi0, x0)
psi = copy(psi0)
dpsi = copy(psi0)
integrate_mcwf(dmcwf_, j_, tspan_, psi, seed, fout; kwargs...)
end

function recast!(state::State{B,T,C}, x::C) where {B<:Basis,T<:QuantumState{B},C<:Vector{ComplexF64}}
N = length(state.quantum)
Expand Down Expand Up @@ -139,4 +187,135 @@ function dmaster_h_dynamic(t::Float64, state::State{B,T}, fquantum::Function,
fclassical(t, state.quantum, state.classical, dstate.classical)
end

function dmcwf_h_dynamic(t::Float64, psi::T, fquantum::Function, fclassical::Function, rates::DecayRates,
dpsi::T, tmp::K) where {T,K}
fquantum_(t, rho) = fquantum(t, psi.quantum, psi.classical)
timeevolution.timeevolution_mcwf.dmcwf_h_dynamic(t, psi.quantum, fquantum_, rates, dpsi.quantum, tmp)
fclassical(t, psi.quantum, psi.classical, dpsi.classical)
end

function jump_dynamic(rng, t::Float64, psi::T, fquantum::Function, fclassical::Function, fjump_classical::Function, psi_new::T, rates::DecayRates) where T<:State
result = fquantum(t, psi.quantum, psi.classical)
QO_CHECKS[] && @assert 3 <= length(result) <= 4
J = result[2]
if length(result) == 3
rates_ = rates
else
rates_ = result[4]
end
i = jump(rng, t, psi.quantum, J, psi_new.quantum, rates_)
fjump_classical(t, psi_new.quantum, psi.classical, i)
psi_new.classical .= psi.classical
end

function integrate_mcwf(dmcwf::Function, jumpfun::Function, tspan,
psi0::T, seed, fout::Function;
display_beforeevent=false, display_afterevent=false,
#TODO: Remove kwargs
save_everystep=false, callback=nothing,
alg=OrdinaryDiffEq.DP5(),
kwargs...) where {B<:Basis,T<:State}

tmp = copy(psi0)
psi_tmp = copy(psi0)
x0 = [psi0.quantum.data; psi0.classical]
rng = MersenneTwister(convert(UInt, seed))
jumpnorm = Ref(rand(rng))
n = length(psi0.quantum)
djumpnorm(x::Vector{ComplexF64}, t::Float64, integrator) = norm(x[1:n])^2 - (1-jumpnorm[])

if !display_beforeevent && !display_afterevent
function dojump(integrator)
x = integrator.u
recast!(x, psi_tmp)
t = integrator.t
jumpfun(rng, t, psi_tmp, tmp)
recast!(tmp, x)
jumpnorm[] = rand(rng)
end
cb = OrdinaryDiffEq.ContinuousCallback(djumpnorm,dojump,
save_positions = (display_beforeevent,display_afterevent))


timeevolution.integrate(float(tspan), dmcwf, x0,
copy(psi0), copy(psi0), fout;
callback = cb,
kwargs...)
else
# Temporary workaround until proper tooling for saving
# TODO: Replace by proper call to timeevolution.integrate
function fout_(x::Vector{ComplexF64}, t::Float64, integrator)
recast!(x, state)
fout(t, state)
end

state = copy(psi0)
dstate = copy(psi0)
out_type = pure_inference(fout, Tuple{eltype(tspan),typeof(state)})
out = DiffEqCallbacks.SavedValues(Float64,out_type)
scb = DiffEqCallbacks.SavingCallback(fout_,out,saveat=tspan,
save_everystep=save_everystep,
save_start = false)

function dojump_display(integrator)
x = integrator.u
t = integrator.t

affect! = scb.affect!
if display_beforeevent
affect!.saveiter += 1
copyat_or_push!(affect!.saved_values.t, affect!.saveiter, integrator.t)
copyat_or_push!(affect!.saved_values.saveval, affect!.saveiter,
affect!.save_func(integrator.u, integrator.t, integrator),Val{false})
end

recast!(x, psi_tmp)
jumpfun(rng, t, psi_tmp, tmp)
recast!(tmp, x)

if display_afterevent
affect!.saveiter += 1
copyat_or_push!(affect!.saved_values.t, affect!.saveiter, integrator.t)
copyat_or_push!(affect!.saved_values.saveval, affect!.saveiter,
affect!.save_func(integrator.u, integrator.t, integrator),Val{false})
end
jumpnorm[] = rand(rng)
end

cb = OrdinaryDiffEq.ContinuousCallback(djumpnorm,dojump_display,
save_positions = (false,false))
full_cb = OrdinaryDiffEq.CallbackSet(callback,cb,scb)

function df_(dx::Vector{ComplexF64}, x::Vector{ComplexF64}, p, t)
recast!(x, state)
recast!(dx, dstate)
dmcwf(t, state, dstate)
recast!(dstate, dx)
end

prob = OrdinaryDiffEq.ODEProblem{true}(df_, x0,(tspan[1],tspan[end]))

sol = OrdinaryDiffEq.solve(
prob,
alg;
reltol = 1.0e-6,
abstol = 1.0e-8,
save_everystep = false, save_start = false,
save_end = false,
callback=full_cb, kwargs...)
return out.t, out.saveval
end
end

function integrate_mcwf(dmcwf::Function, jumpfun::Function, tspan,
psi0::T, seed, fout::Nothing;
kwargs...) where {T<:State}
function fout_(t::Float64, x::T)
psi = copy(x)
normalize!(psi)
return psi
end
integrate_mcwf(dmcwf, jumpfun, tspan, psi0, seed, fout_; kwargs...)
end

end # module
52 changes: 51 additions & 1 deletion test/test_semiclassical.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,4 +125,54 @@ semiclassical.master_dynamic(T, state0, fquantum_master, fclassical; fout=f)
tout, state_t = semiclassical.master_dynamic(T, state0, fquantum_master, fclassical)
f(T[end], state_t[end])

end # testset
# Test mcwf
# Set up system where only atom can jump once
ba = SpinBasis(1//2)
bf = FockBasis(5)
sm = sigmam(ba)⊗one(bf)
a = one(ba)⊗destroy(bf)
H = 0*sm
J = [0*a,sm]
Jdagger = dagger.(J)
function fquantum(t,psi,u)
return H, J, Jdagger
end
function fclassical(t,psi,u,du)
du[1] = u[2] # dx
du[2] = 0.0
end
njumps = [0]
function fjump_classical(t,psi,u,i)
@test i==2
njumps .+= 1
u[2] += 1.0
end
u0 = rand(2) .+ 0.0im
ψ0 = semiclassical.State(spinup(ba)⊗fockstate(bf,0),u0)

tout1, ψt1 = semiclassical.mcwf_dynamic(T,ψ0,fquantum,fclassical,fjump_classical,seed=1)
@test njumps == [1]
tout2, ψt2 = semiclassical.mcwf_dynamic(T,ψ0,fquantum,fclassical,fjump_classical,seed=1)
@test ψt2 == ψt1
tout3, ψt3 = semiclassical.mcwf_dynamic(T,ψ0,fquantum,fclassical,fjump_classical;display_beforeevent=true,seed=1)
@test length(ψt3) == length(ψt1)+1
tout4, ψt4 = semiclassical.mcwf_dynamic(T,ψ0,fquantum,fclassical,fjump_classical;display_beforeevent=true,display_afterevent=true,seed=1)
@test length(ψt4) == length(ψt1)+2
tout5, ut = semiclassical.mcwf_dynamic(T,ψ0,fquantum,fclassical,fjump_classical;display_beforeevent=true,display_afterevent=true,seed=1,fout=(t,psi)->copy(psi.classical))

@test ψt1[end].classical[2] == u0[2] + 1.0

# Test classical jump behavior
before_jump = findfirst(t -> !(t∈T), tout3)
after_jump = findlast(t-> !(t∈T), tout4)
@test after_jump == before_jump+1
@test ψt3[before_jump].classical[2] == u0[2]
@test ψt4[after_jump].classical[2] == u0[2] + 1.0
@test ut == [ψ.classical for ψ=ψt4]

# Test quantum jumps
@test ψt1[end].quantum == spindown(ba)⊗fockstate(bf,0)
@test ψt4[before_jump].quantum == ψ0.quantum
@test ψt4[after_jump].quantum == spindown(ba)⊗fockstate(bf,0)

end # testsets