From 47e8099b02524a99921a4ff2b9a1a9d759dc03c0 Mon Sep 17 00:00:00 2001
From: Seneca Meeks <senecacmeeks@gmail.com>
Date: Wed, 17 Jul 2024 17:09:14 -0700
Subject: [PATCH] Adds tools for appending randomized measurement bases and
 processing renyi entropy from bitstring (#6664)

* add utilities for processing renyi entropy and appending randomized measurements + tests

* nit: update test comments

* address comments

* fix np array shape which changes test solution

* Address comments and fix bug

* address comments

* use zip and itertools and update to transformer

* rm print

* type check

* Update cirq-core/cirq/qis/entropy.py

Co-authored-by: Noureldin <noureldinyosri@gmail.com>

* Update cirq-core/cirq/transformers/randomized_measurements.py

Co-authored-by: Noureldin <noureldinyosri@gmail.com>

* comments

* line too long

---------

Co-authored-by: Noureldin <noureldinyosri@gmail.com>
---
 cirq/qis/__init__.py                          |   1 +
 cirq/qis/entropy.py                           | 115 ++++++++++++++++++
 cirq/qis/entropy_test.py                      |  42 +++++++
 cirq/transformers/__init__.py                 |   2 +
 cirq/transformers/randomized_measurements.py  | 109 +++++++++++++++++
 .../randomized_measurements_test.py           |  55 +++++++++
 6 files changed, 324 insertions(+)
 create mode 100644 cirq/qis/entropy.py
 create mode 100644 cirq/qis/entropy_test.py
 create mode 100644 cirq/transformers/randomized_measurements.py
 create mode 100644 cirq/transformers/randomized_measurements_test.py

diff --git a/cirq/qis/__init__.py b/cirq/qis/__init__.py
index 12913d6e8cc..4edceb0dc0c 100644
--- a/cirq/qis/__init__.py
+++ b/cirq/qis/__init__.py
@@ -60,3 +60,4 @@
     average_error,
     decoherence_pauli_error,
 )
