From 75b3f408e14fa0323f7b57195321787e10446d4a Mon Sep 17 00:00:00 2001 From: Seneca Meeks Date: Wed, 17 Jul 2024 17:09:14 -0700 Subject: [PATCH] Adds tools for appending randomized measurement bases and processing renyi entropy from bitstring (#6664) * add utilities for processing renyi entropy and appending randomized measurements + tests * nit: update test comments * address comments * fix np array shape which changes test solution * Address comments and fix bug * address comments * use zip and itertools and update to transformer * rm print * type check * Update cirq-core/cirq/qis/entropy.py Co-authored-by: Noureldin * Update cirq-core/cirq/transformers/randomized_measurements.py Co-authored-by: Noureldin * comments * line too long --------- Co-authored-by: Noureldin --- cirq-core/cirq/qis/__init__.py | 1 + cirq-core/cirq/qis/entropy.py | 115 ++++++++++++++++++ cirq-core/cirq/qis/entropy_test.py | 42 +++++++ cirq-core/cirq/transformers/__init__.py | 2 + .../transformers/randomized_measurements.py | 109 +++++++++++++++++ .../randomized_measurements_test.py | 55 +++++++++ 6 files changed, 324 insertions(+) create mode 100644 cirq-core/cirq/qis/entropy.py create mode 100644 cirq-core/cirq/qis/entropy_test.py create mode 100644 cirq-core/cirq/transformers/randomized_measurements.py create mode 100644 cirq-core/cirq/transformers/randomized_measurements_test.py diff --git a/cirq-core/cirq/qis/__init__.py b/cirq-core/cirq/qis/__init__.py index 12913d6e8cc..4edceb0dc0c 100644 --- a/cirq-core/cirq/qis/__init__.py +++ b/cirq-core/cirq/qis/__init__.py @@ -60,3 +60,4 @@ average_error, decoherence_pauli_error, ) +from cirq.qis.entropy import process_renyi_entropy_from_bitstrings diff --git a/cirq-core/cirq/qis/entropy.py b/cirq-core/cirq/qis/entropy.py new file mode 100644 index 00000000000..6a70ffaa428 --- /dev/null +++ b/cirq-core/cirq/qis/entropy.py @@ -0,0 +1,115 @@ +# Copyright 2024 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from concurrent.futures import ThreadPoolExecutor +from collections.abc import Sequence +from itertools import product +from typing import Any, Optional + +import numpy as np +import numpy.typing as npt + + +def _get_hamming_distance( + bitstring_1: npt.NDArray[np.int8], bitstring_2: npt.NDArray[np.int8] +) -> int: + """Calculates the Hamming distance between two bitstrings. + Args: + bitstring_1: Bitstring 1 + bitstring_2: Bitstring 2 + Returns: The Hamming distance + """ + return (bitstring_1 ^ bitstring_2).sum().item() + + +def _bitstrings_to_probs( + bitstrings: npt.NDArray[np.int8], +) -> tuple[npt.NDArray[np.int8], npt.NDArray[Any]]: + """Given a list of bitstrings from different measurements returns a probability distribution. + Args: + bitstrings: The bitstring + Returns: + A tuple of bitstrings and their corresponding probabilities. + """ + + num_shots = bitstrings.shape[0] + unique_bitstrings, counts = np.unique(bitstrings, return_counts=True, axis=0) + probs = counts / num_shots + + return (unique_bitstrings, probs) + + +def _bitstring_format_helper( + measured_bitstrings: npt.NDArray[np.int8], subsystem: Sequence[int] | None = None +) -> npt.NDArray[np.int8]: + """Formats the bitstring for analysis based on the selected subsystem. + Args: + measured_bitstrings: List of sampled measurement outcomes as a numpy array of bitstrings. + subsystem: Subsystem of interest + Returns: The bitstring string for the subsystem + """ + if subsystem is None: + return measured_bitstrings + + return measured_bitstrings[:, :, subsystem] + + +def _compute_bitstrings_contribution_to_purity(bitstrings: npt.NDArray[np.int8]) -> float: + """Computes the contribution to the purity of the bitstrings. + Args: + bitstrings: The bitstrings measured using the same unitary operators + Returns: The purity of the bitstring + """ + + bitstrings, probs = _bitstrings_to_probs(bitstrings) + purity = 0 + for (s, p), (s_prime, p_prime) in product(zip(bitstrings, probs), repeat=2): + purity += (-2.0) ** float(-_get_hamming_distance(s, s_prime)) * p * p_prime + + return purity * 2 ** (bitstrings.shape[-1]) + + +def process_renyi_entropy_from_bitstrings( + measured_bitstrings: npt.NDArray[np.int8], + subsystem: tuple[int] | None = None, + pool: Optional[ThreadPoolExecutor] = None, +) -> float: + """Compute the Rényi entropy of an array of bitstrings. + Args: + measured_bitstrings: List of sampled measurement outcomes as a numpy array of bitstrings. + subsystem: Subsystem of interest + pool: ThreadPoolExecutor used to paralelleize the computation. + + Returns: + A float indicating the computed entropy. + """ + bitstrings = _bitstring_format_helper(measured_bitstrings, subsystem) + num_shots = bitstrings.shape[1] + num_qubits = bitstrings.shape[-1] + + if num_shots == 1: + return 0 + + if pool is not None: + purities = list(pool.map(_compute_bitstrings_contribution_to_purity, list(bitstrings))) + purity = np.mean(purities) + + else: + purity = np.mean( + [_compute_bitstrings_contribution_to_purity(bitstring) for bitstring in bitstrings] + ) + + purity_unbiased = purity * num_shots / (num_shots - 1) - (2**num_qubits) / (num_shots - 1) + + return -np.log2(purity_unbiased) diff --git a/cirq-core/cirq/qis/entropy_test.py b/cirq-core/cirq/qis/entropy_test.py new file mode 100644 index 00000000000..e38d74217c6 --- /dev/null +++ b/cirq-core/cirq/qis/entropy_test.py @@ -0,0 +1,42 @@ +# Copyright 2024 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from concurrent.futures import ThreadPoolExecutor +import pytest +import numpy as np + +from cirq.qis.entropy import process_renyi_entropy_from_bitstrings + + +@pytest.mark.parametrize('pool', [None, ThreadPoolExecutor(max_workers=1)]) +def test_process_renyi_entropy_from_bitstrings(pool): + bitstrings = np.array( + [ + [[0, 1, 1, 0], [1, 0, 0, 0], [0, 0, 0, 1], [0, 1, 1, 1]], + [[0, 1, 1, 0], [0, 1, 1, 0], [1, 1, 0, 0], [1, 1, 0, 1]], + [[0, 0, 1, 1], [0, 0, 0, 0], [0, 1, 1, 1], [0, 1, 1, 1]], + [[1, 0, 1, 1], [0, 1, 0, 0], [0, 1, 0, 1], [0, 1, 0, 0]], + [[1, 0, 1, 1], [0, 0, 1, 0], [0, 0, 1, 0], [0, 0, 0, 1]], + ] + ) + substsytem = (0, 1) + entropy = process_renyi_entropy_from_bitstrings(bitstrings, substsytem, pool) + assert entropy == 0.5145731728297583 + + +def test_process_renyi_entropy_from_bitstrings_safeguards_against_divide_by_0_error(): + bitstrings = np.array([[[0, 1, 1, 0]], [[0, 1, 1, 0]], [[0, 0, 1, 1]]]) + + entropy = process_renyi_entropy_from_bitstrings(bitstrings) + assert entropy == 0 diff --git a/cirq-core/cirq/transformers/__init__.py b/cirq-core/cirq/transformers/__init__.py index 5a98df78a1a..5f3f8f2f17a 100644 --- a/cirq-core/cirq/transformers/__init__.py +++ b/cirq-core/cirq/transformers/__init__.py @@ -134,3 +134,5 @@ SqrtCZGaugeTransformer, SqrtISWAPGaugeTransformer, ) + +from cirq.transformers.randomized_measurements import RandomizedMeasurements diff --git a/cirq-core/cirq/transformers/randomized_measurements.py b/cirq-core/cirq/transformers/randomized_measurements.py new file mode 100644 index 00000000000..30b7f94b91b --- /dev/null +++ b/cirq-core/cirq/transformers/randomized_measurements.py @@ -0,0 +1,109 @@ +# Copyright 2024 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Sequence +from typing import Any, Literal + +import cirq +import numpy as np +from cirq.transformers import transformer_api + + +@transformer_api.transformer +class RandomizedMeasurements: + """A transformer that appends a moment of random rotations to map qubits to + random pauli bases.""" + + def __init__(self, subsystem: Sequence[int] | None = None): + """Class structure for performing and analyzing a general randomized measurement protocol. + For more details on the randomized measurement toolbox see https://arxiv.org/abs/2203.11374 + + Args: + subsystem: The specific subsystem (e.g qubit index) to measure in random basis + """ + self.subsystem = subsystem + + def __call__( + self, + circuit: 'cirq.AbstractCircuit', + rng: np.random.Generator | None = None, + *, + context: transformer_api.TransformerContext | None = None, + ): + """Apply the transformer to the given circuit. Given an input circuit returns + a list of circuits with the pre-measurement unitaries. If no arguments are specified, + it will default to computing the entropy of the entire circuit. + + Args: + circuit: The circuit to add randomized measurements to. + rng: Random number generator. + context: Not used; to satisfy transformer API. + + Returns: + List of circuits with pre-measurement unitaries and measurements added + """ + if rng is None: + rng = np.random.default_rng() + + qubits = sorted(circuit.all_qubits()) + num_qubits = len(qubits) + + pre_measurement_unitaries_list = self._generate_unitaries_list(rng, num_qubits) + pre_measurement_moment = self.unitaries_to_moment(pre_measurement_unitaries_list, qubits) + + return cirq.Circuit.from_moments( + *circuit.moments, pre_measurement_moment, cirq.M(*qubits, key='m') + ) + + def _generate_unitaries_list(self, rng: np.random.Generator, num_qubits: int) -> Sequence[Any]: + """Generates a list of pre-measurement unitaries.""" + + pauli_strings = rng.choice(["X", "Y", "Z"], size=num_qubits) + + if self.subsystem is not None: + for i in range(pauli_strings.shape[0]): + if i not in self.subsystem: + pauli_strings[i] = np.array("Z") + + return pauli_strings.tolist() + + def unitaries_to_moment( + self, unitaries: Sequence[Literal["X", "Y", "Z"]], qubits: Sequence[Any] + ) -> 'cirq.Moment': + """Outputs the cirq moment associated with the pre-measurement rotations. + Args: + unitaries: List of pre-measurement unitaries + qubits: List of qubits + + Returns: The cirq moment associated with the pre-measurement rotations + """ + op_list: list[cirq.Operation] = [] + for idx, pauli in enumerate(unitaries): + op_list.append(_pauli_basis_rotation(pauli).on(qubits[idx])) + + return cirq.Moment.from_ops(*op_list) + + +def _pauli_basis_rotation(basis: Literal["X", "Y", "Z"]) -> 'cirq.Gate': + """Given a measurement basis returns the associated rotation. + Args: + basis: Measurement basis + Returns: The cirq gate for associated with measurement basis + """ + if basis == "X": + return cirq.Ry(rads=-np.pi / 2) + elif basis == "Y": + return cirq.Rx(rads=np.pi / 2) + elif basis == "Z": + return cirq.I diff --git a/cirq-core/cirq/transformers/randomized_measurements_test.py b/cirq-core/cirq/transformers/randomized_measurements_test.py new file mode 100644 index 00000000000..7ff355eb461 --- /dev/null +++ b/cirq-core/cirq/transformers/randomized_measurements_test.py @@ -0,0 +1,55 @@ +# Copyright 2024 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import cirq +import cirq.transformers.randomized_measurements as rand_meas + + +def test_randomized_measurements_appends_two_moments_on_returned_circuit(): + # Create a 4-qubit circuit + q0, q1, q2, q3 = cirq.LineQubit.range(4) + circuit = cirq.Circuit([cirq.H(q0), cirq.CNOT(q0, q1), cirq.CNOT(q1, q2), cirq.CNOT(q2, q3)]) + num_moments_pre = len(circuit.moments) + + # Append randomized measurements to subsystem + circuit = rand_meas.RandomizedMeasurements()(circuit) + + num_moments_post = len(circuit.moments) + assert num_moments_post == num_moments_pre + 2 + + +def test_append_randomized_measurements_leaves_qubits_not_in_specified_subsystem_unchanged(): + # Create a 4-qubit circuit + q0, q1, q2, q3 = cirq.LineQubit.range(4) + circuit = cirq.Circuit([cirq.H(q0), cirq.CNOT(q0, q1), cirq.CNOT(q1, q2), cirq.CNOT(q2, q3)]) + + # Append randomized measurements to subsystem + circuit = rand_meas.RandomizedMeasurements(subsystem=(0, 1))(circuit) + + # assert latter subsystems were not changed. + assert circuit.operation_at(q2, 4) == cirq.I(q2) + assert circuit.operation_at(q3, 4) == cirq.I(q3) + + +def test_append_randomized_measurements_leaves_qubits_not_in_noncontinuous_subsystem_unchanged(): + # Create a 4-qubit circuit + q0, q1, q2, q3 = cirq.LineQubit.range(4) + circuit = cirq.Circuit([cirq.H(q0), cirq.CNOT(q0, q1), cirq.CNOT(q1, q2), cirq.CNOT(q2, q3)]) + + # Append randomized measurements to subsystem + circuit = rand_meas.RandomizedMeasurements(subsystem=(0, 2))(circuit) + + # assert latter subsystems were not changed. + assert circuit.operation_at(q1, 4) == cirq.I(q1) + assert circuit.operation_at(q3, 4) == cirq.I(q3)