Skip to content

Commit

Permalink
Implement macros to skip checks
Browse files Browse the repository at this point in the history
  • Loading branch information
david-pl committed Aug 13, 2018
1 parent e897c16 commit c5f8bd6
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 18 deletions.
6 changes: 3 additions & 3 deletions src/QuantumOptics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ module QuantumOptics
using SparseArrays, LinearAlgebra

export bases, Basis, GenericBasis, CompositeBasis, basis,
tensor, , permutesystems,
tensor, , permutesystems, @ismultiplicable,
states, StateVector, Bra, Ket, basisstate, norm,
dagger, normalize, normalize!,
operators, Operator, expect, variance, identityoperator, ptrace, embed, dense, tr,
Expand Down Expand Up @@ -33,7 +33,7 @@ export bases, Basis, GenericBasis, CompositeBasis, basis,
entropy_vn, fidelity, ptranspose, PPT,
negativity, logarithmic_negativity,
spectralanalysis, eigenstates, eigenenergies, simdiag,
timeevolution, diagonaljumps,
timeevolution, diagonaljumps, @skipchecks,
steadystate,
timecorrelations,
semiclassical,
Expand Down Expand Up @@ -63,7 +63,7 @@ include("transformations.jl")
include("phasespace.jl")
include("metrics.jl")
module timeevolution
export diagonaljumps
export diagonaljumps, @skipchecks
include("timeevolution_base.jl")
include("master.jl")
include("schroedinger.jl")
Expand Down
20 changes: 18 additions & 2 deletions src/bases.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ export Basis, GenericBasis, CompositeBasis, basis,
tensor, , ptrace, permutesystems,
IncompatibleBases,
samebases, multiplicable,
check_samebases, check_multiplicable
check_samebases, check_multiplicable, @ismultiplicable

import Base: ==, ^

Expand Down Expand Up @@ -214,14 +214,30 @@ function multiplicable(b1::CompositeBasis, b2::CompositeBasis)
return true
end


const MULTI_CHECK = Ref(true)
"""
@ismultiplicable
Macro to skip multiplicability checks.
"""
macro ismultiplicable(ex)
return quote
MULTI_CHECK.x = false
local val = $(esc(ex))
MULTI_CHECK.x = true
val
end
end

"""
check_multiplicable(a, b)
Throw an [`IncompatibleBases`](@ref) error if the objects are
not multiplicable.
"""
function check_multiplicable(b1, b2)
if !multiplicable(b1, b2)
if MULTI_CHECK[] && !multiplicable(b1, b2)
throw(IncompatibleBases())
end
end
Expand Down
10 changes: 5 additions & 5 deletions src/master.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module timeevolution_master

export master, master_nh, master_h, master_dynamic, master_nh_dynamic

import ..integrate, ..recast!
import ..integrate, ..recast!, ..QO_CHECKS

using ...bases, ...states, ...operators
using ...operators_dense, ...operators_sparse
Expand Down Expand Up @@ -296,29 +296,29 @@ function dmaster_h_dynamic(t::Float64, rho::DenseOperator, f::Function,
rates::DecayRates,
drho::DenseOperator, tmp::DenseOperator)
result = f(t, rho)
@assert 3 <= length(result) <= 4
QO_CHECKS[] && @assert 3 <= length(result) <= 4
if length(result) == 3
H, J, Jdagger = result
rates_ = rates
else
H, J, Jdagger, rates_ = result
end
check_master(rho, H, J, Jdagger, rates_)
QO_CHECKS[] && check_master(rho, H, J, Jdagger, rates_)
dmaster_h(rho, H, rates_, J, Jdagger, drho, tmp)
end

