Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add representation for general stabilizer codes to the python side. #278

Merged
merged 8 commits into from
Sep 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions src/mqt/qecc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from ._version import version as __version__
from .analog_information_decoding.simulators.analog_tannergraph_decoding import AnalogTannergraphDecoder, AtdSimulator
from .analog_information_decoding.simulators.quasi_single_shot_v2 import QssSimulator
from .codes import CSSCode, InvalidCSSCodeError
from .codes import CSSCode, StabilizerCode
from .pyqecc import (
Code,
Decoder,
Expand All @@ -33,10 +33,8 @@
"DecodingResultStatus",
"DecodingRunInformation",
"GrowthVariant",
"InvalidCSSCodeError",
"InvalidCSSCodeError",
# "SoftInfoDecoder",
"QssSimulator",
"StabilizerCode",
"UFDecoder",
"UFHeuristic",
"__version__",
Expand Down
3 changes: 3 additions & 0 deletions src/mqt/qecc/codes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,16 @@
from .css_code import CSSCode, InvalidCSSCodeError
from .hexagonal_color_code import HexagonalColorCode
from .square_octagon_color_code import SquareOctagonColorCode
from .stabilizer_code import InvalidStabilizerCodeError, StabilizerCode

__all__ = [
"CSSCode",
"ColorCode",
"HexagonalColorCode",
"InvalidCSSCodeError",
"InvalidStabilizerCodeError",
"LatticeType",
"SquareOctagonColorCode",
"StabilizerCode",
"construct_bb_code",
]
128 changes: 55 additions & 73 deletions src/mqt/qecc/codes/css_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,19 @@

from __future__ import annotations

import sys
from pathlib import Path
from typing import TYPE_CHECKING

import numpy as np
from ldpc import mod2

from .stabilizer_code import StabilizerCode

if TYPE_CHECKING: # pragma: no cover
import numpy.typing as npt


class CSSCode:
class CSSCode(StabilizerCode):
Fixed Show fixed Hide fixed
"""A class for representing CSS codes."""

def __init__(
Expand All @@ -25,79 +26,59 @@
z_distance: int | None = None,
) -> None:
"""Initialize the code."""
self._check_valid_check_matrices(Hx, Hz)

if Hx is None:
assert Hz is not None
self.n = Hz.shape[1]
self.Hx = np.zeros((0, self.n), dtype=np.int8)
else:
self.Hx = Hx
if Hz is None:
assert Hx is not None
self.n = Hx.shape[1]
self.Hz = np.zeros((0, self.n), dtype=np.int8)
else:
self.Hz = Hz

z_padding = np.zeros(self.Hx.shape, dtype=np.int8)
x_padding = np.zeros(self.Hz.shape, dtype=np.int8)

x_padded = np.hstack([self.Hx, z_padding])
z_padded = np.hstack([x_padding, self.Hz])
phases = np.zeros((x_padded.shape[0] + z_padded.shape[0], 1), dtype=np.int8)
super().__init__(np.hstack((np.vstack((x_padded, z_padded)), phases)), distance)

self.distance = distance

Check warning

Code scanning / CodeQL

Overwriting attribute in super-class or sub-class Warning

Assignment overwrites attribute distance, which was previously defined in subclass ColorCode.
Assignment overwrites attribute distance, which was previously defined in superclass StabilizerCode.
self.x_distance = x_distance if x_distance is not None else distance
self.z_distance = z_distance if z_distance is not None else distance

if self.distance < 0:
msg = "The distance must be a non-negative integer"
raise InvalidCSSCodeError(msg)
if Hx is None and Hz is None:
msg = "At least one of the check matrices must be provided"
raise InvalidCSSCodeError(msg)
if self.x_distance < self.distance or self.z_distance < self.distance:
msg = "The x and z distances must be greater than or equal to the distance"
raise InvalidCSSCodeError(msg)
if Hx is not None and Hz is not None:
if Hx.shape[1] != Hz.shape[1]:
msg = "Check matrices must have the same number of columns"
raise InvalidCSSCodeError(msg)
if np.any(Hx @ Hz.T % 2 != 0):
msg = "The check matrices must be orthogonal"
raise InvalidCSSCodeError(msg)

self.Hx = Hx
self.Hz = Hz
self.n = Hx.shape[1] if Hx is not None else Hz.shape[1] # type: ignore[union-attr]
self.k = self.n - (Hx.shape[0] if Hx is not None else 0) - (Hz.shape[0] if Hz is not None else 0)
self.Lx = CSSCode._compute_logical(self.Hz, self.Hx)

Check warning

Code scanning / CodeQL

Overwriting attribute in super-class or sub-class Warning

Assignment overwrites attribute Lx, which was previously defined in superclass StabilizerCode.
Assignment overwrites attribute Lx, which was previously defined in superclass StabilizerCode.
self.Lz = CSSCode._compute_logical(self.Hx, self.Hz)

Check warning

Code scanning / CodeQL

Overwriting attribute in super-class or sub-class Warning

Assignment overwrites attribute Lz, which was previously defined in superclass StabilizerCode.
Assignment overwrites attribute Lz, which was previously defined in superclass StabilizerCode.

def __hash__(self) -> int:
"""Compute a hash for the CSS code."""
x_hash = int.from_bytes(self.Hx.tobytes(), sys.byteorder) if self.Hx is not None else 0
z_hash = int.from_bytes(self.Hz.tobytes(), sys.byteorder) if self.Hz is not None else 0
return hash(x_hash ^ z_hash)

def __eq__(self, other: object) -> bool:
"""Check if two CSS codes are equal."""
if not isinstance(other, CSSCode):
return NotImplemented
if self.Hx is None and other.Hx is None:
assert self.Hz is not None
assert other.Hz is not None
return np.array_equal(self.Hz, other.Hz)
if self.Hz is None and other.Hz is None:
assert self.Hx is not None
assert other.Hx is not None
return np.array_equal(self.Hx, other.Hx)
if (self.Hx is None and other.Hx is not None) or (self.Hx is not None and other.Hx is None):
return False
if (self.Hz is None and other.Hz is not None) or (self.Hz is not None and other.Hz is None):
return False
assert self.Hx is not None
assert other.Hx is not None
assert self.Hz is not None
assert other.Hz is not None
return bool(
mod2.rank(self.Hx) == mod2.rank(np.vstack([self.Hx, other.Hx]))
and mod2.rank(self.Hz) == mod2.rank(np.vstack([self.Hz, other.Hz]))
)
def x_checks_as_pauli_strings(self) -> list[str]:
"""Return the x checks as Pauli strings."""
return ["".join("X" if bit == 1 else "I" for bit in row) for row in self.Hx]

@staticmethod
def _compute_logical(m1: npt.NDArray[np.int8] | None, m2: npt.NDArray[np.int8] | None) -> npt.NDArray[np.int8]:
"""Compute the logical matrix L."""
if m1 is None:
ker_m2 = mod2.nullspace(m2) # compute the kernel basis of m2
pivots = mod2.row_echelon(ker_m2)[-1]
logs = np.zeros_like(ker_m2, dtype=np.int8) # type: npt.NDArray[np.int8]
for i, pivot in enumerate(pivots):
logs[i, pivot] = 1
return logs
def z_checks_as_pauli_strings(self) -> list[str]:
"""Return the z checks as Pauli strings."""
return ["".join("Z" if bit == 1 else "I" for bit in row) for row in self.Hz]

if m2 is None:
return mod2.nullspace(m1).astype(np.int8)
def x_logicals_as_pauli_strings(self) -> list[str]:
"""Return the x logicals as a Pauli strings."""
return ["".join("X" if bit == 1 else "I" for bit in row) for row in self.Lx]

def z_logicals_as_pauli_strings(self) -> list[str]:
"""Return the z logicals as Pauli strings."""
return ["".join("Z" if bit == 1 else "I" for bit in row) for row in self.Lz]

@staticmethod
def _compute_logical(m1: npt.NDArray[np.int8], m2: npt.NDArray[np.int8]) -> npt.NDArray[np.int8]:
"""Compute the logical matrix L."""
ker_m1 = mod2.nullspace(m1) # compute the kernel basis of m1
im_m2_transp = mod2.row_basis(m2) # compute the image basis of m2
log_stack = np.vstack([im_m2_transp, ker_m1])
pehamTom marked this conversation as resolved.
Show resolved Hide resolved
Expand Down Expand Up @@ -161,19 +142,20 @@
self.Hx.shape[0] == self.Hz.shape[0] and mod2.rank(self.Hx) == mod2.rank(np.vstack([self.Hx, self.Hz]))
)

