Skip to content

Commit

Permalink
Implement macros to skip checks
Browse files Browse the repository at this point in the history
commit f221430
Author: david-pl <[email protected]>
Date:   Fri Aug 17 11:16:22 2018 +0200

    Update macro docstrings

commit 62f10e8
Author: David Plankensteiner <[email protected]>
Date:   Tue Aug 14 21:17:58 2018 +0200

    Fix stochastic checks

commit 5c9eff5
Author: David Plankensteiner <[email protected]>
Date:   Tue Aug 14 20:45:51 2018 +0200

    Rename macros

commit f78cf33
Author: david-pl <[email protected]>
Date:   Tue Aug 14 16:01:36 2018 +0200

    Start renaming stuff

commit c5f8bd6
Author: David Plankensteiner <[email protected]>
Date:   Mon Aug 13 20:25:58 2018 +0200

    Implement macros to skip checks
  • Loading branch information
david-pl committed Aug 17, 2018
1 parent c2ad518 commit 4f7e1a8
Show file tree
Hide file tree
Showing 9 changed files with 98 additions and 29 deletions.
6 changes: 3 additions & 3 deletions src/QuantumOptics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module QuantumOptics
using SparseArrays, LinearAlgebra

export bases, Basis, GenericBasis, CompositeBasis, basis,
tensor, , permutesystems,
tensor, , permutesystems, @samebases,
states, StateVector, Bra, Ket, basisstate, norm,
dagger, normalize, normalize!,
operators, Operator, expect, variance, identityoperator, ptrace, embed, dense, tr,
Expand Down Expand Up @@ -31,7 +31,7 @@ export bases, Basis, GenericBasis, CompositeBasis, basis,
entropy_vn, fidelity, ptranspose, PPT,
negativity, logarithmic_negativity,
spectralanalysis, eigenstates, eigenenergies, simdiag,
timeevolution, diagonaljumps,
timeevolution, diagonaljumps, @skiptimecheck,
steadystate,
timecorrelations,
semiclassical,
Expand Down Expand Up @@ -61,7 +61,7 @@ include("transformations.jl")
include("phasespace.jl")
include("metrics.jl")
module timeevolution
export diagonaljumps
export diagonaljumps, @skiptimechecks
include("timeevolution_base.jl")
include("master.jl")
include("schroedinger.jl")
Expand Down
22 changes: 19 additions & 3 deletions src/bases.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ export Basis, GenericBasis, CompositeBasis, basis,
tensor, , ptrace, permutesystems,
IncompatibleBases,
samebases, multiplicable,
check_samebases, check_multiplicable
check_samebases, check_multiplicable, @samebases

import Base: ==, ^

Expand Down Expand Up @@ -175,6 +175,22 @@ Exception that should be raised for an illegal algebraic operation.
"""
mutable struct IncompatibleBases <: Exception end

const BASES_CHECK = Ref(true)
"""
@samebases
Macro to skip checks for same bases. Useful for `*`, `expect` and similar
functions.
"""
macro samebases(ex)
return quote
BASES_CHECK.x = false
local val = $(esc(ex))
BASES_CHECK.x = true
val
end
end

"""
samebases(a, b)
Expand All @@ -189,7 +205,7 @@ Throw an [`IncompatibleBases`](@ref) error if the objects don't have
the same bases.
"""
function check_samebases(b1, b2)
if !samebases(b1, b2)
if BASES_CHECK[] && !samebases(b1, b2)
throw(IncompatibleBases())
end
end
Expand Down Expand Up @@ -221,7 +237,7 @@ Throw an [`IncompatibleBases`](@ref) error if the objects are
not multiplicable.
"""
function check_multiplicable(b1, b2)
if !multiplicable(b1, b2)
if BASES_CHECK[] && !multiplicable(b1, b2)
throw(IncompatibleBases())
end
end
Expand Down
10 changes: 5 additions & 5 deletions src/master.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module timeevolution_master

export master, master_nh, master_h, master_dynamic, master_nh_dynamic

import ..integrate, ..recast!
import ..integrate, ..recast!, ..QO_CHECKS

