Skip to content

Commit

Permalink
Parametric typing for operators
Browse files Browse the repository at this point in the history
  • Loading branch information
david-pl committed Sep 29, 2018
1 parent dec97eb commit 28fadbb
Show file tree
Hide file tree
Showing 15 changed files with 132 additions and 100 deletions.
2 changes: 1 addition & 1 deletion src/manybody.jl
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ where ``X`` is the N-particle operator, ``x`` is the one-body operator and
``|u⟩`` are the one-body states associated to the
different modes of the N-particle basis.
"""
function manybodyoperator(basis::ManyBodyBasis, op::T)::T where T<:AbstractOperator
function manybodyoperator(basis::ManyBodyBasis, op::T) where T<:AbstractOperator
@assert op.basis_l == op.basis_r
if op.basis_l == basis.onebodybasis
result = manybodyoperator_1(basis, op)
Expand Down
2 changes: 1 addition & 1 deletion src/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ For fast time evolution also at least the function
implemented. Many other generic multiplication functions can be defined in
terms of this function and are provided automatically.
"""
abstract type AbstractOperator end
abstract type AbstractOperator{BL<:Basis,BR<:Basis} end


# Common error messages
Expand Down
17 changes: 12 additions & 5 deletions src/operators_dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,20 @@ Dense array implementation of Operator.
The matrix consisting of complex floats is stored in the `data` field.
"""
mutable struct DenseOperator <: AbstractOperator
basis_l::Basis
basis_r::Basis
mutable struct DenseOperator{BL<:Basis,BR<:Basis,T<:Matrix{ComplexF64}} <: AbstractOperator{BL,BR}
basis_l::BL
basis_r::BR
data::Matrix{ComplexF64}
DenseOperator(b1::Basis, b2::Basis, data) = length(b1) == size(data, 1) && length(b2) == size(data, 2) ? new(b1, b2, data) : throw(DimensionMismatch())
function DenseOperator{BL,BR,T}(b1::BL, b2::BR, data::T) where {BL<:Basis,BR<:Basis,T<:Matrix{ComplexF64}}
if !(length(b1) == size(data, 1) && length(b2) == size(data, 2))
throw(DimensionMismatch())
end
new(b1, b2, data)
end
end

DenseOperator(b1::BL, b2::BR, data::T) where {BL<:Basis,BR<:Basis,T<:Matrix{ComplexF64}} = DenseOperator{BL,BR,T}(b1, b2, data)
DenseOperator(b1::Basis, b2::Basis, data) = DenseOperator(b1, b2, convert(Matrix{ComplexF64}, data))
DenseOperator(b::Basis, data) = DenseOperator(b, b, data)
DenseOperator(b1::Basis, b2::Basis) = DenseOperator(b1, b2, zeros(ComplexF64, length(b1), length(b2)))
DenseOperator(b::Basis) = DenseOperator(b, b)
Expand Down Expand Up @@ -154,7 +161,7 @@ function operators.permutesystems(a::DenseOperator, perm::Vector{Int})
DenseOperator(permutesystems(a.basis_l, perm), permutesystems(a.basis_r, perm), data)
end

operators.identityoperator(::Type{DenseOperator}, b1::Basis, b2::Basis) = DenseOperator(b1, b2, Matrix{ComplexF64}(I, length(b1), length(b2)))
operators.identityoperator(::Type{T}, b1::Basis, b2::Basis) where {T<:DenseOperator} = DenseOperator(b1, b2, Matrix{ComplexF64}(I, length(b1), length(b2)))

"""
projector(a::Ket, b::Bra)
Expand Down
10 changes: 6 additions & 4 deletions src/operators_lazyproduct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ The factors of the product are stored in the `operators` field. Additionally a
complex factor is stored in the `factor` field which allows for fast
multiplication with numbers.
"""
mutable struct LazyProduct <: AbstractOperator
basis_l::Basis
basis_r::Basis
mutable struct LazyProduct{BL<:Basis,BR<:Basis} <: AbstractOperator{BL,BR}
basis_l::BL
basis_r::BR
factor::ComplexF64
operators::Vector{AbstractOperator}

Expand All @@ -32,7 +32,9 @@ mutable struct LazyProduct <: AbstractOperator
for i = 2:length(operators)
check_multiplicable(operators[i-1], operators[i])
end
new(operators[1].basis_l, operators[end].basis_r, factor, operators)
BL = typeof(operators[1].basis_l)
BR = typeof(operators[end].basis_r)
new{BL,BR}(operators[1].basis_l, operators[end].basis_r, factor, operators)
end
end
LazyProduct(operators::Vector, factor::Number=1) = LazyProduct(convert(Vector{AbstractOperator}, operators), factor)
Expand Down
10 changes: 6 additions & 4 deletions src/operators_lazysum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ All operators have to be given in respect to the same bases. The field
`factors` accounts for an additional multiplicative factor for each operator
stored in the `operators` field.
"""
mutable struct LazySum <: AbstractOperator
basis_l::Basis
basis_r::Basis
mutable struct LazySum{BL<:Basis,BR<:Basis} <: AbstractOperator{BL,BR}
basis_l::BL
basis_r::BR
factors::Vector{ComplexF64}
operators::Vector{AbstractOperator}