def stabs_as_pauli_strings(self) -> tuple[list[str] | None, list[str] | None]:
"""Return the stabilizers as Pauli strings."""
x_str = None if self.Hx is None else ["".join(["I" if x == 0 else "X" for x in row]) for row in self.Hx]
z_str = None if self.Hz is None else ["".join(["I" if z == 0 else "Z" for z in row]) for row in self.Hz]
return x_str, z_str

def z_logicals_as_pauli_string(self) -> str:
"""Return the logical Z operator as a Pauli string."""
return "".join(["I" if z == 0 else "Z" for z in self.Lx[0]])
@staticmethod
def _check_valid_check_matrices(Hx: npt.NDArray[np.int8] | None, Hz: npt.NDArray[np.int8] | None) -> None: # noqa: N803
pehamTom marked this conversation as resolved.
Show resolved Hide resolved
"""Check if the code is a valid CSS code."""
if Hx is None and Hz is None:
msg = "At least one of the check matrices must be provided"
raise InvalidCSSCodeError(msg)

def x_logicals_as_pauli_string(self) -> str:
"""Return the logical X operator as a Pauli string."""
return "".join(["I" if x == 0 else "X" for x in self.Lz[0]])
if Hx is not None and Hz is not None:
if Hx.shape[1] != Hz.shape[1]:
msg = "Check matrices must have the same number of columns"
raise InvalidCSSCodeError(msg)
if np.any(Hx @ Hz.T % 2 != 0):
msg = "The check matrices must be orthogonal"
raise InvalidCSSCodeError(msg)

