diff --git a/docs/source/StatePrep.ipynb b/docs/source/StatePrep.ipynb
new file mode 100644
index 00000000..85c8f41d
--- /dev/null
+++ b/docs/source/StatePrep.ipynb
@@ -0,0 +1,272 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "26db7b98-afc0-48d5-846b-e4fb1fe9e6f4",
+ "metadata": {},
+ "source": [
+ "# Fault tolerant state preparation of Pauli eigenstates for CSS codes\n",
+ "\n",
+ "The QECC package contains functionality for synthesizing and simulating fault tolerant and non-fault tolerant state preparation circuits for Pauli eigenstates of CSS codes.\n",
+ "Currently it supports synthesizing circuits for preparing the $|0\\rangle_L^k$ and $|+\\rangle_L^k$ states of arbitrary $[[n,k,d]]$ CSS codes.\n",
+ "\n",
+ "## Synthesizing non-FT state preparation circuits\n",
+ "\n",
+ "A non-fault tolerant preparation circuit can be generated directly from a CSS code. Let's consider the [Steane code](https://errorcorrectionzoo.org/c/steane) which is a $[[7, 1, 3]]$ color code."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "f556173a-1657-403f-8226-bdb565c9b8ee",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from mqt.qecc import CSSCode\n",
+ "\n",
+ "steane_code = CSSCode.from_code_name(\"Steane\")\n",
+ "print(steane_code.stabs_as_pauli_strings())"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "b8e3bf10-7253-40f6-9b5d-ad6b85df6876",
+ "metadata": {},
+ "source": [
+ "A state preparation circuit for the logical $|0\\rangle_L$ of this code is a circuit that generates a state that is stabilized by all of the above Pauli operators and the logical $Z_L$ operator of the Steane code. \n",
+ "\n",
+ "The code is small enough that we can generate a CNOT-optimal state preparation circuit for it:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "35cc5bff-74cf-4ebc-be24-bc614b08a15f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from mqt.qecc.ft_stateprep import gate_optimal_prep_circuit\n",
+ "\n",
+ "non_ft_sp = gate_optimal_prep_circuit(steane_code, zero_state=True, max_timeout=2)\n",
+ "\n",
+ "non_ft_sp.circ.draw(output=\"mpl\", initial_state=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "a648f3c7-1002-423d-99b9-f4d15b0d4cc3",
+ "metadata": {},
+ "source": [
+ "We see that the minimal number of CNOTs required to prepare the logical $|0\\rangle_L$ circuit of the Steane code is $7$.\n",
+ "\n",
+ "## Synthesizing FT state preparation circuits\n",
+ "The circuit above is not fault-tolerant. For example, an $X$ error on qubit $q_1$ before the last CNOT propagates to a weight $2$ X error on $q_1$ and $q_2$. This is to be expected since we apply two-qubit gates between the qubits of a single logical qubit. \n",
+ "\n",
+ "A common method to turn a non-FT protocol into a fault tolerant one is through post-selection. We can try to detect whether an error was propagated through the circuit and restart the preparation in case of a detection event. A circuit performing such measurements is called a *verification circuit*. \n",
+ "\n",
+ "Verification circuits need to be carefully constructed such that only stabilizers of the code are measured and no more measurements are performed than necessary. Finding good verification circuits is NP-complete.\n",
+ "\n",
+ "QECC can automatically generate optimal verification circuits."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "edb629ed-50d2-4be0-9345-d938c316a9ac",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from mqt.qecc.ft_stateprep import gate_optimal_verification_circuit\n",
+ "\n",
+ "ft_sp = gate_optimal_verification_circuit(non_ft_sp)\n",
+ "\n",
+ "ft_sp.draw(output=\"mpl\", initial_state=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "9fd98702-1505-4f7e-a5d9-8510d36ec78b",
+ "metadata": {},
+ "source": [
+ "We have just automatically generated the known FT state preparation circuit for the Steane\n",
+ "code: [^1]. We see that if an X error happens on qubit $q_1$ before the last CNOT causes the verification circuit to measure a $-1$. \n",
+ "\n",
+ "## Simulating state preparation circuits\n",
+ "\n",
+ "If we want to see the probability of a logical error happening after post-selecting, QECC provides simulation utilities that can quickly generate results. We can simulate the non-FT and FT circuits and compare the results.\n",
+ "\n",
+ "[^1]: https://www.nature.com/articles/srep19578"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "9ccb7239-c669-432e-b195-b8f0887c83a3",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from mqt.qecc.ft_stateprep import NoisyNDFTStatePrepSimulator\n",
+ "\n",
+ "p = 0.05\n",
+ "\n",
+ "non_ft_simulator = NoisyNDFTStatePrepSimulator(non_ft_sp.circ, code=steane_code, zero_state=True, p=p)\n",
+ "ft_simulator = NoisyNDFTStatePrepSimulator(ft_sp, code=steane_code, zero_state=True, p=p)\n",
+ "\n",
+ "\n",
+ "pl_non_ft, ra_non_ft, _, _ = non_ft_simulator.logical_error_rate(min_errors=10)\n",
+ "pl_ft, ra_ft, _, _ = ft_simulator.logical_error_rate(min_errors=10)\n",
+ "\n",
+ "print(f\"Logical error rate for non-FT state preparation: {pl_non_ft}\")\n",
+ "print(f\"Logical error rate for FT state preparation: {pl_ft}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "56258f47-3a31-4b95-a374-9b4d0725551a",
+ "metadata": {},
+ "source": [
+ "The error rates seem quite close to each other. To properly judge the fault tolerance of the circuits we want to look at how the logical error rate scales with the physical error rate."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "3514dc4d-9346-4a60-b685-45241068585c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ps = [0.1, 0.05, 0.01, 0.008, 0.006, 0.004, 0.002, 0.001]\n",
+ "\n",
+ "non_ft_simulator.plot_state_prep(ps, min_errors=50, name=\"non-FT\")\n",
+ "ft_simulator.plot_state_prep(ps, min_errors=50, name=\"FT\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "d0e7d899-aa16-4252-8186-485465483df6",
+ "metadata": {},
+ "source": [
+ "Indeed we observe a quadratic scaling for the fault tolerant state preparation circuit while the logical error rate scales linearly for the non-fault tolerant state preparation.\n",
+ "\n",
+ "## Beyond distance 3\n",
+ "\n",
+ "Distance 3 circuits are particularly simple for fault tolerant state preparation because for the $|0\\rangle_L$ we can completely ignore Z errors. Due to error degeneracy any Z error is equivalent to a weight 1 or 0 error. \n",
+ "\n",
+ "Additionally one has to pay special attention to the order of measurements in the verification circuits when more than one independent error in the state preparation circuit is considered. \n",
+ "\n",
+ "Because both error types are considered, the verification circuit now measures both X- and Z-stabilizers. Unfortunately a Z error in an X measurement can propagate to the data qubits and vice versa for Z measurements. Therefore, if we check for Z errors after we have checked for X errors the measurements might introduce more X errors on the data qubits. We can check those again but that would just turn the situation around; now Z errors can propagate to the data qubits.\n",
+ "\n",
+ "Detecting such *hook errors* can be achieved via flag-fault tolerant stabilizer measurements [^2]. Usually, information from such hook errors is used to augment an error correction scheme. But we can also use these flags as additional measurements on which we post-select. If one of the flags triggers, this indicates that a hook error happened and we reset.\n",
+ "\n",
+ "By default QECC automatically performs such additional measurements when necessary. Let's consider a larger code to illustrate the point. The [square-octagon color code](https://errorcorrectionzoo.org/c/488_color) is defined on the following lattice:\n",
+ "\n",
+ "\n",
+ "\n",
+ "The distance 5 code uses 17 qubits from this lattice, i.e., we have a $[[17, 1, 5]]$ CSS code. Given the size of the code, synthesizing an optimal state preparation circuit might take a long time. QECC also has a fast heuristic state preparation circuit synthesis.\n",
+ "\n",
+ "[^2]: https://arxiv.org/abs/1708.02246"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "e77aa2f2-4920-4e1c-a84f-8511053a07b5",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from mqt.qecc.ft_stateprep import heuristic_prep_circuit\n",
+ "\n",
+ "cc = CSSCode.from_code_name(\"cc_4_8_8\")\n",
+ "cc_non_ft_sp = heuristic_prep_circuit(cc, zero_state=True, optimize_depth=True)\n",
+ "\n",
+ "cc_non_ft_sp.circ.draw(output=\"mpl\", initial_state=True, scale=0.7)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6f29ff66-836e-4d11-ab8e-860325bdd7a9",
+ "metadata": {},
+ "source": [
+ "Even though optimal state preparation circuit synthesis seems out of range we can still synthesize good verification circuits in a short time if we give an initial guess on how many measurements we will need."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "6c43c729-d303-4d99-b519-aa87a9fb2342",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "cc_ft_sp = gate_optimal_verification_circuit(cc_non_ft_sp, max_timeout=2, max_ancillas=3)\n",
+ "\n",
+ "cc_ft_sp.draw(output=\"mpl\", initial_state=True, fold=-1, scale=0.2)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "2af33324-99db-48b0-aef8-4ef93546d3b1",
+ "metadata": {},
+ "source": [
+ "We see that the overhead for the verification overshadows the state preparation by a large margin. But this verification circuit is still much smaller than the naive variant of post-selecting on the stabilizer generators of the code."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b16d84d5-3b82-4f0c-8aef-3a35f54d547d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from mqt.qecc.ft_stateprep import naive_verification_circuit\n",
+ "\n",
+ "cc_ft_naive = naive_verification_circuit(cc_non_ft_sp)\n",
+ "\n",
+ "print(f\"CNOTs required for naive FT state preparation: {cc_ft_naive.num_nonlocal_gates()}\")\n",
+ "print(f\"CNOTs required for optimized FT state preparation: {cc_ft_sp.num_nonlocal_gates()}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "7be705a4-7732-43d5-b4e2-c13619e99703",
+ "metadata": {},
+ "source": [
+ "We expect that the distance 5 color code should be prepared with a better logical error rate than the Steane code. And this is indeed the case:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c81d9f32-af86-4388-9915-6d5be01efce3",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "cc_simulator = NoisyNDFTStatePrepSimulator(cc_ft_sp, code=cc, zero_state=True, p=p)\n",
+ "\n",
+ "ps = [0.1, 0.05, 0.01, 0.008, 0.006, 0.004, 0.002, 0.001, 0.0008, 0.0006, 0.0004, 0.0003]\n",
+ "\n",
+ "ft_simulator.plot_state_prep(ps, min_errors=50, name=\"Distance 3\") # simulate Steane code as comparison\n",
+ "cc_simulator.plot_state_prep(ps, min_errors=50, name=\"Distance 5\")"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/docs/source/images/488_color_code.svg b/docs/source/images/488_color_code.svg
new file mode 100644
index 00000000..807bb288
--- /dev/null
+++ b/docs/source/images/488_color_code.svg
@@ -0,0 +1 @@
+
diff --git a/docs/source/index.rst b/docs/source/index.rst
index db3ceca5..63e080c1 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -27,6 +27,7 @@ please let us know at our :doc:`Support ` page or by reaching out to us
Installation
LightsOutDecoder
EccFramework
+ StatePrep
Publications
.. toctree::
diff --git a/docs/source/library/Library.rst b/docs/source/library/Library.rst
index dbc2ddbb..3c811190 100644
--- a/docs/source/library/Library.rst
+++ b/docs/source/library/Library.rst
@@ -15,3 +15,4 @@ Library
DecodingRunInformation
SamplePauliError
LightsOutDecoder
+ StatePrep
diff --git a/docs/source/library/StatePrep.rst b/docs/source/library/StatePrep.rst
new file mode 100644
index 00000000..c1c09e4a
--- /dev/null
+++ b/docs/source/library/StatePrep.rst
@@ -0,0 +1,29 @@
+Fault tolerant state preparation
+================================
+
+QECC provides functionality to synthesize and simulate state preparation circuits for logical basis states for arbitrary :math:`[[n, k, d]]` quantum CSS codes.
+
+ .. currentmodule:: mqt.qecc.ft_stateprep
+
+Non-fault tolerant state preparation circuits can be synthesized using :func:`depth_optimal_prep_circuit`, :func:`gate_optimal_prep_circuit` and :func:`heuristic_prep_circuit`.
+
+ .. autofunction:: depth_optimal_prep_circuit
+
+ .. autofunction:: gate_optimal_prep_circuit
+
+ .. autofunction:: heuristic_prep_circuit
+
+These methods return a :class:`StatePrepCircuit` from which the circuit can be obtained as a qiskit :code:`QuantumCircuit` object via the :code:`circ` member. The :class:`StatePrepCircuit` class contains methods for computing the state preparation circuit's fault set. :class:`StatePrepCircuit` are given as input to the verification circuit synthesis methods which add verification measurements to the circuit such that postselection on +1 measurement results of these circuit outputs a state with a logical error rate on the order of :math:`O(p^{\frac{d}{2}})`.
+
+Gate-optimal verification circuits can be generated using the :func:`gate_optimal_verification_circuit` method. This method uses the SMT solver `Z3 `_ for iteratively searching for better verification circuits. The search is guided by a :code:`min_timeout` and :code:`max_timeout` parameter. Initially, the search is only allowed to continue for the minimal amount of time. This time budget is successively increased until a solution is found. At this point the maximum number of CNOTs is reduced until the synthesis takes :code:`max_timeout` time or if the SAT solver manages to prove that no smaller circuit exists.
+
+ .. autofunction:: gate_optimal_verification_circuit
+
+If the optimal synthesis takes too long, the :func:`heuristic_verification_circuit` method can be used. This method reduces the synthesis of state preparation circuits to a set cover problem. Quality of solution can be traded with performance via the :code:`max_covering_sets` parameter. The smaller this parameter is set the lower the number of sets from which a covering is obtained.
+
+ .. autofunction:: heuristic_verification_circuit
+
+State preparation circuits can be simulated using the :class:`NoisyNDFTStatePrepSimulator` class.
+
+ .. autoclass:: NoisyNDFTStatePrepSimulator
+ :members:
diff --git a/pyproject.toml b/pyproject.toml
index 16aa4ab2..50012d92 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -37,13 +37,15 @@ classifiers = [
]
requires-python = ">=3.8"
dependencies = [
- "z3-solver>=4.11,<4.14",
+ "z3-solver>=4.12,<4.14",
"qecsim",
"ldpc>=0.1.51",
"numpy>=1.26,<2; python_version > '3.11'", # some of our dependencies are not yet compatible with numpy 2.0
"numpy>=1.24,<2; python_version <= '3.11'", # some of our dependencies are not yet compatible with numpy 2.0
"qiskit[qasm3-import]>=1.0.0",
"qiskit-aer>=0.14.0",
+ "stim >= 1.13.0",
+ "multiprocess >= 0.70.16",
"bposd>=1.6",
"numba>=0.59; python_version > '3.11'",
"numpy>=0.57; python_version <= '3.11'",
@@ -176,7 +178,7 @@ exclude = [
[[tool.mypy.overrides]]
module = ["qiskit.*", "qecsim.*","qiskit_aer.*", "matplotlib.*", "scipy.*", "ldpc.*", "pytest_console_scripts.*",
- "z3.*", "bposd.*", "numba.*", "pymatching.*"]
+ "z3.*", "bposd.*", "numba.*", "pymatching.*", "stim.*", "multiprocess.*"]
ignore_missing_imports = true
@@ -275,6 +277,7 @@ convention = "google"
wille = "wille"
ser = "ser"
aer = "aer"
+anc = "anc"
[tool.repo-review]
diff --git a/src/mqt/qecc/__init__.py b/src/mqt/qecc/__init__.py
index 0f55aaf5..bda36cf3 100644
--- a/src/mqt/qecc/__init__.py
+++ b/src/mqt/qecc/__init__.py
@@ -9,6 +9,7 @@
from ._version import version as __version__
from .analog_information_decoding.simulators.analog_tannergraph_decoding import AnalogTannergraphDecoder, AtdSimulator
from .analog_information_decoding.simulators.quasi_single_shot_v2 import QssSimulator
+from .code import CSSCode, InvalidCSSCodeError
from .pyqecc import (
Code,
Decoder,
@@ -25,12 +26,14 @@
__all__ = [
"AnalogTannergraphDecoder",
"AtdSimulator",
+ "CSSCode",
"Code",
"Decoder",
"DecodingResult",
"DecodingResultStatus",
"DecodingRunInformation",
"GrowthVariant",
+ "InvalidCSSCodeError",
# "SoftInfoDecoder",
"QssSimulator",
"UFDecoder",
diff --git a/src/mqt/qecc/cc_decoder/color_code.py b/src/mqt/qecc/cc_decoder/color_code.py
index 722822a5..c7b81231 100644
--- a/src/mqt/qecc/cc_decoder/color_code.py
+++ b/src/mqt/qecc/cc_decoder/color_code.py
@@ -6,7 +6,8 @@
from typing import TYPE_CHECKING
import numpy as np
-from ldpc import mod2
+
+from ..code import CSSCode
if TYPE_CHECKING: # pragma: no cover
import numpy.typing as npt
@@ -19,7 +20,7 @@ class LatticeType(str, Enum):
SQUARE_OCTAGON = "square_octagon"
-class ColorCode:
+class ColorCode(CSSCode):
"""A base class for color codes on a three-valent, three-colourable lattice."""
def __init__(self, distance: int, lattice_type: LatticeType) -> None:
@@ -33,9 +34,8 @@ def __init__(self, distance: int, lattice_type: LatticeType) -> None:
self.add_qubits()
self.H: npt.NDArray[np.int_] = np.zeros((len(self.ancilla_qubits), len(self.data_qubits)), dtype=int)
self.construct_layout()
- self.compute_logical()
- self.n = len(self.qubits_to_faces)
- self.k = self.L.shape[1]
+ CSSCode.__init__(self, distance, self.H, self.H)
+ self.L = self.Lz
def __hash__(self) -> int:
"""Compute a hash for the color code."""
@@ -54,13 +54,8 @@ def construct_layout(self) -> None:
"""Construct the adjacency lists of the code from the qubits lists. Assumes add_qubits was called."""
def compute_logical(self) -> None:
- """Compute the logical matrix L."""
- ker_hx = mod2.nullspace(self.H) # compute the kernel basis of hx
- im_hz_transp = mod2.row_basis(self.H) # compute the image basis of hz.T
- log_stack = np.vstack([im_hz_transp, ker_hx])
- pivots = mod2.row_echelon(log_stack.T)[3]
- log_op_indices = [i for i in range(im_hz_transp.shape[0], log_stack.shape[0]) if i in pivots]
- self.L = log_stack[log_op_indices]
+ """Compute the logical operators of the code."""
+ self.L = self._compute_logical(self.H, self.H)
def get_syndrome(self, error: npt.NDArray[np.int_]) -> npt.NDArray[np.int_]:
"""Compute the syndrome of the error."""
diff --git a/src/mqt/qecc/code.py b/src/mqt/qecc/code.py
new file mode 100644
index 00000000..a201b9fb
--- /dev/null
+++ b/src/mqt/qecc/code.py
@@ -0,0 +1,234 @@
+"""Class for representing quantum error correction codes."""
+
+from __future__ import annotations
+
+import sys
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+import numpy as np
+from ldpc import mod2
+
+if TYPE_CHECKING: # pragma: no cover
+ import numpy.typing as npt
+
+
+class CSSCode:
+ """A class for representing CSS codes."""
+
+ def __init__(
+ self,
+ distance: int,
+ Hx: npt.NDArray[np.int8] | None = None, # noqa: N803
+ Hz: npt.NDArray[np.int8] | None = None, # noqa: N803
+ x_distance: int | None = None,
+ z_distance: int | None = None,
+ ) -> None:
+ """Initialize the code."""
+ self.distance = distance
+ self.x_distance = x_distance if x_distance is not None else distance
+ self.z_distance = z_distance if z_distance is not None else distance
+
+ if self.distance < 0:
+ msg = "The distance must be a non-negative integer"
+ raise InvalidCSSCodeError(msg)
+ if Hx is None and Hz is None:
+ msg = "At least one of the check matrices must be provided"
+ raise InvalidCSSCodeError(msg)
+ if self.x_distance < self.distance or self.z_distance < self.distance:
+ msg = "The x and z distances must be greater than or equal to the distance"
+ raise InvalidCSSCodeError(msg)
+ if Hx is not None and Hz is not None:
+ if Hx.shape[1] != Hz.shape[1]:
+ msg = "Check matrices must have the same number of columns"
+ raise InvalidCSSCodeError(msg)
+ if np.any(Hx @ Hz.T % 2 != 0):
+ msg = "The check matrices must be orthogonal"
+ raise InvalidCSSCodeError(msg)
+
+ self.Hx = Hx
+ self.Hz = Hz
+ self.n = Hx.shape[1] if Hx is not None else Hz.shape[1] # type: ignore[union-attr]
+ self.k = self.n - (Hx.shape[0] if Hx is not None else 0) - (Hz.shape[0] if Hz is not None else 0)
+ self.Lx = CSSCode._compute_logical(self.Hz, self.Hx)
+ self.Lz = CSSCode._compute_logical(self.Hx, self.Hz)
+
+ def __hash__(self) -> int:
+ """Compute a hash for the CSS code."""
+ x_hash = int.from_bytes(self.Hx.tobytes(), sys.byteorder) if self.Hx is not None else 0
+ z_hash = int.from_bytes(self.Hz.tobytes(), sys.byteorder) if self.Hz is not None else 0
+ return hash(x_hash ^ z_hash)
+
+ def __eq__(self, other: object) -> bool:
+ """Check if two CSS codes are equal."""
+ if not isinstance(other, CSSCode):
+ return NotImplemented
+ if self.Hx is None and other.Hx is None:
+ assert self.Hz is not None
+ assert other.Hz is not None
+ return np.array_equal(self.Hz, other.Hz)
+ if self.Hz is None and other.Hz is None:
+ assert self.Hx is not None
+ assert other.Hx is not None
+ return np.array_equal(self.Hx, other.Hx)
+ if (self.Hx is None and other.Hx is not None) or (self.Hx is not None and other.Hx is None):
+ return False
+ if (self.Hz is None and other.Hz is not None) or (self.Hz is not None and other.Hz is None):
+ return False
+ assert self.Hx is not None
+ assert other.Hx is not None
+ assert self.Hz is not None
+ assert other.Hz is not None
+ return bool(
+ mod2.rank(self.Hx) == mod2.rank(np.vstack([self.Hx, other.Hx]))
+ and mod2.rank(self.Hz) == mod2.rank(np.vstack([self.Hz, other.Hz]))
+ )
+
+ @staticmethod
+ def _compute_logical(m1: npt.NDArray[np.int8] | None, m2: npt.NDArray[np.int8] | None) -> npt.NDArray[np.int8]:
+ """Compute the logical matrix L."""
+ if m1 is None:
+ ker_m2 = mod2.nullspace(m2) # compute the kernel basis of m2
+ pivots = mod2.row_echelon(ker_m2)[-1]
+ logs = np.zeros_like(ker_m2, dtype=np.int8) # type: npt.NDArray[np.int8]
+ for i, pivot in enumerate(pivots):
+ logs[i, pivot] = 1
+ return logs
+
+ if m2 is None:
+ return mod2.nullspace(m1).astype(np.int8) # type: ignore[no-any-return]
+
+ ker_m1 = mod2.nullspace(m1) # compute the kernel basis of m1
+ im_m2_transp = mod2.row_basis(m2) # compute the image basis of m2
+ log_stack = np.vstack([im_m2_transp, ker_m1])
+ pivots = mod2.row_echelon(log_stack.T)[3]
+ log_op_indices = [i for i in range(im_m2_transp.shape[0], log_stack.shape[0]) if i in pivots]
+ return log_stack[log_op_indices]
+
+ def get_x_syndrome(self, error: npt.NDArray[np.int8]) -> npt.NDArray[np.int8]:
+ """Compute the x syndrome of the error."""
+ if self.Hx is None:
+ return np.empty((0, error.shape[0]), dtype=np.int8)
+ return self.Hx @ error % 2
+
+ def get_z_syndrome(self, error: npt.NDArray[np.int8]) -> npt.NDArray[np.int8]:
+ """Compute the z syndrome of the error."""
+ if self.Hz is None:
+ return np.empty((0, error.shape[0]), dtype=np.int8)
+ return self.Hz @ error % 2
+
+ def check_if_logical_x_error(self, residual: npt.NDArray[np.int8]) -> bool:
+ """Check if the residual is a logical error."""
+ return bool((self.Lz @ residual % 2 == 1).any())
+
+ def check_if_logical_z_error(self, residual: npt.NDArray[np.int8]) -> bool:
+ """Check if the residual is a logical error."""
+ return bool((self.Lx @ residual % 2 == 1).any())
+
+ def stabilizer_eq_x_error(self, error_1: npt.NDArray[np.int8], error_2: npt.NDArray[np.int8]) -> bool:
+ """Check if two X errors are in the same coset."""
+ if self.Hx is None:
+ return bool(np.array_equal(error_1, error_2))
+ m1 = np.vstack([self.Hx, error_1])
+ m2 = np.vstack([self.Hx, error_2])
+ m3 = np.vstack([self.Hx, error_1, error_2])
+ return bool(mod2.rank(m1) == mod2.rank(m2) == mod2.rank(m3))
+
+ def stabilizer_eq_z_error(self, error_1: npt.NDArray[np.int8], error_2: npt.NDArray[np.int8]) -> bool:
+ """Check if two Z errors are in the same coset."""
+ if self.Hz is None:
+ return bool(np.array_equal(error_1, error_2))
+ m1 = np.vstack([self.Hz, error_1])
+ m2 = np.vstack([self.Hz, error_2])
+ m3 = np.vstack([self.Hz, error_1, error_2])
+ return bool(mod2.rank(m1) == mod2.rank(m2) == mod2.rank(m3))
+
+ def is_self_dual(self) -> bool:
+ """Check if the code is self-dual."""
+ if self.Hx is None or self.Hz is None:
+ return False
+ return bool(
+ self.Hx.shape[0] == self.Hz.shape[0] and mod2.rank(self.Hx) == mod2.rank(np.vstack([self.Hx, self.Hz]))
+ )
+
+ def stabs_as_pauli_strings(self) -> tuple[list[str] | None, list[str] | None]:
+ """Return the stabilizers as Pauli strings."""
+ x_str = None if self.Hx is None else ["".join(["I" if x == 0 else "X" for x in row]) for row in self.Hx]
+ z_str = None if self.Hz is None else ["".join(["I" if z == 0 else "Z" for z in row]) for row in self.Hz]
+ return x_str, z_str
+
+ def z_logicals_as_pauli_string(self) -> str:
+ """Return the logical Z operator as a Pauli string."""
+ return "".join(["I" if z == 0 else "Z" for z in self.Lx[0]])
+
+ def x_logicals_as_pauli_string(self) -> str:
+ """Return the logical X operator as a Pauli string."""
+ return "".join(["I" if x == 0 else "X" for x in self.Lz[0]])
+
+ @staticmethod
+ def from_code_name(code_name: str, distance: int | None = None) -> CSSCode:
+ r"""Return CSSCode object for a known code.
+
+ The following codes are supported:
+ - [[7, 1, 3]] Steane (\"Steane\")
+ - [[15, 1, 3]] tetrahedral code (\"Tetrahedral\")
+ - [[15, 7, 3]] Hamming code (\"Hamming\")
+ - [[9, 1, 3]] Shore code (\"Shor\")
+ - [[9, 1, 3]] rotated surface code (\"Surface, 3\"), also default when no distance is given
+ - [[25, 1, 5]] rotated surface code (\"Surface, 5\")
+ - [[17, 1, 5]] 4,8,8 color code (\"CC_4_8_8\")
+ - [[23, 1, 7]] golay code (\"Golay\")
+ - 6,6,6 color code for arbitrary distances (\"CC_6_6_6, d\")
+
+
+ Args:
+ code_name: The name of the code.
+ distance: The distance of the code.
+ """
+ prefix = (Path(__file__) / "../sample_codes/").resolve()
+ paths = {
+ "steane": prefix / "steane/",
+ "tetrahedral": prefix / "tetrahedral/",
+ "hamming": prefix / "hamming/",
+ "shor": prefix / "shor/",
+ "surface_3": prefix / "rotated_surface_d3/",
+ "surface_5": prefix / "rotated_surface_d5/",
+ "cc_4_8_8": prefix / "cc_4_8_8_d5/",
+ "golay": prefix / "golay/",
+ }
+
+ distances = {
+ "steane": (3, 3),
+ "tetrahedral": (7, 3),
+ "hamming": (3, 3),
+ "shor": (3, 3),
+ "cc_4_8_8": (5, 5),
+ "golay": (7, 7),
+ "surface_3": (3, 3),
+ "surface_5": (5, 5),
+ } # X, Z distances
+
+ code_name = code_name.lower()
+ if code_name == "surface":
+ if distance is None:
+ distance = 3
+ code_name += "_%d" % distance
+
+ if code_name in paths:
+ hx = np.load(paths[code_name] / "hx.npy")
+ hz = np.load(paths[code_name] / "hz.npy")
+
+ if code_name in distances:
+ x_distance, z_distance = distances[code_name]
+ distance = min(x_distance, z_distance)
+ return CSSCode(distance, hx, hz, x_distance=x_distance, z_distance=z_distance)
+
+ if distance is None:
+ msg = f"Distance is not specified for {code_name}"
+ raise InvalidCSSCodeError(msg)
+ msg = f"Unknown code name: {code_name}"
+ raise InvalidCSSCodeError(msg)
+
+
+class InvalidCSSCodeError(ValueError):
+ """Raised when the CSS code is invalid."""
diff --git a/src/mqt/qecc/ft_stateprep/__init__.py b/src/mqt/qecc/ft_stateprep/__init__.py
new file mode 100644
index 00000000..3e597881
--- /dev/null
+++ b/src/mqt/qecc/ft_stateprep/__init__.py
@@ -0,0 +1,30 @@
+"""Methods for synthesizing fault tolerant state preparation circuits."""
+
+from __future__ import annotations
+
+from .simulation import LutDecoder, NoisyNDFTStatePrepSimulator
+from .state_prep import (
+ StatePrepCircuit,
+ depth_optimal_prep_circuit,
+ gate_optimal_prep_circuit,
+ gate_optimal_verification_circuit,
+ gate_optimal_verification_stabilizers,
+ heuristic_prep_circuit,
+ heuristic_verification_circuit,
+ heuristic_verification_stabilizers,
+ naive_verification_circuit,
+)
+
+__all__ = [
+ "LutDecoder",
+ "NoisyNDFTStatePrepSimulator",
+ "StatePrepCircuit",
+ "depth_optimal_prep_circuit",
+ "gate_optimal_prep_circuit",
+ "gate_optimal_verification_circuit",
+ "gate_optimal_verification_stabilizers",
+ "heuristic_prep_circuit",
+ "heuristic_verification_circuit",
+ "heuristic_verification_stabilizers",
+ "naive_verification_circuit",
+]
diff --git a/src/mqt/qecc/ft_stateprep/simulation.py b/src/mqt/qecc/ft_stateprep/simulation.py
new file mode 100644
index 00000000..3add1953
--- /dev/null
+++ b/src/mqt/qecc/ft_stateprep/simulation.py
@@ -0,0 +1,403 @@
+"""Simulation of Non-deterministic fault tolerant state preparation."""
+
+from __future__ import annotations
+
+import logging
+from collections import defaultdict
+from typing import TYPE_CHECKING
+
+import matplotlib.pyplot as plt
+import numpy as np
+import stim
+from qiskit.converters import circuit_to_dag, dag_to_circuit
+
+from ..code import InvalidCSSCodeError
+
+if TYPE_CHECKING: # pragma: no cover
+ import numpy.typing as npt
+ from qiskit import QuantumCircuit
+
+ from ..code import CSSCode
+
+
+class NoisyNDFTStatePrepSimulator:
+ """Class for simulating noisy state preparation circuit using a depolarizing noise model."""
+
+ def __init__(self, state_prep_circ: QuantumCircuit, code: CSSCode, p: float, zero_state: bool = True) -> None:
+ """Initialize the simulator.
+
+ Args:
+ state_prep_circ: The state preparation circuit.
+ code: The code to simulate.
+ p: The error rate.
+ zero_state: Whether thezero state is prepared or nor.
+ """
+ if code.Hx is None or code.Hz is None:
+ msg = "The code must have both X and Z checks."
+ raise InvalidCSSCodeError(msg)
+
+ self.circ = state_prep_circ
+ self.num_qubits = state_prep_circ.num_qubits
+ self.code = code
+ self.p = p
+ self.zero_state = zero_state
+ # Store which measurements are X, Z or data measurements.
+ # The indices refer to the indices of the measurements in the stim circuit.
+ self.x_verification_measurements = [] # type: list[int]
+ self.z_verification_measurements = [] # type: list[int]
+ self.x_measurements = [] # type: list[int]
+ self.z_measurements = [] # type: list[int]
+ self.data_measurements = [] # type: list[int]
+ self.n_measurements = 0
+ self.stim_circ = stim.Circuit()
+ self.decoder = LutDecoder(code)
+ self.set_p(p)
+
+ def set_p(self, p: float) -> None:
+ """Set the error rate.
+
+ This reinitializes the stim circuit.
+
+ Args:
+ p: The error rate.
+ """
+ self.x_verification_measurements = []
+ self.z_verification_measurements = []
+ self.x_measurements = []
+ self.z_measurements = []
+ self.data_measurements = []
+ self.n_measurements = 0
+ self.p = p
+ self._reused_qubits = 0
+ self.stim_circ = self.to_stim_circ()
+ self.num_qubits = (
+ self.stim_circ.num_qubits
+ - (len(self.x_verification_measurements) + len(self.z_verification_measurements))
+ + self._reused_qubits
+ )
+ self.measure_stabilizers()
+ if self.zero_state:
+ self.measure_z()
+ else:
+ self.measure_x()
+
+ def to_stim_circ(self) -> stim.Circuit:
+ """Convert a QuantumCircuit to a noisy STIM circuit.
+
+ A depolarizing error model is used:
+ - Single-qubit gates and idling qubits are followed by a single-qubit Pauli error with probability 2/9 p. This reflects the fact that two-qubit gates are more likely to fail.
+ - Two-qubit gates are followed by a two-qubit Pauli error with probability p/15.
+ - Measurements flip with a probability of 2/3 p.
+ - Qubit are initialized in the -1 Eigenstate with probability 2/3 p.
+
+ Args:
+ circ: The QuantumCircuit to convert.
+ p: The error rate.
+ """
+ initialized = [False for _ in self.circ.qubits]
+ stim_circuit = stim.Circuit()
+ ctrls = []
+
+ def idle_error(used_qubits: list[int]) -> None:
+ for q in self.circ.qubits:
+ qubit = self.circ.find_bit(q)[0]
+ if initialized[qubit] and qubit not in used_qubits:
+ stim_circuit.append_operation("DEPOLARIZE1", [self.circ.find_bit(q)[0]], [2 * self.p / 3])
+
+ dag = circuit_to_dag(self.circ)
+ layers = dag.layers()
+ used_qubits = [] # type: list[int]
+
+ targets = set()
+ measured = defaultdict(int) # type: defaultdict[int, int]
+ for layer in layers:
+ layer_circ = dag_to_circuit(layer["graph"])
+
+ # Apply idling errors to all qubits that were unused in the previous layer
+ if len(used_qubits) > 0:
+ idle_error(used_qubits)
+
+ used_qubits = []
+ for gate in layer_circ.data:
+ if gate[0].name == "h":
+ qubit = self.circ.find_bit(gate[1][0])[0]
+ ctrls.append(qubit)
+ if initialized[qubit]:
+ stim_circuit.append_operation("H", [qubit])
+ stim_circuit.append_operation("DEPOLARIZE1", [qubit], [2 * self.p / 3])
+ used_qubits.append(qubit)
+
+ elif gate[0].name == "cx":
+ ctrl = self.circ.find_bit(gate[1][0])[0]
+ target = self.circ.find_bit(gate[1][1])[0]
+ targets.add(target)
+ if not initialized[ctrl]:
+ if ctrl in ctrls:
+ stim_circuit.append_operation("H", [ctrl])
+ stim_circuit.append_operation("Z_ERROR", [ctrl], [2 * self.p / 3]) # Wrong initialization
+ else:
+ stim_circuit.append_operation("X_ERROR", [ctrl], [2 * self.p / 3]) # Wrong initialization
+ initialized[ctrl] = True
+ if not initialized[target]:
+ stim_circuit.append_operation("X_ERROR", [target], [2 * self.p / 3]) # Wrong initialization
+ initialized[target] = True
+
+ stim_circuit.append_operation("CX", [ctrl, target])
+ stim_circuit.append_operation("DEPOLARIZE2", [ctrl, target], [self.p])
+ used_qubits.extend([ctrl, target])
+
+ elif gate[0].name == "measure":
+ anc = self.circ.find_bit(gate[1][0])[0]
+ stim_circuit.append_operation("X_ERROR", [anc], [2 * self.p / 3])
+ stim_circuit.append_operation("MR", [anc])
+ if anc in targets:
+ self.z_verification_measurements.append(self.n_measurements)
+ else:
+ self.x_verification_measurements.append(self.n_measurements)
+ self.n_measurements += 1
+ used_qubits.extend([anc])
+ initialized[anc] = False
+ measured[anc] += 1
+ if measured[anc] == 2:
+ self._reused_qubits += 1
+
+ return stim_circuit
+
+ def measure_stabilizers(self) -> stim.Circuit:
+ """Measure the stabilizers of the code.
+
+ An ancilla is used for each measurement.
+ """
+ assert self.code.Hx is not None
+ assert self.code.Hz is not None
+
+ for check in self.code.Hx:
+ supp = _support(check)
+ anc = self.stim_circ.num_qubits
+ self.stim_circ.append_operation("H", [anc])
+ for q in supp:
+ self.stim_circ.append_operation("CX", [anc, q])
+ self.stim_circ.append_operation("MRX", [anc])
+ self.x_measurements.append(self.n_measurements)
+ self.n_measurements += 1
+
+ for check in self.code.Hz:
+ supp = _support(check)
+ anc = self.stim_circ.num_qubits
+ for q in supp:
+ self.stim_circ.append_operation("CX", [q, anc])
+ self.stim_circ.append_operation("MRZ", [anc])
+ self.z_measurements.append(self.n_measurements)
+ self.n_measurements += 1
+
+ def measure_z(self) -> None:
+ """Measure all data qubits in the Z basis."""
+ self.data_measurements = [self.n_measurements + i for i in range(self.num_qubits)]
+ self.n_measurements += self.num_qubits
+ self.stim_circ.append_operation("MRZ", list(range(self.num_qubits)))
+
+ def measure_x(self) -> None:
+ """Measure all data qubits in the X basis."""
+ self.data_measurements = [self.n_measurements + i for i in range(self.num_qubits)]
+ self.n_measurements += self.num_qubits
+ self.stim_circ.append_operation("MRX", list(range(self.num_qubits)))
+
+ def logical_error_rate(
+ self,
+ shots: int = 100000,
+ shots_per_batch: int = 100000,
+ at_least_min_errors: bool = True,
+ min_errors: int = 500,
+ ) -> tuple[float, float, int, int]:
+ """Estimate the logical error rate of the code.
+
+ Args:
+ shots: The number of shots to use.
+ shots_per_batch: The number of shots per batch.
+ at_least_min_errors: Whether to continue simulating until at least min_errors are found.
+ min_errors: The minimum number of errors to find before stopping.
+ """
+ batch = min(shots_per_batch, shots)
+ p_l = 0.0
+ r_a = 0.0
+
+ num_logical_errors = 0
+
+ if self.zero_state:
+ self.decoder.generate_x_lut()
+ else:
+ self.decoder.generate_z_lut()
+
+ i = 1
+ while i <= int(np.ceil(shots / batch)) or at_least_min_errors:
+ num_logical_errors_batch, discarded_batch = self._simulate_batch(batch)
+
+ logging.log(
+ logging.INFO,
+ f"Batch {i}: {num_logical_errors_batch} logical errors and {discarded_batch} discarded shots. {batch - discarded_batch} shots used.",
+ )
+ p_l_batch = num_logical_errors_batch / (batch - discarded_batch) if discarded_batch != batch else 0.0
+ p_l = ((i - 1) * p_l + p_l_batch) / i
+
+ r_a_batch = 1 - discarded_batch / batch
+
+ # Update statistics
+ num_logical_errors += num_logical_errors_batch
+ r_a = ((i - 1) * r_a + r_a_batch) / i
+
+ if at_least_min_errors and num_logical_errors >= min_errors:
+ break
+ i += 1
+
+ return p_l, r_a, num_logical_errors, i * batch
+
+ def _simulate_batch(self, shots: int = 1024) -> tuple[int, int]:
+ sampler = self.stim_circ.compile_sampler()
+ detection_events = sampler.sample(shots)
+
+ # Filter events where the verification circuit flagged
+ verification_measurements = self.x_verification_measurements + self.z_verification_measurements
+ index_array = np.where(np.all(detection_events[:, verification_measurements] == 0, axis=1))[0]
+ filtered_events = detection_events[index_array].astype(np.int8)
+
+ if len(filtered_events) == 0: # All events were discarded
+ return 0, shots
+
+ state = filtered_events[:, self.data_measurements]
+
+ if self.zero_state:
+ checks = filtered_events[:, self.z_measurements]
+ observables = self.code.Lz
+ estimates = self.decoder.batch_decode_x(checks)
+ else:
+ checks = filtered_events[:, self.x_measurements]
+ observables = self.code.Lx
+ estimates = self.decoder.batch_decode_z(checks)
+
+ corrected = state + estimates
+
+ num_discarded = detection_events.shape[0] - filtered_events.shape[0]
+ num_logical_errors = np.sum(
+ np.any(corrected @ observables.T % 2 != 0, axis=1)
+ ) # number of non-commuting corrected states
+ return num_logical_errors, num_discarded
+
+ def plot_state_prep(self, ps: list[float], min_errors: int = 500, name: str | None = None) -> None:
+ """Plot the logical error rate and accaptence rate as a function of the physical error rate.
+
+ Args:
+ ps: The physical error rates to plot.
+ min_errors: The minimum number of errors to find before stopping.
+ name: The name of the plot.
+ """
+ p_ls = []
+ r_as = []
+ for p in ps:
+ self.set_p(p)
+ p_l, r_a, _num_logical_errors, _num_shots = self.logical_error_rate(min_errors=min_errors)
+ p_ls.append(p_l)
+ r_as.append(r_a)
+
+ plt.subplot(1, 2, 1)
+ plt.plot(ps, p_ls, marker="o", label=name)
+ plt.xscale("log")
+ plt.yscale("log")
+ plt.xlabel("Physical error rate")
+ plt.ylabel("Logical error rate")
+
+ if name is not None:
+ plt.legend()
+
+ plt.subplot(1, 2, 2)
+ plt.plot(ps, r_as, marker="o", label=name)
+ plt.xscale("log")
+ plt.yscale("log")
+ plt.xlabel("Physical error rate")
+ plt.ylabel("Acceptance rate")
+ if name is not None:
+ plt.legend()
+ plt.tight_layout()
+
+
+class LutDecoder:
+ """Lookup table decoder for a CSSCode."""
+
+ def __init__(self, code: CSSCode, init_luts: bool = True) -> None:
+ """Initialize the decoder.
+
+ Args:
+ code: The code to decode.
+ init_luts: Whether to initialize the lookup tables at object creation.
+ """
+ self.code = code
+ self.x_lut = {} # type: dict[bytes, npt.NDArray[np.int8]]
+ self.z_lut = {} # type: dict[bytes, npt.NDArray[np.int8]]
+ if init_luts:
+ self.generate_x_lut()
+ self.generate_z_lut()
+
+ def batch_decode_x(self, syndromes: npt.NDArray[np.int_]) -> npt.NDArray[np.int8]:
+ """Decode the X errors given a batch of syndromes."""
+ return np.apply_along_axis(self.decode_x, 1, syndromes)
+
+ def batch_decode_z(self, syndromes: npt.NDArray[np.int_]) -> npt.NDArray[np.int8]:
+ """Decode the Z errors given a batch of syndromes."""
+ return np.apply_along_axis(self.decode_z, 1, syndromes)
+
+ def decode_x(self, syndrome: npt.NDArray[np.int8]) -> npt.NDArray[np.int8]:
+ """Decode the X errors given a syndrome."""
+ if len(self.x_lut) == 0:
+ self.generate_x_lut()
+ return self.x_lut[syndrome.tobytes()]
+
+ def decode_z(self, syndrome: npt.NDArray[np.int8]) -> npt.NDArray[np.int8]:
+ """Decode the Z errors given a syndrome."""
+ if len(self.z_lut) == 0:
+ self.generate_z_lut()
+ return self.z_lut[syndrome.tobytes()]
+
+ def generate_x_lut(self) -> None:
+ """Generate the lookup table for the X errors."""
+ if len(self.x_lut) != 0:
+ return
+
+ assert self.code.Hz is not None, "The code does not have a Z stabilizer matrix."
+ self.x_lut = LutDecoder._generate_lut(self.code.Hz)
+ if self.code.is_self_dual():
+ self.z_lut = self.x_lut
+
+ def generate_z_lut(self) -> None:
+ """Generate the lookup table for the Z errors."""
+ if len(self.z_lut) != 0:
+ return
+
+ assert self.code.Hx is not None, "The code does not have an X stabilizer matrix."
+ self.z_lut = LutDecoder._generate_lut(self.code.Hx)
+ if self.code.is_self_dual():
+ self.z_lut = self.x_lut
+
+ @staticmethod
+ def _generate_lut(checks: npt.NDArray[np.int_]) -> dict[bytes, npt.NDArray[np.int_]]:
+ """Generate a lookup table for the stabilizer state.
+
+ The lookup table maps error syndromes to their errors.
+ """
+ n_qubits = checks.shape[1]
+
+ syndromes = defaultdict(list)
+ lut = {} # type: dict[bytes, npt.NDArray[np.int8]]
+ for i in range(2**n_qubits):
+ state = np.array(list(np.binary_repr(i).zfill(n_qubits))).astype(np.int8) # type: npt.NDArray[np.int_]
+ syndrome = checks @ state % 2
+ syndromes[syndrome.astype(np.int8).tobytes()].append(state)
+
+ # Sort according to weight
+ for key, v in syndromes.items():
+ lut[key] = np.array(min(v, key=np.sum))
+
+ return lut
+
+
+def _support(v: npt.NDArray[np.int_]) -> npt.NDArray[np.int_]:
+ """Return the support of a vector."""
+ return np.where(v != 0)[0]
diff --git a/src/mqt/qecc/ft_stateprep/state_prep.py b/src/mqt/qecc/ft_stateprep/state_prep.py
new file mode 100644
index 00000000..4937f66d
--- /dev/null
+++ b/src/mqt/qecc/ft_stateprep/state_prep.py
@@ -0,0 +1,1523 @@
+"""Synthesizing state preparation circuits for CSS codes."""
+
+from __future__ import annotations
+
+import logging
+from collections import defaultdict, deque
+from typing import TYPE_CHECKING, Any
+
+import multiprocess
+import numpy as np
+import z3
+from ldpc import mod2
+from qiskit import AncillaRegister, ClassicalRegister, QuantumCircuit, QuantumRegister
+from qiskit.converters import circuit_to_dag
+from qiskit.dagcircuit import DAGOutNode
+
+from ..code import InvalidCSSCodeError
+
+logger = logging.getLogger(__name__)
+
+
+if TYPE_CHECKING: # pragma: no cover
+ from collections.abc import Callable
+
+ import numpy.typing as npt
+ from qiskit import AncillaQubit, ClBit, DagCircuit, DAGNode, Qubit
+ from qiskit.quantum_info import PauliList
+
+ from ..code import CSSCode
+
+
+class StatePrepCircuit:
+ """Represents a state preparation circuit for a CSS code."""
+
+ def __init__(self, circ: QuantumCircuit, code: CSSCode, zero_state: bool = True) -> None:
+ """Initialize a state preparation circuit.
+
+ Args:
+ circ: The state preparation circuit.
+ code: The CSS code to prepare the state for.
+ zero_state: If True, prepare the +1 eigenstate of the Z basis. If False, prepare the +1 eigenstate of the X basis.
+ """
+ self.circ = circ
+ self.code = code
+ self.zero_state = zero_state
+
+ if code.Hx is None or code.Hz is None:
+ msg = "The CSS code must have both X and Z checks."
+ raise InvalidCSSCodeError(msg)
+
+ self.x_checks = code.Hx.copy() if zero_state else np.vstack((code.Lx.copy(), code.Hx.copy()))
+ self.z_checks = code.Hz.copy() if not zero_state else np.vstack((code.Lz.copy(), code.Hz.copy()))
+
+ self.num_qubits = circ.num_qubits
+ self.max_errors = (code.distance - 1) // 2
+ self.x_fault_sets = [None for _ in range(self.max_errors + 1)] # type: list[npt.NDArray[np.int8] | None]
+ self.z_fault_sets = [None for _ in range(self.max_errors + 1)] # type: list[npt.NDArray[np.int8] | None]
+ self.x_fault_sets_unreduced = [None for _ in range(self.max_errors + 1)] # type: list[npt.NDArray[np.int8] | None]
+ self.z_fault_sets_unreduced = [None for _ in range(self.max_errors + 1)] # type: list[npt.NDArray[np.int8] | None]
+
+ self.max_x_measurements = len(self.x_checks)
+ self.max_z_measurements = len(self.z_checks)
+
+ def compute_fault_sets(self, reduce: bool = True) -> None:
+ """Compute the fault sets for the state preparation circuit."""
+ self.compute_fault_set(self.max_errors, x_errors=True, reduce=reduce)
+ self.compute_fault_set(self.max_errors, x_errors=False, reduce=reduce)
+
+ def compute_fault_set(
+ self, num_errors: int = 1, x_errors: bool = True, reduce: bool = True
+ ) -> npt.NDArray[np.int8]:
+ """Compute the fault set of the state.
+
+ Args:
+ state: The stabilizer state to compute the fault set for.
+ num_errors: The number of independent errors to propagate through the circuit.
+ x_errors: If True, compute the fault set for X errors. If False, compute the fault set for Z errors.
+ reduce: If True, reduce the fault set by the stabilizers of the code to reduce weights.
+
+ Returns:
+ The fault set of the state.
+ """
+ faults = self.x_fault_sets[num_errors] if x_errors else self.z_fault_sets[num_errors] # type: npt.NDArray[np.int8] | None
+ if faults is not None:
+ return faults
+
+ if num_errors == 1:
+ logging.info("Computing fault set for 1 error.")
+ dag = circuit_to_dag(self.circ)
+ for node in dag.front_layer(): # remove hadamards
+ dag.remove_op_node(node)
+ fault_list = []
+ # propagate every error before a control
+ for node in dag.topological_op_nodes():
+ error = _propagate_error(dag, node, x_errors=x_errors)
+ fault_list.append(error)
+ faults = np.array(fault_list, dtype=np.int8)
+ faults = np.unique(faults, axis=0)
+
+ if x_errors and self.x_fault_sets_unreduced[1] is None:
+ non_propagated_single_errors = np.eye(self.num_qubits, dtype=np.int8) # type: npt.NDArray[np.int8]
+ self.x_fault_sets_unreduced[1] = np.vstack((faults, non_propagated_single_errors))
+ elif not x_errors and self.z_fault_sets[1] is None:
+ non_propagated_single_errors = np.eye(self.num_qubits, dtype=np.int8)
+ self.z_fault_sets_unreduced[1] = np.vstack((faults, non_propagated_single_errors))
+ else:
+ logging.info(f"Computing fault set for {num_errors} errors.")
+ self.compute_fault_set(num_errors - 1, x_errors, reduce=reduce)
+ if x_errors:
+ faults = self.x_fault_sets_unreduced[num_errors - 1]
+ single_faults = self.x_fault_sets_unreduced[1]
+ else:
+ faults = self.z_fault_sets_unreduced[num_errors - 1]
+ single_faults = self.z_fault_sets_unreduced[1]
+
+ assert faults is not None
+ assert single_faults is not None
+
+ new_faults = (faults[:, np.newaxis, :] + single_faults).reshape(-1, self.num_qubits) % 2
+ # remove duplicates
+ faults = np.unique(new_faults, axis=0)
+ if x_errors:
+ self.x_fault_sets_unreduced[num_errors] = faults.copy()
+ else:
+ self.z_fault_sets_unreduced[num_errors] = faults.copy()
+
+ # reduce faults by stabilizer
+ stabs = self.x_checks if x_errors else self.z_checks
+ faults = _remove_trivial_faults(faults, stabs, self.code, x_errors, num_errors)
+
+ # remove stabilizer equivalent faults
+ if reduce:
+ logging.info("Removing stabilizer equivalent faults.")
+ faults = _remove_stabilizer_equivalent_faults(faults, stabs)
+ if x_errors:
+ self.x_fault_sets[num_errors] = faults
+ else:
+ self.z_fault_sets[num_errors] = faults
+ return faults
+
+ def combine_faults(
+ self, additional_faults: npt.NDArray[np.int8], x_errors: bool = True
+ ) -> list[npt.NDArray[np.int8] | None]:
+ """Combine fault sets of circuit with additional independent faults.
+
+ Args:
+ additional_faults: The additional faults to combine with the fault set of the circuit.
+ x_errors: If True, combine the fault sets for X errors. If False, combine the fault sets for Z errors.
+ """
+ self.compute_fault_sets()
+ if len(additional_faults) == 0:
+ return self.x_fault_sets if x_errors else self.z_fault_sets
+
+ fault_sets_unreduced = self.x_fault_sets_unreduced.copy() if x_errors else self.z_fault_sets_unreduced.copy()
+ assert fault_sets_unreduced[1] is not None
+ fault_sets_unreduced[1] = np.vstack((fault_sets_unreduced[1], additional_faults))
+
+ for i in range(1, self.max_errors):
+ uncombined = fault_sets_unreduced[i]
+ assert uncombined is not None
+ combined = (uncombined[:, np.newaxis, :] + additional_faults).reshape(-1, self.num_qubits) % 2
+ next_faults = fault_sets_unreduced[i + 1]
+ assert next_faults is not None
+ fault_sets_unreduced[i + 1] = np.vstack((next_faults, combined))
+ fault_sets = [None for _ in range(self.max_errors + 1)] # type: list[None | npt.NDArray[np.int8]]
+ stabs = self.x_checks if x_errors else self.z_checks
+ for num_errors in range(1, self.max_errors + 1):
+ fs = fault_sets_unreduced[num_errors]
+ assert fs is not None
+ fault_sets[num_errors] = _remove_trivial_faults(fs, stabs, self.code, x_errors, num_errors)
+ return fault_sets
+
+
+def heuristic_prep_circuit(code: CSSCode, optimize_depth: bool = True, zero_state: bool = True) -> StatePrepCircuit:
+ """Return a circuit that prepares the +1 eigenstate of the code w.r.t. the Z or X basis.
+
+ Args:
+ code: The CSS code to prepare the state for.
+ optimize_depth: If True, optimize the depth of the circuit. This may lead to a higher number of CNOTs.
+ zero_state: If True, prepare the +1 eigenstate of the Z basis. If False, prepare the +1 eigenstate of the X basis.
+ """
+ logging.info("Starting heuristic state preparation.")
+ if code.Hx is None or code.Hz is None:
+ msg = "The code must have both X and Z stabilizers defined."
+ raise InvalidCSSCodeError(msg)
+ checks = code.Hx.copy() if zero_state else code.Hz.copy()
+ rank = mod2.rank(checks)
+
+ def is_reduced() -> bool:
+ return bool(len(np.where(np.all(checks == 0, axis=0))[0]) == checks.shape[1] - rank)
+
+ costs = np.array([
+ [np.sum((checks[:, i] + checks[:, j]) % 2) for j in range(checks.shape[1])] for i in range(checks.shape[1])
+ ])
+ costs -= np.sum(checks, axis=0)
+ np.fill_diagonal(costs, 1)
+
+ used_qubits = [] # type: list[np.int_]
+ cnots = [] # type: list[tuple[int, int]]
+ while not is_reduced():
+ m = np.zeros((checks.shape[1], checks.shape[1]), dtype=bool) # type: npt.NDArray[np.bool_]
+ m[used_qubits, :] = True
+ m[:, used_qubits] = True
+
+ costs_unused = np.ma.array(costs, mask=m) # type: ignore[no-untyped-call]
+ if np.all(costs_unused >= 0): # no more reductions possible
+ if used_qubits == []: # local minimum => get out by making matrix triangular
+ logging.warning("Local minimum reached. Making matrix triangular.")
+ checks = mod2.reduced_row_echelon(checks)[0]
+ costs = np.array([
+ [np.sum((checks[:, i] + checks[:, j]) % 2) for j in range(checks.shape[1])]
+ for i in range(checks.shape[1])
+ ])
+ costs -= np.sum(checks, axis=0)
+ np.fill_diagonal(costs, 1)
+ else: # try to move onto the next layer
+ used_qubits = []
+ continue
+
+ i, j = np.unravel_index(np.argmin(costs_unused), costs.shape)
+ cnots.append((int(i), int(j)))
+
+ if optimize_depth:
+ used_qubits.append(i)
+ used_qubits.append(j)
+
+ # update checks
+ checks[:, j] = (checks[:, i] + checks[:, j]) % 2
+ # update costs
+ new_weights = np.sum((checks[:, j][:, np.newaxis] + checks) % 2, axis=0)
+ costs[j, :] = new_weights - np.sum(checks, axis=0)
+ costs[:, j] = new_weights - np.sum(checks[:, j])
+ np.fill_diagonal(costs, 1)
+
+ circ = _build_circuit_from_list_and_checks(cnots, checks, zero_state)
+ return StatePrepCircuit(circ, code, zero_state)
+
+
+def _generate_circ_with_bounded_depth(
+ checks: npt.NDArray[np.int8], max_depth: int, zero_state: bool = True
+) -> QuantumCircuit | None:
+ assert max_depth > 0, "max_depth should be greater than 0"
+ columns = np.array([
+ [[z3.Bool(f"x_{d}_{i}_{j}") for j in range(checks.shape[1])] for i in range(checks.shape[0])]
+ for d in range(max_depth + 1)
+ ])
+
+ additions = np.array([
+ [[z3.Bool(f"add_{d}_{i}_{j}") for j in range(checks.shape[1])] for i in range(checks.shape[1])]
+ for d in range(max_depth)
+ ])
+ n_cols = checks.shape[1]
+ s = z3.Solver()
+
+ # create initial matrix
+ columns[0, :, :] = checks.astype(bool)
+
+ s.add(_column_addition_constraint(columns, additions))
+
+ # qubit can be involved in at most one addition at each depth
+ for d in range(max_depth):
+ for col in range(n_cols):
+ s.add(
+ z3.PbLe(
+ [(additions[d, col_1, col], 1) for col_1 in range(n_cols) if col != col_1]
+ + [(additions[d, col, col_2], 1) for col_2 in range(n_cols) if col != col_2],
+ 1,
+ )
+ )
+
+ # if column is not involved in any addition at certain depth, it is the same as the previous column
+ for d in range(1, max_depth + 1):
+ for col in range(n_cols):
+ s.add(
+ z3.Implies(
+ z3.Not(
+ z3.Or(
+ list(np.delete(additions[d - 1, :, col], [col]))
+ + list(np.delete(additions[d - 1, col, :], [col]))
+ )
+ ),
+ _symbolic_vector_eq(columns[d, :, col], columns[d - 1, :, col]),
+ )
+ )
+
+ s.add(_final_matrix_constraint(columns))
+
+ if s.check() == z3.sat:
+ m = s.model()
+ cnots = [
+ (i, j)
+ for d in range(max_depth)
+ for j in range(checks.shape[1])
+ for i in range(checks.shape[1])
+ if m[additions[d, i, j]]
+ ]
+
+ checks = np.array([
+ [bool(m[columns[max_depth, i, j]]) for j in range(checks.shape[1])] for i in range(checks.shape[0])
+ ])
+
+ return _build_circuit_from_list_and_checks(cnots, checks, zero_state=zero_state)
+
+ return None
+
+
+def _generate_circ_with_bounded_gates(
+ checks: npt.NDArray[np.int8], max_cnots: int, zero_state: bool = True
+) -> QuantumCircuit:
+ """Find the gate optimal circuit for a given check matrix and maximum depth."""
+ columns = np.array([
+ [[z3.Bool(f"x_{d}_{i}_{j}") for j in range(checks.shape[1])] for i in range(checks.shape[0])]
+ for d in range(max_cnots + 1)
+ ])
+ n_bits = int(np.ceil(np.log2(checks.shape[1])))
+ targets = [z3.BitVec(f"target_{d}", n_bits) for d in range(max_cnots)]
+ controls = [z3.BitVec(f"control_{d}", n_bits) for d in range(max_cnots)]
+ s = z3.Solver()
+
+ additions = np.array([
+ [
+ [z3.And(controls[d] == col_1, targets[d] == col_2) for col_2 in range(checks.shape[1])]
+ for col_1 in range(checks.shape[1])
+ ]
+ for d in range(max_cnots)
+ ])
+
+ # create initial matrix
+ columns[0, :, :] = checks.astype(bool)
+ s.add(_column_addition_constraint(columns, additions))
+
+ for d in range(1, max_cnots + 1):
+ # qubit cannot be control and target at the same time
+ s.add(controls[d - 1] != targets[d - 1])
+
+ # control and target must be valid qubits
+ if checks.shape[1] and (checks.shape[1] - 1) != 0:
+ s.add(z3.ULT(controls[d - 1], checks.shape[1]))
+ s.add(z3.ULT(targets[d - 1], checks.shape[1]))
+
+ # if column is not involved in any addition at certain depth, it is the same as the previous column
+ for d in range(1, max_cnots + 1):
+ for col in range(checks.shape[1]):
+ s.add(z3.Implies(targets[d - 1] != col, _symbolic_vector_eq(columns[d, :, col], columns[d - 1, :, col])))
+
+ # assert that final check matrix has checks.shape[1]-checks.shape[0] zero columns
+ s.add(_final_matrix_constraint(columns))
+
+ if s.check() == z3.sat:
+ m = s.model()
+ cnots = [(m[controls[d]].as_long(), m[targets[d]].as_long()) for d in range(max_cnots)]
+ checks = np.array([
+ [bool(m[columns[max_cnots][i][j]]) for j in range(checks.shape[1])] for i in range(checks.shape[0])
+ ]).astype(int)
+ return _build_circuit_from_list_and_checks(cnots, checks, zero_state=zero_state)
+
+ return None
+
+
+def _optimal_circuit(
+ code: CSSCode,
+ prep_func: Callable[[npt.NDArray[np.int8], int, bool], QuantumCircuit | None],
+ zero_state: bool = True,
+ min_param: int = 1,
+ max_param: int = 10,
+ min_timeout: int = 1,
+ max_timeout: int = 3600,
+) -> StatePrepCircuit | None:
+ """Synthesize a state preparation circuit for a CSS code that minimizes the circuit w.r.t. some metric param according to prep_func.
+
+ Args:
+ code: The CSS code to prepare the state for.
+ zero_state: If True, prepare the +1 eigenstate of the Z basis. If False, prepare the +1 eigenstate of the X basis.
+ prep_func: The function to optimize the circuit with respect to.
+ min_param: minimum parameter to start with
+ max_param: maximum parameter to reach
+ min_timeout: minimum timeout to start with
+ max_timeout: maximum timeout to reach
+ """
+ if code.Hx is None or code.Hz is None:
+ msg = "Code must have both X and Z stabilizers defined."
+ raise ValueError(msg)
+ checks = code.Hx if zero_state else code.Hz
+
+ def fun(param: int) -> QuantumCircuit | None:
+ return prep_func(checks, param, zero_state)
+
+ res = iterative_search_with_timeout(
+ fun,
+ min_param,
+ max_param,
+ min_timeout,
+ max_timeout,
+ )
+
+ if res is None:
+ return None
+ circ, curr_param = res
+ if circ is None:
+ return None
+
+ logging.info(f"Solution found with param {curr_param}")
+ # Solving a SAT instance is much faster than proving unsat in this case
+ # so we iterate backwards until we find an unsat instance or hit a timeout
+ logging.info("Trying to minimize param")
+ while True:
+ logging.info(f"Trying param {curr_param - 1}")
+ opt_res = _run_with_timeout(fun, curr_param - 1, timeout=max_timeout)
+ if opt_res is None or (isinstance(opt_res, str) and opt_res == "timeout"):
+ break
+ circ = opt_res
+ curr_param -= 1
+
+ logging.info(f"Optimal param: {curr_param}")
+ return StatePrepCircuit(circ, code, zero_state)
+
+
+def depth_optimal_prep_circuit(
+ code: CSSCode,
+ zero_state: bool = True,
+ min_depth: int = 1,
+ max_depth: int = 10,
+ min_timeout: int = 1,
+ max_timeout: int = 3600,
+) -> StatePrepCircuit | None:
+ """Synthesize a state preparation circuit for a CSS code that minimizes the circuit depth.
+
+ Args:
+ code: The CSS code to prepare the state for.
+ zero_state: If True, prepare the +1 eigenstate of the Z basis. If False, prepare the +1 eigenstate of the X basis.
+ min_depth: minimum depth to start with
+ max_depth: maximum depth to reach
+ min_timeout: minimum timeout to start with
+ max_timeout: maximum timeout to reach
+ """
+ return _optimal_circuit(
+ code, _generate_circ_with_bounded_depth, zero_state, min_depth, max_depth, min_timeout, max_timeout
+ )
+
+
+def gate_optimal_prep_circuit(
+ code: CSSCode,
+ zero_state: bool = True,
+ min_gates: int = 1,
+ max_gates: int = 10,
+ min_timeout: int = 1,
+ max_timeout: int = 3600,
+) -> StatePrepCircuit | None:
+ """Synthesize a state preparation circuit for a CSS code that minimizes the number of gates.
+
+ Args:
+ code: The CSS code to prepare the state for.
+ zero_state: If True, prepare the +1 eigenstate of the Z basis. If False, prepare the +1 eigenstate of the X basis.
+ min_gates: minimum number of gates to start with
+ max_gates: maximum number of gates to reach
+ min_timeout: minimum timeout to start with
+ max_timeout: maximum timeout to reach
+ """
+ return _optimal_circuit(
+ code, _generate_circ_with_bounded_gates, zero_state, min_gates, max_gates, min_timeout, max_timeout
+ )
+
+
+def _build_circuit_from_list_and_checks(
+ cnots: list[tuple[int, int]], checks: npt.NDArray[np.int8], zero_state: bool = True
+) -> QuantumCircuit:
+ # Build circuit
+ n = checks.shape[1]
+ circ = QuantumCircuit(n)
+
+ controls = [i for i in range(n) if np.sum(checks[:, i]) >= 1]
+ if zero_state:
+ for control in controls:
+ circ.h(control)
+ else:
+ for i in range(n):
+ if i not in controls:
+ circ.h(i)
+
+ for i, j in reversed(cnots):
+ if zero_state:
+ ctrl, tar = i, j
+ else:
+ ctrl, tar = j, i
+ circ.cx(ctrl, tar)
+ return circ
+
+
+def _run_with_timeout(func: Callable[[Any], Any], *args: Any, timeout: int = 10) -> Any | str | None: # noqa: ANN401
+ """Run a function with a timeout.
+
+ If the function does not complete within the timeout, return None.
+
+ Args:
+ func: The function to run.
+ args: The arguments to pass to the function.
+ timeout: The maximum time to allow the function to run for in seconds.
+ """
+ manager = multiprocess.Manager()
+ return_list = manager.list()
+ p = multiprocess.Process(target=lambda: return_list.append(func(*args)))
+ p.start()
+ p.join(timeout)
+ if p.is_alive():
+ p.terminate()
+ return "timeout"
+ return return_list[0]
+
+
+def iterative_search_with_timeout(
+ fun: Callable[[int], QuantumCircuit],
+ min_param: int,
+ max_param: int,
+ min_timeout: int,
+ max_timeout: int,
+ param_factor: float = 2,
+ timeout_factor: float = 2,
+) -> None | tuple[None | QuantumCircuit, int]:
+ """Geometrically increases the parameter and timeout until a result is found or the maximum timeout is reached.
+
+ Args:
+ fun: function to run with increasing parameters and timeouts
+ min_param: minimum parameter to start with
+ max_param: maximum parameter to reach
+ min_timeout: minimum timeout to start with
+ max_timeout: maximum timeout to reach
+ param_factor: factor to increase the parameter by at each iteration
+ timeout_factor: factor to increase the timeout by at each iteration
+ """
+ curr_timeout = min_timeout
+ curr_param = min_param
+ while curr_timeout <= max_timeout:
+ while curr_param <= max_param:
+ logging.info(f"Running iterative search with param={curr_param} and timeout={curr_timeout}")
+ res = _run_with_timeout(fun, curr_param, timeout=curr_timeout)
+ if res is not None and (not isinstance(res, str) or res != "timeout"):
+ return res, curr_param
+ if curr_param == max_param:
+ break
+
+ curr_param = int(curr_param * param_factor)
+ curr_param = min(curr_param, max_param)
+
+ curr_timeout = int(curr_timeout * timeout_factor)
+ curr_param = min_param
+ return None, max_param
+
+
+def gate_optimal_verification_stabilizers(
+ sp_circ: StatePrepCircuit,
+ x_errors: bool = True,
+ min_timeout: int = 1,
+ max_timeout: int = 3600,
+ max_ancillas: int | None = None,
+ additional_faults: npt.NDArray[np.int8] | None = None,
+) -> list[list[npt.NDArray[np.int8]]]:
+ """Return verification stabilizers for the state preparation circuit.
+
+ The method uses an iterative search to find the optimal set of stabilizers by repeatedly computing the optimal circuit for each number of ancillas and cnots. This is repeated for each number of independent correctable errors in the state preparation circuit. Thus the verification circuit is constructed of multiple "layers" of stabilizers, each layer corresponding to a fault set it verifies.
+
+ Args:
+ sp_circ: The state preparation circuit to verify.
+ x_errors: If True, verify the X errors. If False, verify the Z errors.
+ min_timeout: The minimum time to allow each search to run for.
+ max_timeout: The maximum time to allow each search to run for.
+ max_ancillas: The maximum number of ancillas to allow in each layer verification circuit.
+ additional_faults: Faults to verify in addition to the faults propagating in the state preparation circuit.
+
+ Returns:
+ A list of stabilizers to verify the state preparation circuit.
+ """
+ max_errors = (sp_circ.code.distance - 1) // 2
+ layers = [[] for _ in range(max_errors)] # type: list[list[npt.NDArray[np.int8]]]
+ if max_ancillas is None:
+ max_ancillas = sp_circ.max_z_measurements if x_errors else sp_circ.max_x_measurements
+
+ sp_circ.compute_fault_sets()
+ fault_sets = (
+ sp_circ.combine_faults(additional_faults, x_errors)
+ if additional_faults is not None
+ else sp_circ.x_fault_sets
+ if x_errors
+ else sp_circ.z_fault_sets
+ )
+
+ # Find the optimal circuit for every number of errors in the preparation circuit
+ for num_errors in range(1, max_errors + 1):
+ logging.info(f"Finding verification stabilizers for {num_errors} errors")
+ faults = fault_sets[num_errors]
+ assert faults is not None
+
+ if len(faults) == 0:
+ logging.info(f"No non-trivial faults for {num_errors} errors")
+ layers[num_errors - 1] = []
+ continue
+ # Start with maximal number of ancillas
+ # Minimal CNOT solution must be achievable with these
+ num_anc = max_ancillas
+ checks = sp_circ.z_checks if x_errors else sp_circ.x_checks
+ min_cnots = np.min(np.sum(checks, axis=1))
+ max_cnots = np.sum(checks)
+
+ logging.info(
+ f"Finding verification stabilizers for {num_errors} errors with {min_cnots} to {max_cnots} CNOTs using {num_anc} ancillas"
+ )
+
+ def fun(num_cnots: int) -> list[npt.NDArray[np.int8]] | None:
+ return verification_stabilizers(sp_circ, faults, num_anc, num_cnots, x_errors=x_errors) # noqa: B023
+
+ res = iterative_search_with_timeout(
+ fun,
+ min_cnots,
+ max_cnots,
+ min_timeout,
+ max_timeout,
+ )
+
+ if res is None:
+ logging.info(f"No verification stabilizers found for {num_errors} errors")
+ layers[num_errors - 1] = []
+ continue
+ measurements, num_cnots = res
+ if measurements is None or (isinstance(measurements, str) and measurements == "timeout"):
+ logging.info(f"No verification stabilizers found for {num_errors} errors")
+ return [] # No solution found
+
+ logging.info(f"Found verification stabilizers for {num_errors} errors with {num_cnots} CNOTs")
+ # If any measurements are unused we can reduce the number of ancillas at least by that
+ measurements = [m for m in measurements if np.any(m)]
+ num_anc = len(measurements)
+ # Iterate backwards to find the minimal number of cnots
+ logging.info(f"Finding minimal number of CNOTs for {num_errors} errors")
+
+ def search_cnots(num_cnots: int) -> list[npt.NDArray[np.int8]] | None:
+ return verification_stabilizers(sp_circ, faults, num_anc, num_cnots, x_errors=x_errors) # noqa: B023
+
+ while num_cnots - 1 > 0:
+ logging.info(f"Trying {num_cnots - 1} CNOTs")
+
+ cnot_opt = _run_with_timeout(
+ search_cnots,
+ num_cnots - 1,
+ timeout=max_timeout,
+ )
+ if cnot_opt is None or (isinstance(cnot_opt, str) and cnot_opt == "timeout"):
+ break
+ num_cnots -= 1
+ measurements = cnot_opt
+ logging.info(f"Minimal number of CNOTs for {num_errors} errors is: {num_cnots}")
+
+ # If the number of CNOTs is minimal, we can reduce the number of ancillas
+ logging.info(f"Finding minimal number of ancillas for {num_errors} errors")
+ while num_anc - 1 > 0:
+ logging.info(f"Trying {num_anc - 1} ancillas")
+
+ def search_anc(num_anc: int) -> list[npt.NDArray[np.int8]] | None:
+ return verification_stabilizers(sp_circ, faults, num_anc, num_cnots, x_errors=x_errors) # noqa: B023
+
+ anc_opt = _run_with_timeout(
+ search_anc,
+ num_anc - 1,
+ timeout=max_timeout,
+ )
+ if anc_opt is None or (isinstance(anc_opt, str) and anc_opt == "timeout"):
+ break
+ num_anc -= 1
+ measurements = anc_opt
+ logging.info(f"Minimal number of ancillas for {num_errors} errors is: {num_anc}")
+ layers[num_errors - 1] = measurements
+
+ return layers
+
+
+def _verification_circuit(
+ sp_circ: StatePrepCircuit,
+ verification_stabs_fun: Callable[
+ [StatePrepCircuit, bool, npt.NDArray[np.int8] | None], list[list[npt.NDArray[np.int8]]]
+ ],
+ full_fault_tolerance: bool = True,
+) -> QuantumCircuit:
+ logging.info("Finding verification stabilizers for the state preparation circuit")
+ layers_1 = verification_stabs_fun(sp_circ, sp_circ.zero_state) # type: ignore[call-arg]
+ measurements_1 = [measurement for layer in layers_1 for measurement in layer]
+
+ if full_fault_tolerance:
+ additional_errors = _hook_errors(measurements_1)
+ layers_2 = verification_stabs_fun(sp_circ, not sp_circ.zero_state, additional_errors)
+ measurements_2 = [measurement for layer in layers_2 for measurement in layer]
+ else:
+ measurements_2 = []
+
+ if sp_circ.zero_state:
+ return _measure_ft_stabs(sp_circ, measurements_2, measurements_1, full_fault_tolerance=full_fault_tolerance)
+ return _measure_ft_stabs(sp_circ, measurements_1, measurements_2, full_fault_tolerance=full_fault_tolerance)
+
+
+def gate_optimal_verification_circuit(
+ sp_circ: StatePrepCircuit,
+ min_timeout: int = 1,
+ max_timeout: int = 3600,
+ max_ancillas: int | None = None,
+ full_fault_tolerance: bool = True,
+) -> QuantumCircuit:
+ """Return a verified state preparation circuit.
+
+ The verification circuit is a set of stabilizers such that each propagated error in sp_circ anticommutes with some verification stabilizer.
+
+ The method uses an iterative search to find the optimal set of stabilizers by repeatedly computing the optimal circuit for each number of ancillas and cnots. This is repeated for each number of independent correctable errors in the state preparation circuit. Thus the verification circuit is constructed of multiple "layers" of stabilizers, each layer corresponding to a fault set it verifies.
+
+ Args:
+ sp_circ: The state preparation circuit to verify.
+ min_timeout: The minimum time to allow each search to run for.
+ max_timeout: The maximum time to allow each search to run for.
+ max_ancillas: The maximum number of ancillas to allow in each layer verification circuit.
+ full_fault_tolerance: If True, the verification circuit will be constructed to be fault tolerant to all errors in the state preparation circuit. If False, the verification circuit will be constructed to be fault tolerant only to the type of errors that can cause a logical error. For a logical |0> state preparation circuit, this means the verification circuit will be fault tolerant to X errors but not for Z errors. For a logical |+> state preparation circuit, this means the verification circuit will be fault tolerant to Z errors but not for X errors.
+ """
+
+ def verification_stabs_fun(
+ sp_circ: StatePrepCircuit,
+ zero_state: bool,
+ additional_errors: npt.NDArray[np.int8] | None = None,
+ ) -> list[list[npt.NDArray[np.int8]]]:
+ return gate_optimal_verification_stabilizers(
+ sp_circ, zero_state, min_timeout, max_timeout, max_ancillas, additional_errors
+ )
+
+ return _verification_circuit(sp_circ, verification_stabs_fun, full_fault_tolerance=full_fault_tolerance)
+
+
+def heuristic_verification_circuit(
+ sp_circ: StatePrepCircuit,
+ max_covering_sets: int = 10000,
+ find_coset_leaders: bool = True,
+ full_fault_tolerance: bool = True,
+) -> QuantumCircuit:
+ """Return a verified state preparation circuit.
+
+ The method uses a greedy set covering heuristic to find a small set of stabilizers that verifies the state preparation circuit. The heuristic is not guaranteed to find the optimal set of stabilizers.
+
+ Args:
+ sp_circ: The state preparation circuit to verify.
+ max_covering_sets: The maximum number of covering sets to consider.
+ find_coset_leaders: Whether to find coset leaders for the found measurements. This is done using SAT solvers so it can be slow.
+ full_fault_tolerance: If True, the verification circuit will be constructed to be fault tolerant to all errors in the state preparation circuit. If False, the verification circuit will be constructed to be fault tolerant only to the type of errors that can cause a logical error. For a logical |0> state preparation circuit, this means the verification circuit will be fault tolerant to X errors but not for Z errors. For a logical |+> state preparation circuit, this means the verification circuit will be fault tolerant to Z errors but not for X errors.
+ """
+
+ def verification_stabs_fun(
+ sp_circ: StatePrepCircuit, zero_state: bool, additional_errors: npt.NDArray[np.int8] | None = None
+ ) -> list[list[npt.NDArray[np.int8]]]:
+ return heuristic_verification_stabilizers(
+ sp_circ, zero_state, max_covering_sets, find_coset_leaders, additional_errors
+ )
+
+ return _verification_circuit(sp_circ, verification_stabs_fun, full_fault_tolerance=full_fault_tolerance)
+
+
+def heuristic_verification_stabilizers(
+ sp_circ: StatePrepCircuit,
+ x_errors: bool = True,
+ max_covering_sets: int = 10000,
+ find_coset_leaders: bool = True,
+ additional_faults: npt.NDArray[np.int8] | None = None,
+) -> list[list[npt.NDArray[np.int8]]]:
+ """Return verification stabilizers for the preparation circuit.
+
+ Args:
+ sp_circ: The state preparation circuit to verify.
+ x_errors: Whether to find verification stabilizers for X errors. If False, find for Z errors.
+ max_covering_sets: The maximum number of covering sets to consider.
+ find_coset_leaders: Whether to find coset leaders for the found measurements. This is done using SAT solvers so it can be slow.
+ additional_faults: Faults to verify in addition to the faults propagating in the state preparation circuit.
+ """
+ logging.info("Finding verification stabilizers using heuristic method")
+ max_errors = (sp_circ.code.distance - 1) // 2
+ layers = [[] for _ in range(max_errors)] # type: list[list[npt.NDArray[np.int8]]]
+ sp_circ.compute_fault_sets()
+ fault_sets = (
+ sp_circ.combine_faults(additional_faults, x_errors)
+ if additional_faults is not None
+ else sp_circ.x_fault_sets
+ if x_errors
+ else sp_circ.z_fault_sets
+ )
+ orthogonal_checks = sp_circ.z_checks if x_errors else sp_circ.x_checks
+ for num_errors in range(1, max_errors + 1):
+ logging.info(f"Finding verification stabilizers for {num_errors} errors")
+ faults = fault_sets[num_errors]
+ assert faults is not None
+ logging.info(f"There are {len(faults)} faults")
+ if len(faults) == 0:
+ layers[num_errors - 1] = []
+ continue
+
+ layers[num_errors - 1] = _heuristic_layer(faults, orthogonal_checks, find_coset_leaders, max_covering_sets)
+
+ return layers
+
+
+def _covers(s: npt.NDArray[np.int8], faults: npt.NDArray[np.int8]) -> frozenset[int]:
+ return frozenset(np.where(s @ faults.T % 2 != 0)[0])
+
+
+def _set_cover(
+ n: int, cands: set[frozenset[int]], mapping: dict[frozenset[int], list[npt.NDArray[np.int8]]]
+) -> set[frozenset[int]]:
+ universe = set(range(n))
+ cover = set() # type: set[frozenset[int]]
+
+ def sort_key(stab: frozenset[int], universe: set[int] = universe) -> tuple[int, np.int_]:
+ return (len(stab & universe), -np.sum(mapping[stab]))
+
+ while universe:
+ best = max(cands, key=sort_key)
+ cover.add(best)
+ universe -= best
+ return cover
+
+
+def _extend_covering_sets(
+ candidate_sets: set[frozenset[int]], size_limit: int, mapping: dict[frozenset[int], list[npt.NDArray[np.int8]]]
+) -> set[frozenset[int]]:
+ to_remove = set() # type: set[frozenset[int]]
+ to_add = set() # type: set[frozenset[int]]
+ for c1 in candidate_sets:
+ for c2 in candidate_sets:
+ if len(to_add) >= size_limit:
+ break
+
+ comb = c1 ^ c2
+ if c1 == c2 or comb in candidate_sets or comb in to_add or comb in to_remove:
+ continue
+
+ mapping[comb].extend([(s1 + s2) % 2 for s1 in mapping[c1] for s2 in mapping[c2]])
+
+ if len(c1 & c2) == 0:
+ to_remove.add(c1)
+ to_remove.add(c2)
+ to_add.add(c1 ^ c2)
+
+ return candidate_sets.union(to_add)
+
+
+def _heuristic_layer(
+ faults: npt.NDArray[np.int8], checks: npt.NDArray[np.int8], find_coset_leaders: bool, max_covering_sets: int
+) -> list[npt.NDArray[np.int8]]:
+ syndromes = checks @ faults.T % 2
+ candidates = np.where(np.any(syndromes != 0, axis=1))[0]
+ non_candidates = np.where(np.all(syndromes == 0, axis=1))[0]
+ candidate_checks = checks[candidates]
+ non_candidate_checks = checks[non_candidates]
+
+ logging.info("Converting Stabilizer Checks to covering sets")
+ candidate_sets_ordered = [(_covers(s, faults), s, i) for i, s in enumerate(candidate_checks)]
+ mapping = defaultdict(list)
+ for cand, _, i in candidate_sets_ordered:
+ mapping[cand].append(candidate_checks[i])
+ candidate_sets = {cand for cand, _, _ in candidate_sets_ordered}
+
+ logging.info("Finding initial set cover")
+ cover = _set_cover(len(faults), candidate_sets, mapping)
+ logging.info(f"Initial set cover has {len(cover)} sets")
+
+ def cost(cover: set[frozenset[int]]) -> tuple[int, int]:
+ cost1 = len(cover)
+ cost2 = sum(np.sum(mapping[stab]) for stab in cover)
+ return cost1, cost2
+
+ cost1, cost2 = cost(cover)
+ prev_candidates = candidate_sets.copy()
+
+ # find good cover
+ improved = True
+ while improved and len(candidate_sets) < max_covering_sets:
+ improved = False
+ # add all symmetric differences to candidates
+ candidate_sets = _extend_covering_sets(candidate_sets, max_covering_sets, mapping)
+ new_cover = _set_cover(len(faults), candidate_sets, mapping)
+ logging.info(f"New Covering set has {len(new_cover)} sets")
+ new_cost1 = len(new_cover)
+ new_cost2 = sum(np.sum(mapping[stab]) for stab in new_cover)
+ if new_cost1 < cost1 or (new_cost1 == cost1 and new_cost2 < cost2):
+ cover = new_cover
+ cost1 = new_cost1
+ cost2 = new_cost2
+ improved = True
+ elif candidate_sets == prev_candidates:
+ break
+ prev_candidates = candidate_sets
+
+ # reduce stabilizers in cover
+ logging.info(f"Found covering set of size {len(cover)}.")
+ if find_coset_leaders and len(non_candidates) > 0:
+ logging.info("Finding coset leaders.")
+ measurements = []
+ for c in cover:
+ leaders = [_coset_leader(m, non_candidate_checks) for m in mapping[c]]
+ leaders.sort(key=np.sum)
+ measurements.append(leaders[0])
+ else:
+ measurements = [min(mapping[c], key=np.sum) for c in cover]
+
+ return measurements
+
+
+def _measure_ft_x(qc: QuantumCircuit, x_measurements: list[npt.NDArray[np.int8]], flags: bool = False) -> None:
+ if len(x_measurements) == 0:
+ return
+ num_x_anc = len(x_measurements)
+ x_anc = AncillaRegister(num_x_anc, "x_anc")
+ x_c = ClassicalRegister(num_x_anc, "x_c")
+ qc.add_register(x_anc)
+ qc.add_register(x_c)
+
+ for i, m in enumerate(x_measurements):
+ stab = np.where(m != 0)[0]
+ if flags:
+ measure_flagged(qc, stab, x_anc[i], x_c[i], z_measurement=False)
+ else:
+ qc.h(x_anc)
+ qc.cx([x_anc[i]] * len(stab), stab)
+ qc.h(x_anc)
+ qc.measure(x_anc, x_c)
+
+
+def _measure_ft_z(qc: QuantumCircuit, z_measurements: list[npt.NDArray[np.int8]], flags: bool = False) -> None:
+ if len(z_measurements) == 0:
+ return
+ num_z_anc = len(z_measurements)
+ z_anc = AncillaRegister(num_z_anc, "z_anc")
+ z_c = ClassicalRegister(num_z_anc, "z_c")
+ qc.add_register(z_anc)
+ qc.add_register(z_c)
+
+ for i, m in enumerate(z_measurements):
+ stab = np.where(m != 0)[0]
+ if flags:
+ measure_flagged(qc, stab, z_anc[i], z_c[i], z_measurement=True)
+ else:
+ qc.cx(stab, [z_anc[i]] * len(stab))
+ qc.measure(z_anc, z_c)
+
+
+def _measure_ft_stabs(
+ sp_circ: StatePrepCircuit,
+ x_measurements: list[npt.NDArray[np.int8]],
+ z_measurements: list[npt.NDArray[np.int8]],
+ full_fault_tolerance: bool = True,
+) -> QuantumCircuit:
+ # Create the verification circuit
+ q = QuantumRegister(sp_circ.num_qubits, "q")
+ measured_circ = QuantumCircuit(q)
+ measured_circ.compose(sp_circ.circ, inplace=True)
+
+ if sp_circ.zero_state:
+ _measure_ft_z(measured_circ, z_measurements)
+ if full_fault_tolerance:
+ _measure_ft_x(measured_circ, x_measurements, flags=True)
+ else:
+ _measure_ft_x(measured_circ, x_measurements)
+ if full_fault_tolerance:
+ _measure_ft_z(measured_circ, z_measurements, flags=True)
+
+ return measured_circ
+
+
+def _vars_to_stab(
+ measurement: list[z3.BoolRef | bool], generators: npt.NDArray[np.int8]
+) -> npt.NDArray[z3.BoolRef | bool]: # type: ignore[type-var]
+ measurement_stab = _symbolic_scalar_mult(generators[0], measurement[0])
+ for i, scalar in enumerate(measurement[1:]):
+ measurement_stab = _symbolic_vector_add(measurement_stab, _symbolic_scalar_mult(generators[i + 1], scalar))
+ return measurement_stab
+
+
+def verification_stabilizers(
+ sp_circ: StatePrepCircuit,
+ fault_set: npt.NDArray[np.int8],
+ num_anc: int,
+ num_cnots: int,
+ x_errors: bool = True,
+) -> list[npt.NDArray[np.int8]] | None:
+ """Return verification stabilizers for num_errors independent errors in the state preparation circuit using z3.
+
+ Args:
+ sp_circ: The state preparation circuit.
+ fault_set: The set of errors to verify.
+ num_anc: The maximum number of ancilla qubits to use.
+ num_cnots: The maximumg number of CNOT gates to use.
+ num_errors: The number of errors occur in the state prep circuit.
+ x_errors: If True, the errors are X errors. Otherwise, the errors are Z errors.
+ """
+ # Measurements are written as sums of generators
+ # The variables indicate which generators are non-zero in the sum
+ gens = sp_circ.z_checks if x_errors else sp_circ.x_checks
+ n_gens = gens.shape[0]
+
+ measurement_vars = [[z3.Bool(f"m_{anc}_{i}") for i in range(n_gens)] for anc in range(num_anc)]
+ solver = z3.Solver()
+
+ measurement_stabs = [_vars_to_stab(vars_, gens) for vars_ in measurement_vars]
+
+ # assert that each error is detected
+ solver.add(
+ z3.And([
+ z3.PbGe([(_odd_overlap(measurement, error), 1) for measurement in measurement_stabs], 1)
+ for error in fault_set
+ ])
+ )
+
+ # assert that not too many CNOTs are used
+ solver.add(
+ z3.PbLe(
+ [(measurement[q], 1) for measurement in measurement_stabs for q in range(sp_circ.num_qubits)], num_cnots
+ )
+ )
+
+ if solver.check() == z3.sat:
+ model = solver.model()
+ # Extract stabilizer measurements from model
+ actual_measurements = []
+ for m in measurement_vars:
+ v = np.zeros(sp_circ.num_qubits, dtype=np.int8) # type: npt.NDArray[np.int8]
+ for g in range(n_gens):
+ if model[m[g]]:
+ v += gens[g]
+ actual_measurements.append(v % 2)
+
+ return actual_measurements
+ return None
+
+
+def _coset_leader(error: npt.NDArray[np.int8], generators: npt.NDArray[np.int8]) -> npt.NDArray[np.int8]:
+ if len(generators) == 0:
+ return error
+ s = z3.Optimize()
+ leader = [z3.Bool(f"e_{i}") for i in range(len(error))]
+ coeff = [z3.Bool(f"c_{i}") for i in range(len(generators))]
+
+ g = _vars_to_stab(coeff, generators)
+
+ s.add(_symbolic_vector_eq(np.array(leader), _symbolic_vector_add(error.astype(bool), g)))
+ s.minimize(z3.Sum(leader))
+
+ s.check() # always SAT
+ m = s.model()
+ return np.array([bool(m[leader[i]]) for i in range(len(error))]).astype(int)
+
+
+def _symbolic_scalar_mult(v: npt.NDArray[np.int8], a: z3.BoolRef | bool) -> npt.NDArray[z3.BoolRef]:
+ """Multiply a concrete vector by a symbolic scalar."""
+ return np.array([a if s == 1 else False for s in v])
+
+
+def _symbolic_vector_add(
+ v1: npt.NDArray[z3.BoolRef | bool], v2: npt.NDArray[z3.BoolRef | bool]
+) -> npt.NDArray[z3.BoolRef | bool]:
+ """Add two symbolic vectors."""
+ v_new = [False for _ in range(len(v1))]
+ for i in range(len(v1)):
+ # If one of the elements is a bool, we can simplify the expression
+ v1_i_is_bool = isinstance(v1[i], (bool, np.bool_))
+ v2_i_is_bool = isinstance(v2[i], (bool, np.bool_))
+ if v1_i_is_bool:
+ v1[i] = bool(v1[i])
+ if v1[i]:
+ v_new[i] = z3.Not(v2[i]) if not v2_i_is_bool else not v2[i]
+ else:
+ v_new[i] = v2[i]
+
+ elif v2_i_is_bool:
+ v2[i] = bool(v2[i])
+ if v2[i]:
+ v_new[i] = z3.Not(v1[i])
+ else:
+ v_new[i] = v1[i]
+
+ else:
+ v_new[i] = z3.Xor(v1[i], v2[i])
+
+ return np.array(v_new)
+
+
+def _odd_overlap(v_sym: npt.NDArray[z3.BoolRef | bool], v_con: npt.NDArray[np.int8]) -> z3.BoolRef:
+ """Return True if the overlap of symbolic vector with constant vector is odd."""
+ return z3.PbEq([(v_sym[i], 1) for i, c in enumerate(v_con) if c == 1], 1)
+
+
+def _symbolic_vector_eq(v1: npt.NDArray[z3.BoolRef | bool], v2: npt.NDArray[z3.BoolRef | bool]) -> z3.BoolRef:
+ """Return assertion that two symbolic vectors should be equal."""
+ constraints = [False for _ in v1]
+ for i in range(len(v1)):
+ # If one of the elements is a bool, we can simplify the expression
+ v1_i_is_bool = isinstance(v1[i], (bool, np.bool_))
+ v2_i_is_bool = isinstance(v2[i], (bool, np.bool_))
+ if v1_i_is_bool:
+ v1[i] = bool(v1[i])
+ if v1[i]:
+ constraints[i] = v2[i]
+ else:
+ constraints[i] = z3.Not(v2[i]) if not v2_i_is_bool else not v2[i]
+
+ elif v2_i_is_bool:
+ v2[i] = bool(v2[i])
+ if v2[i]:
+ constraints[i] = v1[i]
+ else:
+ constraints[i] = z3.Not(v1[i])
+ else:
+ constraints[i] = v1[i] == v2[i]
+ return z3.And(constraints)
+
+
+def _column_addition_constraint(
+ columns: npt.NDArray[z3.BoolRef | bool],
+ col_add_vars: npt.NDArray[z3.BoolRef],
+) -> z3.BoolRef:
+ assert len(columns.shape) == 3
+ max_depth = col_add_vars.shape[0] # type: ignore[unreachable]
+ n_cols = col_add_vars.shape[2]
+
+ constraints = []
+ for d in range(1, max_depth + 1):
+ for col_1 in range(n_cols):
+ for col_2 in range(col_1 + 1, n_cols):
+ col_sum = _symbolic_vector_add(columns[d - 1, :, col_1], columns[d - 1, :, col_2])
+
+ # encode col_2 += col_1
+ add_col1_to_col2 = z3.Implies(
+ col_add_vars[d - 1, col_1, col_2],
+ z3.And(
+ _symbolic_vector_eq(columns[d, :, col_2], col_sum),
+ _symbolic_vector_eq(columns[d, :, col_1], columns[d - 1, :, col_1]),
+ ),
+ )
+
+ # encode col_1 += col_2
+ add_col2_to_col1 = z3.Implies(
+ col_add_vars[d - 1, col_2, col_1],
+ z3.And(
+ _symbolic_vector_eq(columns[d, :, col_1], col_sum),
+ _symbolic_vector_eq(columns[d, :, col_2], columns[d - 1, :, col_2]),
+ ),
+ )
+
+ constraints.extend([add_col1_to_col2, add_col2_to_col1])
+
+ return z3.And(constraints)
+
+
+def _final_matrix_constraint(columns: npt.NDArray[z3.BoolRef | bool]) -> z3.BoolRef:
+ assert len(columns.shape) == 3
+ return z3.PbEq( # type: ignore[unreachable]
+ [(z3.Not(z3.Or(list(columns[-1, :, col]))), 1) for col in range(columns.shape[2])],
+ columns.shape[2] - columns.shape[1],
+ )
+
+
+def _propagate_error(dag: DagCircuit, node: DAGNode, x_errors: bool = True) -> PauliList:
+ """Propagates a Pauli error through a circuit beginning from control of node."""
+ control = node.qargs[0]._index # noqa: SLF001
+ error = np.array([0] * dag.num_qubits(), dtype=np.int8) # type: npt.NDArray[np.int8]
+ error[control] = 1
+ # propagate error through circuit via bfs
+ q = deque([node])
+ visited = set() # type: set[DAGNode]
+ while q:
+ node = q.popleft()
+ if node in visited or isinstance(node, DAGOutNode):
+ continue
+ control = node.qargs[0]._index # noqa: SLF001
+ target = node.qargs[1]._index # noqa: SLF001
+ if x_errors:
+ error[target] = (error[target] + error[control]) % 2
+ else:
+ error[control] = (error[target] + error[control]) % 2
+ for succ in dag.successors(node):
+ q.append(succ)
+ return error
+
+
+def _remove_trivial_faults(
+ faults: npt.NDArray[np.int8], stabs: npt.NDArray[np.int8], code: CSSCode, x_errors: bool, num_errors: int
+) -> npt.NDArray[np.int8]:
+ # remove trivial faults
+ faults = faults.copy()
+ logging.info("Removing trivial faults.")
+ d_error = code.x_distance if x_errors else code.z_distance
+ t_error = (d_error - 1) // 2
+ t = (code.distance - 1) // 2
+ max_w = t_error // t
+ for i, fault in enumerate(faults):
+ faults[i] = _coset_leader(fault, stabs)
+ faults = faults[np.where(np.sum(faults, axis=1) > max_w * num_errors)[0]]
+
+ # unique faults
+ return np.unique(faults, axis=0)
+
+
+def _remove_stabilizer_equivalent_faults(
+ faults: npt.NDArray[np.int8], stabilizers: npt.NDArray[np.int8]
+) -> npt.NDArray[np.int8]:
+ """Remove stabilizer equivalent faults from a list of faults."""
+ faults = faults.copy()
+ stabilizers = stabilizers.copy()
+ removed = set()
+
+ logging.debug(f"Removing stabilizer equivalent faults from {len(faults)} faults.")
+ for i, f1 in enumerate(faults):
+ if i in removed:
+ continue
+ stabs_ext1 = np.vstack((stabilizers, f1))
+ if mod2.rank(stabs_ext1) == mod2.rank(stabilizers):
+ removed.add(i)
+ continue
+
+ for j, f2 in enumerate(faults[i + 1 :]):
+ if j + i + 1 in removed:
+ continue
+ stabs_ext2 = np.vstack((stabs_ext1, f2))
+
+ if mod2.rank(stabs_ext2) == mod2.rank(stabs_ext1):
+ removed.add(j + i + 1)
+
+ logging.debug(f"Removed {len(removed)} stabilizer equivalent faults.")
+ indices = list(set(range(len(faults))) - removed)
+ if len(indices) == 0:
+ return np.array([])
+
+ return faults[indices]
+
+
+def naive_verification_circuit(sp_circ: StatePrepCircuit) -> QuantumCircuit:
+ """Naive verification circuit for a state preparation circuit."""
+ if sp_circ.code.Hx is None or sp_circ.code.Hz is None:
+ msg = "Code must have stabilizers defined."
+ raise ValueError(msg)
+
+ z_measurements = list(sp_circ.code.Hx)
+ x_measurements = list(sp_circ.code.Hz)
+ reps = (sp_circ.code.distance - 1) // 2
+ return _measure_ft_stabs(sp_circ, x_measurements * reps, z_measurements * reps)
+
+
+def w_flag_pattern(w: int) -> list[int]:
+ """Return the w-flag construction from https://arxiv.org/abs/1708.02246.
+
+ Args:
+ w: The number of w-flags to construct.
+
+ Returns:
+ The w-flag pattern.
+ """
+ s1 = [2 * j + 2 for j in reversed(range((w - 4) // 2))]
+ s2 = [w - 3, 0]
+ s3 = [2 * j + 1 for j in reversed(range((w - 4) // 2))]
+ return s1 + s2 + s3 + [w - 2]
+
+
+def _ancilla_cnot(qc: QuantumCircuit, qubit: Qubit | AncillaQubit, ancilla: AncillaQubit, z_measurement: bool) -> None:
+ if z_measurement:
+ qc.cx(qubit, ancilla)
+ else:
+ qc.cx(ancilla, qubit)
+
+
+def _flag_measure(qc: QuantumCircuit, flag: AncillaQubit, meas_bit: ClBit, z_measurement: bool) -> None:
+ if z_measurement:
+ qc.h(flag)
+ qc.measure(flag, meas_bit)
+
+
+def _flag_reset(qc: QuantumCircuit, flag: AncillaQubit, z_measurement: bool) -> None:
+ qc.reset(flag)
+ if z_measurement:
+ qc.h(flag)
+
+
+def _flag_init(qc: QuantumCircuit, flag: AncillaQubit, z_measurement: bool) -> None:
+ if z_measurement:
+ qc.h(flag)
+
+
+def _measure_stab_unflagged(
+ qc: QuantumCircuit,
+ stab: list[Qubit] | npt.NDArray[np.int_],
+ ancilla: AncillaQubit,
+ measurement_bit: ClBit,
+ z_measurement: bool = True,
+) -> None:
+ if not z_measurement:
+ qc.h(ancilla)
+ qc.cx([ancilla] * len(stab), stab)
+ qc.h(ancilla)
+ else:
+ qc.cx(stab, [ancilla] * len(stab))
+ qc.measure(ancilla, measurement_bit)
+
+
+def measure_flagged(
+ qc: QuantumCircuit,
+ stab: list[Qubit] | npt.NDArray[np.int_],
+ ancilla: AncillaQubit,
+ measurement_bit: ClBit,
+ z_measurement: bool = True,
+) -> None:
+ """Measure a w-flagged stabilizer with the general scheme.
+
+ The measurement is done in place.
+
+ Args:
+ qc: The quantum circuit to add the measurement to.
+ stab: The qubits to measure.
+ ancilla: The ancilla qubit to use for the measurement.
+ measurement_bit: The classical bit to store the measurement result of the ancilla.
+ z_measurement: Whether to measure an X (False) or Z (True) stabilizer.
+ """
+ w = len(stab)
+ if w < 3:
+ _measure_stab_unflagged(qc, stab, ancilla, measurement_bit, z_measurement)
+
+ if w == 4:
+ measure_flagged_4(qc, stab, ancilla, measurement_bit, z_measurement)
+ return
+
+ if w == 6:
+ measure_flagged_6(qc, stab, ancilla, measurement_bit, z_measurement)
+ return
+
+ if w == 8:
+ measure_flagged_8(qc, stab, ancilla, measurement_bit, z_measurement)
+ return
+
+ flag_reg = AncillaRegister(w - 1)
+ meas_reg = ClassicalRegister(w - 1)
+ qc.add_register(flag_reg)
+ qc.add_register(meas_reg)
+
+ pattern = w_flag_pattern(w)
+
+ if not z_measurement:
+ qc.h(ancilla)
+
+ for flag in flag_reg:
+ _flag_init(qc, flag, z_measurement)
+ _ancilla_cnot(qc, flag, ancilla, z_measurement)
+
+ for i, q in enumerate(stab[:-3]):
+ flag = pattern[i]
+ _ancilla_cnot(qc, q, ancilla, z_measurement)
+ _ancilla_cnot(qc, flag_reg[flag], ancilla, z_measurement)
+ _flag_measure(qc, flag_reg[flag], meas_reg[flag], z_measurement)
+ _flag_reset(qc, flag_reg[flag], z_measurement)
+ _ancilla_cnot(qc, flag_reg[flag], ancilla, z_measurement)
+
+ _ancilla_cnot(qc, stab[-2], ancilla, z_measurement)
+
+ subpattern_length = (w - 4) // 2
+ reordered_pattern = [
+ pattern[-1],
+ *list(reversed(pattern[subpattern_length + 2 : -1])),
+ *list(reversed(pattern[subpattern_length : subpattern_length + 2])),
+ *list(reversed(pattern[:subpattern_length])),
+ ]
+
+ for flag in reordered_pattern:
+ _ancilla_cnot(qc, flag_reg[flag], ancilla, z_measurement)
+ _flag_measure(qc, flag_reg[flag], meas_reg[flag], z_measurement)
+
+ _ancilla_cnot(qc, stab[-1], ancilla, z_measurement)
+
+ if not z_measurement:
+ qc.h(ancilla)
+ qc.measure(ancilla, measurement_bit)
+
+
+def measure_flagged_4(
+ qc: QuantumCircuit,
+ stab: list[Qubit] | npt.NDArray[np.int_],
+ ancilla: AncillaQubit,
+ measurement_bit: ClBit,
+ z_measurement: bool = True,
+) -> None:
+ """Measure a 4-flagged stabilizer using an optimized scheme."""
+ assert len(stab) == 4
+ flag_reg = AncillaRegister(1)
+ meas_reg = ClassicalRegister(1)
+ qc.add_register(flag_reg)
+ qc.add_register(meas_reg)
+ flag = flag_reg[0]
+ flag_meas = meas_reg[0]
+
+ if not z_measurement:
+ qc.h(ancilla)
+
+ _ancilla_cnot(qc, stab[0], ancilla, z_measurement)
+ _flag_init(qc, flag, z_measurement)
+
+ _ancilla_cnot(qc, flag, ancilla, z_measurement)
+
+ _ancilla_cnot(qc, stab[1], ancilla, z_measurement)
+ _ancilla_cnot(qc, stab[2], ancilla, z_measurement)
+
+ _ancilla_cnot(qc, flag, ancilla, z_measurement)
+ _flag_measure(qc, flag, flag_meas, z_measurement)
+
+ _ancilla_cnot(qc, stab[3], ancilla, z_measurement)
+
+ if not z_measurement:
+ qc.h(ancilla)
+ qc.measure(ancilla, measurement_bit)
+
+
+def measure_flagged_6(
+ qc: QuantumCircuit,
+ stab: list[Qubit] | npt.NDArray[np.int_],
+ ancilla: AncillaQubit,
+ measurement_bit: ClBit,
+ z_measurement: bool = True,
+) -> None:
+ """Measure a 6-flagged stabilizer using an optimized scheme."""
+ assert len(stab) == 6
+ flag = AncillaRegister(2)
+ meas = ClassicalRegister(2)
+
+ qc.add_register(flag)
+ qc.add_register(meas)
+
+ if not z_measurement:
+ qc.h(ancilla)
+
+ _ancilla_cnot(qc, stab[0], ancilla, z_measurement)
+
+ _flag_init(qc, flag[0], z_measurement)
+ _ancilla_cnot(qc, flag[0], ancilla, z_measurement)
+
+ _ancilla_cnot(qc, stab[1], ancilla, z_measurement)
+
+ _flag_init(qc, flag[1], z_measurement)
+ _ancilla_cnot(qc, flag[1], ancilla, z_measurement)
+
+ _ancilla_cnot(qc, stab[2], ancilla, z_measurement)
+ _ancilla_cnot(qc, stab[3], ancilla, z_measurement)
+
+ _ancilla_cnot(qc, flag[0], ancilla, z_measurement)
+ _flag_measure(qc, flag[0], meas[0], z_measurement)
+
+ _ancilla_cnot(qc, stab[4], ancilla, z_measurement)
+
+ _ancilla_cnot(qc, flag[1], ancilla, z_measurement)
+ _flag_measure(qc, flag[1], meas[1], z_measurement)
+
+ _ancilla_cnot(qc, stab[5], ancilla, z_measurement)
+
+ if not z_measurement:
+ qc.h(ancilla)
+ qc.measure(ancilla, measurement_bit)
+
+
+def measure_flagged_8(
+ qc: QuantumCircuit,
+ stab: list[Qubit] | npt.NDArray[np.int_],
+ ancilla: AncillaQubit,
+ measurement_bit: ClBit,
+ z_measurement: bool = True,
+) -> None:
+ """Measure an 8-flagged stabilizer using an optimized scheme."""
+ assert len(stab) == 8
+ flag = AncillaRegister(3)
+ meas = ClassicalRegister(3)
+ qc.add_register(flag)
+ qc.add_register(meas)
+
+ if not z_measurement:
+ qc.h(ancilla)
+
+ _ancilla_cnot(qc, stab[0], ancilla, z_measurement)
+
+ _flag_init(qc, flag[0], z_measurement)
+ _ancilla_cnot(qc, flag[0], ancilla, z_measurement)
+
+ _ancilla_cnot(qc, stab[1], ancilla, z_measurement)
+
+ _flag_init(qc, flag[1], z_measurement)
+ _ancilla_cnot(qc, flag[1], ancilla, z_measurement)
+
+ _ancilla_cnot(qc, stab[2], ancilla, z_measurement)
+ _ancilla_cnot(qc, stab[3], ancilla, z_measurement)
+
+ _flag_init(qc, flag[2], z_measurement)
+ _ancilla_cnot(qc, flag[2], ancilla, z_measurement)
+
+ _ancilla_cnot(qc, stab[4], ancilla, z_measurement)
+ _ancilla_cnot(qc, stab[5], ancilla, z_measurement)
+
+ _ancilla_cnot(qc, flag[0], ancilla, z_measurement)
+ _flag_measure(qc, flag[0], meas[0], z_measurement)
+
+ _ancilla_cnot(qc, stab[6], ancilla, z_measurement)
+
+ _ancilla_cnot(qc, flag[2], ancilla, z_measurement)
+ _flag_measure(qc, flag[2], meas[2], z_measurement)
+
+ _ancilla_cnot(qc, flag[1], ancilla, z_measurement)
+ _flag_measure(qc, flag[1], meas[1], z_measurement)
+
+ _ancilla_cnot(qc, stab[7], ancilla, z_measurement)
+
+ if not z_measurement:
+ qc.h(ancilla)
+ qc.measure(ancilla, measurement_bit)
+
+
+def _hook_errors(stabs: list[npt.NDArray[np.int8]]) -> npt.NDArray[np.int8]:
+ """Assuming CNOTs are executed in ascending order of qubit index, this function gives all the hook errors of the given stabilizer measurements."""
+ errors = []
+ for stab in stabs:
+ error = stab.copy()
+ for i in range(len(stab)):
+ if stab[i] == 1:
+ error[i] = 0
+ errors.append(error.copy())
+ error[i] = 1
+ return np.array(errors)
diff --git a/src/mqt/qecc/sample_codes/cc_4_8_8_d5/hx.npy b/src/mqt/qecc/sample_codes/cc_4_8_8_d5/hx.npy
new file mode 100644
index 00000000..bad45712
Binary files /dev/null and b/src/mqt/qecc/sample_codes/cc_4_8_8_d5/hx.npy differ
diff --git a/src/mqt/qecc/sample_codes/cc_4_8_8_d5/hz.npy b/src/mqt/qecc/sample_codes/cc_4_8_8_d5/hz.npy
new file mode 100644
index 00000000..bad45712
Binary files /dev/null and b/src/mqt/qecc/sample_codes/cc_4_8_8_d5/hz.npy differ
diff --git a/src/mqt/qecc/sample_codes/golay/hx.npy b/src/mqt/qecc/sample_codes/golay/hx.npy
new file mode 100644
index 00000000..645dc143
Binary files /dev/null and b/src/mqt/qecc/sample_codes/golay/hx.npy differ
diff --git a/src/mqt/qecc/sample_codes/golay/hz.npy b/src/mqt/qecc/sample_codes/golay/hz.npy
new file mode 100644
index 00000000..645dc143
Binary files /dev/null and b/src/mqt/qecc/sample_codes/golay/hz.npy differ
diff --git a/src/mqt/qecc/sample_codes/hamming/hx.npy b/src/mqt/qecc/sample_codes/hamming/hx.npy
new file mode 100644
index 00000000..eb25185f
Binary files /dev/null and b/src/mqt/qecc/sample_codes/hamming/hx.npy differ
diff --git a/src/mqt/qecc/sample_codes/hamming/hz.npy b/src/mqt/qecc/sample_codes/hamming/hz.npy
new file mode 100644
index 00000000..eb25185f
Binary files /dev/null and b/src/mqt/qecc/sample_codes/hamming/hz.npy differ
diff --git a/src/mqt/qecc/sample_codes/rotated_surface_d3/hx.npy b/src/mqt/qecc/sample_codes/rotated_surface_d3/hx.npy
new file mode 100644
index 00000000..8fddc13b
Binary files /dev/null and b/src/mqt/qecc/sample_codes/rotated_surface_d3/hx.npy differ
diff --git a/src/mqt/qecc/sample_codes/rotated_surface_d3/hz.npy b/src/mqt/qecc/sample_codes/rotated_surface_d3/hz.npy
new file mode 100644
index 00000000..14ef1817
Binary files /dev/null and b/src/mqt/qecc/sample_codes/rotated_surface_d3/hz.npy differ
diff --git a/src/mqt/qecc/sample_codes/rotated_surface_d5/hx.npy b/src/mqt/qecc/sample_codes/rotated_surface_d5/hx.npy
new file mode 100644
index 00000000..b456ca41
Binary files /dev/null and b/src/mqt/qecc/sample_codes/rotated_surface_d5/hx.npy differ
diff --git a/src/mqt/qecc/sample_codes/rotated_surface_d5/hz.npy b/src/mqt/qecc/sample_codes/rotated_surface_d5/hz.npy
new file mode 100644
index 00000000..ef8c8532
Binary files /dev/null and b/src/mqt/qecc/sample_codes/rotated_surface_d5/hz.npy differ
diff --git a/src/mqt/qecc/sample_codes/shor/hx.npy b/src/mqt/qecc/sample_codes/shor/hx.npy
new file mode 100644
index 00000000..584d47ff
Binary files /dev/null and b/src/mqt/qecc/sample_codes/shor/hx.npy differ
diff --git a/src/mqt/qecc/sample_codes/shor/hz.npy b/src/mqt/qecc/sample_codes/shor/hz.npy
new file mode 100644
index 00000000..e4cbdbd5
Binary files /dev/null and b/src/mqt/qecc/sample_codes/shor/hz.npy differ
diff --git a/src/mqt/qecc/sample_codes/steane/hx.npy b/src/mqt/qecc/sample_codes/steane/hx.npy
new file mode 100644
index 00000000..7531a5ab
Binary files /dev/null and b/src/mqt/qecc/sample_codes/steane/hx.npy differ
diff --git a/src/mqt/qecc/sample_codes/steane/hz.npy b/src/mqt/qecc/sample_codes/steane/hz.npy
new file mode 100644
index 00000000..7531a5ab
Binary files /dev/null and b/src/mqt/qecc/sample_codes/steane/hz.npy differ
diff --git a/src/mqt/qecc/sample_codes/tetrahedral/hx.npy b/src/mqt/qecc/sample_codes/tetrahedral/hx.npy
new file mode 100644
index 00000000..5d8cc634
Binary files /dev/null and b/src/mqt/qecc/sample_codes/tetrahedral/hx.npy differ
diff --git a/src/mqt/qecc/sample_codes/tetrahedral/hz.npy b/src/mqt/qecc/sample_codes/tetrahedral/hz.npy
new file mode 100644
index 00000000..1db80cdf
Binary files /dev/null and b/src/mqt/qecc/sample_codes/tetrahedral/hz.npy differ
diff --git a/test/python/ft_stateprep/test_simulation.py b/test/python/ft_stateprep/test_simulation.py
new file mode 100644
index 00000000..d053589a
--- /dev/null
+++ b/test/python/ft_stateprep/test_simulation.py
@@ -0,0 +1,141 @@
+"""Test the simulation of fault-tolerant state preparation circuits."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import numpy as np
+import pytest
+
+from mqt.qecc import CSSCode
+from mqt.qecc.ft_stateprep import (
+ LutDecoder,
+ NoisyNDFTStatePrepSimulator,
+ gate_optimal_verification_circuit,
+ heuristic_prep_circuit,
+)
+
+if TYPE_CHECKING: # pragma: no cover
+ from qiskit import QuantumCircuit
+
+
+@pytest.fixture
+def steane_code() -> CSSCode:
+ """Return the Steane code."""
+ return CSSCode.from_code_name("steane")
+
+
+@pytest.fixture
+def non_ft_steane_zero(steane_code: CSSCode) -> QuantumCircuit:
+ """Return a non fault-tolerant Steane code state preparation circuit."""
+ return heuristic_prep_circuit(steane_code).circ
+
+
+@pytest.fixture
+def non_ft_steane_plus(steane_code: CSSCode) -> QuantumCircuit:
+ """Return a non fault-tolerant Steane code state preparation circuit."""
+ return heuristic_prep_circuit(steane_code, zero_state=False).circ
+
+
+@pytest.fixture
+def ft_steane_zero(steane_code: CSSCode) -> QuantumCircuit:
+ """Return a fault-tolerant Steane code state preparation circuit."""
+ circ = heuristic_prep_circuit(steane_code)
+ return gate_optimal_verification_circuit(circ, max_timeout=2)
+
+
+@pytest.fixture
+def ft_steane_plus(steane_code: CSSCode) -> QuantumCircuit:
+ """Return a fault-tolerant Steane code state preparation circuit."""
+ circ = heuristic_prep_circuit(steane_code, zero_state=False)
+ return gate_optimal_verification_circuit(circ, max_timeout=2)
+
+
+def test_lut(steane_code: CSSCode) -> None:
+ """Test the LutDecoder class."""
+ assert steane_code.Hx is not None, "Steane code does not have X stabilizers."
+ assert steane_code.Hz is not None, "Steane code does not have Z stabilizers."
+
+ lut = LutDecoder(steane_code, init_luts=False)
+
+ assert len(lut.x_lut) == 0
+ assert len(lut.z_lut) == 0
+
+ lut.generate_x_lut()
+ lut.generate_z_lut()
+
+ assert len(lut.x_lut) != 0
+ assert lut.x_lut is lut.z_lut # Code is self dual so luts should be the same
+
+ error_1 = np.zeros(steane_code.n, dtype=np.int8) # type: ignore[var-annotated]
+ error_1[0] = 1
+
+ error_w1 = (steane_code.Hx[0] + error_1) % 2
+ syndrome_1 = steane_code.get_x_syndrome(error_w1)
+ estimate_1 = lut.decode_x(syndrome_1.astype(np.int8))
+ assert steane_code.stabilizer_eq_x_error(estimate_1, error_1)
+ assert steane_code.stabilizer_eq_z_error(estimate_1, error_1)
+
+ error_2 = np.zeros(steane_code.n, dtype=np.int8) # type: ignore[var-annotated]
+ error_2[0] = 1
+ error_2[1] = 1
+ error_w2 = (steane_code.Hx[0] + error_2) % 2
+ syndrome_2 = steane_code.get_x_syndrome(error_w2)
+ estimate_2 = lut.decode_x(syndrome_2.astype(np.int8))
+
+ # Weight 2 error should have be estimated to be weight 1
+ assert not steane_code.stabilizer_eq_x_error(estimate_2, error_2)
+ assert np.sum(estimate_2) == 1
+
+ error_3 = np.ones((steane_code.n), dtype=np.int8) # type: ignore[var-annotated]
+ error_w3 = (steane_code.Hx[0] + error_3) % 2
+ syndrome_3 = steane_code.get_x_syndrome(error_w3)
+ estimate_3 = lut.decode_x(syndrome_3.astype(np.int8))
+ # Weight 3 error should have be estimated to be weight 0
+ assert not steane_code.stabilizer_eq_x_error(estimate_3, error_3)
+ assert steane_code.stabilizer_eq_x_error(estimate_3, np.zeros(steane_code.n, dtype=np.int8))
+ assert np.sum(estimate_3) == 0
+
+
+def test_non_ft_sim_zero(steane_code: CSSCode, non_ft_steane_zero: QuantumCircuit) -> None:
+ """Test the simulation of a non fault-tolerant state preparation circuit for the Steane |0>."""
+ tol = 5e-4
+ p = 1e-3
+ lower = 1e-4
+ simulator = NoisyNDFTStatePrepSimulator(non_ft_steane_zero, steane_code, p=p)
+ p_l, _, _, _ = simulator.logical_error_rate(min_errors=10)
+
+ assert p_l - tol > lower
+
+
+def test_ft_sim_zero(steane_code: CSSCode, ft_steane_zero: QuantumCircuit) -> None:
+ """Test the simulation of a fault-tolerant state preparation circuit for the Steane |0>."""
+ tol = 5e-4
+ p = 1e-3
+ lower = 1e-4
+ simulator = NoisyNDFTStatePrepSimulator(ft_steane_zero, steane_code, p=p)
+ p_l, _, _, _ = simulator.logical_error_rate(min_errors=10)
+
+ assert p_l - tol < lower
+
+
+def test_non_ft_sim_plus(steane_code: CSSCode, non_ft_steane_plus: QuantumCircuit) -> None:
+ """Test the simulation of a non fault-tolerant state preparation circuit for the Steane |0>."""
+ tol = 5e-4
+ p = 1e-3
+ lower = 1e-4
+ simulator = NoisyNDFTStatePrepSimulator(non_ft_steane_plus, steane_code, p=p, zero_state=False)
+ p_l, _, _, _ = simulator.logical_error_rate(min_errors=10)
+
+ assert p_l - tol > lower
+
+
+def test_ft_sim_plus(steane_code: CSSCode, ft_steane_plus: QuantumCircuit) -> None:
+ """Test the simulation of a fault-tolerant state preparation circuit for the Steane |0>."""
+ tol = 5e-4
+ p = 1e-3
+ lower = 1e-4
+ simulator = NoisyNDFTStatePrepSimulator(ft_steane_plus, steane_code, p=p, zero_state=False)
+ p_l, _, _, _ = simulator.logical_error_rate(min_errors=10)
+
+ assert p_l - tol < lower
diff --git a/test/python/ft_stateprep/test_stateprep.py b/test/python/ft_stateprep/test_stateprep.py
new file mode 100644
index 00000000..a6a007c1
--- /dev/null
+++ b/test/python/ft_stateprep/test_stateprep.py
@@ -0,0 +1,427 @@
+"""Test synthesis of state preparation and verification circuits for FT state preparation."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import numpy as np
+import pytest
+from ldpc import mod2
+from qiskit.quantum_info import Clifford
+
+from mqt.qecc import CSSCode
+from mqt.qecc.ft_stateprep import (
+ depth_optimal_prep_circuit,
+ gate_optimal_prep_circuit,
+ gate_optimal_verification_circuit,
+ gate_optimal_verification_stabilizers,
+ heuristic_prep_circuit,
+ heuristic_verification_circuit,
+ heuristic_verification_stabilizers,
+)
+
+if TYPE_CHECKING: # pragma: no cover
+ import numpy.typing as npt
+ from qiskit import QuantumCircuit
+
+ from mqt.qecc.ft_stateprep import StatePrepCircuit
+
+
+@pytest.fixture
+def steane_code() -> CSSCode:
+ """Return the Steane code."""
+ return CSSCode.from_code_name("Steane")
+
+
+@pytest.fixture
+def surface_code() -> CSSCode:
+ """Return the distance 3 rotated Surface Code."""
+ return CSSCode.from_code_name("surface", 3)
+
+
+@pytest.fixture
+def tetrahedral_code() -> CSSCode:
+ """Return the tetrahedral code."""
+ return CSSCode.from_code_name("tetrahedral")
+
+
+@pytest.fixture
+def cc_4_8_8_code() -> CSSCode:
+ """Return the d=5 4,8,8 color code."""
+ return CSSCode.from_code_name("cc_4_8_8")
+
+
+@pytest.fixture
+def steane_code_sp(steane_code: CSSCode) -> StatePrepCircuit:
+ """Return a non-ft state preparation circuit for the Steane code."""
+ sp_circ = heuristic_prep_circuit(steane_code)
+ sp_circ.compute_fault_sets()
+ return sp_circ
+
+
+@pytest.fixture
+def tetrahedral_code_sp(tetrahedral_code: CSSCode) -> StatePrepCircuit:
+ """Return a non-ft state preparation circuit for the tetrahedral code."""
+ sp_circ = heuristic_prep_circuit(tetrahedral_code)
+ sp_circ.compute_fault_sets()
+ return sp_circ
+
+
+@pytest.fixture
+def color_code_d5_sp(cc_4_8_8_code: CSSCode) -> StatePrepCircuit:
+ """Return a non-ft state preparation circuit for the d=5 4,8,8 color code."""
+ sp_circ = heuristic_prep_circuit(cc_4_8_8_code)
+ sp_circ.compute_fault_sets()
+ return sp_circ
+
+
+def eq_span(a: npt.NDArray[np.int_], b: npt.NDArray[np.int_]) -> bool:
+ """Check if two matrices have the same row space."""
+ return a.shape == b.shape and mod2.rank(np.vstack((a, b))) == mod2.rank(a) == mod2.rank(b)
+
+
+def in_span(m: npt.NDArray[np.int_], v: npt.NDArray[np.int_]) -> bool:
+ """Check if a vector is in the row space of a matrix."""
+ return bool(mod2.rank(np.vstack((m, v))) == mod2.rank(m))
+
+
+def get_stabs(qc: QuantumCircuit) -> tuple[npt.NDArray[np.int_], npt.NDArray[np.int_]]:
+ """Return the stabilizers of a quantum circuit."""
+ cliff = Clifford(qc)
+ x = cliff.stab_x.astype(int)
+ x = x[np.where(np.logical_not(np.all(x == 0, axis=1)))[0]]
+ z = cliff.stab_z.astype(int)
+ z = z[np.where(np.logical_not(np.all(z == 0, axis=1)))[0]]
+ return x, z
+
+
+@pytest.mark.parametrize("code_name", ["steane", "tetrahedral", "surface", "cc_4_8_8"])
+def test_heuristic_prep_consistent(code_name: str) -> None:
+ """Check that heuristic_prep_circuit returns a valid circuit with the correct stabilizers."""
+ code = CSSCode.from_code_name(code_name)
+
+ sp_circ = heuristic_prep_circuit(code)
+ circ = sp_circ.circ
+ max_cnots = np.sum(code.Hx) + np.sum(code.Hz) # type: ignore[arg-type]
+
+ assert circ.num_qubits == code.n
+ assert circ.num_nonlocal_gates() <= max_cnots
+
+ x, z = get_stabs(circ)
+ assert eq_span(code.Hx, x) # type: ignore[arg-type]
+ assert eq_span(np.vstack((code.Hz, code.Lz)), z) # type: ignore[arg-type]
+
+
+@pytest.mark.parametrize("code", ["steane_code", "surface_code"])
+def test_gate_optimal_prep_consistent(code: CSSCode, request) -> None: # type: ignore[no-untyped-def]
+ """Check that gate_optimal_prep_circuit returns a valid circuit with the correct stabilizers."""
+ code = request.getfixturevalue(code)
+ sp_circ = gate_optimal_prep_circuit(code, max_timeout=6)
+ assert sp_circ is not None
+ assert sp_circ.zero_state
+
+ circ = sp_circ.circ
+ max_cnots = np.sum(code.Hx) + np.sum(code.Hz) # type: ignore[arg-type]
+
+ assert circ.num_qubits == code.n
+ assert circ.num_nonlocal_gates() <= max_cnots
+
+ x, z = get_stabs(circ)
+ assert eq_span(code.Hx, x) # type: ignore[arg-type]
+ assert eq_span(np.vstack((code.Hz, code.Lz)), z) # type: ignore[arg-type]
+
+
+@pytest.mark.parametrize("code", ["steane_code", "surface_code"])
+def test_depth_optimal_prep_consistent(code: CSSCode, request) -> None: # type: ignore[no-untyped-def]
+ """Check that depth_optimal_prep_circuit returns a valid circuit with the correct stabilizers."""
+ code = request.getfixturevalue(code)
+
+ sp_circ = depth_optimal_prep_circuit(code, max_timeout=6)
+ assert sp_circ is not None
+ circ = sp_circ.circ
+ max_cnots = np.sum(code.Hx) + np.sum(code.Hz) # type: ignore[arg-type]
+
+ assert circ.num_qubits == code.n
+ assert circ.num_nonlocal_gates() <= max_cnots
+
+ x, z = get_stabs(circ)
+ assert eq_span(code.Hx, x) # type: ignore[arg-type]
+ assert eq_span(np.vstack((code.Hz, code.Lz)), z) # type: ignore[arg-type]
+
+
+@pytest.mark.parametrize("code", ["steane_code", "surface_code"])
+def test_plus_state_gate_optimal(code: CSSCode, request) -> None: # type: ignore[no-untyped-def]
+ """Test synthesis of the plus state."""
+ code = request.getfixturevalue(code)
+ sp_circ_plus = gate_optimal_prep_circuit(code, max_timeout=5, zero_state=False)
+
+ assert sp_circ_plus is not None
+ assert not sp_circ_plus.zero_state
+
+ circ_plus = sp_circ_plus.circ
+ max_cnots = np.sum(code.Hx) + np.sum(code.Hz) # type: ignore[arg-type]
+
+ assert circ_plus.num_qubits == code.n
+ assert circ_plus.num_nonlocal_gates() <= max_cnots
+
+ x, z = get_stabs(circ_plus)
+ assert eq_span(code.Hz, z) # type: ignore[arg-type]
+ assert eq_span(np.vstack((code.Hx, code.Lx)), x) # type: ignore[arg-type]
+
+ sp_circ_zero = gate_optimal_prep_circuit(code, max_timeout=5, zero_state=True)
+
+ assert sp_circ_zero is not None
+
+ circ_zero = sp_circ_zero.circ
+ x_zero, z_zero = get_stabs(circ_zero)
+
+ if code.is_self_dual():
+ assert np.array_equal(x, z_zero)
+ assert np.array_equal(z, x_zero)
+ else:
+ assert not np.array_equal(x, z_zero)
+ assert not np.array_equal(z, x_zero)
+
+
+@pytest.mark.parametrize("code", ["steane_code", "surface_code", "tetrahedral_code"])
+def test_plus_state_heuristic(code: CSSCode, request) -> None: # type: ignore[no-untyped-def]
+ """Test synthesis of the plus state."""
+ code = request.getfixturevalue(code)
+ sp_circ_plus = heuristic_prep_circuit(code, zero_state=False)
+
+ assert sp_circ_plus is not None
+ assert not sp_circ_plus.zero_state
+
+ circ_plus = sp_circ_plus.circ
+ max_cnots = np.sum(code.Hx) + np.sum(code.Hz) # type: ignore[arg-type]
+
+ assert circ_plus.num_qubits == code.n
+ assert circ_plus.num_nonlocal_gates() <= max_cnots
+
+ x, z = get_stabs(circ_plus)
+ assert eq_span(code.Hz, z) # type: ignore[arg-type]
+ assert eq_span(np.vstack((code.Hx, code.Lx)), x) # type: ignore[arg-type]
+
+ sp_circ_zero = heuristic_prep_circuit(code, zero_state=True)
+ circ_zero = sp_circ_zero.circ
+ x_zero, z_zero = get_stabs(circ_zero)
+
+ if code.is_self_dual():
+ assert np.array_equal(x, z_zero)
+ assert np.array_equal(z, x_zero)
+ else:
+ assert not np.array_equal(x, z_zero)
+ assert not np.array_equal(z, x_zero)
+
+
+def test_optimal_steane_verification_circuit(steane_code_sp: StatePrepCircuit) -> None:
+ """Test that the optimal verification circuit for the Steane code is correct."""
+ circ = steane_code_sp
+ ver_stabs_layers = gate_optimal_verification_stabilizers(circ, x_errors=True, max_timeout=5)
+
+ assert len(ver_stabs_layers) == 1 # 1 Ancilla measurement
+
+ ver_stabs = ver_stabs_layers[0]
+
+ assert np.sum(ver_stabs) == 3 # 3 CNOTs
+ z_gens = circ.z_checks
+
+ for stab in ver_stabs:
+ assert in_span(z_gens, stab)
+
+ errors = circ.compute_fault_set(1)
+ non_detected = np.where(np.all(ver_stabs @ errors.T % 2 == 0, axis=1))[0]
+ assert len(non_detected) == 0
+
+ # Check that circuit is correct
+ circ_ver = gate_optimal_verification_circuit(circ)
+ assert circ_ver.num_qubits == circ.num_qubits + 1
+ assert circ_ver.num_nonlocal_gates() == np.sum(ver_stabs) + circ.circ.num_nonlocal_gates()
+ assert circ_ver.depth() == np.sum(ver_stabs) + circ.circ.depth() + 1 # 1 for the measurement
+
+
+def test_heuristic_steane_verification_circuit(steane_code_sp: StatePrepCircuit) -> None:
+ """Test that the optimal verification circuit for the Steane code is correct."""
+ circ = steane_code_sp
+
+ ver_stabs_layers = heuristic_verification_stabilizers(circ, x_errors=True)
+
+ assert len(ver_stabs_layers) == 1 # 1 layer of verification measurements
+
+ ver_stabs = ver_stabs_layers[0]
+ assert len(ver_stabs) == 1 # 1 Ancilla measurement
+ assert np.sum(ver_stabs[0]) == 3 # 3 CNOTs
+ z_gens = circ.z_checks
+
+ for stab in ver_stabs:
+ assert in_span(z_gens, stab)
+
+ errors = circ.compute_fault_set(1)
+ non_detected = np.where(np.all(ver_stabs @ errors.T % 2 == 0, axis=1))[0]
+ assert len(non_detected) == 0
+
+ # Check that circuit is correct
+ circ_ver = heuristic_verification_circuit(circ)
+ assert circ_ver.num_qubits == circ.num_qubits + 1
+ assert circ_ver.num_nonlocal_gates() == np.sum(ver_stabs) + circ.circ.num_nonlocal_gates()
+ assert circ_ver.depth() == np.sum(ver_stabs) + circ.circ.depth() + 1 # 1 for the measurement
+
+
+def test_optimal_tetrahedral_verification_circuit(tetrahedral_code_sp: StatePrepCircuit) -> None:
+ """Test the optimal verification circuit for the tetrahedral code is correct.
+
+ The tetrahedral code has an x-distance of 7. We expect that the verification only checks for a single propagated error since the tetrahedral code has a distance of 3.
+ """
+ circ = tetrahedral_code_sp
+
+ ver_stabs_layers = gate_optimal_verification_stabilizers(circ, x_errors=True, max_ancillas=1, max_timeout=5)
+
+ assert len(ver_stabs_layers) == 1 # 1 layer of verification measurements
+
+ ver_stabs = ver_stabs_layers[0]
+ assert len(ver_stabs) == 1 # 1 Ancilla measurement
+ assert np.sum(ver_stabs[0]) == 3 # 3 CNOTs
+ z_gens = circ.z_checks
+
+ for stab in ver_stabs:
+ assert in_span(z_gens, stab)
+
+ errors = circ.compute_fault_set(1)
+ non_detected = np.where(np.all(ver_stabs @ errors.T % 2 == 0, axis=1))[0]
+ assert len(non_detected) == 0
+
+ # Check that circuit is correct
+ circ_ver = gate_optimal_verification_circuit(circ, max_ancillas=1, max_timeout=5)
+ assert circ_ver.num_qubits == circ.num_qubits + 1
+ assert circ_ver.num_nonlocal_gates() == np.sum(ver_stabs) + circ.circ.num_nonlocal_gates()
+ assert circ_ver.depth() == np.sum(ver_stabs) + circ.circ.depth() + 1 # 1 for the measurement
+
+
+def test_heuristic_tetrahedral_verification_circuit(tetrahedral_code_sp: StatePrepCircuit) -> None:
+ """Test the optimal verification circuit for the tetrahedral code is correct.
+
+ The tetrahedral code has an x-distance of 7. We expect that the verification only checks for a single propagated error since the tetrahedral code has a distance of 3.
+ """
+ circ = tetrahedral_code_sp
+
+ ver_stabs_layers = heuristic_verification_stabilizers(circ, x_errors=True)
+
+ assert len(ver_stabs_layers) == 1 # 1 layer of verification measurements
+
+ ver_stabs = ver_stabs_layers[0]
+ assert len(ver_stabs) == 1 # 1 Ancilla measurement
+ assert np.sum(ver_stabs[0]) == 3 # 3 CNOTs
+ z_gens = circ.z_checks
+
+ for stab in ver_stabs:
+ assert in_span(z_gens, stab)
+
+ errors = circ.compute_fault_set(1)
+ non_detected = np.where(np.all(ver_stabs @ errors.T % 2 == 0, axis=1))[0]
+ assert len(non_detected) == 0
+
+ # Check that circuit is correct
+ circ_ver = heuristic_verification_circuit(circ)
+ assert circ_ver.num_qubits == circ.num_qubits + 1
+ assert circ_ver.num_nonlocal_gates() == np.sum(ver_stabs) + circ.circ.num_nonlocal_gates()
+ assert circ_ver.depth() == np.sum(ver_stabs) + circ.circ.depth() + 1 # 1 for the measurement
+
+
+def test_not_full_ft_opt_cc5(color_code_d5_sp: StatePrepCircuit) -> None:
+ """Test that the optimal verification is also correct for higher distance.
+
+ Ignore Z errors.
+ Due to time constraints, we set the timeout for each search to 2 seconds.
+ """
+ circ = color_code_d5_sp
+
+ ver_stabs_layers = gate_optimal_verification_stabilizers(circ, x_errors=True, max_ancillas=3, max_timeout=5)
+
+ assert len(ver_stabs_layers) == 2 # 2 layers of verification measurements
+
+ ver_stabs_1 = ver_stabs_layers[0]
+ assert len(ver_stabs_1) == 2 # 2 Ancilla measurements
+ assert np.sum(ver_stabs_1) == 9 # 9 CNOTs
+
+ ver_stabs_2 = ver_stabs_layers[1]
+ assert len(ver_stabs_2) == 3 # 2 Ancilla measurements
+ assert np.sum(ver_stabs_2) <= 14 # less than 14 CNOTs (sometimes 13, sometimes 14 depending on how fast the CPU is)
+
+ z_gens = circ.z_checks
+
+ for stab in np.vstack((ver_stabs_1, ver_stabs_2)):
+ assert in_span(z_gens, stab)
+
+ errors_1 = circ.compute_fault_set(1)
+ non_detected = np.where(np.all(ver_stabs_1 @ errors_1.T % 2 == 0, axis=1))[0]
+ assert len(non_detected) == 0
+
+ errors_2 = circ.compute_fault_set(2)
+ non_detected = np.where(np.all(ver_stabs_2 @ errors_2.T % 2 == 0, axis=1))[0]
+ assert len(non_detected) == 0
+
+ # Check that circuit is correct
+ n_cnots = np.sum(ver_stabs_1) + np.sum(ver_stabs_2)
+ circ_ver = gate_optimal_verification_circuit(circ, max_ancillas=3, max_timeout=5, full_fault_tolerance=True)
+ assert circ_ver.num_qubits > circ.num_qubits + 5 # overhead from the flags
+ assert circ_ver.num_nonlocal_gates() > n_cnots + circ.circ.num_nonlocal_gates() # Overhead from Flag CNOTS
+
+
+def test_not_full_ft_heuristic_cc5(color_code_d5_sp: StatePrepCircuit) -> None:
+ """Test that the optimal verification circuit for the Steane code is correct.
+
+ Ignore Z errors.
+ """
+ circ = color_code_d5_sp
+ ver_stabs_layers = heuristic_verification_stabilizers(circ, x_errors=True)
+
+ assert len(ver_stabs_layers) == 2 # 2 layers of verification measurements
+
+ ver_stabs_1 = ver_stabs_layers[0]
+ ver_stabs_2 = ver_stabs_layers[1]
+
+ z_gens = circ.z_checks
+
+ for stab in np.vstack((ver_stabs_1, ver_stabs_2)):
+ assert in_span(z_gens, stab)
+
+ errors_1 = circ.compute_fault_set(1)
+ non_detected = np.where(np.all(ver_stabs_1 @ errors_1.T % 2 == 0, axis=1))[0]
+ assert len(non_detected) == 0
+
+ errors_2 = circ.compute_fault_set(2)
+ non_detected = np.where(np.all(ver_stabs_2 @ errors_2.T % 2 == 0, axis=1))[0]
+ assert len(non_detected) == 0
+
+ # Check that circuit is correct
+ circ_ver = heuristic_verification_circuit(circ, full_fault_tolerance=False)
+ n_cnots = np.sum(ver_stabs_1) + np.sum(ver_stabs_2)
+ assert circ_ver.num_qubits == circ.num_qubits + len(ver_stabs_1) + len(ver_stabs_2)
+ assert circ_ver.num_nonlocal_gates() == n_cnots + circ.circ.num_nonlocal_gates()
+
+
+def test_full_ft_opt_cc5(color_code_d5_sp: StatePrepCircuit) -> None:
+ """Test that the optimal verification is also correct for higher distance.
+
+ Include Z errors.
+ Due to time constraints, we set the timeout for each search to 2 seconds.
+ """
+ circ = color_code_d5_sp
+
+ circ_ver_full_ft = gate_optimal_verification_circuit(circ, max_ancillas=3, max_timeout=5, full_fault_tolerance=True)
+ circ_ver_x_ft = gate_optimal_verification_circuit(circ, max_ancillas=3, max_timeout=5, full_fault_tolerance=False)
+ assert circ_ver_full_ft.num_nonlocal_gates() > circ_ver_x_ft.num_nonlocal_gates()
+ assert circ_ver_full_ft.depth() > circ_ver_x_ft.depth()
+
+
+def test_full_ft_heuristic_cc5(color_code_d5_sp: StatePrepCircuit) -> None:
+ """Test that the optimal verification is also correct for higher distance.
+
+ Include Z errors.
+ """
+ circ = color_code_d5_sp
+
+ circ_ver_full_ft = heuristic_verification_circuit(circ, full_fault_tolerance=True)
+ circ_ver_x_ft = heuristic_verification_circuit(circ, full_fault_tolerance=False)
+ assert circ_ver_full_ft.num_nonlocal_gates() > circ_ver_x_ft.num_nonlocal_gates()
+ assert circ_ver_full_ft.depth() > circ_ver_x_ft.depth()
diff --git a/test/python/test_code.py b/test/python/test_code.py
new file mode 100644
index 00000000..2026da4f
--- /dev/null
+++ b/test/python/test_code.py
@@ -0,0 +1,138 @@
+"""Test the CSSCode class."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import numpy as np
+import pytest
+
+from mqt.qecc import CSSCode, InvalidCSSCodeError
+
+if TYPE_CHECKING: # pragma: no cover
+ import numpy.typing as npt
+
+
+@pytest.fixture
+def rep_code() -> tuple[npt.NDArray[np.int8] | None, npt.NDArray[np.int8] | None]:
+ """Return the parity check matrices for the repetition code."""
+ hx = np.array([[1, 1, 0], [0, 0, 1]])
+ hz = None
+ return hx, hz
+
+
+@pytest.fixture
+def steane_code() -> tuple[npt.NDArray[np.int8], npt.NDArray[np.int8]]:
+ """Return the check matrices for the Steane code."""
+ hx = np.array([[1, 1, 1, 1, 0, 0, 0], [1, 0, 1, 0, 1, 0, 1], [0, 1, 1, 0, 1, 1, 0]])
+ hz = hx
+ return hx, hz
+
+
+def test_invalid_css_codes() -> None:
+ """Test that an invalid CSS code raises an error."""
+ # Violates CSS condition
+ hx = np.array([[1, 1, 1]])
+ hz = np.array([[1, 0, 0]])
+ with pytest.raises(InvalidCSSCodeError):
+ CSSCode(distance=3, Hx=hx, Hz=hz)
+
+ # Distances don't match
+ hz = np.array([[1, 1, 0]])
+ with pytest.raises(InvalidCSSCodeError):
+ CSSCode(distance=3, Hx=hx, Hz=hz, x_distance=4, z_distance=1)
+
+ # Checks not over the same number of qubits
+ hz = np.array([[1, 1]])
+ with pytest.raises(InvalidCSSCodeError):
+ CSSCode(distance=3, Hx=hx, Hz=hz)
+
+ # Invalid distance
+ with pytest.raises(InvalidCSSCodeError):
+ CSSCode(distance=-1, Hx=hx)
+
+ # Checks not provided
+ with pytest.raises(InvalidCSSCodeError):
+ CSSCode(distance=3)
+
+
+@pytest.mark.parametrize("checks", ["steane_code", "rep_code"])
+def test_logicals(checks: tuple[npt.NDArray[np.int8] | None, npt.NDArray[np.int8] | None], request) -> None: # type: ignore[no-untyped-def]
+ """Test the logical operators of the CSSCode class."""
+ hx, hz = request.getfixturevalue(checks)
+ code = CSSCode(distance=3, Hx=hx, Hz=hz)
+ assert code.Lx is not None
+ assert code.Lz is not None
+ assert code.Lx.shape[1] == code.Lz.shape[1] == hx.shape[1]
+ assert code.Lx.shape[0] == code.Lz.shape[0]
+
+ # assert that logicals anticommute
+ assert code.Lx @ code.Lz.T % 2 != 0
+
+ # assert that logicals commute with stabilizers
+ if code.Hz is not None:
+ assert np.all(code.Lx @ code.Hz.T % 2 == 0)
+ if code.Hx is not None:
+ assert np.all(code.Lz @ code.Hx.T % 2 == 0)
+
+
+def test_errors(steane_code: tuple[npt.NDArray[np.int8], npt.NDArray[np.int8]]) -> None:
+ """Test error detection and symdromes."""
+ hx, hz = steane_code
+ code = CSSCode(distance=3, Hx=hx, Hz=hz)
+ e1 = np.array([1, 0, 0, 0, 0, 0, 0])
+ e2 = np.array([0, 1, 0, 0, 1, 0, 0])
+ e3 = np.array([0, 0, 0, 0, 0, 1, 1])
+ e4 = np.array([0, 1, 1, 1, 0, 0, 0])
+
+ assert np.array_equal(code.get_x_syndrome(e1), code.get_z_syndrome(e2))
+ assert np.array_equal(code.get_x_syndrome(e2), code.get_z_syndrome(e2))
+
+ x_syndrome_1 = code.get_x_syndrome(e1)
+ x_syndrome_2 = code.get_x_syndrome(e2)
+ x_syndrome_3 = code.get_x_syndrome(e3)
+ x_syndrome_4 = code.get_x_syndrome(e4)
+
+ assert np.array_equal(x_syndrome_1, x_syndrome_2)
+ assert not np.array_equal(x_syndrome_1, x_syndrome_3)
+ assert np.array_equal(x_syndrome_1, x_syndrome_4)
+
+ # e1 and e2 have same syndrome but if we add them we get a logical error
+ assert code.check_if_logical_x_error((e1 + e2) % 2)
+ assert code.check_if_logical_z_error((e1 + e2) % 2)
+ assert not code.stabilizer_eq_x_error(e1, e2)
+ assert not code.stabilizer_eq_z_error(e1, e2)
+
+ # e1 and e4 on the other hand do not induce a logical error because they are stabilizer equivalent
+ assert not code.check_if_logical_x_error((e1 + e4) % 2)
+ assert not code.check_if_logical_z_error((e1 + e4) % 2)
+ assert code.stabilizer_eq_x_error(e1, e4)
+ assert code.stabilizer_eq_z_error(e1, e4)
+
+
+def test_steane(steane_code: tuple[npt.NDArray[np.int8], npt.NDArray[np.int8]]) -> None:
+ """Test utility functions and correctness of the Steane code."""
+ hx, hz = steane_code
+ code = CSSCode(distance=3, Hx=hx, Hz=hz)
+ assert code.n == 7
+ assert code.k == 1
+ assert code.distance == 3
+ assert code.is_self_dual()
+
+ x_paulis, z_paulis = code.stabs_as_pauli_strings()
+ assert x_paulis is not None
+ assert z_paulis is not None
+ assert len(x_paulis) == len(z_paulis) == 3
+ assert x_paulis == ["XXXXIII", "XIXIXIX", "IXXIXXI"]
+ assert z_paulis == ["ZZZZIII", "ZIZIZIZ", "IZZIZZI"]
+
+ x_log = code.x_logicals_as_pauli_string()
+ z_log = code.z_logicals_as_pauli_string()
+ assert x_log.count("X") == 3
+ assert x_log.count("I") == 4
+ assert z_log.count("Z") == 3
+ assert z_log.count("I") == 4
+
+ hx_reordered = hx[::-1, :]
+ code_reordered = CSSCode(distance=3, Hx=hx_reordered, Hz=hz)
+ assert code == code_reordered