Expand All @@ -31,7 +31,9 @@ mutable struct LazySum <: AbstractOperator
@assert operators[1].basis_l == operators[i].basis_l
@assert operators[1].basis_r == operators[i].basis_r
end
new(operators[1].basis_l, operators[1].basis_r, factors, operators)
BL = typeof(operators[1].basis_l)
BR = typeof(operators[1].basis_r)
new{BL,BR}(operators[1].basis_l, operators[1].basis_r, factors, operators)
end
end
LazySum(factors::Vector{T}, operators::Vector) where {T<:Number} = LazySum(complex(factors), AbstractOperator[op for op in operators])
Expand Down
10 changes: 5 additions & 5 deletions src/operators_lazytensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@ specifies in which subsystem the corresponding operator lives. Additionally,
a complex factor is stored in the `factor` field which allows for fast
multiplication with numbers.
"""
mutable struct LazyTensor <: AbstractOperator
basis_l::CompositeBasis
basis_r::CompositeBasis
mutable struct LazyTensor{BL<:CompositeBasis,BR<:CompositeBasis} <: AbstractOperator{BL,BR}
basis_l::BL
basis_r::BR
factor::ComplexF64
indices::Vector{Int}
operators::Vector{AbstractOperator}

function LazyTensor(op::LazyTensor, factor::Number)
new(op.basis_l, op.basis_r, factor, op.indices, op.operators)
new{typeof(op.basis_l),typeof(op.basis_r)}(op.basis_l, op.basis_r, factor, op.indices, op.operators)
end

function LazyTensor(basis_l::Basis, basis_r::Basis,
Expand All @@ -54,7 +54,7 @@ mutable struct LazyTensor <: AbstractOperator
indices = indices[perm]
ops = ops[perm]
end
new(basis_l, basis_r, complex(factor), indices, ops)
new{typeof(basis_l),typeof(basis_r)}(basis_l, basis_r, complex(factor), indices, ops)
end
end

Expand Down
14 changes: 8 additions & 6 deletions src/operators_sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,20 @@ Sparse array implementation of Operator.
The matrix is stored as the julia built-in type `SparseMatrixCSC`
in the `data` field.
"""
mutable struct SparseOperator <: AbstractOperator
basis_l::Basis
basis_r::Basis
data::SparseMatrixCSC{ComplexF64, Int}
function SparseOperator(b1::Basis, b2::Basis, data)
mutable struct SparseOperator{BL<:Basis,BR<:Basis,T<:SparseMatrixCSC{ComplexF64,Int}} <: AbstractOperator{BL,BR}
basis_l::BL
basis_r::BR
data::T
function SparseOperator{BL,BR,T}(b1::Basis, b2::Basis, data::T) where {BL<:Basis,BR<:Basis,T<:SparseMatrixCSC{ComplexF64,Int}}
if length(b1) != size(data, 1) || length(b2) != size(data, 2)
throw(DimensionMismatch())
end
new(b1, b2, data)
end
end

SparseOperator(b1::BL, b2::BR, data::T) where {BL<:Basis,BR<:Basis,T<:SparseMatrixCSC{ComplexF64,Int}} = SparseOperator{BL,BR,T}(b1, b2, data)
SparseOperator(b1::Basis, b2::Basis, data) = SparseOperator(b1, b2, convert(SparseMatrixCSC{ComplexF64,Int}, data))
SparseOperator(b::Basis, data::SparseMatrixCSC{ComplexF64, Int}) = SparseOperator(b, b, data)
SparseOperator(b::Basis, data::Matrix{ComplexF64}) = SparseOperator(b, sparse(data))
SparseOperator(op::DenseOperator) = SparseOperator(op.basis_l, op.basis_r, sparse(op.data))
Expand Down Expand Up @@ -116,7 +118,7 @@ function operators.permutesystems(rho::SparseOperator, perm::Vector{Int})
SparseOperator(permutesystems(rho.basis_l, perm), permutesystems(rho.basis_r, perm), data)
end