@staticmethod
def from_code_name(code_name: str, distance: int | None = None) -> CSSCode:
Expand Down
3 changes: 2 additions & 1 deletion src/mqt/qecc/codes/hexagonal_color_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@ class HexagonalColorCode(ColorCode):

def __init__(self, distance: int) -> None:
"""Hexagonal Color Code initialization from base class."""
ColorCode.__init__(self, distance=distance, lattice_type=LatticeType.HEXAGON)
super().__init__(distance=distance, lattice_type=LatticeType.HEXAGON)

def add_qubits(self) -> None:
"""Add qubits to the code."""
colour = ["r", "b", "g"]
y = 0

x_max = self.distance + self.distance // 2
while x_max > 0:
ancilla_colour = colour[y % 3]
Expand Down
2 changes: 1 addition & 1 deletion src/mqt/qecc/codes/square_octagon_color_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def __init__(self, distance: int) -> None:
# additionally to ancilla_qubits (on squares) we have the ones on octagons
self.octagon_ancilla_qubits: set[tuple[int, int]] = set()
self.square_ancilla_qubits: set[tuple[int, int]] = set()
ColorCode.__init__(self, distance=distance, lattice_type=LatticeType.SQUARE_OCTAGON)
super().__init__(distance=distance, lattice_type=LatticeType.SQUARE_OCTAGON)

def add_qubits(self) -> None:
"""Add qubits to the code."""
Expand Down
Loading
Loading