Skip to content

Commit

Permalink
Broadcasting for sparse and dense operators
Browse files Browse the repository at this point in the history
  • Loading branch information
david-pl committed Dec 7, 2018
1 parent f910a11 commit bedef6e
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 8 deletions.
4 changes: 2 additions & 2 deletions src/QuantumOptics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ export bases, Basis, GenericBasis, CompositeBasis, basis,
tensor, , permutesystems, @samebases,
states, StateVector, Bra, Ket, basisstate, norm,
dagger, normalize, normalize!,
operators, AbstractOperator, expect, variance, identityoperator, ptrace, embed, dense, tr,
sparse,
operators, AbstractOperator, DataOperator, expect, variance,
identityoperator, ptrace, embed, dense, tr, sparse,
operators_dense, DenseOperator, projector, dm,
operators_sparse, SparseOperator, diagonaloperator,
operators_lazysum, LazySum,
Expand Down
15 changes: 14 additions & 1 deletion src/operators.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module operators

export AbstractOperator, length, basis, dagger, ishermitian, tensor, embed,
export AbstractOperator, DataOperator, length, basis, dagger, ishermitian, tensor, embed,
tr, ptrace, normalize, normalize!, expect, variance,
exp, permutesystems, identityoperator, dense

Expand All @@ -26,6 +26,16 @@ terms of this function and are provided automatically.
"""
abstract type AbstractOperator{BL<:Basis,BR<:Basis} end

"""
Abstract type for operators with a data field.
This is an abstract type for operators that have a direct matrix representation
stored in their `.data` field.
Type hierarchy: `DataOperator <: AbstractOperator`
"""
abstract type DataOperator{BL<:Basis,BR<:Basis} <: AbstractOperator{BL,BR} end


# Common error messages
arithmetic_unary_error(funcname, x::AbstractOperator) = throw(ArgumentError("$funcname is not defined for this type of operator: $(typeof(x)).\nTry to convert to another operator type first with e.g. dense() or sparse()."))
Expand All @@ -35,6 +45,9 @@ addnumbererror() = throw(ArgumentError("Can't add or subtract a number and an op
length(a::AbstractOperator) = length(a.basis_l)::Int*length(a.basis_r)::Int
basis(a::AbstractOperator) = (check_samebases(a); a.basis_l)

# Ensure scalar broadcasting
Base.broadcastable(x::AbstractOperator) = Ref(x)

# Arithmetic operations
+(a::AbstractOperator, b::AbstractOperator) = arithmetic_binary_error("Addition", a, b)
+(a::Number, b::AbstractOperator) = addnumbererror()
Expand Down
27 changes: 24 additions & 3 deletions src/operators_dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module operators_dense

export DenseOperator, dense, projector, dm

import Base: ==, +, -, *, /
import Base: ==, +, -, *, /, Broadcast
import ..operators

using LinearAlgebra, Base.Cartesian
Expand All @@ -16,7 +16,7 @@ Dense array implementation of Operator.
The matrix consisting of complex floats is stored in the `data` field.
"""
mutable struct DenseOperator{BL<:Basis,BR<:Basis,T<:Matrix{ComplexF64}} <: AbstractOperator{BL,BR}
mutable struct DenseOperator{BL<:Basis,BR<:Basis,T<:Matrix{ComplexF64}} <: DataOperator{BL,BR}
basis_l::BL
basis_r::BR
data::Matrix{ComplexF64}
Expand All @@ -37,7 +37,28 @@ DenseOperator{B1,B2}(b1::B1, b2::B2) where {B1<:Basis,B2<:Basis} = DenseOperator
DenseOperator(b::Basis) = DenseOperator(b, b)
DenseOperator(op::AbstractOperator) = dense(op)

Base.copy(x::DenseOperator) = DenseOperator(x.basis_l, x.basis_r, copy(x.data))
Base.copy(x::T) where T<:DataOperator = T(x.basis_l, x.basis_r, copy(x.data))

# Broadcasting
Base.size(A::DataOperator) = size(A.data)
Base.axes(A::DataOperator) = axes(A.data)

function Base.copyto!(dest::DataOperator{BL,BR}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {BL<:Basis,BR<:Basis,Style,Axes,F,Args<:Tuple{Vararg{Base.RefValue{<:DataOperator{BL,BR}}}}}
axes(dest) == axes(bc) || Base.Broadcast.throwdm(axes(dest), axes(bc))
# Performance optimization: broadcast!(identity, dest, A) is equivalent to copyto!(dest, A) if indices match
if bc.f === identity && isa(bc.args, Tuple{Base.RefValue{<:DataOperator{BL,BR}}}) # only a single input argument to broadcast!
A = bc.args[1][]
if axes(dest) == axes(A)
return copyto!(dest, A)
end
end
args_ = Tuple(a[].data for a=bc.args)
bc_ = Broadcast.Broadcasted(bc.f, args_, axes(bc))
copyto!(dest.data, bc_)
return dest
end

Base.copyto!(A::DataOperator{BL,BR},B::DataOperator{BL,BR}) where {BL<:Basis,BR<:Basis} = (copyto!(A.data,B.data); A)

"""
dense(op::AbstractOperator)
Expand Down
4 changes: 2 additions & 2 deletions src/operators_sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module operators_sparse

export SparseOperator, diagonaloperator

import Base: ==, *, /, +, -
import Base: ==, *, /, +, -, Broadcast
import ..operators
import SparseArrays: sparse

Expand All @@ -18,7 +18,7 @@ Sparse array implementation of Operator.
The matrix is stored as the julia built-in type `SparseMatrixCSC`
in the `data` field.
"""
mutable struct SparseOperator{BL<:Basis,BR<:Basis,T<:SparseMatrixCSC{ComplexF64,Int}} <: AbstractOperator{BL,BR}
mutable struct SparseOperator{BL<:Basis,BR<:Basis,T<:SparseMatrixCSC{ComplexF64,Int}} <: DataOperator{BL,BR}
basis_l::BL
basis_r::BR
data::T
Expand Down

0 comments on commit bedef6e

Please sign in to comment.