function dmaster_nh_dynamic(t::Float64, rho::DenseOperator, f::Function,
rates::DecayRates,
drho::DenseOperator, tmp::DenseOperator)
result = f(t, rho)
@assert 4 <= length(result) <= 5
QO_CHECKS[] && @assert 4 <= length(result) <= 5
if length(result) == 4
Hnh, Hnh_dagger, J, Jdagger = result
rates_ = rates
else
Hnh, Hnh_dagger, J, Jdagger, rates_ = result
end
check_master(rho, Hnh, J, Jdagger, rates_)
QO_CHECKS[] && check_master(rho, Hnh, J, Jdagger, rates_)
dmaster_nh(rho, Hnh, Hnh_dagger, rates_, J, Jdagger, drho, tmp)
end

Expand Down
12 changes: 6 additions & 6 deletions src/mcwf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import OrdinaryDiffEq

# TODO: Remove imports
import DiffEqCallbacks, RecursiveArrayTools.copyat_or_push!
import ..recast!
import ..recast!, ..QO_CHECKS
Base.@pure pure_inference(fout,T) = Core.Compiler.return_type(fout, T)

const DecayRates = Union{Vector{Float64}, Matrix{Float64}, Nothing}
Expand Down Expand Up @@ -196,28 +196,28 @@ end
function dmcwf_h_dynamic(t::Float64, psi::Ket, f::Function, rates::DecayRates,
dpsi::Ket, tmp::Ket)
result = f(t, psi)
@assert 3 <= length(result) <= 4
QO_CHECKS[] && @assert 3 <= length(result) <= 4
if length(result) == 3
H, J, Jdagger = result
rates_ = rates
else
H, J, Jdagger, rates_ = result
end
check_mcwf(psi, H, J, Jdagger, rates_)
QO_CHECKS[] && check_mcwf(psi, H, J, Jdagger, rates_)
dmcwf_h(psi, H, J, Jdagger, dpsi, tmp, rates)
end

function dmcwf_nh_dynamic(t::Float64, psi::Ket, f::Function, dpsi::Ket)
result = f(t, psi)
@assert 3 <= length(result) <= 4
QO_CHECKS[] && @assert 3 <= length(result) <= 4
H, J, Jdagger = result[1:3]
check_mcwf(psi, H, J, Jdagger, nothing)
QO_CHECKS[] && check_mcwf(psi, H, J, Jdagger, nothing)
dmcwf_nh(psi, H, dpsi)
end

function jump_dynamic(rng, t::Float64, psi::Ket, f::Function, psi_new::Ket, rates::DecayRates)
result = f(t, psi)
@assert 3 <= length(result) <= 4
QO_CHECKS[] && @assert 3 <= length(result) <= 4
J = result[2]
if length(result) == 3
rates_ = rates
Expand Down
4 changes: 2 additions & 2 deletions src/schroedinger.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module timeevolution_schroedinger

export schroedinger, schroedinger_dynamic

import ..integrate, ..recast!
import ..integrate, ..recast!, ..QO_CHECKS

using ...bases, ...states, ...operators

Expand Down Expand Up @@ -77,7 +77,7 @@ end

function dschroedinger_dynamic(t::Float64, psi0::T, f::Function, dpsi::T) where T<:StateVector
H = f(t, psi0)
check_schroedinger(psi0, H)
QO_CHECKS[] && check_schroedinger(psi0, H)
dschroedinger(psi0, H, dpsi)
end

Expand Down
18 changes: 18 additions & 0 deletions src/timeevolution_base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ using ..metrics

import OrdinaryDiffEq, DiffEqCallbacks, StochasticDiffEq

export @skipchecks

const DiffArray = Union{Vector{ComplexF64}, Array{ComplexF64, 2}}

function recast! end
Expand Down Expand Up @@ -175,4 +177,20 @@ function integrate_stoch(tspan::Vector{Float64}, df::Function, dg::Function, x0:
end



const QO_CHECKS = Ref(true)
"""
@skipchecks
Macro to skip checks during time evolution.
"""
macro skipchecks(ex)
return quote
QO_CHECKS.x = false
local val = $(esc(ex))
QO_CHECKS.x = true
val
end
end

Base.@pure pure_inference(fout,T) = Core.Compiler.return_type(fout, T)

0 comments on commit c5f8bd6

Please sign in to comment.