operators.identityoperator(::Type{SparseOperator}, b1::Basis, b2::Basis) = SparseOperator(b1, b2, sparse(ComplexF64(1)*I, length(b1), length(b2)))
operators.identityoperator(::Type{T}, b1::Basis, b2::Basis) where {T<:SparseOperator} = SparseOperator(b1, b2, sparse(ComplexF64(1)*I, length(b1), length(b2)))
operators.identityoperator(b1::Basis, b2::Basis) = identityoperator(SparseOperator, b1, b2)
operators.identityoperator(b::Basis) = identityoperator(b, b)

Expand Down
30 changes: 23 additions & 7 deletions src/particle.jl
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ end
Abstract type for all implementations of FFT operators.
"""
abstract type FFTOperator <: AbstractOperator end
abstract type FFTOperator{BL<:Basis,BR<:Basis} <: AbstractOperator{BL,BR} end

const PlanFFT = FFTW.cFFTWPlan

Expand All @@ -274,15 +274,24 @@ const PlanFFT = FFTW.cFFTWPlan
Operator performing a fast fourier transformation when multiplied with a state
that is a Ket or an Operator.
"""
mutable struct FFTOperators <: FFTOperator
basis_l::Basis
basis_r::Basis
mutable struct FFTOperators{BL<:Basis,BR<:Basis} <: FFTOperator{BL,BR}
basis_l::BL
basis_r::BR
fft_l!::PlanFFT
fft_r!::PlanFFT
fft_l2!::PlanFFT
fft_r2!::PlanFFT
mul_before::Array{ComplexF64}
mul_after::Array{ComplexF64}
function FFTOperators(b1::BL, b2::BR,
fft_l!::PlanFFT,
fft_r!::PlanFFT,
fft_l2!::PlanFFT,
fft_r2!::PlanFFT,
mul_before::Array{ComplexF64},
mul_after::Array{ComplexF64}) where {BL<:Basis,BR<:Basis}
new{BL,BR}(b1, b2, fft_l!, fft_r!, fft_l2!, fft_r2!, mul_before, mul_after)
end
end

"""
Expand All @@ -291,13 +300,20 @@ end
Operator that can only perform fast fourier transformations on Kets.
This is much more memory efficient when only working with Kets.
"""
mutable struct FFTKets <: FFTOperator
basis_l::Basis
basis_r::Basis
mutable struct FFTKets{BL<:Basis,BR<:Basis} <: FFTOperator{BL,BR}
basis_l::BL
basis_r::BR
fft_l!::PlanFFT
fft_r!::PlanFFT
mul_before::Array{ComplexF64}
mul_after::Array{ComplexF64}
function FFTKets(b1::BL, b2::BR,
fft_l!::PlanFFT,
fft_r!::PlanFFT,
mul_before::Array{ComplexF64},
mul_after::Array{ComplexF64}) where {BL<:Basis,BR<:Basis}
new{BL,BR}(b1, b2, fft_l!, fft_r!, mul_before, mul_after)
end
end

