diff --git a/CHANGELOG.md b/CHANGELOG.md index 4c30de4..52ab459 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,11 @@ # News +## v0.4.0 - 2024-11-26 + +- Add `OperatorBasis` and `SuperOperatorBasis` abstract types along with corresponding `fullbasis` function to obtain these from instances of subtypes of `AbstractOperator` and `AbstractSuperOperator`. +- Change type parameters for `StateVector`, `AbstractKet` `AbstractBra` `AbstractOperator` `AbstractSuperOperator` to elimitate all type parameters. + + ## v0.3.6 - 2024-09-08 - Add `coherentstate`, `thermalstate`, `displace`, `squeeze`, `wigner`, previously from QuantumOptics. diff --git a/Project.toml b/Project.toml index d8b70c0..9f743ed 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "QuantumInterface" uuid = "5717a53b-5d69-4fa3-b976-0bf2f97ca1e5" authors = ["QuantumInterface.jl contributors"] -version = "0.3.6" +version = "0.4.0" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/src/QuantumInterface.jl b/src/QuantumInterface.jl index 34efa04..c23e652 100644 --- a/src/QuantumInterface.jl +++ b/src/QuantumInterface.jl @@ -1,14 +1,144 @@ module QuantumInterface -import Base: ==, +, -, *, /, ^, length, one, exp, conj, conj!, transpose, copy -import LinearAlgebra: tr, ishermitian, norm, normalize, normalize! -import Base: show, summary -import SparseArrays: sparse, spzeros, AbstractSparseMatrix # TODO move to an extension +## +# Basis specific +## + +""" + basis(a) + +Return the basis of an object. + +If it's ambiguous, e.g. if an operator or superoperator has a different +left and right basis, an [`IncompatibleBases`](@ref) error is thrown. +""" +function basis end + +""" + fullbasis(a) + +Return the full basis of an object. + +Returns subtype of `Basis` when a is a subtype of `StateVector`. +Returns a subtype of `OperatorBasis` a is a subtype of `AbstractOperator`. +Returns a subtype of `SuperOperatorBasis` when a is a subtype of `AbstractSuperOperator`. +""" +function fullbasis end + +""" + length(b::Basis) + +Total dimension of the Hilbert space. +""" +function length end + +function bases end + +function spinnumber end + +function cutoff end + +function offset end + +## +# Standard methods +## + +""" + multiplicable(a, b) + +Check if any two subtypes of `StateVector`, `AbstractOperator`, +or `AbstractSuperOperator` can be multiplied in the given order. + +Spcefically this checks whether the right basis of a is equal +to the left basis of b +""" +function multiplicable end + +""" + check_multiplicable(a, b) + +Throw an [`IncompatibleBases`](@ref) error if the objects are +not multiplicable as determined by `multiplicable(a, b)`. + +If the macro `@compatiblebases` is used anywhere up the call stack, +this check is disabled. +""" +function check_multiplicable end + +""" + addible(a, b) + +Check if any two subtypes of `StateVector`, `AbstractOperator`, +or `AbstractSuperOperator` can be added together. + +Spcefically this checks whether the left basis of a is equal +to the left basis of b and whether the right basis of a is equal +to the right basis of b. +""" +function addible end + +""" + check_addible(a, b) + +Throw an [`IncompatibleBases`](@ref) error if the objects are +not addible as determined by `addible(a, b)`. + +If the macro `@compatiblebases` is used anywhere up the call stack, +this check is disabled. +""" +function check_addible end + +""" + issquare(a) + +Check if any two subtypes of `StateVector`, `AbstractOperator`, +or `AbstractSuperOperator` are square. + +Spcefically this checks whether the left basis of a is equal +to the right basis of a. +For subtypes of `StateVector` this is always false. +""" +function addible end + +""" + check_issquare(a, b) + +Throw an [`IncompatibleBases`](@ref) error if the objects are +not addible as determined by `addible(a, b)`. + +If the macro `@compatiblebases` is used anywhere up the call stack, +this check is disabled. +""" +function check_addible end + +const BASES_CHECK = Ref(true) + +""" + @compatiblebases + +Macro to skip checks for compatible bases. Useful for `*`, `expect` and similar +functions. +""" +macro compatiblebases(ex) + return quote + BASES_CHECK.x = false + local val = $(esc(ex)) + BASES_CHECK.x = true + val + end +end function apply! end function dagger end +""" + directsum(x, y, z...) + +Direct sum of the given objects. Alternatively, the unicode +symbol ⊕ (\\oplus) can be used. +""" function directsum end const ⊕ = directsum directsum() = GenericBasis(0) @@ -86,8 +216,9 @@ function squeeze end function wigner end -include("bases.jl") include("abstract_types.jl") +include("bases.jl") +include("show.jl") include("linalg.jl") include("tensor.jl") diff --git a/src/abstract_types.jl b/src/abstract_types.jl index f8667c9..70a2ce6 100644 --- a/src/abstract_types.jl +++ b/src/abstract_types.jl @@ -1,3 +1,18 @@ +""" +Abstract base class for all specialized bases. + +The Basis class is meant to specify a basis of the Hilbert space of the +studied system. Besides basis specific information all subclasses must +implement a shape variable which indicates the dimension of the used +Hilbert space. For a spin-1/2 Hilbert space this would be the +vector `[2]`. A system composed of two spins would then have a +shape vector `[2 2]`. + +Composite systems can be defined with help of the [`CompositeBasis`](@ref) +class. +""" +abstract type Basis end + """ Abstract base class for `Bra` and `Ket` states. @@ -6,9 +21,9 @@ in respect to a certain basis. These coefficients are stored in the `data` field and the basis is defined in the `basis` field. """ -abstract type StateVector{B,T} end -abstract type AbstractKet{B,T} <: StateVector{B,T} end -abstract type AbstractBra{B,T} <: StateVector{B,T} end +abstract type StateVector end +abstract type AbstractKet <: StateVector end +abstract type AbstractBra <: StateVector end """ Abstract base class for all operators. @@ -21,7 +36,7 @@ For fast time evolution also at least the function implemented. Many other generic multiplication functions can be defined in terms of this function and are provided automatically. """ -abstract type AbstractOperator{BL,BR} end +abstract type AbstractOperator end """ Base class for all super operator classes. @@ -37,21 +52,4 @@ A_{bl_1,bl_2} = S_{(bl_1,bl_2) ↔ (br_1,br_2)} B_{br_1,br_2} A_{br_1,br_2} = B_{bl_1,bl_2} S_{(bl_1,bl_2) ↔ (br_1,br_2)} ``` """ -abstract type AbstractSuperOperator{B1,B2} end - -function summary(stream::IO, x::AbstractOperator) - print(stream, "$(typeof(x).name.name)(dim=$(length(x.basis_l))x$(length(x.basis_r)))\n") - if samebases(x) - print(stream, " basis: ") - show(stream, basis(x)) - else - print(stream, " basis left: ") - show(stream, x.basis_l) - print(stream, "\n basis right: ") - show(stream, x.basis_r) - end -end - -show(stream::IO, x::AbstractOperator) = summary(stream, x) - -traceout!(s::StateVector, i) = ptrace(s,i) +abstract type AbstractSuperOperator end diff --git a/src/bases.jl b/src/bases.jl index 6e4b077..380d248 100644 --- a/src/bases.jl +++ b/src/bases.jl @@ -1,18 +1,3 @@ -""" -Abstract base class for all specialized bases. - -The Basis class is meant to specify a basis of the Hilbert space of the -studied system. Besides basis specific information all subclasses must -implement a shape variable which indicates the dimension of the used -Hilbert space. For a spin-1/2 Hilbert space this would be the -vector `[2]`. A system composed of two spins would then have a -shape vector `[2 2]`. - -Composite systems can be defined with help of the [`CompositeBasis`](@ref) -class. -""" -abstract type Basis end - """ length(b::Basis) @@ -20,16 +5,6 @@ Total dimension of the Hilbert space. """ Base.length(b::Basis) = prod(b.shape) -""" - basis(a) - -Return the basis of an object. - -If it's ambiguous, e.g. if an operator has a different left and right basis, -an [`IncompatibleBases`](@ref) error is thrown. -""" -function basis end - """ GenericBasis(N) @@ -64,6 +39,7 @@ end CompositeBasis(bases) = CompositeBasis([length(b) for b ∈ bases], bases) CompositeBasis(bases::Basis...) = CompositeBasis((bases...,)) CompositeBasis(bases::Vector) = CompositeBasis((bases...,)) +bases(b::CompositeBasis) = b.bases Base.:(==)(b1::T, b2::T) where T<:CompositeBasis = equal_shape(b1.shape, b2.shape) @@ -142,8 +118,6 @@ Exception that should be raised for an illegal algebraic operation. """ mutable struct IncompatibleBases <: Exception end -const BASES_CHECK = Ref(true) - """ @samebases @@ -283,6 +257,8 @@ struct FockBasis{T} <: Basis new{T}([N-offset+1], N, offset) end end +cutoff(b::FockBasis) = b.N +offset(b::FockBasis) = b.offset Base.:(==)(b1::FockBasis, b2::FockBasis) = (b1.N==b2.N && b1.offset==b2.offset) @@ -348,6 +324,7 @@ struct SpinBasis{S,T} <: Basis end SpinBasis(spinnumber::Rational) = SpinBasis{spinnumber}(spinnumber) SpinBasis(spinnumber) = SpinBasis(convert(Rational{Int}, spinnumber)) +spinnumber(b::SpinBasis) = b.spinnumber Base.:(==)(b1::SpinBasis, b2::SpinBasis) = b1.spinnumber==b2.spinnumber @@ -366,9 +343,9 @@ SumBasis(shape, bases::Vector) = (tmp = (bases...,); SumBasis(shape, tmp)) SumBasis(bases::Vector) = SumBasis((bases...,)) SumBasis(bases::Basis...) = SumBasis((bases...,)) -==(b1::T, b2::T) where T<:SumBasis = equal_shape(b1.shape, b2.shape) -==(b1::SumBasis, b2::SumBasis) = false -length(b::SumBasis) = sum(b.shape) +Base.:(==)(b1::T, b2::T) where T<:SumBasis = equal_shape(b1.shape, b2.shape) +Base.:(==)(b1::SumBasis, b2::SumBasis) = false +Base.length(b::SumBasis) = sum(b.shape) """ directsum(b1::Basis, b2::Basis) @@ -393,62 +370,3 @@ function directsum(b1::SumBasis, b2::SumBasis) bases = [b1.bases...;b2.bases...] return SumBasis(shape, (bases...,)) end - -embed(b::SumBasis, indices, ops) = embed(b, b, indices, ops) - -## -# show methods -## - -function show(stream::IO, x::GenericBasis) - if length(x.shape) == 1 - write(stream, "Basis(dim=$(x.shape[1]))") - else - s = replace(string(x.shape), " " => "") - write(stream, "Basis(shape=$s)") - end -end - -function show(stream::IO, x::CompositeBasis) - write(stream, "[") - for i in 1:length(x.bases) - show(stream, x.bases[i]) - if i != length(x.bases) - write(stream, " ⊗ ") - end - end - write(stream, "]") -end - -function show(stream::IO, x::SpinBasis) - d = denominator(x.spinnumber) - n = numerator(x.spinnumber) - if d == 1 - write(stream, "Spin($n)") - else - write(stream, "Spin($n/$d)") - end -end - -function show(stream::IO, x::FockBasis) - if iszero(x.offset) - write(stream, "Fock(cutoff=$(x.N))") - else - write(stream, "Fock(cutoff=$(x.N), offset=$(x.offset))") - end -end - -function show(stream::IO, x::NLevelBasis) - write(stream, "NLevel(N=$(x.N))") -end - -function show(stream::IO, x::SumBasis) - write(stream, "[") - for i in 1:length(x.bases) - show(stream, x.bases[i]) - if i != length(x.bases) - write(stream, " ⊕ ") - end - end - write(stream, "]") -end diff --git a/src/embed_permute.jl b/src/embed_permute.jl index c2cc4ca..ac63bc4 100644 --- a/src/embed_permute.jl +++ b/src/embed_permute.jl @@ -67,8 +67,8 @@ function embed(basis_l::CompositeBasis, basis_r::CompositeBasis, ops_sb = [x[2] for x in idxop_sb] for (idxsb, opsb) in zip(indices_sb, ops_sb) - (opsb.basis_l == basis_l.bases[idxsb]) || throw(IncompatibleBases()) - (opsb.basis_r == basis_r.bases[idxsb]) || throw(IncompatibleBases()) + (opsb.basis_l == basis_l.bases[idxsb]) || throw(IncompatibleBases()) # FIXME issue #12 + (opsb.basis_r == basis_r.bases[idxsb]) || throw(IncompatibleBases()) # FIXME issue #12 end S = length(operators) > 0 ? mapreduce(eltype, promote_type, operators) : Any @@ -83,6 +83,8 @@ function embed(basis_l::CompositeBasis, basis_r::CompositeBasis, return embed_op end +embed(b::SumBasis, indices, ops) = embed(b, b, indices, ops) + permutesystems(a::AbstractOperator, perm) = arithmetic_unary_error("Permutations of subsystems", a) nsubsystems(s::AbstractKet) = nsubsystems(basis(s)) diff --git a/src/identityoperator.jl b/src/identityoperator.jl index 5959882..aa031ff 100644 --- a/src/identityoperator.jl +++ b/src/identityoperator.jl @@ -1,4 +1,4 @@ -one(x::Union{<:Basis,<:AbstractOperator}) = identityoperator(x) +Base.one(x::Union{<:Basis,<:AbstractOperator}) = identityoperator(x) """ identityoperator(a::Basis[, b::Basis]) @@ -22,4 +22,4 @@ identityoperator(::Type{T}, ::Type{Any}, b1::Basis, b2::Basis) where T<:Abstract identityoperator(b1::Basis, b2::Basis) = identityoperator(ComplexF64, b1, b2) """Prepare the identity superoperator over a given space.""" -function identitysuperoperator end \ No newline at end of file +function identitysuperoperator end diff --git a/src/julia_base.jl b/src/julia_base.jl index 9a0532d..126694c 100644 --- a/src/julia_base.jl +++ b/src/julia_base.jl @@ -1,3 +1,5 @@ +import Base: +, -, *, /, ^, length, exp, conj, conj!, adjoint, transpose, copy + # Common error messages arithmetic_unary_error(funcname, x::AbstractOperator) = throw(ArgumentError("$funcname is not defined for this type of operator: $(typeof(x)).\nTry to convert to another operator type first with e.g. dense() or sparse().")) arithmetic_binary_error(funcname, a::AbstractOperator, b::AbstractOperator) = throw(ArgumentError("$funcname is not defined for this combination of types of operators: $(typeof(a)), $(typeof(b)).\nTry to convert to a common operator type first with e.g. dense() or sparse().")) @@ -8,33 +10,33 @@ addnumbererror() = throw(ArgumentError("Can't add or subtract a number and an op # States ## --(a::T) where {T<:StateVector} = T(a.basis, -a.data) +-(a::T) where {T<:StateVector} = T(a.basis, -a.data) # FIXME issue #12 *(a::StateVector, b::Number) = b*a -copy(a::T) where {T<:StateVector} = T(a.basis, copy(a.data)) -length(a::StateVector) = length(a.basis)::Int -basis(a::StateVector) = a.basis +copy(a::T) where {T<:StateVector} = T(a.basis, copy(a.data)) # FIXME issue #12 +length(a::StateVector) = length(a.basis)::Int # FIXME issue #12 +basis(a::StateVector) = a.basis # FIXME issue #12 directsum(x::StateVector...) = reduce(directsum, x) +adjoint(a::StateVector) = dagger(a) + + # Array-like functions -Base.size(x::StateVector) = size(x.data) -@inline Base.axes(x::StateVector) = axes(x.data) +Base.size(x::StateVector) = size(x.data) # FIXME issue #12 +@inline Base.axes(x::StateVector) = axes(x.data) # FIXME issue #12 Base.ndims(x::StateVector) = 1 Base.ndims(::Type{<:StateVector}) = 1 -Base.eltype(x::StateVector) = eltype(x.data) +Base.eltype(x::StateVector) = eltype(x.data) # FIXME issue #12 # Broadcasting Base.broadcastable(x::StateVector) = x -Base.adjoint(a::StateVector) = dagger(a) - - ## # Operators ## -length(a::AbstractOperator) = length(a.basis_l)::Int*length(a.basis_r)::Int -basis(a::AbstractOperator) = (check_samebases(a); a.basis_l) -basis(a::AbstractSuperOperator) = (check_samebases(a); a.basis_l[1]) +length(a::AbstractOperator) = length(a.basis_l)::Int*length(a.basis_r)::Int # FIXME issue #12 +basis(a::AbstractOperator) = (check_samebases(a); a.basis_l) # FIXME issue #12 +basis(a::AbstractSuperOperator) = (check_samebases(a); a.basis_l[1]) # FIXME issue #12 # Ensure scalar broadcasting Base.broadcastable(x::AbstractOperator) = Ref(x) @@ -60,14 +62,17 @@ Operator exponential. """ exp(op::AbstractOperator) = throw(ArgumentError("exp() is not defined for this type of operator: $(typeof(op)).\nTry to convert to dense operator first with dense().")) -Base.size(op::AbstractOperator) = (length(op.basis_l),length(op.basis_r)) +Base.size(op::AbstractOperator) = (length(op.basis_l),length(op.basis_r)) # FIXME issue #12 function Base.size(op::AbstractOperator, i::Int) i < 1 && throw(ErrorException("dimension index is < 1")) i > 2 && return 1 - i==1 ? length(op.basis_l) : length(op.basis_r) + i==1 ? length(op.basis_l) : length(op.basis_r) # FIXME issue #12 end -Base.adjoint(a::AbstractOperator) = dagger(a) +adjoint(a::AbstractOperator) = dagger(a) + +transpose(a::AbstractOperator) = arithmetic_unary_error("Transpose", a) + conj(a::AbstractOperator) = arithmetic_unary_error("Complex conjugate", a) conj!(a::AbstractOperator) = conj(a::AbstractOperator) diff --git a/src/julia_linalg.jl b/src/julia_linalg.jl index d2f4d3d..3087d0a 100644 --- a/src/julia_linalg.jl +++ b/src/julia_linalg.jl @@ -1,3 +1,5 @@ +import LinearAlgebra: tr, ishermitian, norm, normalize, normalize! + """ ishermitian(op::AbstractOperator) @@ -17,7 +19,7 @@ tr(x::AbstractOperator) = arithmetic_unary_error("Trace", x) Norm of the given bra or ket state. """ -norm(x::StateVector) = norm(x.data) +norm(x::StateVector) = norm(x.data) # FIXME issue #12 """ normalize(x::StateVector) @@ -31,7 +33,7 @@ normalize(x::StateVector) = x/norm(x) In-place normalization of the given bra or ket so that `norm(x)` is one. """ -normalize!(x::StateVector) = (normalize!(x.data); x) +normalize!(x::StateVector) = (normalize!(x.data); x) # FIXME issue #12 """ normalize(op) diff --git a/src/linalg.jl b/src/linalg.jl index 8bb47cd..fb0779d 100644 --- a/src/linalg.jl +++ b/src/linalg.jl @@ -1,10 +1,10 @@ -samebases(a::AbstractOperator) = samebases(a.basis_l, a.basis_r)::Bool -samebases(a::AbstractOperator, b::AbstractOperator) = samebases(a.basis_l, b.basis_l)::Bool && samebases(a.basis_r, b.basis_r)::Bool -check_samebases(a::Union{AbstractOperator, AbstractSuperOperator}) = check_samebases(a.basis_l, a.basis_r) -multiplicable(a::AbstractOperator, b::AbstractOperator) = multiplicable(a.basis_r, b.basis_l) +samebases(a::AbstractOperator) = samebases(a.basis_l, a.basis_r)::Bool # FIXME issue #12 +samebases(a::AbstractOperator, b::AbstractOperator) = samebases(a.basis_l, b.basis_l)::Bool && samebases(a.basis_r, b.basis_r)::Bool # FIXME issue #12 +check_samebases(a::Union{AbstractOperator, AbstractSuperOperator}) = check_samebases(a.basis_l, a.basis_r) # FIXME issue #12 +multiplicable(a::AbstractOperator, b::AbstractOperator) = multiplicable(a.basis_r, b.basis_l) # FIXME issue #12 dagger(a::AbstractOperator) = arithmetic_unary_error("Hermitian conjugate", a) -transpose(a::AbstractOperator) = arithmetic_unary_error("Transpose", a) directsum(a::AbstractOperator...) = reduce(directsum, a) ptrace(a::AbstractOperator, index) = arithmetic_unary_error("Partial trace", a) _index_complement(b::CompositeBasis, indices) = complement(length(b.bases), indices) reduced(a, indices) = ptrace(a, _index_complement(basis(a), indices)) +traceout!(s::StateVector, i) = ptrace(s,i) diff --git a/src/show.jl b/src/show.jl new file mode 100644 index 0000000..38607b0 --- /dev/null +++ b/src/show.jl @@ -0,0 +1,69 @@ +import Base: show, summary + +function summary(stream::IO, x::AbstractOperator) + print(stream, "$(typeof(x).name.name)(dim=$(length(x.basis_l))x$(length(x.basis_r)))\n") + if samebases(x) + print(stream, " basis: ") + show(stream, basis(x)) + else + print(stream, " basis left: ") + show(stream, x.basis_l) + print(stream, "\n basis right: ") + show(stream, x.basis_r) + end +end + +show(stream::IO, x::AbstractOperator) = summary(stream, x) + +function show(stream::IO, x::GenericBasis) + if length(x.shape) == 1 + write(stream, "Basis(dim=$(x.shape[1]))") + else + s = replace(string(x.shape), " " => "") + write(stream, "Basis(shape=$s)") + end +end + +function show(stream::IO, x::CompositeBasis) + write(stream, "[") + for i in 1:length(x.bases) + show(stream, x.bases[i]) + if i != length(x.bases) + write(stream, " ⊗ ") + end + end + write(stream, "]") +end + +function show(stream::IO, x::SpinBasis) + d = denominator(x.spinnumber) + n = numerator(x.spinnumber) + if d == 1 + write(stream, "Spin($n)") + else + write(stream, "Spin($n/$d)") + end +end + +function show(stream::IO, x::FockBasis) + if iszero(x.offset) + write(stream, "Fock(cutoff=$(x.N))") + else + write(stream, "Fock(cutoff=$(x.N), offset=$(x.offset))") + end +end + +function show(stream::IO, x::NLevelBasis) + write(stream, "NLevel(N=$(x.N))") +end + +function show(stream::IO, x::SumBasis) + write(stream, "[") + for i in 1:length(x.bases) + show(stream, x.bases[i]) + if i != length(x.bases) + write(stream, " ⊕ ") + end + end + write(stream, "]") +end diff --git a/src/sparse.jl b/src/sparse.jl index 2ba8f5f..d6b301c 100644 --- a/src/sparse.jl +++ b/src/sparse.jl @@ -1,4 +1,4 @@ -# TODO make an extension? +import SparseArrays: sparse, spzeros, AbstractSparseMatrix # TODO move to an extension # dense(a::AbstractOperator) = arithmetic_unary_error("Conversion to dense", a) diff --git a/test/runtests.jl b/test/runtests.jl index 0bccf25..826fe33 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -26,6 +26,7 @@ end println("Starting tests with $(Threads.nthreads()) threads out of `Sys.CPU_THREADS = $(Sys.CPU_THREADS)`...") @doset "sortedindices" +@doset "bases" #VERSION >= v"1.9" && @doset "doctests" get(ENV,"JET_TEST","")=="true" && @doset "jet" VERSION >= v"1.9" && @doset "aqua" diff --git a/test/test_bases.jl b/test/test_bases.jl new file mode 100644 index 0000000..1d91673 --- /dev/null +++ b/test/test_bases.jl @@ -0,0 +1,55 @@ +using Test +using QuantumInterface: tensor, ⊗, ptrace, reduced, permutesystems, equal_bases, multiplicable +using QuantumInterface: GenericBasis, CompositeBasis, NLevelBasis, FockBasis + +@testset "basis" begin + +shape1 = [5] +shape2 = [2, 3] +shape3 = [6] + +b1 = GenericBasis(shape1) +b2 = GenericBasis(shape2) +b3 = GenericBasis(shape3) + +@test b1.shape == shape1 +@test b2.shape == shape2 +@test b1 != b2 +@test b1 != FockBasis(2) +@test b1 == b1 + +@test tensor(b1) == b1 +comp_b1 = tensor(b1, b2) +comp_uni = b1 ⊗ b2 +comp_b2 = tensor(b1, b1, b2) +@test comp_b1.shape == [prod(shape1), prod(shape2)] +@test comp_uni.shape == [prod(shape1), prod(shape2)] +@test comp_b2.shape == [prod(shape1), prod(shape1), prod(shape2)] + +@test b1^3 == CompositeBasis(b1, b1, b1) +@test (b1⊗b2)^2 == CompositeBasis(b1, b2, b1, b2) +@test_throws ArgumentError b1^(0) + +comp_b1_b2 = tensor(comp_b1, comp_b2) +@test comp_b1_b2.shape == [prod(shape1), prod(shape2), prod(shape1), prod(shape1), prod(shape2)] +@test comp_b1_b2 == CompositeBasis(b1, b2, b1, b1, b2) + +@test_throws ArgumentError tensor() +@test comp_b2.shape == tensor(b1, comp_b1).shape +@test comp_b2 == tensor(b1, comp_b1) + +@test_throws ArgumentError ptrace(comp_b1, [1, 2]) +@test ptrace(comp_b2, [1]) == ptrace(comp_b2, [2]) == comp_b1 == ptrace(comp_b2, 1) +@test ptrace(comp_b2, [1, 2]) == ptrace(comp_b1, [1]) +@test ptrace(comp_b2, [2, 3]) == ptrace(comp_b1, [2]) +@test ptrace(comp_b2, [2, 3]) == reduced(comp_b2, [1]) +@test_throws ArgumentError reduced(comp_b1, []) + +comp1 = tensor(b1, b2, b3) +comp2 = tensor(b2, b1, b3) +@test permutesystems(comp1, [2,1,3]) == comp2 + +@test !equal_bases([b1, b2], [b1, b3]) +@test !multiplicable(comp1, b1 ⊗ b2 ⊗ NLevelBasis(prod(b3.shape))) + +end # testset