using ...bases, ...states, ...operators
using ...operators_dense, ...operators_sparse
Expand Down Expand Up @@ -296,29 +296,29 @@ function dmaster_h_dynamic(t::Float64, rho::DenseOperator, f::Function,
rates::DecayRates,
drho::DenseOperator, tmp::DenseOperator)
result = f(t, rho)
@assert 3 <= length(result) <= 4
QO_CHECKS[] && @assert 3 <= length(result) <= 4
if length(result) == 3
H, J, Jdagger = result
rates_ = rates
else
H, J, Jdagger, rates_ = result
end
check_master(rho, H, J, Jdagger, rates_)
QO_CHECKS[] && check_master(rho, H, J, Jdagger, rates_)
dmaster_h(rho, H, rates_, J, Jdagger, drho, tmp)
end

function dmaster_nh_dynamic(t::Float64, rho::DenseOperator, f::Function,
rates::DecayRates,
drho::DenseOperator, tmp::DenseOperator)
result = f(t, rho)
@assert 4 <= length(result) <= 5
QO_CHECKS[] && @assert 4 <= length(result) <= 5
if length(result) == 4
Hnh, Hnh_dagger, J, Jdagger = result
rates_ = rates
else
Hnh, Hnh_dagger, J, Jdagger, rates_ = result
end
check_master(rho, Hnh, J, Jdagger, rates_)
QO_CHECKS[] && check_master(rho, Hnh, J, Jdagger, rates_)
dmaster_nh(rho, Hnh, Hnh_dagger, rates_, J, Jdagger, drho, tmp)
end

Expand Down
12 changes: 6 additions & 6 deletions src/mcwf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import OrdinaryDiffEq

# TODO: Remove imports
import DiffEqCallbacks, RecursiveArrayTools.copyat_or_push!
import ..recast!
import ..recast!, ..QO_CHECKS
Base.@pure pure_inference(fout,T) = Core.Compiler.return_type(fout, T)

const DecayRates = Union{Vector{Float64}, Matrix{Float64}, Nothing}
Expand Down Expand Up @@ -196,28 +196,28 @@ end
function dmcwf_h_dynamic(t::Float64, psi::Ket, f::Function, rates::DecayRates,
dpsi::Ket, tmp::Ket)
result = f(t, psi)
@assert 3 <= length(result) <= 4
QO_CHECKS[] && @assert 3 <= length(result) <= 4
if length(result) == 3
H, J, Jdagger = result
rates_ = rates
else
H, J, Jdagger, rates_ = result
end
check_mcwf(psi, H, J, Jdagger, rates_)
QO_CHECKS[] && check_mcwf(psi, H, J, Jdagger, rates_)
dmcwf_h(psi, H, J, Jdagger, dpsi, tmp, rates)
end

function dmcwf_nh_dynamic(t::Float64, psi::Ket, f::Function, dpsi::Ket)
result = f(t, psi)
@assert 3 <= length(result) <= 4
QO_CHECKS[] && @assert 3 <= length(result) <= 4
H, J, Jdagger = result[1:3]
check_mcwf(psi, H, J, Jdagger, nothing)
QO_CHECKS[] && check_mcwf(psi, H, J, Jdagger, nothing)
dmcwf_nh(psi, H, dpsi)
end

function jump_dynamic(rng, t::Float64, psi::Ket, f::Function, psi_new::Ket, rates::DecayRates)
result = f(t, psi)
@assert 3 <= length(result) <= 4
QO_CHECKS[] && @assert 3 <= length(result) <= 4
J = result[2]
if length(result) == 3
rates_ = rates
Expand Down
4 changes: 2 additions & 2 deletions src/schroedinger.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module timeevolution_schroedinger

export schroedinger, schroedinger_dynamic

import ..integrate, ..recast!
import ..integrate, ..recast!, ..QO_CHECKS

using ...bases, ...states, ...operators

Expand Down Expand Up @@ -77,7 +77,7 @@ end

function dschroedinger_dynamic(t::Float64, psi0::T, f::Function, dpsi::T) where T<:StateVector
H = f(t, psi0)
check_schroedinger(psi0, H)
QO_CHECKS[] && check_schroedinger(psi0, H)
dschroedinger(psi0, H, dpsi)
end

Expand Down
28 changes: 25 additions & 3 deletions src/stochastic_master.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using ...bases, ...states, ...operators
using ...operators_dense, ...operators_sparse
using ...timeevolution
using LinearAlgebra
import ...timeevolution: integrate_stoch, recast!
import ...timeevolution: integrate_stoch, recast!, QO_CHECKS
import ...timeevolution.timeevolution_master: dmaster_h, dmaster_nh, dmaster_h_dynamic, check_master

