Skip to content

Commit

Permalink
Fixed linter errors.
Browse files Browse the repository at this point in the history
  • Loading branch information
pehamTom committed Jun 26, 2024
1 parent d768e3e commit d6d8f93
Show file tree
Hide file tree
Showing 5 changed files with 230 additions and 183 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ exclude = [

[[tool.mypy.overrides]]
module = ["qiskit.*", "qecsim.*","qiskit_aer.*", "matplotlib.*", "scipy.*", "ldpc.*", "pytest_console_scripts.*",
"z3.*", "bposd.*", "numba.*", "pymatching.*"]
"z3.*", "bposd.*", "numba.*", "pymatching.*", "stim.*", "multiprocess.*"]
ignore_missing_imports = true


Expand Down
24 changes: 9 additions & 15 deletions src/mqt/qecc/code.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,21 @@

from __future__ import annotations

import sys
from importlib.resources import files # noqa: TID251
from typing import TYPE_CHECKING

import numpy as np
from ldpc import mod2

try:
from importlib import resources as impresources
except ImportError:
# Try backported to PY<37 `importlib_resources`.
import importlib_resources as impresources # pylint: disable=no-redef

from . import sample_codes # relative-import the *package* containing the templates

if TYPE_CHECKING: # pragma: no cover
import numpy.typing as npt


class CSSCode:
"""A class for representing CSS codes."""

def __init__(self, distance: int, Hx: npt.NDArray[np.int_], Hz: npt.NDArray[np.int_]) -> None:
def __init__(self, distance: int, Hx: npt.NDArray[np.int_], Hz: npt.NDArray[np.int_]) -> None: # noqa: N803
"""Initialize the code."""
self.distance = distance

