Skip to content

Commit

Permalink
fix mcwf
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Jan 11, 2018
1 parent 8c4c3e9 commit b5c73c8
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 27 deletions.
45 changes: 23 additions & 22 deletions src/mcwf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ export mcwf, mcwf_h, mcwf_nh, diagonaljumps

using ...bases, ...states, ...operators, ...ode_dopri
using ...operators_dense, ...operators_sparse

using ..timeevolution
import OrdinaryDiffEq

const DecayRates = Union{Vector{Float64}, Matrix{Float64}, Void}

Expand All @@ -29,40 +30,40 @@ Integrate a single Monte Carlo wave function trajectory.
and therefore must not be changed.
* `kwargs`: Further arguments are passed on to the ode solver.
"""
function integrate_mcwf(dmcwf::Function, jumpfun::Function, tspan, psi0::Ket, seed;
fout=nothing,
kwargs...)
function integrate_mcwf(dmcwf::Function, jumpfun::Function, tspan,
psi0::Ket, seed; fout=nothing,
display_beforeevent=false, display_afterevent=false,
kwargs...)
tmp = copy(psi0)
as_ket(x::Vector{Complex128}) = Ket(psi0.basis, x)
as_vector(psi::Ket) = psi.data
rng = MersenneTwister(convert(UInt, seed))
jumpnorm = Float64[rand(rng)]
djumpnorm(t, x::Vector{Complex128}) = norm(as_ket(x))^2 - (1-jumpnorm[1])
function dojump(t, x::Vector{Complex128})
jumpnorm = Ref(rand(rng))
djumpnorm(t, x::Vector{Complex128},integrator) = norm(as_ket(x))^2 - (1-jumpnorm[])
function dojump(integrator)
x = integrator.u
t = integrator.t
jumpfun(rng, t, as_ket(x), tmp)
x .= tmp.data
jumpnorm[1] = rand(rng)
return ode_dopri.jump
jumpnorm[] = rand(rng)
end
cb = OrdinaryDiffEq.ContinuousCallback(djumpnorm,dojump,
save_positions = (display_beforeevent,display_afterevent))

tout = Float64[]
xout = Ket[]
function fout_(t, x::Vector{Complex128})
function fout_(t, x::Ket)
if fout==nothing
psi = copy(as_ket(x))
psi = copy(x)
psi /= norm(psi)
push!(tout, t)
push!(xout, psi)
return nothing
return psi
else
return fout(t, as_ket(x))
return fout(t, x)
end
end
dmcwf_(t, x::Vector{Complex128}, dx::Vector{Complex128}) = dmcwf(t, as_ket(x), as_ket(dx))
ode_event(dmcwf_, float(tspan), as_vector(psi0), fout_,
djumpnorm, dojump;
kwargs...)
return fout==nothing ? (tout, xout) : nothing

timeevolution.integrate(float(tspan), dmcwf, as_vector(psi0),
copy(psi0), copy(psi0), fout_;
callback = cb,
kwargs...)
end

"""
Expand Down
12 changes: 7 additions & 5 deletions src/timeevolution_base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ function recast! end
df(t, state::T, dstate::T)
"""
function integrate{T}(tspan::Vector{Float64}, df::Function, x0::Vector{Complex128},
state::T, dstate::T, fout::Function,
alg::OrdinaryDiffEq.OrdinaryDiffEqAlgorithm = OrdinaryDiffEq.DP5();
state::T, dstate::T, fout::Function;
alg::OrdinaryDiffEq.OrdinaryDiffEqAlgorithm = OrdinaryDiffEq.DP5(),
steady_state = false, eps = 1e-3, save_everystep = false,
kwargs...)
callback = nothing, kwargs...)

function df_(t, x::Vector{Complex128}, dx::Vector{Complex128})
recast!(x, state)
Expand Down Expand Up @@ -51,6 +51,8 @@ function integrate{T}(tspan::Vector{Float64}, df::Function, x0::Vector{Complex12
cb = scb
end

full_cb = OrdinaryDiffEq.CallbackSet(callback,cb)

# TODO: Expose algorithm choice
sol = OrdinaryDiffEq.solve(
prob,
Expand All @@ -59,12 +61,12 @@ function integrate{T}(tspan::Vector{Float64}, df::Function, x0::Vector{Complex12
abstol = 1.0e-8,
save_everystep = false, save_start = false,
save_end = false,
callback=cb, kwargs...)
callback=full_cb, kwargs...)
out.t,out.saveval
end

function integrate{T}(tspan::Vector{Float64}, df::Function, x0::Vector{Complex128},
state::T, dstate::T, ::Void, args...; kwargs...)
state::T, dstate::T, ::Void; kwargs...)
function fout(t::Float64, state::T)
copy(state)
end
Expand Down

0 comments on commit b5c73c8

Please sign in to comment.