Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MCWF jump times and indices #257

Merged
merged 13 commits into from
Sep 25, 2019
182 changes: 96 additions & 86 deletions src/mcwf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ and therefore must not be changed.
operators. If they are not given they are calculated automatically.
* `display_beforeevent=false`: `fout` is called before every jump.
* `display_afterevent=false`: `fout` is called after every jump.
* `display_jumps=false`: If set to true, an additional list of times and indices
is returned. These correspond to the times at which a jump occured and the index
of the jump operators with which the jump occured, respectively.
* `kwargs...`: Further arguments are passed on to the ode solver.
"""
function mcwf(tspan, psi0::T, H::AbstractOperator{B,B}, J::Vector;
Expand Down Expand Up @@ -157,6 +160,9 @@ normalized nor permanent! It is still in use by the ode solve
and therefore must not be changed.
* `display_beforeevent=false`: `fout` is called before every jump.
* `display_afterevent=false`: `fout` is called after every jump.
* `display_jumps=false`: If set to true, an additional list of times and indices
is returned. These correspond to the times at which a jump occured and the index
of the jump operators with which the jump occured, respectively.
* `kwargs...`: Further arguments are passed on to the ode solver.
"""
function mcwf_dynamic(tspan, psi0::T, f::Function;
Expand Down Expand Up @@ -251,112 +257,116 @@ Integrate a single Monte Carlo wave function trajectory.
function integrate_mcwf(dmcwf::Function, jumpfun::Function, tspan,
psi0::T, seed, fout::Function;
display_beforeevent=false, display_afterevent=false,
#TODO: Remove kwargs
display_jumps=false,
save_everystep=false, callback=nothing,
alg=OrdinaryDiffEq.DP5(),
kwargs...) where {B<:Basis,D<:Vector{ComplexF64},T<:Ket{B,D}}

tmp = copy(psi0)
psi_tmp = copy(psi0)
as_vector(psi::T) = psi.data
rng = MersenneTwister(convert(UInt, seed))
jumpnorm = Ref(rand(rng))
djumpnorm(x::D, t::Float64, integrator) = norm(x)^2 - (1-jumpnorm[])

if !display_beforeevent && !display_afterevent
function dojump(integrator)
x = integrator.u
recast!(x, psi_tmp)
t = integrator.t
jumpfun(rng, t, psi_tmp, tmp)
x .= tmp.data
jumpnorm[] = rand(rng)
kwargs...) where T

# Display before or after events
function save_func!(affect!,integrator)
affect!.saveiter += 1
copyat_or_push!(affect!.saved_values.t, affect!.saveiter, integrator.t)
copyat_or_push!(affect!.saved_values.saveval, affect!.saveiter,
affect!.save_func(integrator.u, integrator.t, integrator),Val{false})
return nothing
end
save_before! = display_beforeevent ? save_func! : (affect!,integrator)->nothing
save_after! = display_afterevent ? save_func! : (affect!,integrator)->nothing

# Display jump operator index and times
jump_t = Float64[]
jump_index = Int[]
save_t_index = if display_jumps
function(t,i)
push!(jump_t,t)
push!(jump_index,i)
return nothing
end
cb = OrdinaryDiffEq.ContinuousCallback(djumpnorm,dojump,
save_positions = (display_beforeevent,display_afterevent))


timeevolution.integrate(float(tspan), dmcwf, as_vector(psi0),
copy(psi0), copy(psi0), fout;
callback = cb,
kwargs...)
else
# Temporary workaround until proper tooling for saving
# TODO: Replace by proper call to timeevolution.integrate
function fout_(x::D, t::Float64, integrator)
recast!(x, state)
fout(t, state)
end

state = copy(psi0)
dstate = copy(psi0)
out_type = pure_inference(fout, Tuple{eltype(tspan),typeof(state)})
out = DiffEqCallbacks.SavedValues(Float64,out_type)
scb = DiffEqCallbacks.SavingCallback(fout_,out,saveat=tspan,
save_everystep=save_everystep,
save_start = false)

function dojump_display(integrator)
x = integrator.u
t = integrator.t

affect! = scb.affect!
if display_beforeevent
affect!.saveiter += 1
copyat_or_push!(affect!.saved_values.t, affect!.saveiter, integrator.t)
copyat_or_push!(affect!.saved_values.saveval, affect!.saveiter,
affect!.save_func(integrator.u, integrator.t, integrator),Val{false})
end

recast!(x, psi_tmp)
jumpfun(rng, t, psi_tmp, tmp)
x .= tmp.data
(t,i)->nothing
end

if display_afterevent
affect!.saveiter += 1
copyat_or_push!(affect!.saved_values.t, affect!.saveiter, integrator.t)
copyat_or_push!(affect!.saved_values.saveval, affect!.saveiter,
affect!.save_func(integrator.u, integrator.t, integrator),Val{false})
end
jumpnorm[] = rand(rng)
end
function fout_(x::Vector{ComplexF64}, t::Float64, integrator)
recast!(x, state)
fout(t, state)
end

