Skip to content

Commit

Permalink
MCWF jump times and indices (#257)
Browse files Browse the repository at this point in the history
* semicl mcwf

* semicl mcwf läuft

* semiclassical mcfw

* Change MCWF interface to display jumps

* Semiclassical mcwf with display event

* Add display_which and display_t to semiclassical mcwf

* Clean up integrate_mcwf

* Clean up semiclassical mcwf

* Add docstrings

* Add docstrings in semiclassical.mcwf

* Fix tests
  • Loading branch information
david-pl authored Sep 25, 2019
1 parent b9aa28d commit f510fef
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 187 deletions.
182 changes: 96 additions & 86 deletions src/mcwf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ and therefore must not be changed.
operators. If they are not given they are calculated automatically.
* `display_beforeevent=false`: `fout` is called before every jump.
* `display_afterevent=false`: `fout` is called after every jump.
* `display_jumps=false`: If set to true, an additional list of times and indices
is returned. These correspond to the times at which a jump occured and the index
of the jump operators with which the jump occured, respectively.
* `kwargs...`: Further arguments are passed on to the ode solver.
"""
function mcwf(tspan, psi0::T, H::AbstractOperator{B,B}, J::Vector;
Expand Down Expand Up @@ -157,6 +160,9 @@ normalized nor permanent! It is still in use by the ode solve
and therefore must not be changed.
* `display_beforeevent=false`: `fout` is called before every jump.
* `display_afterevent=false`: `fout` is called after every jump.
* `display_jumps=false`: If set to true, an additional list of times and indices
is returned. These correspond to the times at which a jump occured and the index
of the jump operators with which the jump occured, respectively.
* `kwargs...`: Further arguments are passed on to the ode solver.
"""
function mcwf_dynamic(tspan, psi0::T, f::Function;
Expand Down Expand Up @@ -251,112 +257,116 @@ Integrate a single Monte Carlo wave function trajectory.
function integrate_mcwf(dmcwf::Function, jumpfun::Function, tspan,
psi0::T, seed, fout::Function;
display_beforeevent=false, display_afterevent=false,
#TODO: Remove kwargs
display_jumps=false,
save_everystep=false, callback=nothing,
alg=OrdinaryDiffEq.DP5(),
kwargs...) where {B<:Basis,D<:Vector{ComplexF64},T<:Ket{B,D}}

tmp = copy(psi0)
psi_tmp = copy(psi0)
as_vector(psi::T) = psi.data
rng = MersenneTwister(convert(UInt, seed))
jumpnorm = Ref(rand(rng))
djumpnorm(x::D, t::Float64, integrator) = norm(x)^2 - (1-jumpnorm[])

if !display_beforeevent && !display_afterevent
function dojump(integrator)
x = integrator.u
recast!(x, psi_tmp)
t = integrator.t
jumpfun(rng, t, psi_tmp, tmp)
x .= tmp.data
jumpnorm[] = rand(rng)
kwargs...) where T

# Display before or after events
function save_func!(affect!,integrator)
affect!.saveiter += 1
copyat_or_push!(affect!.saved_values.t, affect!.saveiter, integrator.t)
copyat_or_push!(affect!.saved_values.saveval, affect!.saveiter,
affect!.save_func(integrator.u, integrator.t, integrator),Val{false})
return nothing
end
save_before! = display_beforeevent ? save_func! : (affect!,integrator)->nothing
save_after! = display_afterevent ? save_func! : (affect!,integrator)->nothing

# Display jump operator index and times
jump_t = Float64[]
jump_index = Int[]
save_t_index = if display_jumps
function(t,i)
push!(jump_t,t)
push!(jump_index,i)
return nothing
end
cb = OrdinaryDiffEq.ContinuousCallback(djumpnorm,dojump,
save_positions = (display_beforeevent,display_afterevent))


timeevolution.integrate(float(tspan), dmcwf, as_vector(psi0),
copy(psi0), copy(psi0), fout;
callback = cb,
kwargs...)
else
# Temporary workaround until proper tooling for saving
# TODO: Replace by proper call to timeevolution.integrate
function fout_(x::D, t::Float64, integrator)
recast!(x, state)
fout(t, state)
end

state = copy(psi0)
dstate = copy(psi0)
out_type = pure_inference(fout, Tuple{eltype(tspan),typeof(state)})
out = DiffEqCallbacks.SavedValues(Float64,out_type)
scb = DiffEqCallbacks.SavingCallback(fout_,out,saveat=tspan,
save_everystep=save_everystep,
save_start = false)

function dojump_display(integrator)
x = integrator.u
t = integrator.t

affect! = scb.affect!
if display_beforeevent
affect!.saveiter += 1
copyat_or_push!(affect!.saved_values.t, affect!.saveiter, integrator.t)
copyat_or_push!(affect!.saved_values.saveval, affect!.saveiter,
affect!.save_func(integrator.u, integrator.t, integrator),Val{false})
end

recast!(x, psi_tmp)
jumpfun(rng, t, psi_tmp, tmp)
x .= tmp.data
(t,i)->nothing
end

if display_afterevent
affect!.saveiter += 1
copyat_or_push!(affect!.saved_values.t, affect!.saveiter, integrator.t)
copyat_or_push!(affect!.saved_values.saveval, affect!.saveiter,
affect!.save_func(integrator.u, integrator.t, integrator),Val{false})
end
jumpnorm[] = rand(rng)
end
function fout_(x::Vector{ComplexF64}, t::Float64, integrator)
recast!(x, state)
fout(t, state)
end

cb = OrdinaryDiffEq.ContinuousCallback(djumpnorm,dojump_display,
save_positions = (false,false))
full_cb = OrdinaryDiffEq.CallbackSet(callback,cb,scb)
state = copy(psi0)
dstate = copy(psi0)
out_type = pure_inference(fout, Tuple{eltype(tspan),typeof(state)})
out = DiffEqCallbacks.SavedValues(Float64,out_type)
scb = DiffEqCallbacks.SavingCallback(fout_,out,saveat=tspan,
save_everystep=save_everystep,
save_start = false)

cb = jump_callback(jumpfun, seed, scb, save_before!, save_after!, save_t_index, psi0)
full_cb = OrdinaryDiffEq.CallbackSet(callback,cb,scb)

function df_(dx::D, x::D, p, t) where D<:Vector{ComplexF64}
recast!(x, state)
recast!(dx, dstate)
dmcwf(t, state, dstate)
recast!(dstate, dx)
end

function df_(dx::D, x::D, p, t)
recast!(x, state)
recast!(dx, dstate)
dmcwf(t, state, dstate)
recast!(dstate, dx)
end
prob = OrdinaryDiffEq.ODEProblem{true}(df_, as_vector(psi0),(tspan[1],tspan[end]))

prob = OrdinaryDiffEq.ODEProblem{true}(df_, as_vector(psi0),(tspan[1],tspan[end]))
sol = OrdinaryDiffEq.solve(
prob,
alg;
reltol = 1.0e-6,
abstol = 1.0e-8,
save_everystep = false, save_start = false,
save_end = false,
callback=full_cb, kwargs...)

sol = OrdinaryDiffEq.solve(
prob,
alg;
reltol = 1.0e-6,
abstol = 1.0e-8,
save_everystep = false, save_start = false,
save_end = false,
callback=full_cb, kwargs...)
if display_jumps
return out.t, out.saveval, jump_t, jump_index
else
return out.t, out.saveval
end
end

function integrate_mcwf(dmcwf::Function, jumpfun::Function, tspan,
psi0::T, seed, fout::Nothing;
kwargs...) where {T<:Ket}
kwargs...) where T
function fout_(t::Float64, x::T)
psi = copy(x)
psi /= norm(psi)
return psi
return normalize(x)
end
integrate_mcwf(dmcwf, jumpfun, tspan, psi0, seed, fout_; kwargs...)
end

