diff --git a/src/operators_dense.jl b/src/operators_dense.jl index 4dbe81c9..3ef4e185 100644 --- a/src/operators_dense.jl +++ b/src/operators_dense.jl @@ -302,7 +302,7 @@ end # Broadcasting Base.size(A::DataOperator) = size(A.data) -Base.axes(A::DataOperator) = axes(A.data) +@inline Base.axes(A::DataOperator) = axes(A.data) Base.broadcastable(A::DataOperator) = A # Custom broadcasting styles @@ -311,26 +311,21 @@ struct DenseOperatorStyle{BL<:Basis,BR<:Basis} <: DataOperatorStyle{BL,BR} end # Style precedence rules Broadcast.BroadcastStyle(::Type{<:DenseOperator{BL,BR}}) where {BL<:Basis,BR<:Basis} = DenseOperatorStyle{BL,BR}() -Broadcast.BroadcastStyle(::T, ::Broadcast.AbstractArrayStyle) where T<:DataOperatorStyle = T() Broadcast.BroadcastStyle(::DenseOperatorStyle{B1,B2}, ::DenseOperatorStyle{B3,B4}) where {B1<:Basis,B2<:Basis,B3<:Basis,B4<:Basis} = throw(bases.IncompatibleBases()) # Out-of-place broadcasting -function Base.copy(bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {BL<:Basis,BR<:Basis,Style<:DenseOperatorStyle{BL,BR},Axes,F,Args<:Tuple} +@inline function Base.copy(bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {BL<:Basis,BR<:Basis,Style<:DenseOperatorStyle{BL,BR},Axes,F,Args<:Tuple} bcf = Broadcast.flatten(bc) - args_ = Tuple(isa(a, DataOperator{BL,BR}) ? a.data : a for a=bcf.args) - bl,br = find_bases(bcf.args) + args_ = Tuple(a.data for a=bcf.args) + bl,br = states.find_basis(bcf.args) bc_ = Broadcast.Broadcasted(bcf.f, args_, axes(bcf)) # TODO: remove convert return DenseOperator{BL,BR}(bl, br, convert(Matrix{ComplexF64}, copy(bc_))) end -find_bases(bc::Broadcast.Broadcasted) = find_bases(bc.args) -find_bases(args::Tuple) = find_bases(find_bases(args[1]), Base.tail(args)) -find_bases(x) = x -find_bases(a::DataOperator, rest) = (a.basis_l, a.basis_r) -find_bases(::Any, rest) = find_bases(rest) +states.find_basis(a::DataOperator, rest) = (a.basis_l, a.basis_r) # In-place broadcasting -function Base.copyto!(dest::DataOperator{BL,BR}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {BL<:Basis,BR<:Basis,Style<:DataOperatorStyle{BL,BR},Axes,F,Args} +@inline function Base.copyto!(dest::DataOperator{BL,BR}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {BL<:Basis,BR<:Basis,Style<:DataOperatorStyle{BL,BR},Axes,F,Args} axes(dest) == axes(bc) || Base.Broadcast.throwdm(axes(dest), axes(bc)) # Performance optimization: broadcast!(identity, dest, A) is equivalent to copyto!(dest, A) if indices match if bc.f === identity && isa(bc.args, Tuple{<:DataOperator{BL,BR}}) # only a single input argument to broadcast! @@ -341,11 +336,13 @@ function Base.copyto!(dest::DataOperator{BL,BR}, bc::Broadcast.Broadcasted{Style end # Get the underlying data fields of operators and broadcast them as arrays bcf = Broadcast.flatten(bc) - args_ = Tuple(isa(a, DataOperator{BL,BR}) ? a.data : a for a=bcf.args) + args_ = Tuple(a.data for a=bcf.args) bc_ = Broadcast.Broadcasted(bcf.f, args_, axes(bcf)) copyto!(dest.data, bc_) return dest end -Base.copyto!(A::DataOperator{BL,BR},B::DataOperator{BL,BR}) where {BL<:Basis,BR<:Basis} = (copyto!(A.data,B.data); A) +@inline Base.copyto!(A::DataOperator{BL,BR},B::DataOperator{BL,BR}) where {BL<:Basis,BR<:Basis} = (copyto!(A.data,B.data); A) +@inline Base.copyto!(dest::DataOperator{BL,BR}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {BL<:Basis,BR<:Basis,Style<:DataOperatorStyle,Axes,F,Args} = + throw(bases.IncompatibleBases()) end # module diff --git a/src/operators_sparse.jl b/src/operators_sparse.jl index 2fa14ebe..7e2153e5 100644 --- a/src/operators_sparse.jl +++ b/src/operators_sparse.jl @@ -4,7 +4,7 @@ export SparseOperator, diagonaloperator import Base: ==, *, /, +, -, Broadcast import ..operators -import ..operators_dense: DataOperatorStyle, DenseOperatorStyle, find_bases +import ..operators_dense: DataOperatorStyle, DenseOperatorStyle import SparseArrays: sparse using ..bases, ..states, ..operators, ..operators_dense, ..sparsematrix @@ -157,10 +157,10 @@ Broadcast.BroadcastStyle(::DenseOperatorStyle{B1,B2}, ::SparseOperatorStyle{B1,B Broadcast.BroadcastStyle(::DenseOperatorStyle{B1,B2}, ::SparseOperatorStyle{B3,B4}) where {B1<:Basis,B2<:Basis,B3<:Basis,B4<:Basis} = throw(bases.IncompatibleBases()) Broadcast.BroadcastStyle(::SparseOperatorStyle{B1,B2}, ::SparseOperatorStyle{B3,B4}) where {B1<:Basis,B2<:Basis,B3<:Basis,B4<:Basis} = throw(bases.IncompatibleBases()) -function Base.copy(bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {BL<:Basis,BR<:Basis,Style<:SparseOperatorStyle{BL,BR},Axes,F,Args<:Tuple} +@inline function Base.copy(bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {BL<:Basis,BR<:Basis,Style<:SparseOperatorStyle{BL,BR},Axes,F,Args<:Tuple} bcf = Broadcast.flatten(bc) - args_ = Tuple(isa(a, DataOperator{BL,BR}) ? a.data : a for a=bcf.args) - bl,br = find_bases(bcf.args) + args_ = Tuple(a.data for a=bcf.args) + bl,br = states.find_basis(bcf.args) bc_ = Broadcast.Broadcasted(bcf.f, args_, axes(bcf)) return SparseOperator{BL,BR}(bl, br, copy(bc_)) end diff --git a/src/states.jl b/src/states.jl index 2432895c..1b618f28 100644 --- a/src/states.jl +++ b/src/states.jl @@ -186,4 +186,87 @@ end samebases(a::T, b::T) where {T<:StateVector} = samebases(a.basis, b.basis)::Bool +# Array-like functions +Base.size(x::StateVector) = size(x.data) +@inline Base.axes(x::StateVector) = axes(x.data) +Base.ndims(x::StateVector) = 1 +Base.ndims(::Type{<:StateVector}) = 1 + +# Broadcasting +Base.broadcastable(x::StateVector) = x + +# Custom broadcasting style +abstract type StateVectorStyle{B<:Basis} <: Broadcast.BroadcastStyle end +struct KetStyle{B<:Basis} <: StateVectorStyle{B} end +struct BraStyle{B<:Basis} <: StateVectorStyle{B} end + +# Style precedence rules +Broadcast.BroadcastStyle(::Type{<:Ket{B}}) where {B<:Basis} = KetStyle{B}() +Broadcast.BroadcastStyle(::Type{<:Bra{B}}) where {B<:Basis} = BraStyle{B}() +Broadcast.BroadcastStyle(::KetStyle{B1}, ::KetStyle{B2}) where {B1<:Basis,B2<:Basis} = throw(bases.IncompatibleBases()) +Broadcast.BroadcastStyle(::BraStyle{B1}, ::BraStyle{B2}) where {B1<:Basis,B2<:Basis} = throw(bases.IncompatibleBases()) + +# Out-of-place broadcasting +@inline function Base.copy(bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {B<:Basis,Style<:KetStyle{B},Axes,F,Args<:Tuple} + bcf = Broadcast.flatten(bc) + args_ = Tuple(a.data for a=bcf.args) + bc_ = Broadcast.Broadcasted(bcf.f, args_, axes(bcf)) + b = find_basis(bcf) + return Ket{B}(b, copy(bc_)) +end +@inline function Base.copy(bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {B<:Basis,Style<:BraStyle{B},Axes,F,Args<:Tuple} + bcf = Broadcast.flatten(bc) + args_ = Tuple(a.data for a=bcf.args) + bc_ = Broadcast.Broadcasted(bcf.f, args_, axes(bcf)) + b = find_basis(bcf) + return Bra{B}(b, copy(bc_)) +end +find_basis(bc::Broadcast.Broadcasted) = find_basis(bc.args) +find_basis(args::Tuple) = find_basis(find_basis(args[1]), Base.tail(args)) +find_basis(x) = x +find_basis(a::StateVector, rest) = a.basis +find_basis(::Any, rest) = find_basis(rest) + +# In-place broadcasting for Kets +@inline function Base.copyto!(dest::Ket{B}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {B<:Basis,Style<:KetStyle{B},Axes,F,Args} + axes(dest) == axes(bc) || Base.Broadcast.throwdm(axes(dest), axes(bc)) + # Performance optimization: broadcast!(identity, dest, A) is equivalent to copyto!(dest, A) if indices match + if bc.f === identity && isa(bc.args, Tuple{<:Ket{B}}) # only a single input argument to broadcast! + A = bc.args[1] + if axes(dest) == axes(A) + return copyto!(dest, A) + end + end + # Get the underlying data fields of kets and broadcast them as arrays + bcf = Broadcast.flatten(bc) + args_ = Tuple(a.data for a=bcf.args) + bc_ = Broadcast.Broadcasted(bcf.f, args_, axes(bcf)) + copyto!(dest.data, bc_) + return dest +end +@inline Base.copyto!(dest::Ket{B1}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {B1<:Basis,B2<:Basis,Style<:KetStyle{B2},Axes,F,Args} = + throw(bases.IncompatibleBases()) + +# In-place broadcasting for Bras +@inline function Base.copyto!(dest::Bra{B}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {B<:Basis,Style<:BraStyle{B},Axes,F,Args} + axes(dest) == axes(bc) || Base.Broadcast.throwdm(axes(dest), axes(bc)) + # Performance optimization: broadcast!(identity, dest, A) is equivalent to copyto!(dest, A) if indices match + if bc.f === identity && isa(bc.args, Tuple{<:Bra{B}}) # only a single input argument to broadcast! + A = bc.args[1] + if axes(dest) == axes(A) + return copyto!(dest, A) + end + end + # Get the underlying data fields of bras and broadcast them as arrays + bcf = Broadcast.flatten(bc) + args_ = Tuple(a.data for a=bcf.args) + bc_ = Broadcast.Broadcasted(bcf.f, args_, axes(bcf)) + copyto!(dest.data, bc_) + return dest +end +@inline Base.copyto!(dest::Bra{B1}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {B1<:Basis,B2<:Basis,Style<:BraStyle{B2},Axes,F,Args} = + throw(bases.IncompatibleBases()) + +@inline Base.copyto!(A::T,B::T) where T<:StateVector = (copyto!(A.data,B.data); A) + end # module diff --git a/src/superoperators.jl b/src/superoperators.jl index 928a8fb5..b5ca853a 100644 --- a/src/superoperators.jl +++ b/src/superoperators.jl @@ -5,6 +5,7 @@ export SuperOperator, DenseSuperOperator, SparseSuperOperator, import Base: ==, *, /, +, - import ..bases +import ..states import SparseArrays: sparse using ..bases, ..operators, ..operators_dense, ..operators_sparse @@ -44,6 +45,7 @@ mutable struct DenseSuperOperator{B1<:Tuple{Basis,Basis},B2<:Tuple{Basis,Basis}, new(basis_l, basis_r, data) end end +DenseSuperOperator{BL,BR}(basis_l::BL, basis_r::BR, data::T) where {BL<:Tuple{Basis,Basis},BR<:Tuple{Basis,Basis},T<:Matrix{ComplexF64}} = DenseSuperOperator{BL,BR,T}(basis_l, basis_r, data) DenseSuperOperator(basis_l::BL, basis_r::BR, data::T) where {BL<:Tuple{Basis,Basis},BR<:Tuple{Basis,Basis},T<:Matrix{ComplexF64}} = DenseSuperOperator{BL,BR,T}(basis_l, basis_r, data) function DenseSuperOperator(basis_l::Tuple{Basis, Basis}, basis_r::Tuple{Basis, Basis}) @@ -222,4 +224,69 @@ Operator exponential which can for example used to calculate time evolutions. """ Base.exp(op::DenseSuperOperator) = DenseSuperOperator(op.basis_l, op.basis_r, exp(op.data)) +# Array-like functions +Base.size(A::SuperOperator) = size(A.data) +@inline Base.axes(A::SuperOperator) = axes(A.data) +Base.ndims(A::SuperOperator) = 2 +Base.ndims(::Type{<:SuperOperator}) = 2 + +# Broadcasting +Base.broadcastable(A::SuperOperator) = A + +# Custom broadcasting styles +abstract type SuperOperatorStyle{BL<:Tuple{Basis,Basis},BR<:Tuple{Basis,Basis}} <: Broadcast.BroadcastStyle end +struct DenseSuperOperatorStyle{BL<:Tuple{Basis,Basis},BR<:Tuple{Basis,Basis}} <: SuperOperatorStyle{BL,BR} end +struct SparseSuperOperatorStyle{BL<:Tuple{Basis,Basis},BR<:Tuple{Basis,Basis}} <: SuperOperatorStyle{BL,BR} end + +# Style precedence rules +Broadcast.BroadcastStyle(::Type{<:DenseSuperOperator{BL,BR}}) where {BL<:Tuple{Basis,Basis},BR<:Tuple{Basis,Basis}} = DenseSuperOperatorStyle{BL,BR}() +Broadcast.BroadcastStyle(::Type{<:SparseSuperOperator{BL,BR}}) where {BL<:Tuple{Basis,Basis},BR<:Tuple{Basis,Basis}} = SparseSuperOperatorStyle{BL,BR}() +Broadcast.BroadcastStyle(::DenseSuperOperatorStyle{B1,B2}, ::SparseSuperOperatorStyle{B1,B2}) where {B1<:Tuple{Basis,Basis},B2<:Tuple{Basis,Basis}} = DenseSuperOperatorStyle{B1,B2}() +Broadcast.BroadcastStyle(::DenseSuperOperatorStyle{B1,B2}, ::DenseSuperOperatorStyle{B3,B4}) where {B1<:Tuple{Basis,Basis},B2<:Tuple{Basis,Basis},B3<:Tuple{Basis,Basis},B4<:Tuple{Basis,Basis}} = throw(bases.IncompatibleBases()) +Broadcast.BroadcastStyle(::SparseSuperOperatorStyle{B1,B2}, ::SparseSuperOperatorStyle{B3,B4}) where {B1<:Tuple{Basis,Basis},B2<:Tuple{Basis,Basis},B3<:Tuple{Basis,Basis},B4<:Tuple{Basis,Basis}} = throw(bases.IncompatibleBases()) +Broadcast.BroadcastStyle(::DenseSuperOperatorStyle{B1,B2}, ::SparseSuperOperatorStyle{B3,B4}) where {B1<:Tuple{Basis,Basis},B2<:Tuple{Basis,Basis},B3<:Tuple{Basis,Basis},B4<:Tuple{Basis,Basis}} = throw(bases.IncompatibleBases()) + +# Out-of-place broadcasting +@inline function Base.copy(bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {BL<:Tuple{Basis,Basis},BR<:Tuple{Basis,Basis},Style<:DenseSuperOperatorStyle{BL,BR},Axes,F,Args<:Tuple} + bcf = Broadcast.flatten(bc) + args_ = Tuple(a.data for a=bcf.args) + bl,br = states.find_basis(bcf.args) + bc_ = Broadcast.Broadcasted(bcf.f, args_, axes(bcf)) + # TODO: remove convert + return DenseSuperOperator{BL,BR}(bl, br, convert(Matrix{ComplexF64}, copy(bc_))) +end +@inline function Base.copy(bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {BL<:Tuple{Basis,Basis},BR<:Tuple{Basis,Basis},Style<:SparseSuperOperatorStyle{BL,BR},Axes,F,Args<:Tuple} + bcf = Broadcast.flatten(bc) + args_ = Tuple(a.data for a=bcf.args) + bl,br = states.find_basis(bcf.args) + bc_ = Broadcast.Broadcasted(bcf.f, args_, axes(bcf)) + return SuperOperator{BL,BR}(bl, br, copy(bc_)) +end +states.find_basis(a::SuperOperator, rest) = (a.basis_l, a.basis_r) + +# In-place broadcasting +@inline function Base.copyto!(dest::SuperOperator{BL,BR}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {BL<:Tuple{Basis,Basis},BR<:Tuple{Basis,Basis},Style<:SuperOperatorStyle{BL,BR},Axes,F,Args} + axes(dest) == axes(bc) || Base.Broadcast.throwdm(axes(dest), axes(bc)) + # Performance optimization: broadcast!(identity, dest, A) is equivalent to copyto!(dest, A) if indices match + if bc.f === identity && isa(bc.args, Tuple{<:SuperOperator{BL,BR}}) # only a single input argument to broadcast! + A = bc.args[1] + if axes(dest) == axes(A) + return copyto!(dest, A) + end + end + # Get the underlying data fields of operators and broadcast them as arrays + bcf = Broadcast.flatten(bc) + args_ = Tuple(a.data for a=bcf.args) + bc_ = Broadcast.Broadcasted(bcf.f, args_, axes(bcf)) + copyto!(dest.data, bc_) + return dest +end +@inline Base.copyto!(A::SuperOperator{BL,BR},B::SuperOperator{BL,BR}) where {BL<:Tuple{Basis,Basis},BR<:Tuple{Basis,Basis}} = (copyto!(A.data,B.data); A) +@inline function Base.copyto!(dest::SuperOperator{B1,B2}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where { + B1<:Tuple{Basis,Basis},B2<:Tuple{Basis,Basis},B3<:Tuple{Basis,Basis}, + B4<:Tuple{Basis,Basis},Style<:SuperOperatorStyle{B3,B4},Axes,F,Args + } + throw(bases.IncompatibleBases()) +end + end # module diff --git a/test/test_operators_dense.jl b/test/test_operators_dense.jl index f797554c..5a3496c0 100644 --- a/test/test_operators_dense.jl +++ b/test/test_operators_dense.jl @@ -345,4 +345,17 @@ bnlevel = NLevelBasis(2) @test ishermitian(DenseOperator(bspin, bspin, [1.0 im; -im 2.0])) == true @test ishermitian(DenseOperator(bspin, bnlevel, [1.0 im; -im 2.0])) == false +# Test broadcasting +op1_ = copy(op1) +op1 .= 2*op1 +@test op1 == op1_ .+ op1_ +op1 .= op1_ +@test op1 == op1_ +op1 .= op1_ .+ 3 * op1_ +@test op1 == 4*op1_ +@test_throws DimensionMismatch op1 .= op2 +bf = FockBasis(3) +op3 = randoperator(bf) +@test_throws bases.IncompatibleBases op1 .+ op3 + end # testset diff --git a/test/test_operators_sparse.jl b/test/test_operators_sparse.jl index 2e117126..eb26389d 100644 --- a/test/test_operators_sparse.jl +++ b/test/test_operators_sparse.jl @@ -346,4 +346,17 @@ bnlevel = NLevelBasis(2) @test ishermitian(SparseOperator(bspin, bspin, sparse([1.0 im; -im 2.0]))) == true @test ishermitian(SparseOperator(bspin, bnlevel, sparse([1.0 im; -im 2.0]))) == false +# Test broadcasting +@test_throws DimensionMismatch op1 .+ op2 +@test op1 .+ op1 == op1 + op1 +op1 .= DenseOperator(op1) +@test isa(op1, SparseOperator) +@test isa(op1 .+ DenseOperator(op1), DenseOperator) +op3 = sprandop(FockBasis(1),FockBasis(2)) +@test_throws bases.IncompatibleBases op1 .+ op3 +@test_throws bases.IncompatibleBases op1 .= op3 +op_ = copy(op1) +op_ .+= op1 +@test op_ == 2*op1 + end # testset diff --git a/test/test_states.jl b/test/test_states.jl index ab37ebf5..a92072ce 100644 --- a/test/test_states.jl +++ b/test/test_states.jl @@ -147,4 +147,19 @@ psi321 = psi3 ⊗ psi2 ⊗ psi1 @test 1e-14 > D(dagger(psi312), permutesystems(dagger(psi123), [3, 1, 2])) @test 1e-14 > D(dagger(psi321), permutesystems(dagger(psi123), [3, 2, 1])) +# Test Broadcasting +@test_throws bases.IncompatibleBases psi123 .= psi132 +@test_throws bases.IncompatibleBases psi123 .+ psi132 +bra123 = dagger(psi123) +bra132 = dagger(psi132) +@test_throws ArgumentError psi123 .+ bra123 +@test_throws bases.IncompatibleBases bra123 .= bra132 +@test_throws bases.IncompatibleBases bra123 .+ bra132 +psi_ = copy(psi123) +psi_ .+= psi123 +@test psi_ == 2*psi123 +bra_ = copy(bra123) +bra_ .= 3*bra123 +@test bra_ == 3*dagger(psi123) + end # testset diff --git a/test/test_superoperators.jl b/test/test_superoperators.jl index e4530678..543e25d4 100644 --- a/test/test_superoperators.jl +++ b/test/test_superoperators.jl @@ -1,6 +1,6 @@ using Test using QuantumOptics -using SparseArrays +using SparseArrays, LinearAlgebra @testset "superoperators" begin @@ -163,7 +163,7 @@ op2 = DenseOperator(spinbasis, [0.2+0.1im 0.1+2.3im; 0.8+4.0im 0.3+1.4im]) L = liouvillian(H, J) ρ = -1im*(H*ρ₀ - ρ₀*H) for j=J - ρ += j*ρ₀*dagger(j) - 0.5*dagger(j)*j*ρ₀ - 0.5*ρ₀*dagger(j)*j + ρ .+= j*ρ₀*dagger(j) - 0.5*dagger(j)*j*ρ₀ - 0.5*ρ₀*dagger(j)*j end @test tracedistance(L*ρ₀, ρ) < 1e-10 @@ -180,4 +180,20 @@ tout, ρt = timeevolution.master([0.,1.], ρ₀, H, J; reltol=1e-7) rates = diagm(0 => [1.0, 1.0]) @test liouvillian(H, J; rates=rates) == L +# Test broadcasting +@test L .+ L == L + L +Ldense = dense(L) +@test isa(L .+ Ldense, DenseSuperOperator) +L_ = copy(L) +L .+= L +@test L == 2*L_ +L .+= Ldense +@test L == 3*L_ +Ldense_ = dense(L_) +Ldense .+= Ldense +@test Ldense == 2*Ldense_ +Ldense .+= L +@test isa(Ldense, DenseSuperOperator) +@test isapprox(Ldense.data, 5*Ldense_.data) + end # testset