cb = OrdinaryDiffEq.ContinuousCallback(djumpnorm,dojump_display,
save_positions = (false,false))
full_cb = OrdinaryDiffEq.CallbackSet(callback,cb,scb)
state = copy(psi0)
dstate = copy(psi0)
out_type = pure_inference(fout, Tuple{eltype(tspan),typeof(state)})
out = DiffEqCallbacks.SavedValues(Float64,out_type)
scb = DiffEqCallbacks.SavingCallback(fout_,out,saveat=tspan,
save_everystep=save_everystep,
save_start = false)

cb = jump_callback(jumpfun, seed, scb, save_before!, save_after!, save_t_index, psi0)
full_cb = OrdinaryDiffEq.CallbackSet(callback,cb,scb)

function df_(dx::D, x::D, p, t) where D<:Vector{ComplexF64}
recast!(x, state)
recast!(dx, dstate)
dmcwf(t, state, dstate)
recast!(dstate, dx)
end

function df_(dx::D, x::D, p, t)
recast!(x, state)
recast!(dx, dstate)
dmcwf(t, state, dstate)
recast!(dstate, dx)
end
prob = OrdinaryDiffEq.ODEProblem{true}(df_, as_vector(psi0),(tspan[1],tspan[end]))

prob = OrdinaryDiffEq.ODEProblem{true}(df_, as_vector(psi0),(tspan[1],tspan[end]))
sol = OrdinaryDiffEq.solve(
prob,
alg;
reltol = 1.0e-6,
abstol = 1.0e-8,
save_everystep = false, save_start = false,
save_end = false,
callback=full_cb, kwargs...)

sol = OrdinaryDiffEq.solve(
prob,
alg;
reltol = 1.0e-6,
abstol = 1.0e-8,
save_everystep = false, save_start = false,
save_end = false,
callback=full_cb, kwargs...)
if display_jumps
return out.t, out.saveval, jump_t, jump_index
else
return out.t, out.saveval
end
end

function integrate_mcwf(dmcwf::Function, jumpfun::Function, tspan,
psi0::T, seed, fout::Nothing;
kwargs...) where {T<:Ket}
kwargs...) where T
function fout_(t::Float64, x::T)
psi = copy(x)
psi /= norm(psi)
return psi
return normalize(x)
end
integrate_mcwf(dmcwf, jumpfun, tspan, psi0, seed, fout_; kwargs...)
end

function jump_callback(jumpfun::Function, seed, scb, save_before!::Function,
save_after!::Function, save_t_index::Function, psi0::Ket)

tmp = copy(psi0)
psi_tmp = copy(psi0)

rng = MersenneTwister(convert(UInt, seed))
jumpnorm = Ref(rand(rng))
djumpnorm(x::Vector{ComplexF64}, t::Float64, integrator) = norm(x)^2 - (1-jumpnorm[])

function dojump(integrator)
x = integrator.u
t = integrator.t

affect! = scb.affect!
save_before!(affect!,integrator)
recast!(x, psi_tmp)
i = jumpfun(rng, t, psi_tmp, tmp)
x .= tmp.data
save_after!(affect!,integrator)
save_t_index(t,i)

jumpnorm[] = rand(rng)
return nothing
end

return OrdinaryDiffEq.ContinuousCallback(djumpnorm,dojump,
save_positions = (false,false))
end
as_vector(psi::StateVector) = psi.data