Expand All @@ -37,7 +31,7 @@ def __init__(self, distance: int, Hx: npt.NDArray[np.int_], Hz: npt.NDArray[np.i

def __hash__(self) -> int:
"""Compute a hash for the CSS code."""
return hash(self.Hx.tobytes() ^ self.Hz.tobytes())
return hash(int.from_bytes(self.Hx.tobytes(), sys.byteorder) ^ int.from_bytes(self.Hz.tobytes(), sys.byteorder))

def __eq__(self, other: object) -> bool:
"""Check if two CSS codes are equal."""
Expand Down Expand Up @@ -113,7 +107,7 @@ def from_code_name(code_name: str, distance: int = 3) -> CSSCode:
code_name: The name of the code.
distance: The distance of the code.
"""
prefix = impresources.files(sample_codes)
prefix = files("sample_codes")
paths = {
"steane": prefix / "steane/",
"tetrahedral": prefix / "tetrahedral/",
Expand All @@ -138,7 +132,7 @@ def from_code_name(code_name: str, distance: int = 3) -> CSSCode:

if code_name in distances:
distance = distances[code_name]
elif distance is None:
else:
msg = f"Distance is not specified for {code_name}"
raise ValueError(msg)
return CSSCode(distance, hx, hz)
Expand All @@ -149,7 +143,7 @@ def from_code_name(code_name: str, distance: int = 3) -> CSSCode:
class ClassicalCode:
"""A class for representing classical codes."""

def __init__(self, distance: int, H: npt.NDArray[np.int_]) -> None:
def __init__(self, distance: int, H: npt.NDArray[np.int_]) -> None: # noqa: N803
"""Initialize the code."""
self.distance = distance
self.H = H
Expand All @@ -162,6 +156,6 @@ class HyperGraphProductCode(CSSCode):

def __init__(self, c1: ClassicalCode, c2: ClassicalCode) -> None:
"""Initialize the code."""
Hx = np.hstack((np.kron(c1.H.T, np.eye(c2.H.shape[0])), np.kron(np.eye(c1.n), c2.H)))
Hz = np.hstack((np.kron(np.eye(c1.H.shape[0]), c2.H.T), np.kron(c1.H, np.eye(c2.n))))
Hx = np.hstack((np.kron(c1.H.T, np.eye(c2.H.shape[0])), np.kron(np.eye(c1.n), c2.H))) # noqa: N806
Hz = np.hstack((np.kron(np.eye(c1.H.shape[0]), c2.H.T), np.kron(c1.H, np.eye(c2.n)))) # noqa: N806
super().__init__(np.min(c1.distance, c2.distance), Hx, Hz)
5 changes: 3 additions & 2 deletions src/mqt/qecc/ft_stateprep/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from .simulation import LUTDecoder, NoisyNDFTStatePrepSimulator
from .simulation import LutDecoder, NoisyNDFTStatePrepSimulator
from .state_prep import (
StatePrepCircuit,
depth_optimal_prep_circuit,
Expand All @@ -15,7 +15,7 @@
)

__all__ = [
"LUTDecoder",
"LutDecoder",
"NoisyNDFTStatePrepSimulator",
"StatePrepCircuit",
"depth_optimal_prep_circuit",
Expand All @@ -25,4 +25,5 @@
"heuristic_prep_circuit",
"heuristic_verification_circuit",
"heuristic_verification_stabilizers",
"lutDecoder",
]
63 changes: 34 additions & 29 deletions src/mqt/qecc/ft_stateprep/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import logging
from collections import defaultdict
from typing import TYPE_CHECKING

Expand All @@ -26,7 +27,7 @@ def __init__(self, state_prep_circ: QuantumCircuit, code: CSSCode, p: float, zer
state_prep_circ: The state preparation circuit.
code: The code to simulate.
p: The error rate.
zero_state: Whether thezero state is prepared or nor.
zero_state: Whether thezero state is prepared or nor.
"""
self.circ = state_prep_circ
self.num_qubits = state_prep_circ.num_qubits
Expand All @@ -41,7 +42,7 @@ def __init__(self, state_prep_circ: QuantumCircuit, code: CSSCode, p: float, zer
self.data_measurements = [] # type: list[int]
self.n_measurements = 0
self.stim_circ = stim.Circuit()
self.decoder = LUTDecoder(code)
self.decoder = LutDecoder(code)
self.set_p(p)

def set_p(self, p: float) -> None:
Expand Down Expand Up @@ -192,18 +193,22 @@ def logical_error_rate(
num_logical_errors = 0

if self.zero_state:
self.decoder.generate_x_LUT()
self.decoder.generate_x_lut()
else:
self.decoder.generate_z_LUT()
self.decoder.generate_z_lut()

i = 1
while i <= int(np.ceil(shots / batch)) or at_least_min_errors:
num_logical_errors_batch, discarded_batch = self._simulate_batch(batch)

logging.log(
logging.INFO,
f"Batch {i}: {num_logical_errors_batch} logical errors and {discarded_batch} discarded shots. {batch - discarded_batch} shots used.",
)
if discarded_batch != batch:
p_l_batch = num_logical_errors_batch / (batch - discarded_batch)
p_l = ((i - 1) * p_l + p_l_batch) / i

r_a_batch = 1 - discarded_batch / batch

# Update statistics
Expand Down Expand Up @@ -247,22 +252,22 @@ def _simulate_batch(self, shots: int = 1024) -> tuple[int, int]:
return num_logical_errors, num_discarded


class LUTDecoder:
class LutDecoder:
"""Lookup table decoder for a CSSState."""

def __init__(self, code: CSSCode, init_LUTs: bool = True) -> None:
def __init__(self, code: CSSCode, init_luts: bool = True) -> None:
"""Initialize the decoder.
Args:
code: The code to decode.
init_LUTs: Whether to initialize the lookup tables at object creation.
init_luts: Whether to initialize the lookup tables at object creation.
"""
self.code = code
self.x_LUT = {} # type: dict[bytes, npt.NDArray[np.int8]]
self.z_LUT = {} # type: dict[bytes, npt.NDArray[np.int8]]
if init_LUTs:
self.generate_x_LUT()
self.generate_z_LUT()
self.x_lut = {} # type: dict[bytes, npt.NDArray[np.int8]]
self.z_lut = {} # type: dict[bytes, npt.NDArray[np.int8]]
if init_luts:
self.generate_x_lut()
self.generate_z_lut()

def batch_decode_x(self, syndromes: npt.NDArray[np.int8]) -> npt.NDArray[np.int8]:
"""Decode the X errors given a batch of syndromes."""
Expand All @@ -274,35 +279,35 @@ def batch_decode_z(self, syndromes: npt.NDArray[np.int8]) -> npt.NDArray[np.int8

def decode_x(self, syndrome: npt.NDArray[np.int8]) -> npt.NDArray[np.int8]:
"""Decode the X errors given a syndrome."""
if len(self.x_LUT) == 0:
self.generate_x_LUT()
return self.x_LUT[syndrome.tobytes()]
if len(self.x_lut) == 0:
self.generate_x_lut()
return self.x_lut[syndrome.tobytes()]

def decode_z(self, syndrome: npt.NDArray[np.int8]) -> npt.NDArray[np.int8]:
"""Decode the Z errors given a syndrome."""
if len(self.z_LUT) == 0:
self.generate_z_LUT()
return self.z_LUT[syndrome.tobytes()]
if len(self.z_lut) == 0:
self.generate_z_lut()
return self.z_lut[syndrome.tobytes()]

def generate_x_LUT(self) -> None:
def generate_x_lut(self) -> None:
"""Generate the lookup table for the X errors."""
if len(self.x_LUT) != 0:
if len(self.x_lut) != 0:
return

self.x_LUT = LUTDecoder._generate_LUT(self.code.Hz)
self.x_lut = LutDecoder._generate_lut(self.code.Hz)
if self.code.is_self_dual():
self.z_LUT = self.x_LUT
self.z_lut = self.x_lut

def generate_z_LUT(self) -> None:
def generate_z_lut(self) -> None:
"""Generate the lookup table for the Z errors."""
if len(self.z_LUT) != 0:
if len(self.z_lut) != 0:
return
self.z_LUT = LUTDecoder._generate_LUT(self.code.Hx)
self.z_lut = LutDecoder._generate_lut(self.code.Hx)
if self.code.is_self_dual():
self.z_LUT = self.x_LUT
self.z_lut = self.x_lut

@staticmethod
def _generate_LUT(checks: npt.NDArray[np.int_]) -> dict[bytes, npt.NDArray[np.int_]]:
def _generate_lut(checks: npt.NDArray[np.int_]) -> dict[bytes, npt.NDArray[np.int_]]:
"""Generate a lookup table for the stabilizer state.
The lookup table maps error syndromes to their errors.
Expand All @@ -317,7 +322,7 @@ def _generate_LUT(checks: npt.NDArray[np.int_]) -> dict[bytes, npt.NDArray[np.in
syndromes[syndrome.astype(np.int8).tobytes()].append(state)

# Sort according to weight
for key, v in lut.items():
for key, v in syndromes.items():
lut[key] = np.array(min(v, key=np.sum))

return lut
Expand Down
Loading

0 comments on commit d6d8f93

Please sign in to comment.