Skip to content

Commit

Permalink
Parametrize basis dimensions
Browse files Browse the repository at this point in the history
  • Loading branch information
david-pl committed Oct 20, 2018
1 parent 28fadbb commit b451987
Show file tree
Hide file tree
Showing 9 changed files with 109 additions and 49 deletions.
47 changes: 35 additions & 12 deletions src/bases.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ export Basis, GenericBasis, CompositeBasis, basis,
samebases, multiplicable,
check_samebases, check_multiplicable, @samebases

import Base: ==, ^
import Base: ==, ^, eltype

"""
Abstract base class for all specialized bases.
Expand All @@ -21,10 +21,13 @@ shape vector `Int[2 2]`.
Composite systems can be defined with help of the [`CompositeBasis`](@ref)
class.
"""
abstract type Basis end
abstract type Basis{S} end


==(b1::Basis, b2::Basis) = false
==(b1::T, b2::T) where T<:Basis = true

eltype(a::Basis{S}) where S = S

"""
length(b::Basis)
Expand Down Expand Up @@ -53,13 +56,17 @@ Should only be used rarely since it defeats the purpose of checking that the
bases of state vectors and operators are correct for algebraic operations.
The preferred way is to specify special bases for different systems.
"""
mutable struct GenericBasis <: Basis
mutable struct GenericBasis{S} <: Basis{S}
shape::Vector{Int}
function GenericBasis{S}(shape::Vector{Int}) where S
check_bases_parameter(S,length(shape))
new(shape)
end
end

GenericBasis(shape::Vector{Int}) = GenericBasis{(shape...,)}(shape)
GenericBasis(N::Int) = GenericBasis(Int[N])

==(b1::GenericBasis, b2::GenericBasis) = equal_shape(b1.shape, b2.shape)
# ==(b1::GenericBasis, b2::GenericBasis) = equal_shape(b1.shape, b2.shape)


"""
Expand All @@ -71,14 +78,19 @@ Stores the subbases in a vector and creates the shape vector directly
from the shape vectors of these subbases. Instead of creating a CompositeBasis
directly `tensor(b1, b2...)` or `b1 ⊗ b2 ⊗ …` can be used.
"""
mutable struct CompositeBasis{B<:Vector{Basis}} <: Basis
mutable struct CompositeBasis{B<:Tuple{Vararg{Basis}},S} <: Basis{S}
shape::Vector{Int}
bases::B
function CompositeBasis{B,S}(shape::Vector{Int}, bases::B) where {B<:Tuple{Vararg{Basis}},S}
check_bases_parameter(S,length(shape))
new(shape,bases)
end
end
CompositeBasis(bases::B) where B<:Vector{Basis} = CompositeBasis{B}(Int[prod(b.shape) for b in bases], bases)
CompositeBasis(bases::Basis...) = CompositeBasis(Basis[bases...])

==(b1::CompositeBasis, b2::CompositeBasis) = equal_shape(b1.shape, b2.shape) && equal_bases(b1.bases, b2.bases)
CompositeBasis(shape::Vector{Int}, bases::B) where {B<:Tuple{Vararg{Basis}}} = CompositeBasis{B,(shape...,)}(shape,bases)
CompositeBasis(shape::Vector{Int}, bases::Vector{B}) where {B<:Basis} = CompositeBasis(shape, (bases...,))
CompositeBasis(bases::Tuple{Vararg{Basis}}) = CompositeBasis([length(b) for b=bases], bases)
CompositeBasis(bases::Vector{B}) where {B<:Basis} = CompositeBasis((bases...,))
CompositeBasis(bases::Basis...) = CompositeBasis((bases...,))

"""
tensor(x, y, z...)
Expand All @@ -97,8 +109,8 @@ Create a [`CompositeBasis`](@ref) from the given bases.
Any given CompositeBasis is expanded so that the resulting CompositeBasis never
contains another CompositeBasis.
"""
tensor(b1::Basis, b2::Basis) = CompositeBasis(Int[prod(b1.shape); prod(b2.shape)], Basis[b1, b2])
tensor(b1::CompositeBasis, b2::CompositeBasis) = CompositeBasis(Int[b1.shape; b2.shape], Basis[b1.bases; b2.bases])
tensor(b1::Basis, b2::Basis) = CompositeBasis(Int[prod(b1.shape); prod(b2.shape)], (b1, b2))
tensor(b1::CompositeBasis, b2::CompositeBasis) = CompositeBasis(Int[b1.shape; b2.shape], (b1.bases..., b2.bases...,))
function tensor(b1::CompositeBasis, b2::Basis)
N = length(b1.bases)
shape = Vector{Int}(undef, N+1)
Expand Down Expand Up @@ -280,4 +292,15 @@ function permutesystems(b::CompositeBasis, perm::Vector{Int})
CompositeBasis(b.shape[perm], b.bases[perm])
end

