-
Notifications
You must be signed in to change notification settings - Fork 1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adds tools for appending randomized measurement bases and processing …
…renyi entropy from bitstring (#6664) * add utilities for processing renyi entropy and appending randomized measurements + tests * nit: update test comments * address comments * fix np array shape which changes test solution * Address comments and fix bug * address comments * use zip and itertools and update to transformer * rm print * type check * Update cirq-core/cirq/qis/entropy.py Co-authored-by: Noureldin <[email protected]> * Update cirq-core/cirq/transformers/randomized_measurements.py Co-authored-by: Noureldin <[email protected]> * comments * line too long --------- Co-authored-by: Noureldin <[email protected]>
- Loading branch information
1 parent
3922a63
commit 75b3f40
Showing
6 changed files
with
324 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -60,3 +60,4 @@ | |
average_error, | ||
decoherence_pauli_error, | ||
) | ||
from cirq.qis.entropy import process_renyi_entropy_from_bitstrings |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
# Copyright 2024 The Cirq Developers | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from concurrent.futures import ThreadPoolExecutor | ||
from collections.abc import Sequence | ||
from itertools import product | ||
from typing import Any, Optional | ||
|
||
import numpy as np | ||
import numpy.typing as npt | ||
|
||
|
||
def _get_hamming_distance( | ||
bitstring_1: npt.NDArray[np.int8], bitstring_2: npt.NDArray[np.int8] | ||
) -> int: | ||
"""Calculates the Hamming distance between two bitstrings. | ||
Args: | ||
bitstring_1: Bitstring 1 | ||
bitstring_2: Bitstring 2 | ||
Returns: The Hamming distance | ||
""" | ||
return (bitstring_1 ^ bitstring_2).sum().item() | ||
|
||
|
||
def _bitstrings_to_probs( | ||
bitstrings: npt.NDArray[np.int8], | ||
) -> tuple[npt.NDArray[np.int8], npt.NDArray[Any]]: | ||
"""Given a list of bitstrings from different measurements returns a probability distribution. | ||
Args: | ||
bitstrings: The bitstring | ||
Returns: | ||
A tuple of bitstrings and their corresponding probabilities. | ||
""" | ||
|
||
num_shots = bitstrings.shape[0] | ||
unique_bitstrings, counts = np.unique(bitstrings, return_counts=True, axis=0) | ||
probs = counts / num_shots | ||
|
||
return (unique_bitstrings, probs) | ||
|
||
|
||
def _bitstring_format_helper( | ||
measured_bitstrings: npt.NDArray[np.int8], subsystem: Sequence[int] | None = None | ||
) -> npt.NDArray[np.int8]: | ||
"""Formats the bitstring for analysis based on the selected subsystem. | ||
Args: | ||
measured_bitstrings: List of sampled measurement outcomes as a numpy array of bitstrings. | ||
subsystem: Subsystem of interest | ||
Returns: The bitstring string for the subsystem | ||
""" | ||
if subsystem is None: | ||
return measured_bitstrings | ||
|
||
return measured_bitstrings[:, :, subsystem] | ||
|
||
|
||
def _compute_bitstrings_contribution_to_purity(bitstrings: npt.NDArray[np.int8]) -> float: | ||
"""Computes the contribution to the purity of the bitstrings. | ||
Args: | ||
bitstrings: The bitstrings measured using the same unitary operators | ||
Returns: The purity of the bitstring | ||
""" | ||
|
||
bitstrings, probs = _bitstrings_to_probs(bitstrings) | ||
purity = 0 | ||
for (s, p), (s_prime, p_prime) in product(zip(bitstrings, probs), repeat=2): | ||
purity += (-2.0) ** float(-_get_hamming_distance(s, s_prime)) * p * p_prime | ||
|
||
return purity * 2 ** (bitstrings.shape[-1]) | ||
|
||
|
||
def process_renyi_entropy_from_bitstrings( | ||
measured_bitstrings: npt.NDArray[np.int8], | ||
subsystem: tuple[int] | None = None, | ||
pool: Optional[ThreadPoolExecutor] = None, | ||
) -> float: | ||
"""Compute the Rényi entropy of an array of bitstrings. | ||
Args: | ||
measured_bitstrings: List of sampled measurement outcomes as a numpy array of bitstrings. | ||
subsystem: Subsystem of interest | ||
pool: ThreadPoolExecutor used to paralelleize the computation. | ||
Returns: | ||
A float indicating the computed entropy. | ||
""" | ||
bitstrings = _bitstring_format_helper(measured_bitstrings, subsystem) | ||
num_shots = bitstrings.shape[1] | ||
num_qubits = bitstrings.shape[-1] | ||
|
||
if num_shots == 1: | ||
return 0 | ||
|
||
if pool is not None: | ||
purities = list(pool.map(_compute_bitstrings_contribution_to_purity, list(bitstrings))) | ||
purity = np.mean(purities) | ||
|
||
else: | ||
purity = np.mean( | ||
[_compute_bitstrings_contribution_to_purity(bitstring) for bitstring in bitstrings] | ||
) | ||
|
||
purity_unbiased = purity * num_shots / (num_shots - 1) - (2**num_qubits) / (num_shots - 1) | ||
|
||
return -np.log2(purity_unbiased) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
# Copyright 2024 The Cirq Developers | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from concurrent.futures import ThreadPoolExecutor | ||
import pytest | ||
import numpy as np | ||
|
||
from cirq.qis.entropy import process_renyi_entropy_from_bitstrings | ||
|
||
|
||
@pytest.mark.parametrize('pool', [None, ThreadPoolExecutor(max_workers=1)]) | ||
def test_process_renyi_entropy_from_bitstrings(pool): | ||
bitstrings = np.array( | ||
[ | ||
[[0, 1, 1, 0], [1, 0, 0, 0], [0, 0, 0, 1], [0, 1, 1, 1]], | ||
[[0, 1, 1, 0], [0, 1, 1, 0], [1, 1, 0, 0], [1, 1, 0, 1]], | ||
[[0, 0, 1, 1], [0, 0, 0, 0], [0, 1, 1, 1], [0, 1, 1, 1]], | ||
[[1, 0, 1, 1], [0, 1, 0, 0], [0, 1, 0, 1], [0, 1, 0, 0]], | ||
[[1, 0, 1, 1], [0, 0, 1, 0], [0, 0, 1, 0], [0, 0, 0, 1]], | ||
] | ||
) | ||
substsytem = (0, 1) | ||
entropy = process_renyi_entropy_from_bitstrings(bitstrings, substsytem, pool) | ||
assert entropy == 0.5145731728297583 | ||
|
||
|
||
def test_process_renyi_entropy_from_bitstrings_safeguards_against_divide_by_0_error(): | ||
bitstrings = np.array([[[0, 1, 1, 0]], [[0, 1, 1, 0]], [[0, 0, 1, 1]]]) | ||
|
||
entropy = process_renyi_entropy_from_bitstrings(bitstrings) | ||
assert entropy == 0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
# Copyright 2024 The Cirq Developers | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from collections.abc import Sequence | ||
from typing import Any, Literal | ||
|
||
import cirq | ||
import numpy as np | ||
from cirq.transformers import transformer_api | ||
|
||
|
||
@transformer_api.transformer | ||
class RandomizedMeasurements: | ||
"""A transformer that appends a moment of random rotations to map qubits to | ||
random pauli bases.""" | ||
|
||
def __init__(self, subsystem: Sequence[int] | None = None): | ||
"""Class structure for performing and analyzing a general randomized measurement protocol. | ||
For more details on the randomized measurement toolbox see https://arxiv.org/abs/2203.11374 | ||
Args: | ||
subsystem: The specific subsystem (e.g qubit index) to measure in random basis | ||
""" | ||
self.subsystem = subsystem | ||
|
||
def __call__( | ||
self, | ||
circuit: 'cirq.AbstractCircuit', | ||
rng: np.random.Generator | None = None, | ||
*, | ||
context: transformer_api.TransformerContext | None = None, | ||
): | ||
"""Apply the transformer to the given circuit. Given an input circuit returns | ||
a list of circuits with the pre-measurement unitaries. If no arguments are specified, | ||
it will default to computing the entropy of the entire circuit. | ||
Args: | ||
circuit: The circuit to add randomized measurements to. | ||
rng: Random number generator. | ||
context: Not used; to satisfy transformer API. | ||
Returns: | ||
List of circuits with pre-measurement unitaries and measurements added | ||
""" | ||
if rng is None: | ||
rng = np.random.default_rng() | ||
|
||
qubits = sorted(circuit.all_qubits()) | ||
num_qubits = len(qubits) | ||
|
||
pre_measurement_unitaries_list = self._generate_unitaries_list(rng, num_qubits) | ||
pre_measurement_moment = self.unitaries_to_moment(pre_measurement_unitaries_list, qubits) | ||
|
||
return cirq.Circuit.from_moments( | ||
*circuit.moments, pre_measurement_moment, cirq.M(*qubits, key='m') | ||
) | ||
|
||
def _generate_unitaries_list(self, rng: np.random.Generator, num_qubits: int) -> Sequence[Any]: | ||
"""Generates a list of pre-measurement unitaries.""" | ||
|
||
pauli_strings = rng.choice(["X", "Y", "Z"], size=num_qubits) | ||
|
||
if self.subsystem is not None: | ||
for i in range(pauli_strings.shape[0]): | ||
if i not in self.subsystem: | ||
pauli_strings[i] = np.array("Z") | ||
|
||
return pauli_strings.tolist() | ||
|
||
def unitaries_to_moment( | ||
self, unitaries: Sequence[Literal["X", "Y", "Z"]], qubits: Sequence[Any] | ||
) -> 'cirq.Moment': | ||
"""Outputs the cirq moment associated with the pre-measurement rotations. | ||
Args: | ||
unitaries: List of pre-measurement unitaries | ||
qubits: List of qubits | ||
Returns: The cirq moment associated with the pre-measurement rotations | ||
""" | ||
op_list: list[cirq.Operation] = [] | ||
for idx, pauli in enumerate(unitaries): | ||
op_list.append(_pauli_basis_rotation(pauli).on(qubits[idx])) | ||
|
||
return cirq.Moment.from_ops(*op_list) | ||
|
||
|
||
def _pauli_basis_rotation(basis: Literal["X", "Y", "Z"]) -> 'cirq.Gate': | ||
"""Given a measurement basis returns the associated rotation. | ||
Args: | ||
basis: Measurement basis | ||
Returns: The cirq gate for associated with measurement basis | ||
""" | ||
if basis == "X": | ||
return cirq.Ry(rads=-np.pi / 2) | ||
elif basis == "Y": | ||
return cirq.Rx(rads=np.pi / 2) | ||
elif basis == "Z": | ||
return cirq.I |
55 changes: 55 additions & 0 deletions
55
cirq-core/cirq/transformers/randomized_measurements_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
# Copyright 2024 The Cirq Developers | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import cirq | ||
import cirq.transformers.randomized_measurements as rand_meas | ||
|
||
|
||
def test_randomized_measurements_appends_two_moments_on_returned_circuit(): | ||
# Create a 4-qubit circuit | ||
q0, q1, q2, q3 = cirq.LineQubit.range(4) | ||
circuit = cirq.Circuit([cirq.H(q0), cirq.CNOT(q0, q1), cirq.CNOT(q1, q2), cirq.CNOT(q2, q3)]) | ||
num_moments_pre = len(circuit.moments) | ||
|
||
# Append randomized measurements to subsystem | ||
circuit = rand_meas.RandomizedMeasurements()(circuit) | ||
|
||
num_moments_post = len(circuit.moments) | ||
assert num_moments_post == num_moments_pre + 2 | ||
|
||
|
||
def test_append_randomized_measurements_leaves_qubits_not_in_specified_subsystem_unchanged(): | ||
# Create a 4-qubit circuit | ||
q0, q1, q2, q3 = cirq.LineQubit.range(4) | ||
circuit = cirq.Circuit([cirq.H(q0), cirq.CNOT(q0, q1), cirq.CNOT(q1, q2), cirq.CNOT(q2, q3)]) | ||
|
||
# Append randomized measurements to subsystem | ||
circuit = rand_meas.RandomizedMeasurements(subsystem=(0, 1))(circuit) | ||
|
||
# assert latter subsystems were not changed. | ||
assert circuit.operation_at(q2, 4) == cirq.I(q2) | ||
assert circuit.operation_at(q3, 4) == cirq.I(q3) | ||
|
||
|
||
def test_append_randomized_measurements_leaves_qubits_not_in_noncontinuous_subsystem_unchanged(): | ||
# Create a 4-qubit circuit | ||
q0, q1, q2, q3 = cirq.LineQubit.range(4) | ||
circuit = cirq.Circuit([cirq.H(q0), cirq.CNOT(q0, q1), cirq.CNOT(q1, q2), cirq.CNOT(q2, q3)]) | ||
|
||
# Append randomized measurements to subsystem | ||
circuit = rand_meas.RandomizedMeasurements(subsystem=(0, 2))(circuit) | ||
|
||
# assert latter subsystems were not changed. | ||
assert circuit.operation_at(q1, 4) == cirq.I(q1) | ||
assert circuit.operation_at(q3, 4) == cirq.I(q3) |