diff --git a/cirq-core/cirq/study/result.py b/cirq-core/cirq/study/result.py index 90683cfaaff..1afff592342 100644 --- a/cirq-core/cirq/study/result.py +++ b/cirq-core/cirq/study/result.py @@ -109,6 +109,17 @@ def measurements(self) -> Mapping[str, np.ndarray]: qubits for the corresponding measurements. """ + @property + @abc.abstractmethod + def records(self) -> Mapping[str, np.ndarray]: + """A mapping from measurement key to measurement records. + + The value for each key is a 3-D array of booleans, with the first index + running over circuit repetitions, the second index running over instances + of the measurement key in the circuit, and the third index running over + the qubits for the corresponding measurements. + """ + @property @abc.abstractmethod def data(self) -> pd.DataFrame: @@ -296,7 +307,8 @@ def __init__( self, *, # Forces keyword args. params: resolver.ParamResolver, - measurements: Mapping[str, np.ndarray], + measurements: Optional[Mapping[str, np.ndarray]] = None, + records: Optional[Mapping[str, np.ndarray]] = None, ) -> None: """Inits Result. @@ -307,9 +319,20 @@ def __init__( with the first index running over the repetitions, and the second index running over the qubits for the corresponding measurements. + records: A dictionary from measurement gate key to measurement + results. The value for each key is a 3D array of booleans, + with the first index running over the repetitions, the second + index running over "instances" of that key in the circuit, and + the last index running over the qubits for the corresponding + measurements. """ + if measurements is None and records is None: + # For backwards compatibility, allow constructing with None. + measurements = {} + records = {} self._params = params self._measurements = measurements + self._records = records self._data: Optional[pd.DataFrame] = None @property @@ -318,15 +341,45 @@ def params(self) -> 'cirq.ParamResolver': @property def measurements(self) -> Mapping[str, np.ndarray]: + if self._measurements is None: + assert self._records is not None + self._measurements = {} + for key, data in self._records.items(): + reps, instances, qubits = data.shape + if instances != 1: + raise ValueError('Cannot extract 2D measurements for repeated keys') + self._measurements[key] = data.reshape((reps, qubits)) return self._measurements + @property + def records(self) -> Mapping[str, np.ndarray]: + if self._records is None: + assert self._measurements is not None + self._records = { + key: data[:, np.newaxis, :] for key, data in self._measurements.items() + } + return self._records + + @property + def repetitions(self) -> int: + if self._records is not None: + if not self._records: + return 0 + # Get the length quickly from one of the keyed results. + return len(next(iter(self._records.values()))) + else: + if not self._measurements: + return 0 + # Get the length quickly from one of the keyed results. + return len(next(iter(self._measurements.values()))) + @property def data(self) -> pd.DataFrame: if self._data is None: # Convert to a DataFrame with columns as measurement keys, rows as # repetitions and a big endian integer for individual measurements. converted_dict = {} - for key, val in self._measurements.items(): + for key, val in self.measurements.items(): converted_dict[key] = [value.big_endian_bits_to_int(m_vals) for m_vals in val] # Note that when a numpy array is produced from this data frame, # Pandas will try to use np.int64 as dtype, but will upgrade to @@ -393,7 +446,7 @@ def _pack_digits(digits: np.ndarray, pack_bits: str = 'auto') -> Tuple[str, bool if pack_bits == 'force': return _pack_bits(digits), True if pack_bits not in ['auto', 'never']: - raise ValueError("Please set `pack_bits` to 'auto', " "'force', or 'never'.") + raise ValueError("Please set `pack_bits` to 'auto', 'force', or 'never'.") # Do error checking here, otherwise the following logic will work # for both "auto" and "never". diff --git a/cirq-core/cirq/study/result_test.py b/cirq-core/cirq/study/result_test.py index 4d0ee24077e..9a15ebb10b0 100644 --- a/cirq-core/cirq/study/result_test.py +++ b/cirq-core/cirq/study/result_test.py @@ -49,6 +49,54 @@ def test_from_single_parameter_set_deprecation(): assert result.repetitions == 0 +def test_construct_from_measurements(): + r = cirq.ResultDict( + params=None, + measurements={ + 'a': np.array([[0, 0], [1, 1]]), + 'b': np.array([[0, 0, 0], [1, 1, 1]]), + }, + ) + assert np.all(r.measurements['a'] == np.array([[0, 0], [1, 1]])) + assert np.all(r.measurements['b'] == np.array([[0, 0, 0], [1, 1, 1]])) + assert np.all(r.records['a'] == np.array([[[0, 0]], [[1, 1]]])) + assert np.all(r.records['b'] == np.array([[[0, 0, 0]], [[1, 1, 1]]])) + + +def test_construct_from_repeated_measurements(): + r = cirq.ResultDict( + params=None, + records={ + 'a': np.array([[[0, 0], [0, 1]], [[1, 0], [1, 1]]]), + 'b': np.array([[[0, 0, 0]], [[1, 1, 1]]]), + }, + ) + with pytest.raises(ValueError): + _ = r.measurements + assert np.all(r.records['a'] == np.array([[[0, 0], [0, 1]], [[1, 0], [1, 1]]])) + assert np.all(r.records['b'] == np.array([[[0, 0, 0]], [[1, 1, 1]]])) + assert r.repetitions == 2 + + r2 = cirq.ResultDict( + params=None, + records={ + 'a': np.array([[[0, 0]], [[1, 1]]]), + 'b': np.array([[[0, 0, 0]], [[1, 1, 1]]]), + }, + ) + assert np.all(r2.measurements['a'] == np.array([[0, 0], [1, 1]])) + assert np.all(r2.measurements['b'] == np.array([[0, 0, 0], [1, 1, 1]])) + assert np.all(r2.records['a'] == np.array([[[0, 0]], [[1, 1]]])) + assert np.all(r2.records['b'] == np.array([[[0, 0, 0]], [[1, 1, 1]]])) + assert r2.repetitions == 2 + + +def test_empty_measurements(): + assert cirq.ResultDict(params=None).repetitions == 0 + assert cirq.ResultDict(params=None, measurements={}).repetitions == 0 + assert cirq.ResultDict(params=None, records={}).repetitions == 0 + + def test_str(): result = cirq.ResultDict( params=cirq.ParamResolver({}), diff --git a/cirq-core/cirq/work/zeros_sampler.py b/cirq-core/cirq/work/zeros_sampler.py index 9234a59cfbc..d430bfa08fe 100644 --- a/cirq-core/cirq/work/zeros_sampler.py +++ b/cirq-core/cirq/work/zeros_sampler.py @@ -13,6 +13,7 @@ # limitations under the License. import abc +import collections from typing import Dict, List, TYPE_CHECKING import numpy as np @@ -53,17 +54,31 @@ def run_sweep( resolver. Raises: - ValueError if this sampler has a device and the circuit is not - valid for the device. + ValueError: circuit is not valid for the sampler, due to invalid + repeated keys or incompatibility with the sampler's device. """ if self.device: self.device.validate_circuit(program) - measurements: Dict[str, np.ndarray] = {} + num_qubits: Dict[str, int] = {} + num_instances: Dict[str, int] = collections.Counter() for op in program.all_operations(): key = protocols.measurement_key_name(op, default=None) if key is not None: - measurements[key] = np.zeros((repetitions, len(op.qubits)), dtype=int) + n = len(op.qubits) + prev_n = num_qubits.setdefault(key, n) + if n != prev_n: + raise ValueError( + "Different num qubits for repeated measurement: " + f"key={key!r}, prev_n={prev_n}, n={n}" + ) + num_instances[key] += 1 return [ - study.ResultDict(params=param_resolver, measurements=measurements) + study.ResultDict( + params=param_resolver, + records={ + k: np.zeros((repetitions, num_instances[k], n), dtype=int) + for k, n in num_qubits.items() + }, + ) for param_resolver in study.to_resolvers(params) ] diff --git a/cirq-core/cirq/work/zeros_sampler_test.py b/cirq-core/cirq/work/zeros_sampler_test.py index e02f730a116..9aec2218a0c 100644 --- a/cirq-core/cirq/work/zeros_sampler_test.py +++ b/cirq-core/cirq/work/zeros_sampler_test.py @@ -52,6 +52,28 @@ def test_sample(): assert np.all(result1 == result2) +def test_repeated_keys(): + q0, q1, q2 = cirq.LineQubit.range(3) + + c = cirq.Circuit( + cirq.measure(q0, key='a'), + cirq.measure(q1, q2, key='b'), + cirq.measure(q0, key='a'), + cirq.measure(q1, q2, key='b'), + cirq.measure(q1, q2, key='b'), + ) + result = cirq.ZerosSampler().run(c, repetitions=10) + assert result.records['a'].shape == (10, 2, 1) + assert result.records['b'].shape == (10, 3, 2) + + c2 = cirq.Circuit( + cirq.measure(q0, key='a'), + cirq.measure(q1, q2, key='a'), + ) + with pytest.raises(ValueError, match="Different num qubits for repeated measurement"): + cirq.ZerosSampler().run(c2, repetitions=10) + + class OnlyMeasurementsDevice(cirq.Device): def validate_operation(self, operation: 'cirq.Operation') -> None: if not cirq.is_measurement(operation):