+from cirq.qis.entropy import process_renyi_entropy_from_bitstrings
diff --git a/cirq/qis/entropy.py b/cirq/qis/entropy.py
new file mode 100644
index 00000000000..6a70ffaa428
--- /dev/null
+++ b/cirq/qis/entropy.py
@@ -0,0 +1,115 @@
+# Copyright 2024 The Cirq Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from concurrent.futures import ThreadPoolExecutor
+from collections.abc import Sequence
+from itertools import product
+from typing import Any, Optional
+
+import numpy as np
+import numpy.typing as npt
+
+
+def _get_hamming_distance(
+    bitstring_1: npt.NDArray[np.int8], bitstring_2: npt.NDArray[np.int8]
+) -> int:
+    """Calculates the Hamming distance between two bitstrings.
+    Args:
+        bitstring_1: Bitstring 1
+        bitstring_2: Bitstring 2
+    Returns: The Hamming distance
+    """
+    return (bitstring_1 ^ bitstring_2).sum().item()
+
+
+def _bitstrings_to_probs(
+    bitstrings: npt.NDArray[np.int8],
+) -> tuple[npt.NDArray[np.int8], npt.NDArray[Any]]:
+    """Given a list of bitstrings from different measurements returns a probability distribution.
+    Args:
+        bitstrings: The bitstring
+    Returns:
+        A tuple of bitstrings and their corresponding probabilities.
+    """
+
+    num_shots = bitstrings.shape[0]
+    unique_bitstrings, counts = np.unique(bitstrings, return_counts=True, axis=0)
+    probs = counts / num_shots
+
+    return (unique_bitstrings, probs)
+
+
+def _bitstring_format_helper(
+    measured_bitstrings: npt.NDArray[np.int8], subsystem: Sequence[int] | None = None
+) -> npt.NDArray[np.int8]:
+    """Formats the bitstring for analysis based on the selected subsystem.
+    Args:
+        measured_bitstrings: List of sampled measurement outcomes as a numpy array of bitstrings.
+        subsystem: Subsystem of interest
+    Returns: The bitstring string for the subsystem
+    """
+    if subsystem is None:
+        return measured_bitstrings
+
+    return measured_bitstrings[:, :, subsystem]
+
+
+def _compute_bitstrings_contribution_to_purity(bitstrings: npt.NDArray[np.int8]) -> float:
+    """Computes the contribution to the purity of the bitstrings.
+    Args:
+        bitstrings: The bitstrings measured using the same unitary operators
+    Returns: The purity of the bitstring
+    """
+
+    bitstrings, probs = _bitstrings_to_probs(bitstrings)
+    purity = 0
+    for (s, p), (s_prime, p_prime) in product(zip(bitstrings, probs), repeat=2):
+        purity += (-2.0) ** float(-_get_hamming_distance(s, s_prime)) * p * p_prime
+
+    return purity * 2 ** (bitstrings.shape[-1])
+
+
+def process_renyi_entropy_from_bitstrings(
+    measured_bitstrings: npt.NDArray[np.int8],
+    subsystem: tuple[int] | None = None,
+    pool: Optional[ThreadPoolExecutor] = None,
+) -> float:
+    """Compute the Rényi entropy of an array of bitstrings.
+    Args:
+        measured_bitstrings: List of sampled measurement outcomes as a numpy array of bitstrings.
+        subsystem: Subsystem of interest
+        pool: ThreadPoolExecutor used to paralelleize the computation.
+
+    Returns:
+        A float indicating the computed entropy.
+    """
+    bitstrings = _bitstring_format_helper(measured_bitstrings, subsystem)
+    num_shots = bitstrings.shape[1]
+    num_qubits = bitstrings.shape[-1]
+
+    if num_shots == 1:
+        return 0
+
+    if pool is not None:
+        purities = list(pool.map(_compute_bitstrings_contribution_to_purity, list(bitstrings)))
+        purity = np.mean(purities)
+
+    else:
+        purity = np.mean(
+            [_compute_bitstrings_contribution_to_purity(bitstring) for bitstring in bitstrings]
+        )
+
+    purity_unbiased = purity * num_shots / (num_shots - 1) - (2**num_qubits) / (num_shots - 1)
+
+    return -np.log2(purity_unbiased)
diff --git a/cirq/qis/entropy_test.py b/cirq/qis/entropy_test.py
new file mode 100644
index 00000000000..e38d74217c6
--- /dev/null
+++ b/cirq/qis/entropy_test.py
@@ -0,0 +1,42 @@
+# Copyright 2024 The Cirq Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from concurrent.futures import ThreadPoolExecutor
+import pytest
+import numpy as np
+
+from cirq.qis.entropy import process_renyi_entropy_from_bitstrings
+
+
+@pytest.mark.parametrize('pool', [None, ThreadPoolExecutor(max_workers=1)])
+def test_process_renyi_entropy_from_bitstrings(pool):
+    bitstrings = np.array(
+        [
+            [[0, 1, 1, 0], [1, 0, 0, 0], [0, 0, 0, 1], [0, 1, 1, 1]],
+            [[0, 1, 1, 0], [0, 1, 1, 0], [1, 1, 0, 0], [1, 1, 0, 1]],
+            [[0, 0, 1, 1], [0, 0, 0, 0], [0, 1, 1, 1], [0, 1, 1, 1]],
+            [[1, 0, 1, 1], [0, 1, 0, 0], [0, 1, 0, 1], [0, 1, 0, 0]],
+            [[1, 0, 1, 1], [0, 0, 1, 0], [0, 0, 1, 0], [0, 0, 0, 1]],
+        ]
+    )
+    substsytem = (0, 1)
+    entropy = process_renyi_entropy_from_bitstrings(bitstrings, substsytem, pool)
+    assert entropy == 0.5145731728297583
+
+
+def test_process_renyi_entropy_from_bitstrings_safeguards_against_divide_by_0_error():
+    bitstrings = np.array([[[0, 1, 1, 0]], [[0, 1, 1, 0]], [[0, 0, 1, 1]]])
+
+    entropy = process_renyi_entropy_from_bitstrings(bitstrings)
+    assert entropy == 0
diff --git a/cirq/transformers/__init__.py b/cirq/transformers/__init__.py
index 5a98df78a1a..5f3f8f2f17a 100644
--- a/cirq/transformers/__init__.py
+++ b/cirq/transformers/__init__.py
@@ -134,3 +134,5 @@
     SqrtCZGaugeTransformer,
     SqrtISWAPGaugeTransformer,
 )