"""
Expand Down
41 changes: 22 additions & 19 deletions src/semiclassical.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import ..timeevolution: integrate, recast!
using ..bases, ..states, ..operators, ..operators_dense, ..timeevolution


const QuantumState = Union{Ket, DenseOperator}
const QuantumState{B} = Union{Ket{B}, DenseOperator{B,B}}
const DecayRates = Union{Nothing, Vector{Float64}, Matrix{Float64}}

"""
Expand All @@ -16,15 +16,18 @@ Semi-classical state.
It consists of a quantum part, which is either a `Ket` or a `DenseOperator` and
a classical part that is specified as a complex vector of arbitrary length.
"""
mutable struct State{T<:QuantumState}
mutable struct State{B<:Basis,T<:QuantumState{B},C<:Vector{ComplexF64}}
quantum::T
classical::Vector{ComplexF64}
classical::C
function State(quantum::T, classical::C) where {B<:Basis,T<:QuantumState{B},C<:Vector{ComplexF64}}
new{B,T,C}(quantum, classical)
end
end

Base.length(state::State) = length(state.quantum) + length(state.classical)
Base.copy(state::State) = State(copy(state.quantum), copy(state.classical))

function ==(a::State{T}, b::State{T}) where T<:QuantumState
function ==(a::State, b::State)
samebases(a.quantum, b.quantum) &&
length(a.classical)==length(b.classical) &&
(a.classical==b.classical) &&
Expand All @@ -33,9 +36,9 @@ end

operators.expect(op, state::State) = expect(op, state.quantum)
operators.variance(op, state::State) = variance(op, state.quantum)
operators.ptrace(state::State, indices::Vector{Int}) = State{DenseOperator}(ptrace(state.quantum, indices), state.classical)
operators.ptrace(state::State, indices::Vector{Int}) = State(ptrace(state.quantum, indices), state.classical)

operators_dense.dm(x::State{T}) where T<:Ket = State{DenseOperator}(dm(x.quantum), x.classical)
operators_dense.dm(x::State{B,T}) where {B<:Basis,T<:Ket{B}} = State(dm(x.quantum), x.classical)


"""
Expand All @@ -57,11 +60,11 @@ Integrate time-dependent Schrödinger equation coupled to a classical system.
normalized nor permanent!
* `kwargs...`: Further arguments are passed on to the ode solver.
"""
function schroedinger_dynamic(tspan, state0::State{T}, fquantum::Function, fclassical::Function;
function schroedinger_dynamic(tspan, state0::State{B,T}, fquantum::Function, fclassical::Function;
fout::Union{Function,Nothing}=nothing,
kwargs...) where T <: Ket
kwargs...) where {B<:Basis,T<:Ket{B}}
tspan_ = convert(Vector{Float64}, tspan)
dschroedinger_(t, state::State{T}, dstate::State{T}) = dschroedinger_dynamic(t, state, fquantum, fclassical, dstate)
dschroedinger_(t, state::State{B,T}, dstate::State{B,T}) = dschroedinger_dynamic(t, state, fquantum, fclassical, dstate)
x0 = Vector{ComplexF64}(undef, length(state0))
recast!(state0, x0)
state = copy(state0)
Expand All @@ -88,13 +91,13 @@ Integrate time-dependent master equation coupled to a classical system.
permanent!
* `kwargs...`: Further arguments are passed on to the ode solver.
"""
function master_dynamic(tspan, state0::State{DenseOperator}, fquantum, fclassical;
rates::Union{Vector{Float64}, Matrix{Float64}, Nothing}=nothing,
function master_dynamic(tspan, state0::State{B,T}, fquantum, fclassical;
rates::DecayRates=nothing,
fout::Union{Function,Nothing}=nothing,
tmp::DenseOperator=copy(state0.quantum),
kwargs...)
tmp::T=copy(state0.quantum),
kwargs...) where {B<:Basis,T<:DenseOperator{B,B}}
tspan_ = convert(Vector{Float64}, tspan)
function dmaster_(t, state::State{DenseOperator}, dstate::State{DenseOperator})
function dmaster_(t, state::State{B,T}, dstate::State{B,T})
dmaster_h_dynamic(t, state, fquantum, fclassical, rates, dstate, tmp)
end
x0 = Vector{ComplexF64}(undef, length(state0))
Expand All @@ -104,7 +107,7 @@ function master_dynamic(tspan, state0::State{DenseOperator}, fquantum, fclassica
integrate(tspan_, dmaster_, x0, state, dstate, fout; kwargs...)
end

function master_dynamic(tspan, state0::State{T}, fquantum, fclassical; kwargs...) where T<:Ket
function master_dynamic(tspan, state0::State{B,T}, fquantum, fclassical; kwargs...) where {B<:Basis,T<:Ket{B}}
master_dynamic(tspan, dm(state0), fquantum, fclassical; kwargs...)
end

Expand All @@ -122,15 +125,15 @@ function recast!(x::Vector{ComplexF64}, state::State)
copyto!(state.classical, 1, x, N+1, length(state.classical))
end

function dschroedinger_dynamic(t::Float64, state::State{T}, fquantum::Function,
fclassical::Function, dstate::State{T}) where T<:Ket
function dschroedinger_dynamic(t::Float64, state::State{B,T}, fquantum::Function,
fclassical::Function, dstate::State{B,T}) where {B<:Basis,T<:Ket{B}}
fquantum_(t, psi) = fquantum(t, state.quantum, state.classical)
timeevolution.timeevolution_schroedinger.dschroedinger_dynamic(t, state.quantum, fquantum_, dstate.quantum)
fclassical(t, state.quantum, state.classical, dstate.classical)
end

function dmaster_h_dynamic(t::Float64, state::State{DenseOperator}, fquantum::Function,
fclassical::Function, rates::DecayRates, dstate::State{DenseOperator}, tmp::DenseOperator)
function dmaster_h_dynamic(t::Float64, state::State{B,T}, fquantum::Function,
fclassical::Function, rates::DecayRates, dstate::State{B,T}, tmp::T) where {B<:Basis,T<:DenseOperator{B,B}}
fquantum_(t, rho) = fquantum(t, state.quantum, state.classical)
timeevolution.timeevolution_master.dmaster_h_dynamic(t, state.quantum, fquantum_, rates, dstate.quantum, tmp)
fclassical(t, state.quantum, state.classical, dstate.classical)
Expand Down
Loading

0 comments on commit 28fadbb

Please sign in to comment.