function check_bases_parameter(S,N::Int)
if !isa(S,Tuple{Vararg{Int,N}})
if isa(S,Tuple{Vararg{Int}})
throw(ArgumentError("Bases shape parameter has the wrong length!"))
else
throw(ArgumentError("Cannot create basis with parametric field $(typeof(S))!
Needs to be of type `Tuple{Int,N}`."))
end
end
end

end # module
9 changes: 5 additions & 4 deletions src/fock.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,20 @@ using SparseArrays
Basis for a Fock space where `N` specifies a cutoff, i.e. what the highest
included fock state is. Note that the dimension of this basis then is N+1.
"""
mutable struct FockBasis <: Basis
mutable struct FockBasis{S} <: Basis{S}
shape::Vector{Int}
N::Int
function FockBasis(N::Int)
function FockBasis{S}(N::Int) where S
bases.check_bases_parameter(S,1)
if N < 0
throw(DimensionMismatch())
end
new([N+1], N)
end
end
FockBasis(N::Int) = FockBasis{(N+1,)}(N)


==(b1::FockBasis, b2::FockBasis) = b1.N==b2.N
# ==(b1::FockBasis, b2::FockBasis) = b1.N==b2.N

"""
number(b::FockBasis)
Expand Down
6 changes: 4 additions & 2 deletions src/manybody.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,18 @@ The basis has to know the associated one-body basis `b` and which occupation sta
should be included. The occupations_hash is used to speed up checking if two
many-body bases are equal.
"""
mutable struct ManyBodyBasis <: Basis
mutable struct ManyBodyBasis{S} <: Basis{S}
shape::Vector{Int}
onebodybasis::Basis
occupations::Vector{Vector{Int}}
occupations_hash::UInt

function ManyBodyBasis(onebodybasis::Basis, occupations::Vector{Vector{Int}})
function ManyBodyBasis{S}(onebodybasis::Basis, occupations::Vector{Vector{Int}}) where S
bases.check_bases_parameter(S,1)
new([length(occupations)], onebodybasis, occupations, hash(hash.(occupations)))
end
end
ManyBodyBasis(onebodybasis::Basis, occupations::Vector{Vector{Int}}) = ManyBodyBasis{(length(occupations),)}(onebodybasis::Basis, occupations::Vector{Vector{Int}})

"""
fermionstates(Nmodes, Nparticles)
Expand Down
11 changes: 6 additions & 5 deletions src/nlevel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,19 @@ using ..bases, ..states, ..operators, ..operators_sparse
Basis for a system consisting of N states.
"""
mutable struct NLevelBasis <: Basis
mutable struct NLevelBasis{S} <: Basis{S}
shape::Vector{Int}
N::Int
function NLevelBasis(N::Int)
function NLevelBasis{S}(N::Int) where S
if N < 1
throw(DimensionMismatch())
end
bases.check_bases_parameter(S,1)
new([N], N)
end
end

==(b1::NLevelBasis, b2::NLevelBasis) = b1.N == b2.N
NLevelBasis(N::Int) = NLevelBasis{(N,)}(N)
# ==(b1::NLevelBasis, b2::NLevelBasis) = b1.N == b2.N


"""
Expand Down Expand Up @@ -56,4 +57,4 @@ function nlevelstate(b::NLevelBasis, n::Int)
basisstate(b, n)
end

end # module
end # module
38 changes: 26 additions & 12 deletions src/particle.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,20 @@ of ``x_{min}`` and ``x_{max}`` are due to the periodic boundary conditions
more or less arbitrary and are chosen to be
``-\\pi/dp`` and ``\\pi/dp`` with ``dp=(p_{max}-p_{min})/N``.
"""
mutable struct PositionBasis <: Basis
mutable struct PositionBasis{X1,X2,S} <: Basis{S}
shape::Vector{Int}
xmin::Float64
xmax::Float64
N::Int
PositionBasis(xmin::Real, xmax::Real, N::Int) = new([N], xmin, xmax, N)
function PositionBasis{X1,X2,S}(xmin::Float64, xmax::Float64, N::Int) where {X1,X2,S}
bases.check_bases_parameter(S,1)
@assert isa(X1,Float64)
@assert isa(X2,Float64)
new([N], xmin, xmax, N)
end
end
PositionBasis(xmin::T, xmax::T, N::Int) where T<:Float64 = PositionBasis{xmin,xmax,(N,)}(xmin, xmax, N)
PositionBasis(xmin::Real, xmax::Real, N::Int) = PositionBasis(convert(Float64, xmin), convert(Float64, xmax), N)

"""
MomentumBasis(pmin, pmax, Npoints)
Expand All @@ -49,19 +56,26 @@ of ``p_{min}`` and ``p_{max}`` are due to the periodic boundary conditions
more or less arbitrary and are chosen to be
``-\\pi/dx`` and ``\\pi/dx`` with ``dx=(x_{max}-x_{min})/N``.
"""
mutable struct MomentumBasis <: Basis
mutable struct MomentumBasis{P1,P2,S} <: Basis{S}
shape::Vector{Int}
pmin::Float64
pmax::Float64
N::Int
MomentumBasis(pmin::Real, pmax::Real, N::Int) = new([N], pmin, pmax, N)
function MomentumBasis{P1,P2,S}(pmin::Float64, pmax::Float64, N::Int) where {P1,P2,S}
bases.check_bases_parameter(S,1)
@assert isa(P1,Float64)
@assert isa(P2,Float64)
new([N], pmin, pmax, N)
end
end
MomentumBasis(pmin::Float64, pmax::Float64, N::Int) = MomentumBasis{pmin,pmax,(N,)}(pmin, pmax, N)
MomentumBasis(pmin::Real, pmax::Real, N::Int) = MomentumBasis(convert(Float64,pmin), convert(Float64,pmax), N)

PositionBasis(b::MomentumBasis) = (dp = (b.pmax - b.pmin)/b.N; PositionBasis(-pi/dp, pi/dp, b.N))
MomentumBasis(b::PositionBasis) = (dx = (b.xmax - b.xmin)/b.N; MomentumBasis(-pi/dx, pi/dx, b.N))

==(b1::PositionBasis, b2::PositionBasis) = b1.xmin==b2.xmin && b1.xmax==b2.xmax && b1.N==b2.N
==(b1::MomentumBasis, b2::MomentumBasis) = b1.pmin==b2.pmin && b1.pmax==b2.pmax && b1.N==b2.N
# ==(b1::PositionBasis, b2::PositionBasis) = b1.xmin==b2.xmin && b1.xmax==b2.xmax && b1.N==b2.N
# ==(b1::MomentumBasis, b2::MomentumBasis) = b1.pmin==b2.pmin && b1.pmax==b2.pmax && b1.N==b2.N


"""
Expand Down Expand Up @@ -368,8 +382,8 @@ end
function transform(basis_l::CompositeBasis, basis_r::CompositeBasis; ket_only::Bool=false, index::Vector{Int}=Int[])
@assert length(basis_l.bases) == length(basis_r.bases)
if length(index) == 0
check_pos = typeof.(basis_l.bases) .== PositionBasis
check_mom = typeof.(basis_l.bases) .== MomentumBasis
check_pos = [isa.(basis_l.bases, PositionBasis)...]
check_mom = [isa.(basis_l.bases, MomentumBasis)...]
if any(check_pos) && !any(check_mom)
index = [1:length(basis_l.bases);][check_pos]
elseif any(check_mom) && !any(check_pos)
Expand All @@ -378,11 +392,11 @@ function transform(basis_l::CompositeBasis, basis_r::CompositeBasis; ket_only::B
throw(IncompatibleBases())
end
end
if all(typeof.(basis_l.bases[index]) .== PositionBasis)
@assert all(typeof.(basis_r.bases[index]) .== MomentumBasis)
if all(isa.(basis_l.bases[index], PositionBasis))
@assert all(isa.(basis_r.bases[index], MomentumBasis))
transform_xp(basis_l, basis_r, index; ket_only=ket_only)
elseif all(typeof.(basis_l.bases[index]) .== MomentumBasis)
@assert all(typeof.(basis_r.bases[index]) .== PositionBasis)
elseif all(isa.(basis_l.bases[index], MomentumBasis))
@assert all(isa.(basis_r.bases[index], PositionBasis))
transform_px(basis_l, basis_r, index; ket_only=ket_only)
else
throw(IncompatibleBases())
Expand Down
14 changes: 12 additions & 2 deletions src/spin.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,28 @@ The basis can be created for arbitrary spinnumbers by using a rational number,
e.g. `SpinBasis(3//2)`. The Pauli operators are defined for all possible
spin numbers.
"""
mutable struct SpinBasis <: Basis
mutable struct SpinBasis{SN,S} <: Basis{S}
shape::Vector{Int}
spinnumber::Rational{Int}
function SpinBasis(spinnumber::Rational{Int})
function SpinBasis{SN,S}(spinnumber::Rational{Int}) where {SN,S}
n = numerator(spinnumber)
d = denominator(spinnumber)
@assert d==2 || d==1
@assert n > 0
N = numerator(spinnumber*2 + 1)
bases.check_bases_parameter(S,1)
@assert isa(SN,Rational{Int})
new([N], spinnumber)
end
end
function SpinBasis(spinnumber::Rational{Int})
n = numerator(spinnumber)
d = denominator(spinnumber)
@assert d==2 || d==1
@assert n > 0
N = numerator(spinnumber*2 + 1)
SpinBasis{spinnumber,(N,)}(spinnumber)
end
SpinBasis(spinnumber::Int) = SpinBasis(convert(Rational{Int}, spinnumber))

==(b1::SpinBasis, b2::SpinBasis) = b1.spinnumber==b2.spinnumber
Expand Down
12 changes: 10 additions & 2 deletions src/states.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,13 +121,21 @@ In-place normalization of the given bra or ket so that `norm(x)` is one.
"""
normalize!(x::StateVector) = (rmul!(x.data, 1.0/norm(x)); nothing)

function permutesystems(state::T, perm::Vector{Int}) where T<:StateVector
function permutesystems(state::T, perm::Vector{Int}) where T<:Ket
@assert length(state.basis.bases) == length(perm)
@assert isperm(perm)
data = reshape(state.data, state.basis.shape...)
data = permutedims(data, perm)
data = reshape(data, length(data))
T(permutesystems(state.basis, perm), data)
Ket(permutesystems(state.basis, perm), data)
end
function permutesystems(state::T, perm::Vector{Int}) where T<:Bra
@assert length(state.basis.bases) == length(perm)
@assert isperm(perm)
data = reshape(state.data, state.basis.shape...)
data = permutedims(data, perm)
data = reshape(data, length(data))
Bra(permutesystems(state.basis, perm), data)
end

# Creation of basis states.
Expand Down
7 changes: 4 additions & 3 deletions src/subspace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,24 @@ using ..bases, ..states, ..operators, ..operators_dense
A basis describing a subspace embedded a higher dimensional Hilbert space.
"""
mutable struct SubspaceBasis{B<:Basis,T<:Ket} <: Basis
mutable struct SubspaceBasis{B<:Basis,T<:Ket,S} <: Basis{S}
shape::Vector{Int}
superbasis::B
basisstates::Vector{T}
basisstates_hash::UInt

function SubspaceBasis{B,T}(superbasis::B, basisstates::Vector{T}) where {B<:Basis,T<:Ket}
function SubspaceBasis{B,T,S}(superbasis::B, basisstates::Vector{T}) where {B<:Basis,T<:Ket,S}
for state = basisstates
if state.basis != superbasis
throw(ArgumentError("The basis of the basisstates has to be the superbasis."))
end
end
bases.check_bases_parameter(S,1)
basisstates_hash = hash(hash.([hash.(x.data) for x=basisstates]))
new(Int[length(basisstates)], superbasis, basisstates, basisstates_hash)
end
end

SubspaceBasis{B,T}(superbasis::B, basisstates::Vector{T}) where {B<:Basis,T<:Ket} = SubspaceBasis{B,T,(length(basisstates),)}(superbasis, basisstates)
SubspaceBasis(superbasis::B, basisstates::Vector{T}) where {B<:Basis,T<:Ket} = SubspaceBasis{B,T}(superbasis, basisstates)
SubspaceBasis(basisstates::Vector{T}) where T<:Ket = SubspaceBasis(basisstates[1].basis, basisstates)

Expand Down
14 changes: 7 additions & 7 deletions test/test_states.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,15 @@ ket_b2 = randstate(b2)
ket_b3 = randstate(b3)

# Addition
@test_throws bases.IncompatibleBases bra_b1 + bra_b2
@test_throws bases.IncompatibleBases ket_b1 + ket_b2
@test_throws MethodError bra_b1 + bra_b2
@test_throws MethodError ket_b1 + ket_b2
@test 1e-14 > D(bra_b1 + Bra(b1), bra_b1)
@test 1e-14 > D(ket_b1 + Ket(b1), ket_b1)
@test 1e-14 > D(bra_b1 + dagger(ket_b1), dagger(ket_b1) + bra_b1)

# Subtraction
@test_throws bases.IncompatibleBases bra_b1 - bra_b2
@test_throws bases.IncompatibleBases ket_b1 - ket_b2
@test_throws MethodError bra_b1 - bra_b2
@test_throws MethodError ket_b1 - ket_b2
@test 1e-14 > D(bra_b1 - Bra(b1), bra_b1)
@test 1e-14 > D(ket_b1 - Ket(b1), ket_b1)
@test 1e-14 > D(bra_b1 - dagger(ket_b1), -dagger(ket_b1) + bra_b1)
Expand Down Expand Up @@ -86,9 +86,9 @@ idx = LinearIndices(shape)[1, 4, 3]


# Norm
basis = FockBasis(1)
bra = Bra(basis, [3im, -4])
ket = Ket(basis, [-4im, 3])
bf = FockBasis(1)
bra = Bra(bf, [3im, -4])
ket = Ket(bf, [-4im, 3])
@test 5 norm(bra)
@test 5 norm(ket)

Expand Down

0 comments on commit b451987

Please sign in to comment.