Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pl/channelbases #72

Merged
merged 7 commits into from
Feb 21, 2020
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 71 additions & 4 deletions src/matrixbases.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import .Base: length, iterate
export AbstractMatrixBasisIterator, HermitianBasisIterator, AbstractBasis,
AbstractMatrixBasis, HermitianBasis, hermitianbasis, represent, combine
ChannelBasisIterator, AbstractChannelBasis, ChannelBasis, AbstractMatrixBasis,
HermitianBasis, hermitianbasis, represent, combine, channelbasis
abstract type AbstractMatrixBasisIterator{T<:AbstractMatrix} end
struct HermitianBasisIterator{T} <: AbstractMatrixBasisIterator{T}
dim::Int
Expand All @@ -17,17 +18,17 @@ struct HermitianBasis{T} <: AbstractMatrixBasis{T}
end
end



"""
$(SIGNATURES)
- `dim`: dimensions of the matrix.

Returns elementary hermitian matrices of dimension `dim` x `dim`.
"""
hermitianbasis(T::Type{<:AbstractMatrix{<:Number}}, dim::Int) = HermitianBasisIterator{T}(dim)

hermitianbasis(dim::Int) = hermitianbasis(Matrix{ComplexF64}, dim)


function iterate(itr::HermitianBasisIterator{T}, state=(1,1)) where T<:AbstractMatrix{<:Number}
dim = itr.dim
(a, b) = state
Expand Down Expand Up @@ -55,6 +56,72 @@ function represent(basis::Type{T}, m::Matrix{<:Number}) where T<:AbstractMatrixB
represent(basis{typeof(m)}(d), m)
end

function combine(basis::T, v::Vector{<:Number}) where T<:AbstractMatrixBasis
function combine(basis::AbstractMatrixBasis{T}, v::Vector{<:Number}) where T<:AbstractMatrix{<:Number}
sum(basis.iterator .* v)
end

"""
$(SIGNATURES)

"""
abstract type AbstractMatrixBasisIterator{T<:AbstractMatrix} end
plewandowska777 marked this conversation as resolved.
Show resolved Hide resolved
struct ChannelBasisIterator{T} <: AbstractMatrixBasisIterator{T}
idim::Int
odim::Int
hitr::HermitianBasisIterator{T}
function ChannelBasisIterator{T}(idim::Int, odim::Int) where T<:AbstractMatrix{<:Number}
new(idim, odim, HermitianBasisIterator{T}(idim))
end
end

abstract type AbstractChannelBasis{T} <: AbstractMatrixBasis{T} end
struct ChannelBasis{T} <: AbstractMatrixBasis{T}
plewandowska777 marked this conversation as resolved.
Show resolved Hide resolved
iterator::ChannelBasisIterator{T}
function ChannelBasis{T}(idim::Integer, odim::Integer) where T<:AbstractMatrix{<:Number}
new(ChannelBasisIterator{T}(idim, odim))
end
end
channelbasis(T::Type{<:AbstractMatrix{<:Number}}, idim::Int, odim::Int=idim) = ChannelBasis{T}(idim, odim)

channelbasis(idim::Int, odim::Int=idim) = channelbasis(Matrix{ComplexF64}, idim, odim)

function iterate(itr::ChannelBasisIterator{T}, state=(1,1,1,1)) where T<:AbstractMatrix{<:Number}
idim = itr.idim
plewandowska777 marked this conversation as resolved.
Show resolved Hide resolved
odim = itr.odim
hitr = itr.hitr
(a, c, b, d) = state
(a == odim && c == odim && d == 2) && return nothing

Tn = eltype(T)
if a > c
x = (ketbra(T, a, c, odim) ⊗ ketbra(T, b, d, idim) + ketbra(T, c, a, odim) ⊗ ketbra(T, d, b, idim)) / sqrt(Tn(2))
plewandowska777 marked this conversation as resolved.
Show resolved Hide resolved
elseif a < c
x = (im * ketbra(T, a, c, odim) ⊗ ketbra(T, b, d, idim) - im * ketbra(T, c, a, odim) ⊗ ketbra(T, d, b, idim)) / sqrt(Tn(2))
elseif a < odim
H = iterate(hitr, (b, d))[1]
x = (diagm(0 => vcat(ones(Tn, a), Tn[-a], zeros(Tn, odim - a-1))) ⊗ H) / sqrt(Tn(a + a^2))
else
x = Matrix{Tn}(I, idim * odim, idim * odim) / sqrt(Tn(idim * odim))
end
if d < idim
newstate = (a, c, b, d+1)
elseif d == idim && b < idim
newstate = (a, c, b+1, 1)
elseif d == idim && b == idim && c < odim
newstate = (a, c+1, 1, 1)
else
newstate = (a+1, 1, 1, 1)
end
return x, newstate
end
length(itr::ChannelBasisIterator) = itr.idim^2 * itr.odim^2 - itr.idim^2 + 1
plewandowska777 marked this conversation as resolved.
Show resolved Hide resolved

