From 4f7e1a82e575fd80c622312d67858dc81c03f3a1 Mon Sep 17 00:00:00 2001 From: david-pl Date: Fri, 17 Aug 2018 11:41:02 +0200 Subject: [PATCH] Implement macros to skip checks commit f221430c27aba777ec45262aa97a2bff00e1cdcc Author: david-pl Date: Fri Aug 17 11:16:22 2018 +0200 Update macro docstrings commit 62f10e8d2e0ab5dc533936111b8b04f74140930e Author: David Plankensteiner Date: Tue Aug 14 21:17:58 2018 +0200 Fix stochastic checks commit 5c9eff541abe1791d607827f0ffa1d111efe46d6 Author: David Plankensteiner Date: Tue Aug 14 20:45:51 2018 +0200 Rename macros commit f78cf336f1f5d2bc00239c5cc13851a0d9791c77 Author: david-pl Date: Tue Aug 14 16:01:36 2018 +0200 Start renaming stuff commit c5f8bd6b75136e9de9ecbee01828ebc3058a2176 Author: David Plankensteiner Date: Mon Aug 13 20:25:58 2018 +0200 Implement macros to skip checks --- src/QuantumOptics.jl | 6 +++--- src/bases.jl | 22 +++++++++++++++++++--- src/master.jl | 10 +++++----- src/mcwf.jl | 12 ++++++------ src/schroedinger.jl | 4 ++-- src/stochastic_master.jl | 28 +++++++++++++++++++++++++--- src/stochastic_schroedinger.jl | 13 ++++++++++--- src/stochastic_semiclassical.jl | 13 +++++++++---- src/timeevolution_base.jl | 19 +++++++++++++++++++ 9 files changed, 98 insertions(+), 29 deletions(-) diff --git a/src/QuantumOptics.jl b/src/QuantumOptics.jl index 2b399875..a60160b4 100644 --- a/src/QuantumOptics.jl +++ b/src/QuantumOptics.jl @@ -3,7 +3,7 @@ module QuantumOptics using SparseArrays, LinearAlgebra export bases, Basis, GenericBasis, CompositeBasis, basis, - tensor, ⊗, permutesystems, + tensor, ⊗, permutesystems, @samebases, states, StateVector, Bra, Ket, basisstate, norm, dagger, normalize, normalize!, operators, Operator, expect, variance, identityoperator, ptrace, embed, dense, tr, @@ -31,7 +31,7 @@ export bases, Basis, GenericBasis, CompositeBasis, basis, entropy_vn, fidelity, ptranspose, PPT, negativity, logarithmic_negativity, spectralanalysis, eigenstates, eigenenergies, simdiag, - timeevolution, diagonaljumps, + timeevolution, diagonaljumps, @skiptimecheck, steadystate, timecorrelations, semiclassical, @@ -61,7 +61,7 @@ include("transformations.jl") include("phasespace.jl") include("metrics.jl") module timeevolution - export diagonaljumps + export diagonaljumps, @skiptimechecks include("timeevolution_base.jl") include("master.jl") include("schroedinger.jl") diff --git a/src/bases.jl b/src/bases.jl index 710fefb5..d3aaa164 100644 --- a/src/bases.jl +++ b/src/bases.jl @@ -4,7 +4,7 @@ export Basis, GenericBasis, CompositeBasis, basis, tensor, ⊗, ptrace, permutesystems, IncompatibleBases, samebases, multiplicable, - check_samebases, check_multiplicable + check_samebases, check_multiplicable, @samebases import Base: ==, ^ @@ -175,6 +175,22 @@ Exception that should be raised for an illegal algebraic operation. """ mutable struct IncompatibleBases <: Exception end +const BASES_CHECK = Ref(true) +""" + @samebases + +Macro to skip checks for same bases. Useful for `*`, `expect` and similar +functions. +""" +macro samebases(ex) + return quote + BASES_CHECK.x = false + local val = $(esc(ex)) + BASES_CHECK.x = true + val + end +end + """ samebases(a, b) @@ -189,7 +205,7 @@ Throw an [`IncompatibleBases`](@ref) error if the objects don't have the same bases. """ function check_samebases(b1, b2) - if !samebases(b1, b2) + if BASES_CHECK[] && !samebases(b1, b2) throw(IncompatibleBases()) end end @@ -221,7 +237,7 @@ Throw an [`IncompatibleBases`](@ref) error if the objects are not multiplicable. """ function check_multiplicable(b1, b2) - if !multiplicable(b1, b2) + if BASES_CHECK[] && !multiplicable(b1, b2) throw(IncompatibleBases()) end end diff --git a/src/master.jl b/src/master.jl index 799523fe..9336c39b 100644 --- a/src/master.jl +++ b/src/master.jl @@ -2,7 +2,7 @@ module timeevolution_master export master, master_nh, master_h, master_dynamic, master_nh_dynamic -import ..integrate, ..recast! +import ..integrate, ..recast!, ..QO_CHECKS using ...bases, ...states, ...operators using ...operators_dense, ...operators_sparse @@ -296,14 +296,14 @@ function dmaster_h_dynamic(t::Float64, rho::DenseOperator, f::Function, rates::DecayRates, drho::DenseOperator, tmp::DenseOperator) result = f(t, rho) - @assert 3 <= length(result) <= 4 + QO_CHECKS[] && @assert 3 <= length(result) <= 4 if length(result) == 3 H, J, Jdagger = result rates_ = rates else H, J, Jdagger, rates_ = result end - check_master(rho, H, J, Jdagger, rates_) + QO_CHECKS[] && check_master(rho, H, J, Jdagger, rates_) dmaster_h(rho, H, rates_, J, Jdagger, drho, tmp) end @@ -311,14 +311,14 @@ function dmaster_nh_dynamic(t::Float64, rho::DenseOperator, f::Function, rates::DecayRates, drho::DenseOperator, tmp::DenseOperator) result = f(t, rho) - @assert 4 <= length(result) <= 5 + QO_CHECKS[] && @assert 4 <= length(result) <= 5 if length(result) == 4 Hnh, Hnh_dagger, J, Jdagger = result rates_ = rates else Hnh, Hnh_dagger, J, Jdagger, rates_ = result end - check_master(rho, Hnh, J, Jdagger, rates_) + QO_CHECKS[] && check_master(rho, Hnh, J, Jdagger, rates_) dmaster_nh(rho, Hnh, Hnh_dagger, rates_, J, Jdagger, drho, tmp) end diff --git a/src/mcwf.jl b/src/mcwf.jl index 4ee168c4..f4185291 100644 --- a/src/mcwf.jl +++ b/src/mcwf.jl @@ -11,7 +11,7 @@ import OrdinaryDiffEq # TODO: Remove imports import DiffEqCallbacks, RecursiveArrayTools.copyat_or_push! -import ..recast! +import ..recast!, ..QO_CHECKS Base.@pure pure_inference(fout,T) = Core.Compiler.return_type(fout, T) const DecayRates = Union{Vector{Float64}, Matrix{Float64}, Nothing} @@ -196,28 +196,28 @@ end function dmcwf_h_dynamic(t::Float64, psi::Ket, f::Function, rates::DecayRates, dpsi::Ket, tmp::Ket) result = f(t, psi) - @assert 3 <= length(result) <= 4 + QO_CHECKS[] && @assert 3 <= length(result) <= 4 if length(result) == 3 H, J, Jdagger = result rates_ = rates else H, J, Jdagger, rates_ = result end - check_mcwf(psi, H, J, Jdagger, rates_) + QO_CHECKS[] && check_mcwf(psi, H, J, Jdagger, rates_) dmcwf_h(psi, H, J, Jdagger, dpsi, tmp, rates) end function dmcwf_nh_dynamic(t::Float64, psi::Ket, f::Function, dpsi::Ket) result = f(t, psi) - @assert 3 <= length(result) <= 4 + QO_CHECKS[] && @assert 3 <= length(result) <= 4 H, J, Jdagger = result[1:3] - check_mcwf(psi, H, J, Jdagger, nothing) + QO_CHECKS[] && check_mcwf(psi, H, J, Jdagger, nothing) dmcwf_nh(psi, H, dpsi) end function jump_dynamic(rng, t::Float64, psi::Ket, f::Function, psi_new::Ket, rates::DecayRates) result = f(t, psi) - @assert 3 <= length(result) <= 4 + QO_CHECKS[] && @assert 3 <= length(result) <= 4 J = result[2] if length(result) == 3 rates_ = rates diff --git a/src/schroedinger.jl b/src/schroedinger.jl index afb063b4..45e25bec 100644 --- a/src/schroedinger.jl +++ b/src/schroedinger.jl @@ -2,7 +2,7 @@ module timeevolution_schroedinger export schroedinger, schroedinger_dynamic -import ..integrate, ..recast! +import ..integrate, ..recast!, ..QO_CHECKS using ...bases, ...states, ...operators @@ -77,7 +77,7 @@ end function dschroedinger_dynamic(t::Float64, psi0::T, f::Function, dpsi::T) where T<:StateVector H = f(t, psi0) - check_schroedinger(psi0, H) + QO_CHECKS[] && check_schroedinger(psi0, H) dschroedinger(psi0, H, dpsi) end diff --git a/src/stochastic_master.jl b/src/stochastic_master.jl index f74508a0..459fd754 100644 --- a/src/stochastic_master.jl +++ b/src/stochastic_master.jl @@ -6,7 +6,7 @@ using ...bases, ...states, ...operators using ...operators_dense, ...operators_sparse using ...timeevolution using LinearAlgebra -import ...timeevolution: integrate_stoch, recast! +import ...timeevolution: integrate_stoch, recast!, QO_CHECKS import ...timeevolution.timeevolution_master: dmaster_h, dmaster_nh, dmaster_h_dynamic, check_master const DecayRates = Union{Vector{Float64}, Matrix{Float64}, Nothing} @@ -55,7 +55,7 @@ function master(tspan, rho0::DenseOperator, H::Operator, dmaster_stoch(dx::DiffArray, t::Float64, rho::DenseOperator, drho::DenseOperator, n::Int) = dmaster_stochastic(dx, rho, C, Cdagger, drho, n) - isreducible = check_master(rho0, H, J, Jdagger, rates) + isreducible = check_master(rho0, H, J, Jdagger, rates) && check_master_stoch(rho0, C, Cdagger) if !isreducible dmaster_h_determ(t::Float64, rho::DenseOperator, drho::DenseOperator) = dmaster_h(rho, H, rates, J, Jdagger, drho, tmp) @@ -158,8 +158,9 @@ end function dmaster_stoch_dynamic(dx::DiffArray, t::Float64, rho::DenseOperator, f::Function, drho::DenseOperator, n::Int) result = f(t, rho) - @assert 2 == length(result) + QO_CHECKS[] && @assert 2 == length(result) C, Cdagger = result + QO_CHECKS[] && check_master_stoch(rho, C, Cdagger) dmaster_stochastic(dx, rho, C, Cdagger, drho, n) end @@ -174,6 +175,27 @@ function integrate_master_stoch(tspan, df::Function, dg::Function, integrate_stoch(tspan_, df, dg, x0, state, dstate, fout, n; kwargs...) end +function check_master_stoch(rho0::DenseOperator, C::Vector, Cdagger::Vector) + @assert length(C) == length(Cdagger) + isreducible = true + for c=C + @assert isa(c, Operator) + if !(isa(c, DenseOperator) || isa(c, SparseOperator)) + isreducible = false + end + check_samebases(rho0, c) + end + for c=Cdagger + @assert isa(c, Operator) + if !(isa(c, DenseOperator) || isa(c, SparseOperator)) + isreducible = false + end + check_samebases(rho0, c) + end + isreducible +end + + # TODO: Speed up by recasting to n-d arrays, remove vector methods function recast!(x::Union{Vector{ComplexF64}, SubArray{ComplexF64, 1}}, rho::DenseOperator) rho.data = reshape(x, size(rho.data)) diff --git a/src/stochastic_schroedinger.jl b/src/stochastic_schroedinger.jl index dc0824e8..f2723855 100644 --- a/src/stochastic_schroedinger.jl +++ b/src/stochastic_schroedinger.jl @@ -6,7 +6,7 @@ using ...bases, ...states, ...operators using ...operators_dense, ...operators_sparse using ...timeevolution using LinearAlgebra -import ...timeevolution: integrate_stoch, recast! +import ...timeevolution: integrate_stoch, recast!, QO_CHECKS import ...timeevolution.timeevolution_schroedinger: dschroedinger, dschroedinger_dynamic, check_schroedinger import DiffEqCallbacks @@ -45,6 +45,10 @@ function schroedinger(tspan, psi0::Ket, H::Operator, Hs::Vector; state = copy(psi0) check_schroedinger(psi0, H) + for h=Hs + check_schroedinger(psi0, h) + end + dschroedinger_determ(t::Float64, psi::Ket, dpsi::Ket) = dschroedinger(psi, H, dpsi) dschroedinger_stoch(dx::DiffArray, t::Float64, psi::Ket, dpsi::Ket, n::Int) = dschroedinger_stochastic(dx, psi, Hs, dpsi, n) @@ -130,14 +134,12 @@ end function dschroedinger_stochastic(dx::Vector{ComplexF64}, psi::Ket, Hs::Vector{T}, dpsi::Ket, index::Int) where T <: Operator - check_schroedinger(psi, Hs[index]) recast!(dx, dpsi) dschroedinger(psi, Hs[index], dpsi) end function dschroedinger_stochastic(dx::Array{ComplexF64, 2}, psi::Ket, Hs::Vector{T}, dpsi::Ket, n::Int) where T <: Operator for i=1:n - check_schroedinger(psi, Hs[i]) dx_i = @view dx[:, i] recast!(dx_i, dpsi) dschroedinger(psi, Hs[i], dpsi) @@ -147,6 +149,11 @@ end function dschroedinger_stochastic(dx::DiffArray, t::Float64, psi::Ket, f::Function, dpsi::Ket, n::Int) ops = f(t, psi) + if QO_CHECKS[] + @inbounds for h=ops + check_schroedinger(psi, h) + end + end dschroedinger_stochastic(dx, psi, ops, dpsi, n) end diff --git a/src/stochastic_semiclassical.jl b/src/stochastic_semiclassical.jl index 86cd719c..bfd662b9 100644 --- a/src/stochastic_semiclassical.jl +++ b/src/stochastic_semiclassical.jl @@ -7,8 +7,9 @@ using ...operators_dense, ...operators_sparse using ...semiclassical import ...semiclassical: recast!, State, dmaster_h_dynamic using ...timeevolution -import ...timeevolution: integrate_stoch -import ...timeevolution.timeevolution_schroedinger: dschroedinger, dschroedinger_dynamic +import ...timeevolution: integrate_stoch, QO_CHECKS +import ...timeevolution.timeevolution_schroedinger: dschroedinger, dschroedinger_dynamic, check_schroedinger +import ...stochastic.stochastic_master: check_master_stoch using ...stochastic using LinearAlgebra @@ -207,6 +208,7 @@ function dschroedinger_stochastic(dx::Vector{ComplexF64}, t::Float64, dstate::State{Ket}, ::Int) H = fstoch_quantum(t, state.quantum, state.classical) recast!(dx, dstate) + QO_CHECKS[] && check_schroedinger(state.quantum, H[1]) dschroedinger(state.quantum, H[1], dstate.quantum) recast!(dstate, dx) end @@ -217,6 +219,7 @@ function dschroedinger_stochastic(dx::Array{ComplexF64, 2}, for i=1:n dx_i = @view dx[:, i] recast!(dx_i, dstate) + QO_CHECKS[] && check_schroedinger(state.quantum, H[i]) dschroedinger(state.quantum, H[i], dstate.quantum) recast!(dstate, dx_i) end @@ -239,8 +242,9 @@ function dmaster_stoch_dynamic(dx::Vector{ComplexF64}, t::Float64, state::State{DenseOperator}, fstoch_quantum::Function, fstoch_classical::Nothing, dstate::State{DenseOperator}, ::Int) result = fstoch_quantum(t, state.quantum, state.classical) - @assert length(result) == 2 + QO_CHECKS[] && @assert length(result) == 2 C, Cdagger = result + QO_CHECKS[] && check_master_stoch(state.quantum, C, Cdagger) recast!(dx, dstate) operators.gemm!(1, C[1], state.quantum, 0, dstate.quantum) operators.gemm!(1, state.quantum, Cdagger[1], 1, dstate.quantum) @@ -251,8 +255,9 @@ function dmaster_stoch_dynamic(dx::Array{ComplexF64, 2}, t::Float64, state::State{DenseOperator}, fstoch_quantum::Function, fstoch_classical::Nothing, dstate::State{DenseOperator}, n::Int) result = fstoch_quantum(t, state.quantum, state.classical) - @assert length(result) == 2 + QO_CHECKS[] && @assert length(result) == 2 C, Cdagger = result + QO_CHECKS[] && check_master_stoch(state.quantum, C, Cdagger) for i=1:n dx_i = @view dx[:, i] recast!(dx_i, dstate) diff --git a/src/timeevolution_base.jl b/src/timeevolution_base.jl index 4f55a304..fc71d604 100644 --- a/src/timeevolution_base.jl +++ b/src/timeevolution_base.jl @@ -2,6 +2,8 @@ using ..metrics import OrdinaryDiffEq, DiffEqCallbacks, StochasticDiffEq +export @skiptimechecks + const DiffArray = Union{Vector{ComplexF64}, Array{ComplexF64, 2}} function recast! end @@ -175,4 +177,21 @@ function integrate_stoch(tspan::Vector{Float64}, df::Function, dg::Function, x0: end + +const QO_CHECKS = Ref(true) +""" + @skiptimechecks + +Macro to skip checks during time-dependent problems. +Useful for `timeevolution.master_dynamic` and similar functions. +""" +macro skiptimechecks(ex) + return quote + QO_CHECKS.x = false + local val = $(esc(ex)) + QO_CHECKS.x = true + val + end +end + Base.@pure pure_inference(fout,T) = Core.Compiler.return_type(fout, T)