From 775bd1443701ce295d2428f0535d1d98ba75e53b Mon Sep 17 00:00:00 2001 From: Matthew Neeley Date: Mon, 14 Feb 2022 15:45:29 -0800 Subject: [PATCH] Support repeated measurement keys in cirq.Result (#4555) Part of #4274 --- cirq/study/result.py | 59 +++++++++++++++++++++++++++++++-- cirq/study/result_test.py | 48 +++++++++++++++++++++++++++ cirq/work/sampler.py | 34 ++++++++++++++++++- cirq/work/zeros_sampler.py | 22 ++++++------ cirq/work/zeros_sampler_test.py | 22 ++++++++++++ 5 files changed, 171 insertions(+), 14 deletions(-) diff --git a/cirq/study/result.py b/cirq/study/result.py index 067bad47bf4..54f918b3b7f 100644 --- a/cirq/study/result.py +++ b/cirq/study/result.py @@ -109,6 +109,17 @@ def measurements(self) -> Mapping[str, np.ndarray]: qubits for the corresponding measurements. """ + @property + @abc.abstractmethod + def records(self) -> Mapping[str, np.ndarray]: + """A mapping from measurement key to measurement records. + + The value for each key is a 3-D array of booleans, with the first index + running over circuit repetitions, the second index running over instances + of the measurement key in the circuit, and the third index running over + the qubits for the corresponding measurements. + """ + @property @abc.abstractmethod def data(self) -> pd.DataFrame: @@ -315,7 +326,8 @@ def __init__( self, *, # Forces keyword args. params: resolver.ParamResolver, - measurements: Mapping[str, np.ndarray], + measurements: Optional[Mapping[str, np.ndarray]] = None, + records: Optional[Mapping[str, np.ndarray]] = None, ) -> None: """Inits Result. @@ -326,9 +338,20 @@ def __init__( with the first index running over the repetitions, and the second index running over the qubits for the corresponding measurements. + records: A dictionary from measurement gate key to measurement + results. The value for each key is a 3D array of booleans, + with the first index running over the repetitions, the second + index running over "instances" of that key in the circuit, and + the last index running over the qubits for the corresponding + measurements. """ + if measurements is None and records is None: + # For backwards compatibility, allow constructing with None. + measurements = {} + records = {} self._params = params self._measurements = measurements + self._records = records self._data: Optional[pd.DataFrame] = None @property @@ -337,12 +360,42 @@ def params(self) -> 'cirq.ParamResolver': @property def measurements(self) -> Mapping[str, np.ndarray]: + if self._measurements is None: + assert self._records is not None + self._measurements = {} + for key, data in self._records.items(): + reps, instances, qubits = data.shape + if instances != 1: + raise ValueError('Cannot extract 2D measurements for repeated keys') + self._measurements[key] = data.reshape((reps, qubits)) return self._measurements + @property + def records(self) -> Mapping[str, np.ndarray]: + if self._records is None: + assert self._measurements is not None + self._records = { + key: data[:, np.newaxis, :] for key, data in self._measurements.items() + } + return self._records + + @property + def repetitions(self) -> int: + if self._records is not None: + if not self._records: + return 0 + # Get the length quickly from one of the keyed results. + return len(next(iter(self._records.values()))) + else: + if not self._measurements: + return 0 + # Get the length quickly from one of the keyed results. + return len(next(iter(self._measurements.values()))) + @property def data(self) -> pd.DataFrame: if self._data is None: - self._data = self.dataframe_from_measurements(self._measurements) + self._data = self.dataframe_from_measurements(self.measurements) return self._data def __repr__(self) -> str: @@ -404,7 +457,7 @@ def _pack_digits(digits: np.ndarray, pack_bits: str = 'auto') -> Tuple[str, bool if pack_bits == 'force': return _pack_bits(digits), True if pack_bits not in ['auto', 'never']: - raise ValueError("Please set `pack_bits` to 'auto', " "'force', or 'never'.") + raise ValueError("Please set `pack_bits` to 'auto', 'force', or 'never'.") # Do error checking here, otherwise the following logic will work # for both "auto" and "never". diff --git a/cirq/study/result_test.py b/cirq/study/result_test.py index 4d0ee24077e..9a15ebb10b0 100644 --- a/cirq/study/result_test.py +++ b/cirq/study/result_test.py @@ -49,6 +49,54 @@ def test_from_single_parameter_set_deprecation(): assert result.repetitions == 0 +def test_construct_from_measurements(): + r = cirq.ResultDict( + params=None, + measurements={ + 'a': np.array([[0, 0], [1, 1]]), + 'b': np.array([[0, 0, 0], [1, 1, 1]]), + }, + ) + assert np.all(r.measurements['a'] == np.array([[0, 0], [1, 1]])) + assert np.all(r.measurements['b'] == np.array([[0, 0, 0], [1, 1, 1]])) + assert np.all(r.records['a'] == np.array([[[0, 0]], [[1, 1]]])) + assert np.all(r.records['b'] == np.array([[[0, 0, 0]], [[1, 1, 1]]])) + + +def test_construct_from_repeated_measurements(): + r = cirq.ResultDict( + params=None, + records={ + 'a': np.array([[[0, 0], [0, 1]], [[1, 0], [1, 1]]]), + 'b': np.array([[[0, 0, 0]], [[1, 1, 1]]]), + }, + ) + with pytest.raises(ValueError): + _ = r.measurements + assert np.all(r.records['a'] == np.array([[[0, 0], [0, 1]], [[1, 0], [1, 1]]])) + assert np.all(r.records['b'] == np.array([[[0, 0, 0]], [[1, 1, 1]]])) + assert r.repetitions == 2 + + r2 = cirq.ResultDict( + params=None, + records={ + 'a': np.array([[[0, 0]], [[1, 1]]]), + 'b': np.array([[[0, 0, 0]], [[1, 1, 1]]]), + }, + ) + assert np.all(r2.measurements['a'] == np.array([[0, 0], [1, 1]])) + assert np.all(r2.measurements['b'] == np.array([[0, 0, 0], [1, 1, 1]])) + assert np.all(r2.records['a'] == np.array([[[0, 0]], [[1, 1]]])) + assert np.all(r2.records['b'] == np.array([[[0, 0, 0]], [[1, 1, 1]]])) + assert r2.repetitions == 2 + + +def test_empty_measurements(): + assert cirq.ResultDict(params=None).repetitions == 0 + assert cirq.ResultDict(params=None, measurements={}).repetitions == 0 + assert cirq.ResultDict(params=None, records={}).repetitions == 0 + + def test_str(): result = cirq.ResultDict( params=cirq.ParamResolver({}), diff --git a/cirq/work/sampler.py b/cirq/work/sampler.py index 42412539847..5e2fcc8581b 100644 --- a/cirq/work/sampler.py +++ b/cirq/work/sampler.py @@ -14,11 +14,12 @@ """Abstract base class for things sampling quantum circuits.""" import abc +import collections from typing import Dict, FrozenSet, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union import pandas as pd -from cirq import study, ops +from cirq import ops, protocols, study from cirq.work.observable_measurement import ( measure_observables, RepetitionsStoppingCriteria, @@ -365,3 +366,34 @@ def sample_expectation_values( nested_results[param_i][psum_i] += res.mean return nested_results + + @staticmethod + def _get_measurement_shapes( + circuit: 'cirq.AbstractCircuit', + ) -> Dict[str, Tuple[int, Tuple[int, ...]]]: + """Gets the shapes of measurements in the given circuit. + + Returns: + A mapping from measurement key name to a tuple of (num_instances, qid_shape), + where num_instances is the number of times that key appears in the circuit and + qid_shape is the shape of measured qubits for the key, as determined by the + `cirq.qid_shape` protocol. + + Raises: + ValueError: if the qid_shape of different instances of the same measurement + key disagree. + """ + qid_shapes: Dict[str, Tuple[int, ...]] = {} + num_instances: Dict[str, int] = collections.Counter() + for op in circuit.all_operations(): + key = protocols.measurement_key_name(op, default=None) + if key is not None: + qid_shape = protocols.qid_shape(op) + prev_qid_shape = qid_shapes.setdefault(key, qid_shape) + if qid_shape != prev_qid_shape: + raise ValueError( + "Different qid shapes for repeated measurement: " + f"key={key!r}, prev_qid_shape={prev_qid_shape}, qid_shape={qid_shape}" + ) + num_instances[key] += 1 + return {k: (num_instances[k], qid_shape) for k, qid_shape in qid_shapes.items()} diff --git a/cirq/work/zeros_sampler.py b/cirq/work/zeros_sampler.py index 9234a59cfbc..1f8dc7555c9 100644 --- a/cirq/work/zeros_sampler.py +++ b/cirq/work/zeros_sampler.py @@ -13,11 +13,11 @@ # limitations under the License. import abc -from typing import Dict, List, TYPE_CHECKING +from typing import List, TYPE_CHECKING import numpy as np -from cirq import devices, work, study, protocols +from cirq import devices, work, study if TYPE_CHECKING: import cirq @@ -53,17 +53,19 @@ def run_sweep( resolver. Raises: - ValueError if this sampler has a device and the circuit is not - valid for the device. + ValueError: circuit is not valid for the sampler, due to invalid + repeated keys or incompatibility with the sampler's device. """ if self.device: self.device.validate_circuit(program) - measurements: Dict[str, np.ndarray] = {} - for op in program.all_operations(): - key = protocols.measurement_key_name(op, default=None) - if key is not None: - measurements[key] = np.zeros((repetitions, len(op.qubits)), dtype=int) + shapes = self._get_measurement_shapes(program) return [ - study.ResultDict(params=param_resolver, measurements=measurements) + study.ResultDict( + params=param_resolver, + records={ + k: np.zeros((repetitions, num_instances, len(qid_shape)), dtype=int) + for k, (num_instances, qid_shape) in shapes.items() + }, + ) for param_resolver in study.to_resolvers(params) ] diff --git a/cirq/work/zeros_sampler_test.py b/cirq/work/zeros_sampler_test.py index e02f730a116..0d9f0e80993 100644 --- a/cirq/work/zeros_sampler_test.py +++ b/cirq/work/zeros_sampler_test.py @@ -52,6 +52,28 @@ def test_sample(): assert np.all(result1 == result2) +def test_repeated_keys(): + q0, q1, q2 = cirq.LineQubit.range(3) + + c = cirq.Circuit( + cirq.measure(q0, key='a'), + cirq.measure(q1, q2, key='b'), + cirq.measure(q0, key='a'), + cirq.measure(q1, q2, key='b'), + cirq.measure(q1, q2, key='b'), + ) + result = cirq.ZerosSampler().run(c, repetitions=10) + assert result.records['a'].shape == (10, 2, 1) + assert result.records['b'].shape == (10, 3, 2) + + c2 = cirq.Circuit( + cirq.measure(q0, key='a'), + cirq.measure(q1, q2, key='a'), + ) + with pytest.raises(ValueError, match="Different qid shapes for repeated measurement"): + cirq.ZerosSampler().run(c2, repetitions=10) + + class OnlyMeasurementsDevice(cirq.Device): def validate_operation(self, operation: 'cirq.Operation') -> None: if not cirq.is_measurement(operation):