function represent(basis::ChannelBasis{T1}, Φ::AbstractQuantumOperation{T2}) where T1<:AbstractMatrix{<:Number} where T2<:AbstractMatrix{<:Number}
plewandowska777 marked this conversation as resolved.
Show resolved Hide resolved
J = convert(DynamicalMatrix{T2}, Φ)
plewandowska777 marked this conversation as resolved.
Show resolved Hide resolved
represent(basis, J.matrix)
end

function combine(basis::ChannelBasis{T}, v::Vector{<:Number}) where T<:AbstractMatrix
plewandowska777 marked this conversation as resolved.
Show resolved Hide resolved
m = sum(basis.iterator .* v)
plewandowska777 marked this conversation as resolved.
Show resolved Hide resolved
DynamicalMatrix{T}(m, basis.iterator.idim, basis.iterator.odim)
end
40 changes: 40 additions & 0 deletions test/matrixbases.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,14 @@
@test [tr(m[i]' * m[j]) for i=1:d, j=1:d] ≈ Matrix{Float64}(I, d, d)
end

@testset "ChannelBasisIterator" begin
d1 = 2
d2 = 2
plewandowska777 marked this conversation as resolved.
Show resolved Hide resolved
d = d1^2 * d2^2 - d1^2 + 1
m = collect(ChannelBasisIterator{Matrix{ComplexF64}}(d1,d2))
@test [tr(m[i]' * m[j]) for i=1:d, j=1:d] ≈ Matrix{Float64}(I, d, d)
end

@testset "represent, combine" begin
d = 4
A = reshape(collect(1:16), d, d) + reshape(collect(1:16), d, d)'
Expand All @@ -27,9 +35,41 @@ end
@test length(vC) == prod(size(C))
end

@testset "represent, combine" begin
plewandowska777 marked this conversation as resolved.
Show resolved Hide resolved
d1 = 2
plewandowska777 marked this conversation as resolved.
Show resolved Hide resolved
d2 = 4
A = reshape(collect(1:16), d1 * d2, d1) * reshape(collect(1:16), d1 * d2, d1)'
B = Matrix{Float64}(I, d2, d2) ⊗ (ptrace(A, [d2, d1], 1))^(-1/2)
plewandowska777 marked this conversation as resolved.
Show resolved Hide resolved
A = B * A * B'
vA = represent(ChannelBasis{Matrix{ComplexF64}}(d1, d2), A)
Ap = combine(ChannelBasis{Matrix{ComplexF64}}(d1, d2), vA)
@test A ≈ Ap.matrix

A = reshape(collect(1:64), d1 * d2, d1 * d2) * reshape(collect(1:64), d1 * d2, d1 * d2)' + Matrix{Float64}(I, d1 * d2, d1 * d2)
B = Matrix{Float64}(I, d1, d1) ⊗ (ptrace(A, [d1, d2], 1))^(-1/2)
B = B * A * B'
vB = represent(ChannelBasis{Matrix{ComplexF64}}(d2, d1), B)
Bp = combine(ChannelBasis{Matrix{ComplexF64}}(d2, d1), vB)
@test B ≈ Bp.matrix

#vB = represent(ChannelBasis{Matrix{ComplexF32}}(d1,d2), B)
plewandowska777 marked this conversation as resolved.
Show resolved Hide resolved
#@test eltype(vB) == Float32

# C = Float16[1 2; 3 4]
# C += C'
# vC = represent(ChannelBasis, C)
# @test eltype(vC) == eltype(C)
# @test length(vC) == prod(size(C))
end

@testset "hermitainbasis" begin
@test hermitianbasis(Matrix{Float32}, 2) == HermitianBasisIterator{Matrix{Float32}}(2)
@test hermitianbasis(2) == HermitianBasisIterator{Matrix{ComplexF64}}(2)
end

# @testset "channelbasis" begin
plewandowska777 marked this conversation as resolved.
Show resolved Hide resolved
# @test channelbasis(Matrix{Float32}, 2,2) == ChannelBasisIterator{Matrix{Float32}}(2,2)
# @test channelbasis(2,2) == ChannelBasisIterator{Matrix{ComplexF64}}(2,2)
# end

end
1 change: 0 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ using Test
my_tests = ["utils.jl", "base.jl", "ptrace.jl", "ptranspose.jl", "reshuffle.jl",
"channels.jl", "functionals.jl", "gates.jl", "matrixbases.jl",
"permute_systems.jl", "randomqobjects.jl", "convex.jl"]

for my_test in my_tests
include(my_test)
end