+
+from cirq.transformers.randomized_measurements import RandomizedMeasurements
diff --git a/cirq/transformers/randomized_measurements.py b/cirq/transformers/randomized_measurements.py
new file mode 100644
index 00000000000..30b7f94b91b
--- /dev/null
+++ b/cirq/transformers/randomized_measurements.py
@@ -0,0 +1,109 @@
+# Copyright 2024 The Cirq Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections.abc import Sequence
+from typing import Any, Literal
+
+import cirq
+import numpy as np
+from cirq.transformers import transformer_api
+
+
+@transformer_api.transformer
+class RandomizedMeasurements:
+    """A transformer that appends a moment of random rotations to map qubits to
+    random pauli bases."""
+
+    def __init__(self, subsystem: Sequence[int] | None = None):
+        """Class structure for performing and analyzing a general randomized measurement protocol.
+        For more details on the randomized measurement toolbox see https://arxiv.org/abs/2203.11374
+
+        Args:
+            subsystem: The specific subsystem (e.g qubit index) to measure in random basis
+        """
+        self.subsystem = subsystem
+
+    def __call__(
+        self,
+        circuit: 'cirq.AbstractCircuit',
+        rng: np.random.Generator | None = None,
+        *,
+        context: transformer_api.TransformerContext | None = None,
+    ):
+        """Apply the transformer to the given circuit. Given an input circuit returns
+        a list of circuits with the pre-measurement unitaries. If no arguments are specified,
+        it will default to computing the entropy of the entire circuit.
+
+        Args:
+            circuit: The circuit to add randomized measurements to.
+            rng: Random number generator.
+            context: Not used; to satisfy transformer API.
+
+        Returns:
+            List of circuits with pre-measurement unitaries and measurements added
+        """
+        if rng is None:
+            rng = np.random.default_rng()
+
+        qubits = sorted(circuit.all_qubits())
+        num_qubits = len(qubits)
+
+        pre_measurement_unitaries_list = self._generate_unitaries_list(rng, num_qubits)
+        pre_measurement_moment = self.unitaries_to_moment(pre_measurement_unitaries_list, qubits)
+
+        return cirq.Circuit.from_moments(
+            *circuit.moments, pre_measurement_moment, cirq.M(*qubits, key='m')
+        )
+
+    def _generate_unitaries_list(self, rng: np.random.Generator, num_qubits: int) -> Sequence[Any]:
+        """Generates a list of pre-measurement unitaries."""
+
+        pauli_strings = rng.choice(["X", "Y", "Z"], size=num_qubits)
+
+        if self.subsystem is not None:
+            for i in range(pauli_strings.shape[0]):
+                if i not in self.subsystem:
+                    pauli_strings[i] = np.array("Z")
+
+        return pauli_strings.tolist()
+
+    def unitaries_to_moment(
+        self, unitaries: Sequence[Literal["X", "Y", "Z"]], qubits: Sequence[Any]
+    ) -> 'cirq.Moment':
+        """Outputs the cirq moment associated with the pre-measurement rotations.
+        Args:
+            unitaries: List of pre-measurement unitaries
+            qubits: List of qubits
+
+        Returns: The cirq moment associated with the pre-measurement rotations
+        """
+        op_list: list[cirq.Operation] = []
+        for idx, pauli in enumerate(unitaries):
+            op_list.append(_pauli_basis_rotation(pauli).on(qubits[idx]))
+
+        return cirq.Moment.from_ops(*op_list)
+
+
+def _pauli_basis_rotation(basis: Literal["X", "Y", "Z"]) -> 'cirq.Gate':
+    """Given a measurement basis returns the associated rotation.
+    Args:
+        basis: Measurement basis
+    Returns: The cirq gate for associated with measurement basis
+    """
+    if basis == "X":
+        return cirq.Ry(rads=-np.pi / 2)
+    elif basis == "Y":
+        return cirq.Rx(rads=np.pi / 2)
+    elif basis == "Z":
+        return cirq.I
diff --git a/cirq/transformers/randomized_measurements_test.py b/cirq/transformers/randomized_measurements_test.py
new file mode 100644
index 00000000000..7ff355eb461
--- /dev/null
+++ b/cirq/transformers/randomized_measurements_test.py
@@ -0,0 +1,55 @@
+# Copyright 2024 The Cirq Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import cirq
+import cirq.transformers.randomized_measurements as rand_meas
+
+
+def test_randomized_measurements_appends_two_moments_on_returned_circuit():
+    # Create a 4-qubit circuit
+    q0, q1, q2, q3 = cirq.LineQubit.range(4)
+    circuit = cirq.Circuit([cirq.H(q0), cirq.CNOT(q0, q1), cirq.CNOT(q1, q2), cirq.CNOT(q2, q3)])
+    num_moments_pre = len(circuit.moments)
+
+    # Append randomized measurements to subsystem
+    circuit = rand_meas.RandomizedMeasurements()(circuit)
+
+    num_moments_post = len(circuit.moments)
+    assert num_moments_post == num_moments_pre + 2
+
+
+def test_append_randomized_measurements_leaves_qubits_not_in_specified_subsystem_unchanged():
+    # Create a 4-qubit circuit
+    q0, q1, q2, q3 = cirq.LineQubit.range(4)
+    circuit = cirq.Circuit([cirq.H(q0), cirq.CNOT(q0, q1), cirq.CNOT(q1, q2), cirq.CNOT(q2, q3)])
+
+    # Append randomized measurements to subsystem
+    circuit = rand_meas.RandomizedMeasurements(subsystem=(0, 1))(circuit)
+
+    # assert latter subsystems were not changed.
+    assert circuit.operation_at(q2, 4) == cirq.I(q2)
+    assert circuit.operation_at(q3, 4) == cirq.I(q3)
+
+
+def test_append_randomized_measurements_leaves_qubits_not_in_noncontinuous_subsystem_unchanged():
+    # Create a 4-qubit circuit
+    q0, q1, q2, q3 = cirq.LineQubit.range(4)
+    circuit = cirq.Circuit([cirq.H(q0), cirq.CNOT(q0, q1), cirq.CNOT(q1, q2), cirq.CNOT(q2, q3)])
+
+    # Append randomized measurements to subsystem
+    circuit = rand_meas.RandomizedMeasurements(subsystem=(0, 2))(circuit)
+
+    # assert latter subsystems were not changed.
+    assert circuit.operation_at(q1, 4) == cirq.I(q1)
+    assert circuit.operation_at(q3, 4) == cirq.I(q3)