From 2c4cb47f459d8c36d47794be712004d38c08ef43 Mon Sep 17 00:00:00 2001 From: Benzillaist Date: Tue, 31 Oct 2023 19:40:50 -0400 Subject: [PATCH] Added code_evaluation and code_generation methods Added multiple methods: code_generation: - Methods for creating Bicycle and Unicycle codes - Methods for assembling CSS codes - Utility methods for working with codes code_evaluation: - Methods for evaluating codes with a belief propagation decoder --- src/ecc/ECC.jl | 15 +- src/ecc/code_evaluation.jl | 178 ++++++++++++++++++++++ src/ecc/code_generation.jl | 298 +++++++++++++++++++++++++++++++++++++ 3 files changed, 488 insertions(+), 3 deletions(-) create mode 100644 src/ecc/code_evaluation.jl create mode 100644 src/ecc/code_generation.jl diff --git a/src/ecc/ECC.jl b/src/ecc/ECC.jl index b83f4e793..897b00748 100644 --- a/src/ecc/ECC.jl +++ b/src/ecc/ECC.jl @@ -1,18 +1,25 @@ module ECC using LinearAlgebra -using QuantumClifford -using QuantumClifford: AbstractOperation, AbstractStabilizer +using QuantumClifford, CairoMakie, SparseArrays, LDPCDecoders +using QuantumClifford: AbstractOperation, AbstractStabilizer, Stabilizer import QuantumClifford: Stabilizer, MixedDestabilizer +import QuantumClifford.ECC: parity_checks using DocStringExtensions using Combinatorics: combinations +using Statistics: std +using LinearAlgebra: rank abstract type AbstractECC end export Shor9, Steane7, Cleve8, Perfect5, Bitflip3, parity_checks, naive_syndrome_circuit, shor_syndrome_circuit, naive_encoding_circuit, code_n, code_s, code_k, rate, distance, - isdegenerate, faults_matrix + isdegenerate, faults_matrix, CSS_Code, Bicycle_Code, Unicycle_Code, Circ2BicycleH0, + Circ2UnicycleH0, AssembleCSS, BicycleSetGen, BicycleSetGenRand, GetCodeTableau, + GetXTableau, GetZTableau, parity_checks, ReduceBicycle, ReduceUnicycle, + create_lookup_table, evaluate_code_decoder_w_ecirc_pf, plot_code_performance, pf_encoding_plot + """Parity check tableau of a code.""" function parity_checks end @@ -289,6 +296,8 @@ function isdegenerate(H::Stabilizer, d::Int=1) end include("circuits.jl") +include("code_generation.jl") +include("code_evaluation.jl") include("codes/bitflipcode.jl") include("codes/fivequbit.jl") diff --git a/src/ecc/code_evaluation.jl b/src/ecc/code_evaluation.jl new file mode 100644 index 000000000..206d55934 --- /dev/null +++ b/src/ecc/code_evaluation.jl @@ -0,0 +1,178 @@ +# using QuantumClifford.ECC: faults_matrix, naive_syndrome_circuit, parity_checks, AbstractECC, naive_encoding_circuit, Cleve8, Steane7, Shor9, Perfect5 + +"""Generate a lookup table for decoding single qubit errors. Maps s⃗ → e⃗.""" +function create_lookup_table(code::Stabilizer) + lookup_table = Dict() + constraints, qubits = size(code) + # In the case of no errors + lookup_table[ zeros(UInt8, constraints) ] = zero(PauliOperator, qubits) + # In the case of single bit errors + for bit_to_be_flipped in 1:qubits + for error_type in [single_x, single_y, single_z] + # Generate e⃗ + error = error_type(qubits, bit_to_be_flipped) + # Calculate s⃗ + # (check which stabilizer rows do not commute with the Pauli error) + syndrome = comm(error, code) + # Store s⃗ → e⃗ + lookup_table[syndrome] = error + end + end + lookup_table +end + + """Currently uses a 1 qubit lookup table decoder. Assumes scirc is generated with naive_syndrome_circuit - I have a different function for fault tolerant syndromes circuits""" +function evaluate_code_decoder_w_ecirc_pf(checks::Stabilizer, ecirc, scirc, p_error; nframes=10_000, encoding_locs=nothing) + s, n = size(checks) + k = n-s + + pcm = stab_to_gf2(checks) + pcm_X = pcm[1:Int(s/2), 1:n] + pcm_Z = pcm[Int(s/2) + 1:end, n + 1:end] + + O = faults_matrix(checks) + circuit_Z = Base.copy(scirc) + circuit_X = Base.copy(scirc) + + # This is a where the bits to be encoded are, ecircs genereated by naive_encoding_circuit() will put those at the bottom + # Thus, the default is to apply this to the bottom k qubits + if isnothing(encoding_locs) + pre_X = [sHadamard(i) for i in n-k+1:n] + else + pre_X = [sHadamard(i) for i in encoding_locs] + end + + md = MixedDestabilizer(checks) + logview_Z = [ logicalzview(md);] + logcirc_Z, numLogBits, _ = naive_syndrome_circuit(logview_Z) # numLogBits shoudl equal k + + logview_X = [ logicalxview(md);] + logcirc_X, _ = naive_syndrome_circuit(logview_X) + + # Z logic circuit + for gate in logcirc_Z + type = typeof(gate) + if type == sMRZ + push!(circuit_Z, sMRZ(gate.qubit+s, gate.bit+s)) + else + push!(circuit_Z, type(gate.q1, gate.q2+s)) + end + end + + # X logic circuit + for gate in logcirc_X + type = typeof(gate) + if type == sMRZ + push!(circuit_X, sMRZ(gate.qubit+s, gate.bit+s)) + else + push!(circuit_X, type(gate.q1, gate.q2+s)) + end + end + + # Z simulation + errors = [PauliError(i,p_error) for i in 1:n] + + fullcircuit_Z = vcat(ecirc, errors, circuit_Z) + + frames = PauliFrame(nframes, n+s+k, s+k) + pftrajectories(frames, fullcircuit_Z) + syndromes = pfmeasurements(frames)[:, 1:s] + logicalSyndromes = pfmeasurements(frames)[:, s+1: s+k] + + decoded = 0 + for i in 1:nframes + row = syndromes[i,:] + + guess, success = syndrome_decode(sparse(pcm_Z), sparse(pcm_Z'), row[Int(s/2)+1:end], 50, fill(p_error, n), zeros(Int(s/2), n), zeros(Int(s/2), n), zeros(n), zeros(n)) + + guess = vcat(convert(Vector{Bool}, fill(0, n)), convert(Vector{Bool}, guess)) + + if isnothing(guess) + continue + else + result_Z = (O * guess)[k+1:2k] + if result_Z == logicalSyndromes[i,:] + if(i == 1) + end + decoded += 1 + end + end + end + z_error = 1 - decoded / nframes + + # X simulation + fullcircuit_X = vcat(pre_X, ecirc, errors, circuit_X) + frames = PauliFrame(nframes, n+s+k, s+k) + pftrajectories(frames, fullcircuit_X) + syndromes = pfmeasurements(frames)[:, 1:s] + logicalSyndromes = pfmeasurements(frames)[:, s+1: s+k] + + decoded = 0 + for i in 1:nframes + row = syndromes[i,:] + + guess, success = syndrome_decode(sparse(pcm_X), sparse(pcm_X'), row[1:Int(s/2)], 50, fill(p_error, n), zeros(Int(s/2), n), zeros(Int(s/2), n), zeros(n), zeros(n)) + + guess = vcat(convert(Vector{Bool}, guess), convert(Vector{Bool}, fill(0, n))) + + if isnothing(guess) + continue + else + result_X = (O * guess)[1:k] + if result_X == logicalSyndromes[i,:] + decoded += 1 + end + end + end + x_error = 1 - decoded / nframes + + return x_error, z_error +end + +"""Taken from the QEC Seminar notebook for plotting logical vs physical error""" +function plot_code_performance(error_rates, post_ec_error_rates; title="") + f = Figure(resolution=(500,300)) + ax = f[1,1] = Axis(f, xlabel="single (qu)bit error rate", ylabel="Logical error rate",title=title) + ax.aspect = DataAspect() + lim = max(error_rates[end],post_ec_error_rates[end]) + lines!([0,lim], [0,lim], label="single bit", color=:black) + plot!(error_rates, post_ec_error_rates, label="after decoding", color=:black) + xlims!(0,lim) + ylims!(0,lim) + f[1,2] = Legend(f, ax, "Error Rates") + f +end + +function pf_encoding_plot(code::AbstractECC, name=string(typeof(code))) + checks = parity_checks(code) + pf_encoding_plot(checks, name) +end + +function pf_encoding_plot(checks, name="") + (scirc, _), time1, _ = @timed naive_syndrome_circuit(checks) + # a = @timed naive_syndrome_circuit(checks) + # println(a) + # println(scirc) + # println(time1) + ecirc, time2, _ = @timed naive_encoding_circuit(checks) + # a = @timed naive_encoding_circuit(checks) + # println(a) + # println(ecirc) + # println(time2) + + error_rates = 0.000:0.0025:0.2 + post_ec_error_rates, time3, _ = @timed [evaluate_code_decoder_w_ecirc_pf(checks, ecirc, scirc, p) for p in error_rates] + # println(time3) + + total_time = round(time1 + time2 + time3, sigdigits=4) + + x_error = [post_ec_error_rates[i][1] for i in eachindex(post_ec_error_rates)] + z_error = [post_ec_error_rates[i][2] for i in eachindex(post_ec_error_rates)] + a_error = (x_error + z_error) / 2 + + f_x = plot_code_performance(error_rates, x_error,title=""*name*": Belief Decoder X @$total_time"*"s") + f_z = plot_code_performance(error_rates, z_error,title=""*name*": Belief Decoder Z @$total_time"*"s") + f_a = plot_code_performance(error_rates, a_error,title=""*name*": Belief Decoder @$total_time"*"s") + + return f_x, f_z, f_a, total_time +end \ No newline at end of file diff --git a/src/ecc/code_generation.jl b/src/ecc/code_generation.jl new file mode 100644 index 000000000..b0797714a --- /dev/null +++ b/src/ecc/code_generation.jl @@ -0,0 +1,298 @@ +# using QuantumClifford: Stabilizer +# using QuantumClifford.ECC: AbstractECC +# import QuantumClifford.ECC: parity_checks +# using Statistics:std +# using Nemo: residue_ring, matrix +# using LinearAlgebra: rank + +"""Struct for arbitrary CSS error correcting codes. + +This struct holds: + - tab: Boolean matrix with the X part taking up the left side and the Z part taking up the right side + - stab: Stabilizer of the parity check matrix + - n: Block length + - d: Code distance""" +struct CSS <: AbstractECC + tab::Matrix{Bool} + stab::Stabilizer + n::Int +end + +"""Takes an untrimmed bicycle matrix and removes the row which keeps the spread of the column weights minimal. + +Required before the bicycle code can be used. + +Typical usage: +ReduceBicycle(Circ2BicycleH0(array_indices, (block length / 2) ) )""" +function ReduceBicycle(H0::Matrix{Bool}) + m, n = size(H0) + r_i = 0 + std_min = Inf + for i in 1:m + t_H0 = vcat(H0[1:i-1, :], H0[i+1:end, :]) + std_temp = std(convert(Array, sum(t_H0, dims = 1))) + if std_temp < std_min + std_min = std_temp + r_i = i + end + end + return vcat(H0[1:r_i-1, :], H0[r_i+1:end, :]) +end + +"""Takes a list of indices and creates the base of the bicycle matrix. + +For example: +Circ2BicycleH0([1, 2, 4], 7) + +See https://arxiv.org/abs/quant-ph/0304161 for more details""" +function Circ2BicycleH0(circ_indices::Array{Int}, n::Int) + circ_arr = Array{Bool}(undef, n) + circ_matrix = Matrix{Bool}(undef, n, n) + comp_matrix = Matrix{Bool}(undef, n, 2*n) + for i = 1:n + if Int(i-1) in circ_indices + circ_arr[i] = true + else + circ_arr[i] = false + end + end + for i = 1:n + circ_matrix[i,1:n] = circ_arr + li = circ_arr[end] + circ_arr[2:end] = circ_arr[1:end-1] + circ_arr[1] = li + end + comp_matrix[1:n,1:n] = circ_matrix + comp_matrix[1:n,n+1:2*n] = transpose(circ_matrix) + return comp_matrix +end + +"""Takes an untrimmed unicycle matrix and removes linearly dependent rows. + +Required before the unicycle code can be used. + +Typical usage: +ReduceUnicycle(Circ2UnicycleH0(array_indices, block length) )""" +function ReduceUnicycle(m::Matrix{Bool}) + r = LinearAlgebra.rank(nm7) + rrzz = Nemo.residue_ring(Nemo.ZZ, 2) + for i in 1:size(u7)[1] + tm = vcat(m[1:i-1,:], m[i+1:end,:]) + tr = LinearAlgebra.rank(Nemo.matrix(rrzz, tm)) + if(tr == r) + m = tm + i -= 1 + if(size(m)[1] == r) + break + end + end + end + return m +end + +"""Takes a list of indices and creates the base of the unicycle matrix. + +For example: +Circ2UnicycleH0([1, 2, 4], 7) + +See https://arxiv.org/abs/quant-ph/0304161 for more details""" +function Circ2UnicycleH0(circ_indices::Array{Int}, n::Int) + circ_arr = fill(false, n) + one_col = transpose(fill(true, n)) + circ_matrix = Matrix{Bool}(undef, n, n) + comp_matrix = Matrix{Bool}(undef, n, n+1) + for i = 1:n + if i in circ_indices + circ_arr[i] = true + else + circ_arr[i] = false + end + end + for i = 1:n + circ_matrix[i,1:n] = circ_arr + li = circ_arr[end] + circ_arr[2:end] = circ_arr[1:end-1] + circ_arr[1] = li + end + comp_matrix[1:n,1:n] = circ_matrix + comp_matrix[1:n,n+1] = one_col + return comp_matrix +end + +function AssembleCSS end + +"""Creates a CSS code using the two provided matrices where H contains the X checks and G contains the Z checks.""" +function AssembleCSS(H::Matrix{Bool}, G::Matrix{Bool})::CSS + Hy, Hx = size(H) + Gy, Gx = size(G) + comp_matrix = fill(false, (Hy + Gy, Hx + Gx)) + # comp_matrix = Matrix{Bool}(undef, Hy + Gy, Hx + Gx) + comp_matrix[1:Hy, 1:Hx] = H + comp_matrix[Hy+1:end, Hx+1:end] = G + pcm_stab = Stabilizer(fill(0x0, Hy+Gy), GetXTableau(comp_matrix), GetZTableau(comp_matrix)) + return CSS(comp_matrix, pcm_stab, Hx) + # return comp_matrix +end + +"""Creates a CSS code using the provided matrix for the X and Z checks.""" +function AssembleCSS(H::Matrix{Bool})::CSS + return AssembleCSS(H, H) +end + +"""Attempts to generate a list of indices to be used in a bicycle code using a search method""" +function BicycleSetGen(N::Int) + circ_arr::Array{Int} = [0] + diff_arr::Array{Int} = [] + circ_arr[1] = 0 + # test new elements + for add_i = (circ_arr[end] + 1):N - 1 + valid = true + temp_circ_arr = copy(circ_arr) + temp_diff_arr::Array{Int} = [] + push!(temp_circ_arr, add_i) + for j = 1:size(temp_circ_arr)[1] + temp_arr = copy(temp_circ_arr) + # add lesser elements + N to temp_arr + for k = 1:size(temp_circ_arr)[1] + if k < j + push!(temp_arr, temp_circ_arr[k] + N) + else + break + end + end + # test if new index is valid + for k = 1:(size(temp_circ_arr)[1] - 2) + t_diff = (temp_arr[j + k] - temp_arr[j]) % N + if ((t_diff) in temp_diff_arr) + valid = false + break + else + push!(temp_diff_arr, t_diff) + end + end + if !valid + break + end + end + if valid + circ_arr = copy(temp_circ_arr) + diff_arr = copy(temp_diff_arr) + end + end + return circ_arr +end + +"""Attempts to generate a list of indices to be used in a bicycle code using a randomized check method + +Note: This is very slow for large N""" +function BicycleSetGenRand(N::Int, d::Int) + circ_arr::Array{Int} = [0] + diff_arr::Array{Int} = [] + atmp_add::Array{Int} = [0] + circ_arr[1] = 0 + # test new elements + for i = (circ_arr[end] + 1):(N^2) + valid = true + temp_circ_arr = copy(circ_arr) + temp_diff_arr::Array{Int} = [] + add_i = rand(1: N-1) + atmp_add = push!(atmp_add, add_i) + if add_i in circ_arr + continue + end + push!(temp_circ_arr, add_i) + for j = 1:size(temp_circ_arr)[1] + temp_arr = copy(temp_circ_arr) + # add lesser elements + N to temp_arr + for k = 1:size(temp_circ_arr)[1] + if k < j + push!(temp_arr, temp_circ_arr[k] + N) + else + break + end + end + # test if new index is valid + for k = 1:(size(temp_circ_arr)[1] - 2) + t_diff = (temp_arr[j + k] - temp_arr[j]) % N + if ((t_diff) in temp_diff_arr) + valid = false + break + else + push!(temp_diff_arr, t_diff) + end + end + if !valid + break + end + end + if valid + circ_arr = copy(temp_circ_arr) + diff_arr = copy(temp_diff_arr) + if (size(atmp_add)[1] == N) || (size(circ_arr)[1] == d) + break + end + end + end + return circ_arr +end + +"""Takes in a boolean Matrix and returns the parity check tableau as a string of characters. + +Note: Only works when the block length for the X and Z checks are the same!""" +function GetCodeTableau(ecc::Matrix{Bool}) + eccx = size(ecc)[2] + eccy = size(ecc)[1] + ps::String = "" + for i = 1:size(ecc)[1] + for j = 1:(Int(size(ecc)[2]/2)) + if (ecc[i, j] == 0) && (ecc[i, j + Int(eccx / 2)] == 0) + ps = string(ps, "I") + elseif (ecc[i, j] == 1) && (ecc[i, j + Int(eccx / 2)] == 0) + ps = string(ps, "X") + elseif (ecc[i, j] == 1) && (ecc[i, j + Int(eccx / 2)] == 1) + ps = string(ps, "Y") + else + ps = string(ps, "Z") + end + end + ps = string(ps,"\n") + end + return ps +end + +"""Takes in a matrix and returns just the X checks portion while keeping the full height of the matrix. + +Note: Only works when the block length for the X and Z checks are the same!""" +function GetXTableau(ecc::Matrix{Bool}) + return ecc[1:size(ecc)[1], 1:Int(size(ecc)[2]/2)] +end + +"""Takes in a matrix and returns just the Z checks portion while keeping the full height of the matrix. + +Note: Only works when the block length for the X and Z checks are the same!""" +function GetZTableau(ecc::Matrix{Bool}) + return ecc[1:size(ecc)[1], Int(size(ecc)[2]/2) + 1:end] +end + +"""Takes in a matrix and returns just the X checks portion while keeping the full height of the matrix. + +Note: Only works when the block length for the X and Z checks are the same!""" +function GetXTableau(ecc::CSS) + return GetXTableau(ecc.tab) +end + +"""Takes in a matrix and returns just the Z checks portion while keeping the full height of the matrix. + +Note: Only works when the block length for the X and Z checks are the same!""" +function GetZTableau(ecc::CSS) + return GetZTableau(ecc.tab) +end + +"""Returns the matrix form of the X and Z checks.""" +tableau(c::CSS) = c.tab + +"""Returns the stabilizer making up the parity check tableau.""" +parity_checks(c::CSS) = c.stab + +"""Returns the block length of the code.""" +code_n(c::CSS) = c.n \ No newline at end of file