function jump_callback(jumpfun::Function, seed, scb, save_before!::Function,
save_after!::Function, save_t_index::Function, psi0::Ket)

tmp = copy(psi0)
psi_tmp = copy(psi0)

rng = MersenneTwister(convert(UInt, seed))
jumpnorm = Ref(rand(rng))
djumpnorm(x::Vector{ComplexF64}, t::Float64, integrator) = norm(x)^2 - (1-jumpnorm[])

function dojump(integrator)
x = integrator.u
t = integrator.t

affect! = scb.affect!
save_before!(affect!,integrator)
recast!(x, psi_tmp)
i = jumpfun(rng, t, psi_tmp, tmp)
x .= tmp.data
save_after!(affect!,integrator)
save_t_index(t,i)

jumpnorm[] = rand(rng)
return nothing
end

return OrdinaryDiffEq.ContinuousCallback(djumpnorm,dojump,
save_positions = (false,false))
end
as_vector(psi::StateVector) = psi.data

"""
jump(rng, t, psi, J, psi_new)
Expand Down
132 changes: 31 additions & 101 deletions src/semiclassical.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ module semiclassical
import Base: ==
import ..bases, ..operators, ..operators_dense
import ..timeevolution: integrate, recast!, QO_CHECKS
import ..timeevolution.timeevolution_mcwf: jump
import LinearAlgebra: normalize!
import ..timeevolution.timeevolution_mcwf: jump, integrate_mcwf, jump_callback, as_vector
import LinearAlgebra: normalize, normalize!

using Random, LinearAlgebra
import OrdinaryDiffEq
Expand Down Expand Up @@ -36,6 +36,7 @@ end
Base.length(state::State) = length(state.quantum) + length(state.classical)
Base.copy(state::State) = State(copy(state.quantum), copy(state.classical))
normalize!(state::State{B,T}) where {B,T<:Ket} = normalize!(state.quantum)
normalize(state::T) where {B,K<:Ket,T<:State{B,K}} = State(normalize(state.quantum),copy(state.classical))

function ==(a::State, b::State)
samebases(a.quantum, b.quantum) &&
Expand Down Expand Up @@ -140,6 +141,13 @@ Calculate MCWF trajectories coupled to a classical system.
* `fout=nothing`: If given, this function `fout(t, state)` is called every time
an output should be displayed. ATTENTION: The given state is not
permanent!
* `display_beforeevent`: Choose whether or not an additional point should be saved
before a jump occurs. Default is false.
* `display_afterevent`: Choose whether or not an additional point should be saved
after a jump occurs. Default is false.
* `display_jumps=false`: If set to true, an additional list of times and indices
is returned. These correspond to the times at which a jump occured and
the index of the jump operators with which the jump occured, respectively.
* `kwargs...`: Further arguments are passed on to the ode solver.
"""
function mcwf_dynamic(tspan, psi0::State{B,T}, fquantum, fclassical, fjump_classical;
Expand Down Expand Up @@ -206,116 +214,38 @@ function jump_dynamic(rng, t::Float64, psi::T, fquantum::Function, fclassical::F
i = jump(rng, t, psi.quantum, J, psi_new.quantum, rates_)
fjump_classical(t, psi_new.quantum, psi.classical, i)
psi_new.classical .= psi.classical
return i
end

function integrate_mcwf(dmcwf::Function, jumpfun::Function, tspan,
psi0::T, seed, fout::Function;
display_beforeevent=false, display_afterevent=false,
#TODO: Remove kwargs
save_everystep=false, callback=nothing,
alg=OrdinaryDiffEq.DP5(),
kwargs...) where {B<:Basis,T<:State}

function jump_callback(jumpfun::Function, seed, scb, save_before!::Function,
save_after!::Function, save_t_index::Function, psi0::State)
tmp = copy(psi0)
psi_tmp = copy(psi0)
x0 = [psi0.quantum.data; psi0.classical]
rng = MersenneTwister(convert(UInt, seed))
jumpnorm = Ref(rand(rng))
n = length(psi0.quantum)
djumpnorm(x::Vector{ComplexF64}, t::Float64, integrator) = norm(x[1:n])^2 - (1-jumpnorm[])

if !display_beforeevent && !display_afterevent
function dojump(integrator)
x = integrator.u
recast!(x, psi_tmp)
t = integrator.t
jumpfun(rng, t, psi_tmp, tmp)
recast!(tmp, x)
jumpnorm[] = rand(rng)
end
cb = OrdinaryDiffEq.ContinuousCallback(djumpnorm,dojump,
save_positions = (display_beforeevent,display_afterevent))


timeevolution.integrate(float(tspan), dmcwf, x0,
copy(psi0), copy(psi0), fout;
callback = cb,
kwargs...)
else
# Temporary workaround until proper tooling for saving
# TODO: Replace by proper call to timeevolution.integrate
function fout_(x::Vector{ComplexF64}, t::Float64, integrator)
recast!(x, state)
fout(t, state)
end

state = copy(psi0)
dstate = copy(psi0)
out_type = pure_inference(fout, Tuple{eltype(tspan),typeof(state)})
out = DiffEqCallbacks.SavedValues(Float64,out_type)
scb = DiffEqCallbacks.SavingCallback(fout_,out,saveat=tspan,
save_everystep=save_everystep,
save_start = false)

function dojump_display(integrator)
x = integrator.u
t = integrator.t

affect! = scb.affect!
if display_beforeevent
affect!.saveiter += 1
copyat_or_push!(affect!.saved_values.t, affect!.saveiter, integrator.t)
copyat_or_push!(affect!.saved_values.saveval, affect!.saveiter,
affect!.save_func(integrator.u, integrator.t, integrator),Val{false})
end

recast!(x, psi_tmp)
jumpfun(rng, t, psi_tmp, tmp)
recast!(tmp, x)

if display_afterevent
affect!.saveiter += 1
copyat_or_push!(affect!.saved_values.t, affect!.saveiter, integrator.t)
copyat_or_push!(affect!.saved_values.saveval, affect!.saveiter,
affect!.save_func(integrator.u, integrator.t, integrator),Val{false})
end
jumpnorm[] = rand(rng)
end

cb = OrdinaryDiffEq.ContinuousCallback(djumpnorm,dojump_display,
save_positions = (false,false))
full_cb = OrdinaryDiffEq.CallbackSet(callback,cb,scb)

function df_(dx::Vector{ComplexF64}, x::Vector{ComplexF64}, p, t)
recast!(x, state)
recast!(dx, dstate)
dmcwf(t, state, dstate)
recast!(dstate, dx)
end

prob = OrdinaryDiffEq.ODEProblem{true}(df_, x0,(tspan[1],tspan[end]))

sol = OrdinaryDiffEq.solve(
prob,
alg;
reltol = 1.0e-6,
abstol = 1.0e-8,
save_everystep = false, save_start = false,
save_end = false,
callback=full_cb, kwargs...)
return out.t, out.saveval
end
end
function dojump(integrator)
x = integrator.u
t = integrator.t

function integrate_mcwf(dmcwf::Function, jumpfun::Function, tspan,
psi0::T, seed, fout::Nothing;
kwargs...) where {T<:State}
function fout_(t::Float64, x::T)
psi = copy(x)
normalize!(psi)
return psi
affect! = scb.affect!
save_before!(affect!,integrator)
recast!(x, psi_tmp)
i = jumpfun(rng, t, psi_tmp, tmp)
recast!(tmp, x)
save_after!(affect!,integrator)
save_t_index(t,i)

jumpnorm[] = rand(rng)
return nothing
end
integrate_mcwf(dmcwf, jumpfun, tspan, psi0, seed, fout_; kwargs...)

return OrdinaryDiffEq.ContinuousCallback(djumpnorm,dojump,
save_positions = (false,false))
end
as_vector(psi::State{B,K}) where {B,K<:Ket} = [psi.quantum.data; psi.classical]


end # module
Loading

0 comments on commit f510fef

Please sign in to comment.