const DecayRates = Union{Vector{Float64}, Matrix{Float64}, Nothing}
Expand Down Expand Up @@ -55,7 +55,7 @@ function master(tspan, rho0::DenseOperator, H::Operator,
dmaster_stoch(dx::DiffArray, t::Float64, rho::DenseOperator, drho::DenseOperator, n::Int) =
dmaster_stochastic(dx, rho, C, Cdagger, drho, n)

isreducible = check_master(rho0, H, J, Jdagger, rates)
isreducible = check_master(rho0, H, J, Jdagger, rates) && check_master_stoch(rho0, C, Cdagger)
if !isreducible
dmaster_h_determ(t::Float64, rho::DenseOperator, drho::DenseOperator) =
dmaster_h(rho, H, rates, J, Jdagger, drho, tmp)
Expand Down Expand Up @@ -158,8 +158,9 @@ end
function dmaster_stoch_dynamic(dx::DiffArray, t::Float64, rho::DenseOperator,
f::Function, drho::DenseOperator, n::Int)
result = f(t, rho)
@assert 2 == length(result)
QO_CHECKS[] && @assert 2 == length(result)
C, Cdagger = result
QO_CHECKS[] && check_master_stoch(rho, C, Cdagger)
dmaster_stochastic(dx, rho, C, Cdagger, drho, n)
end

Expand All @@ -174,6 +175,27 @@ function integrate_master_stoch(tspan, df::Function, dg::Function,
integrate_stoch(tspan_, df, dg, x0, state, dstate, fout, n; kwargs...)
end

function check_master_stoch(rho0::DenseOperator, C::Vector, Cdagger::Vector)
@assert length(C) == length(Cdagger)
isreducible = true
for c=C
@assert isa(c, Operator)
if !(isa(c, DenseOperator) || isa(c, SparseOperator))
isreducible = false
end
check_samebases(rho0, c)
end
for c=Cdagger
@assert isa(c, Operator)
if !(isa(c, DenseOperator) || isa(c, SparseOperator))
isreducible = false
end
check_samebases(rho0, c)
end
isreducible
end


# TODO: Speed up by recasting to n-d arrays, remove vector methods
function recast!(x::Union{Vector{ComplexF64}, SubArray{ComplexF64, 1}}, rho::DenseOperator)
rho.data = reshape(x, size(rho.data))
Expand Down
13 changes: 10 additions & 3 deletions src/stochastic_schroedinger.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using ...bases, ...states, ...operators
using ...operators_dense, ...operators_sparse
using ...timeevolution
using LinearAlgebra
import ...timeevolution: integrate_stoch, recast!
import ...timeevolution: integrate_stoch, recast!, QO_CHECKS
import ...timeevolution.timeevolution_schroedinger: dschroedinger, dschroedinger_dynamic, check_schroedinger

import DiffEqCallbacks
Expand Down Expand Up @@ -45,6 +45,10 @@ function schroedinger(tspan, psi0::Ket, H::Operator, Hs::Vector;
state = copy(psi0)

check_schroedinger(psi0, H)
for h=Hs
check_schroedinger(psi0, h)
end

dschroedinger_determ(t::Float64, psi::Ket, dpsi::Ket) = dschroedinger(psi, H, dpsi)
dschroedinger_stoch(dx::DiffArray,
t::Float64, psi::Ket, dpsi::Ket, n::Int) = dschroedinger_stochastic(dx, psi, Hs, dpsi, n)
Expand Down Expand Up @@ -130,14 +134,12 @@ end

function dschroedinger_stochastic(dx::Vector{ComplexF64}, psi::Ket, Hs::Vector{T},
dpsi::Ket, index::Int) where T <: Operator
check_schroedinger(psi, Hs[index])
recast!(dx, dpsi)
dschroedinger(psi, Hs[index], dpsi)
end
function dschroedinger_stochastic(dx::Array{ComplexF64, 2}, psi::Ket, Hs::Vector{T},
dpsi::Ket, n::Int) where T <: Operator
for i=1:n
check_schroedinger(psi, Hs[i])
dx_i = @view dx[:, i]
recast!(dx_i, dpsi)
dschroedinger(psi, Hs[i], dpsi)
Expand All @@ -147,6 +149,11 @@ end
function dschroedinger_stochastic(dx::DiffArray,
t::Float64, psi::Ket, f::Function, dpsi::Ket, n::Int)
ops = f(t, psi)
if QO_CHECKS[]
@inbounds for h=ops
check_schroedinger(psi, h)
end
end
dschroedinger_stochastic(dx, psi, ops, dpsi, n)
end

Expand Down
13 changes: 9 additions & 4 deletions src/stochastic_semiclassical.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@ using ...operators_dense, ...operators_sparse
using ...semiclassical
import ...semiclassical: recast!, State, dmaster_h_dynamic
using ...timeevolution
import ...timeevolution: integrate_stoch
import ...timeevolution.timeevolution_schroedinger: dschroedinger, dschroedinger_dynamic
import ...timeevolution: integrate_stoch, QO_CHECKS
import ...timeevolution.timeevolution_schroedinger: dschroedinger, dschroedinger_dynamic, check_schroedinger
import ...stochastic.stochastic_master: check_master_stoch
using ...stochastic
using LinearAlgebra

Expand Down Expand Up @@ -207,6 +208,7 @@ function dschroedinger_stochastic(dx::Vector{ComplexF64}, t::Float64,
dstate::State{Ket}, ::Int)
H = fstoch_quantum(t, state.quantum, state.classical)
recast!(dx, dstate)
QO_CHECKS[] && check_schroedinger(state.quantum, H[1])
dschroedinger(state.quantum, H[1], dstate.quantum)
recast!(dstate, dx)
end
Expand All @@ -217,6 +219,7 @@ function dschroedinger_stochastic(dx::Array{ComplexF64, 2},
for i=1:n
dx_i = @view dx[:, i]
recast!(dx_i, dstate)
QO_CHECKS[] && check_schroedinger(state.quantum, H[i])
dschroedinger(state.quantum, H[i], dstate.quantum)
recast!(dstate, dx_i)
end
Expand All @@ -239,8 +242,9 @@ function dmaster_stoch_dynamic(dx::Vector{ComplexF64}, t::Float64,
state::State{DenseOperator}, fstoch_quantum::Function,
fstoch_classical::Nothing, dstate::State{DenseOperator}, ::Int)
result = fstoch_quantum(t, state.quantum, state.classical)
@assert length(result) == 2
QO_CHECKS[] && @assert length(result) == 2
C, Cdagger = result
QO_CHECKS[] && check_master_stoch(state.quantum, C, Cdagger)
recast!(dx, dstate)
operators.gemm!(1, C[1], state.quantum, 0, dstate.quantum)
operators.gemm!(1, state.quantum, Cdagger[1], 1, dstate.quantum)
Expand All @@ -251,8 +255,9 @@ function dmaster_stoch_dynamic(dx::Array{ComplexF64, 2}, t::Float64,
state::State{DenseOperator}, fstoch_quantum::Function,
fstoch_classical::Nothing, dstate::State{DenseOperator}, n::Int)
result = fstoch_quantum(t, state.quantum, state.classical)
@assert length(result) == 2
QO_CHECKS[] && @assert length(result) == 2
C, Cdagger = result
QO_CHECKS[] && check_master_stoch(state.quantum, C, Cdagger)
for i=1:n
dx_i = @view dx[:, i]
recast!(dx_i, dstate)
Expand Down
19 changes: 19 additions & 0 deletions src/timeevolution_base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ using ..metrics

import OrdinaryDiffEq, DiffEqCallbacks, StochasticDiffEq

export @skiptimechecks

const DiffArray = Union{Vector{ComplexF64}, Array{ComplexF64, 2}}

function recast! end
Expand Down Expand Up @@ -175,4 +177,21 @@ function integrate_stoch(tspan::Vector{Float64}, df::Function, dg::Function, x0:
end



const QO_CHECKS = Ref(true)
"""
@skiptimechecks
Macro to skip checks during time-dependent problems.
Useful for `timeevolution.master_dynamic` and similar functions.
"""
macro skiptimechecks(ex)
return quote
QO_CHECKS.x = false
local val = $(esc(ex))
QO_CHECKS.x = true
val
end
end

Base.@pure pure_inference(fout,T) = Core.Compiler.return_type(fout, T)

0 comments on commit 4f7e1a8

Please sign in to comment.