"""
jump(rng, t, psi, J, psi_new)

Expand Down
132 changes: 31 additions & 101 deletions src/semiclassical.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ module semiclassical
import Base: ==
import ..bases, ..operators, ..operators_dense
import ..timeevolution: integrate, recast!, QO_CHECKS
import ..timeevolution.timeevolution_mcwf: jump
import LinearAlgebra: normalize!
import ..timeevolution.timeevolution_mcwf: jump, integrate_mcwf, jump_callback, as_vector
import LinearAlgebra: normalize, normalize!

using Random, LinearAlgebra
import OrdinaryDiffEq
Expand Down Expand Up @@ -36,6 +36,7 @@ end
Base.length(state::State) = length(state.quantum) + length(state.classical)
Base.copy(state::State) = State(copy(state.quantum), copy(state.classical))
normalize!(state::State{B,T}) where {B,T<:Ket} = normalize!(state.quantum)
normalize(state::T) where {B,K<:Ket,T<:State{B,K}} = State(normalize(state.quantum),copy(state.classical))

function ==(a::State, b::State)
samebases(a.quantum, b.quantum) &&
Expand Down Expand Up @@ -140,6 +141,13 @@ Calculate MCWF trajectories coupled to a classical system.
* `fout=nothing`: If given, this function `fout(t, state)` is called every time
an output should be displayed. ATTENTION: The given state is not
permanent!
* `display_beforeevent`: Choose whether or not an additional point should be saved
before a jump occurs. Default is false.
* `display_afterevent`: Choose whether or not an additional point should be saved
after a jump occurs. Default is false.
* `display_jumps=false`: If set to true, an additional list of times and indices
is returned. These correspond to the times at which a jump occured and
the index of the jump operators with which the jump occured, respectively.
* `kwargs...`: Further arguments are passed on to the ode solver.
"""
function mcwf_dynamic(tspan, psi0::State{B,T}, fquantum, fclassical, fjump_classical;
Expand Down Expand Up @@ -206,116 +214,38 @@ function jump_dynamic(rng, t::Float64, psi::T, fquantum::Function, fclassical::F
i = jump(rng, t, psi.quantum, J, psi_new.quantum, rates_)
fjump_classical(t, psi_new.quantum, psi.classical, i)
psi_new.classical .= psi.classical
return i
end

function integrate_mcwf(dmcwf::Function, jumpfun::Function, tspan,
psi0::T, seed, fout::Function;
display_beforeevent=false, display_afterevent=false,
#TODO: Remove kwargs
save_everystep=false, callback=nothing,
alg=OrdinaryDiffEq.DP5(),
kwargs...) where {B<:Basis,T<:State}

function jump_callback(jumpfun::Function, seed, scb, save_before!::Function,
save_after!::Function, save_t_index::Function, psi0::State)
tmp = copy(psi0)
psi_tmp = copy(psi0)
x0 = [psi0.quantum.data; psi0.classical]
rng = MersenneTwister(convert(UInt, seed))
jumpnorm = Ref(rand(rng))
n = length(psi0.quantum)
djumpnorm(x::Vector{ComplexF64}, t::Float64, integrator) = norm(x[1:n])^2 - (1-jumpnorm[])

if !display_beforeevent && !display_afterevent
function dojump(integrator)
x = integrator.u
recast!(x, psi_tmp)
t = integrator.t
jumpfun(rng, t, psi_tmp, tmp)
recast!(tmp, x)
jumpnorm[] = rand(rng)
end
cb = OrdinaryDiffEq.ContinuousCallback(djumpnorm,dojump,
save_positions = (display_beforeevent,display_afterevent))


timeevolution.integrate(float(tspan), dmcwf, x0,
copy(psi0), copy(psi0), fout;
callback = cb,
kwargs...)
else
# Temporary workaround until proper tooling for saving
# TODO: Replace by proper call to timeevolution.integrate
function fout_(x::Vector{ComplexF64}, t::Float64, integrator)
recast!(x, state)
fout(t, state)
end

state = copy(psi0)
dstate = copy(psi0)
out_type = pure_inference(fout, Tuple{eltype(tspan),typeof(state)})
out = DiffEqCallbacks.SavedValues(Float64,out_type)
scb = DiffEqCallbacks.SavingCallback(fout_,out,saveat=tspan,
save_everystep=save_everystep,
save_start = false)

function dojump_display(integrator)
x = integrator.u
t = integrator.t

affect! = scb.affect!
if display_beforeevent
affect!.saveiter += 1
copyat_or_push!(affect!.saved_values.t, affect!.saveiter, integrator.t)
copyat_or_push!(affect!.saved_values.saveval, affect!.saveiter,
affect!.save_func(integrator.u, integrator.t, integrator),Val{false})
end

recast!(x, psi_tmp)
jumpfun(rng, t, psi_tmp, tmp)
recast!(tmp, x)

if display_afterevent
affect!.saveiter += 1
copyat_or_push!(affect!.saved_values.t, affect!.saveiter, integrator.t)
copyat_or_push!(affect!.saved_values.saveval, affect!.saveiter,
affect!.save_func(integrator.u, integrator.t, integrator),Val{false})
end
jumpnorm[] = rand(rng)
end

cb = OrdinaryDiffEq.ContinuousCallback(djumpnorm,dojump_display,
save_positions = (false,false))
full_cb = OrdinaryDiffEq.CallbackSet(callback,cb,scb)

function df_(dx::Vector{ComplexF64}, x::Vector{ComplexF64}, p, t)
recast!(x, state)
recast!(dx, dstate)
dmcwf(t, state, dstate)
recast!(dstate, dx)
end

prob = OrdinaryDiffEq.ODEProblem{true}(df_, x0,(tspan[1],tspan[end]))

sol = OrdinaryDiffEq.solve(
prob,
alg;
reltol = 1.0e-6,
abstol = 1.0e-8,
save_everystep = false, save_start = false,
save_end = false,
callback=full_cb, kwargs...)
return out.t, out.saveval
end
end
function dojump(integrator)
x = integrator.u
t = integrator.t

function integrate_mcwf(dmcwf::Function, jumpfun::Function, tspan,
psi0::T, seed, fout::Nothing;
kwargs...) where {T<:State}
function fout_(t::Float64, x::T)
psi = copy(x)
normalize!(psi)
return psi
affect! = scb.affect!
save_before!(affect!,integrator)
recast!(x, psi_tmp)
i = jumpfun(rng, t, psi_tmp, tmp)
recast!(tmp, x)
save_after!(affect!,integrator)
save_t_index(t,i)

jumpnorm[] = rand(rng)
return nothing
end
integrate_mcwf(dmcwf, jumpfun, tspan, psi0, seed, fout_; kwargs...)

return OrdinaryDiffEq.ContinuousCallback(djumpnorm,dojump,
save_positions = (false,false))
end
as_vector(psi::State{B,K}) where {B,K<:Ket} = [psi.quantum.data; psi.classical]


end # module
Loading