Skip to content

Commit

Permalink
Separate stochastic base functionality from timeevolution
Browse files Browse the repository at this point in the history
  • Loading branch information
david-pl committed Oct 4, 2019
1 parent d73f8a4 commit 5b3f573
Show file tree
Hide file tree
Showing 5 changed files with 7 additions and 106 deletions.
8 changes: 1 addition & 7 deletions src/QuantumOptics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,8 @@ export qfunc, wigner, coherentspinstate, qfuncsu2, wignersu2, ylm,

include("phasespace.jl")
module timeevolution
using QuantumOpticsBase
using QuantumOpticsBase: check_samebases, check_multiplicable
export diagonaljumps, @skiptimechecks

function recast! end

include("timeevolution_base.jl")
include("master.jl")
include("schroedinger.jl")
Expand All @@ -32,9 +28,7 @@ include("timecorrelations.jl")
include("spectralanalysis.jl")
include("semiclassical.jl")
module stochastic
using QuantumOpticsBase
import ..timeevolution: recast!
include("timeevolution_base.jl")
include("stochastic_base.jl")
include("stochastic_definitions.jl")
include("stochastic_schroedinger.jl")
include("stochastic_master.jl")
Expand Down
3 changes: 1 addition & 2 deletions src/mcwf.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
using Random, LinearAlgebra
import OrdinaryDiffEq

# TODO: Remove imports
import DiffEqCallbacks, RecursiveArrayTools.copyat_or_push!
import RecursiveArrayTools.copyat_or_push!

"""
mcwf_h(tspan, rho0, Hnh, J; <keyword arguments>)
Expand Down
1 change: 0 additions & 1 deletion src/stochastic_master.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import ...timeevolution: dmaster_h, dmaster_nh, dmaster_h_dynamic, check_master

const DecayRates = Union{Vector{Float64}, Matrix{Float64}, Nothing}
const DiffArray = Union{Vector{ComplexF64}, Array{ComplexF64, 2}}

"""
stochastic.master(tspan, rho0, H, J, C; <keyword arguments>)
Expand Down
2 changes: 0 additions & 2 deletions src/stochastic_schroedinger.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import ...timeevolution: dschroedinger, dschroedinger_dynamic, check_schroedinger

import DiffEqCallbacks

"""
stochastic.schroedinger(tspan, state0, H, Hs[; fout, ...])
Expand Down
99 changes: 5 additions & 94 deletions src/timeevolution_base.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import OrdinaryDiffEq, DiffEqCallbacks, StochasticDiffEq
using QuantumOpticsBase
using QuantumOpticsBase: check_samebases, check_multiplicable

export @skiptimechecks
import OrdinaryDiffEq, DiffEqCallbacks

const DiffArray = Union{Vector{ComplexF64}, Array{ComplexF64, 2}}

function recast! end

"""
integrate(tspan::Vector{Float64}, df::Function, x0::Vector{ComplexF64},
state::T, dstate::T, fout::Function; kwargs...)
Expand Down Expand Up @@ -86,98 +89,6 @@ function (c::SteadyStateCondtion)(rho,t,integrator)
end



"""
integrate_stoch(tspan::Vector{Float64}, df::Function, dg::Vector{Function}, x0::Vector{ComplexF64},
state::T, dstate::T, fout::Function; kwargs...)
Integrate using StochasticDiffEq
"""
function integrate_stoch(tspan::Vector{Float64}, df::Function, dg::Function, x0::Vector{ComplexF64},
state::T, dstate::T, fout::Function, n::Int;
save_everystep = false, callback=nothing,
alg::StochasticDiffEq.StochasticDiffEqAlgorithm=StochasticDiffEq.EM(),
noise_rate_prototype = nothing,
noise_prototype_classical = nothing,
noise=nothing,
ncb=nothing,
kwargs...) where T

function df_(dx::Vector{ComplexF64}, x::Vector{ComplexF64}, p, t)
recast!(x, state)
recast!(dx, dstate)
df(t, state, dstate)
recast!(dstate, dx)
end

function dg_(dx::Union{Vector{ComplexF64}, Array{ComplexF64, 2}},
x::Vector{ComplexF64}, p, t)
recast!(x, state)
dg(dx, t, state, dstate, n)
end

function fout_(x::Vector{ComplexF64}, t::Float64, integrator)
recast!(x, state)
fout(t, state)
end

nc = isa(noise_prototype_classical, Nothing) ? 0 : size(noise_prototype_classical)[2]
if isa(noise, Nothing) && n > 0
if n + nc == 1
noise_ = StochasticDiffEq.RealWienerProcess(0.0, 0.0)
else
noise_ = StochasticDiffEq.RealWienerProcess!(0.0, zeros(n + nc))
end
else
noise_ = noise
end
if isa(noise_rate_prototype, Nothing)
if n > 1 || nc > 1 || (n > 0 && nc > 0)
noise_rate_prototype = zeros(ComplexF64, length(x0), n + nc)
end
end

out_type = pure_inference(fout, Tuple{eltype(tspan),typeof(state)})

out = DiffEqCallbacks.SavedValues(Float64,out_type)

scb = DiffEqCallbacks.SavingCallback(fout_,out,saveat=tspan,
save_everystep=save_everystep,
save_start = false)

full_cb = OrdinaryDiffEq.CallbackSet(callback, ncb, scb)

prob = StochasticDiffEq.SDEProblem{true}(df_, dg_, x0,(tspan[1],tspan[end]),
noise=noise_,
noise_rate_prototype=noise_rate_prototype)

sol = StochasticDiffEq.solve(
prob,
alg;
reltol = 1.0e-3,
abstol = 1.0e-3,
save_everystep = false, save_start = false,
save_end = false,
callback=full_cb, kwargs...)

out.t,out.saveval
end

"""
integrate_stoch
Define fout if it was omitted.
"""
function integrate_stoch(tspan::Vector{Float64}, df::Function, dg::Function, x0::Vector{ComplexF64},
state::T, dstate::T, ::Nothing, n::Int; kwargs...) where T
function fout(t::Float64, state::T)
copy(state)
end
integrate_stoch(tspan, df, dg, x0, state, dstate, fout, n; kwargs...)
end



const QO_CHECKS = Ref(true)
"""
@skiptimechecks
Expand Down

0 comments on commit 5b3f573

Please sign in to comment.