Skip to content

Commit

Permalink
Support repeated measurement keys in cirq.Result (#4555)
Browse files Browse the repository at this point in the history
Part of #4274
  • Loading branch information
maffoo authored Feb 14, 2022
1 parent 872b22a commit dd55a86
Show file tree
Hide file tree
Showing 5 changed files with 171 additions and 14 deletions.
59 changes: 56 additions & 3 deletions cirq-core/cirq/study/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,17 @@ def measurements(self) -> Mapping[str, np.ndarray]:
qubits for the corresponding measurements.
"""

@property
@abc.abstractmethod
def records(self) -> Mapping[str, np.ndarray]:
"""A mapping from measurement key to measurement records.
The value for each key is a 3-D array of booleans, with the first index
running over circuit repetitions, the second index running over instances
of the measurement key in the circuit, and the third index running over
the qubits for the corresponding measurements.
"""

@property
@abc.abstractmethod
def data(self) -> pd.DataFrame:
Expand Down Expand Up @@ -315,7 +326,8 @@ def __init__(
self,
*, # Forces keyword args.
params: resolver.ParamResolver,
measurements: Mapping[str, np.ndarray],
measurements: Optional[Mapping[str, np.ndarray]] = None,
records: Optional[Mapping[str, np.ndarray]] = None,
) -> None:
"""Inits Result.
Expand All @@ -326,9 +338,20 @@ def __init__(
with the first index running over the repetitions, and the
second index running over the qubits for the corresponding
measurements.
records: A dictionary from measurement gate key to measurement
results. The value for each key is a 3D array of booleans,
with the first index running over the repetitions, the second
index running over "instances" of that key in the circuit, and
the last index running over the qubits for the corresponding
measurements.
"""
if measurements is None and records is None:
# For backwards compatibility, allow constructing with None.
measurements = {}
records = {}
self._params = params
self._measurements = measurements
self._records = records
self._data: Optional[pd.DataFrame] = None

@property
Expand All @@ -337,12 +360,42 @@ def params(self) -> 'cirq.ParamResolver':

@property
def measurements(self) -> Mapping[str, np.ndarray]:
if self._measurements is None:
assert self._records is not None
self._measurements = {}
for key, data in self._records.items():
reps, instances, qubits = data.shape
if instances != 1:
raise ValueError('Cannot extract 2D measurements for repeated keys')
self._measurements[key] = data.reshape((reps, qubits))
return self._measurements

@property
def records(self) -> Mapping[str, np.ndarray]:
if self._records is None:
assert self._measurements is not None
self._records = {
key: data[:, np.newaxis, :] for key, data in self._measurements.items()
}
return self._records

@property
def repetitions(self) -> int:
if self._records is not None:
if not self._records:
return 0
# Get the length quickly from one of the keyed results.
return len(next(iter(self._records.values())))
else:
if not self._measurements:
return 0
# Get the length quickly from one of the keyed results.
return len(next(iter(self._measurements.values())))

@property
def data(self) -> pd.DataFrame:
if self._data is None:
self._data = self.dataframe_from_measurements(self._measurements)
self._data = self.dataframe_from_measurements(self.measurements)
return self._data

def __repr__(self) -> str:
Expand Down Expand Up @@ -404,7 +457,7 @@ def _pack_digits(digits: np.ndarray, pack_bits: str = 'auto') -> Tuple[str, bool
if pack_bits == 'force':
return _pack_bits(digits), True
if pack_bits not in ['auto', 'never']:
raise ValueError("Please set `pack_bits` to 'auto', " "'force', or 'never'.")
raise ValueError("Please set `pack_bits` to 'auto', 'force', or 'never'.")
# Do error checking here, otherwise the following logic will work
# for both "auto" and "never".

Expand Down
48 changes: 48 additions & 0 deletions cirq-core/cirq/study/result_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,54 @@ def test_from_single_parameter_set_deprecation():
assert result.repetitions == 0


def test_construct_from_measurements():
r = cirq.ResultDict(
params=None,
measurements={
'a': np.array([[0, 0], [1, 1]]),
'b': np.array([[0, 0, 0], [1, 1, 1]]),
},
)
assert np.all(r.measurements['a'] == np.array([[0, 0], [1, 1]]))
assert np.all(r.measurements['b'] == np.array([[0, 0, 0], [1, 1, 1]]))
assert np.all(r.records['a'] == np.array([[[0, 0]], [[1, 1]]]))
assert np.all(r.records['b'] == np.array([[[0, 0, 0]], [[1, 1, 1]]]))


def test_construct_from_repeated_measurements():
r = cirq.ResultDict(
params=None,
records={
'a': np.array([[[0, 0], [0, 1]], [[1, 0], [1, 1]]]),
'b': np.array([[[0, 0, 0]], [[1, 1, 1]]]),
},
)
with pytest.raises(ValueError):
_ = r.measurements
assert np.all(r.records['a'] == np.array([[[0, 0], [0, 1]], [[1, 0], [1, 1]]]))
assert np.all(r.records['b'] == np.array([[[0, 0, 0]], [[1, 1, 1]]]))
assert r.repetitions == 2

r2 = cirq.ResultDict(
params=None,
records={
'a': np.array([[[0, 0]], [[1, 1]]]),
'b': np.array([[[0, 0, 0]], [[1, 1, 1]]]),
},
)
assert np.all(r2.measurements['a'] == np.array([[0, 0], [1, 1]]))
assert np.all(r2.measurements['b'] == np.array([[0, 0, 0], [1, 1, 1]]))
assert np.all(r2.records['a'] == np.array([[[0, 0]], [[1, 1]]]))
assert np.all(r2.records['b'] == np.array([[[0, 0, 0]], [[1, 1, 1]]]))
assert r2.repetitions == 2


def test_empty_measurements():
assert cirq.ResultDict(params=None).repetitions == 0
assert cirq.ResultDict(params=None, measurements={}).repetitions == 0
assert cirq.ResultDict(params=None, records={}).repetitions == 0


def test_str():
result = cirq.ResultDict(
params=cirq.ParamResolver({}),
Expand Down
34 changes: 33 additions & 1 deletion cirq-core/cirq/work/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@
"""Abstract base class for things sampling quantum circuits."""

import abc
import collections
from typing import Dict, FrozenSet, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union

import pandas as pd

from cirq import study, ops
from cirq import ops, protocols, study
from cirq.work.observable_measurement import (
measure_observables,
RepetitionsStoppingCriteria,
Expand Down Expand Up @@ -365,3 +366,34 @@ def sample_expectation_values(
nested_results[param_i][psum_i] += res.mean

return nested_results

@staticmethod
def _get_measurement_shapes(
circuit: 'cirq.AbstractCircuit',
) -> Dict[str, Tuple[int, Tuple[int, ...]]]:
"""Gets the shapes of measurements in the given circuit.
Returns:
A mapping from measurement key name to a tuple of (num_instances, qid_shape),
where num_instances is the number of times that key appears in the circuit and
qid_shape is the shape of measured qubits for the key, as determined by the
`cirq.qid_shape` protocol.
Raises:
ValueError: if the qid_shape of different instances of the same measurement
key disagree.
"""
qid_shapes: Dict[str, Tuple[int, ...]] = {}
num_instances: Dict[str, int] = collections.Counter()
for op in circuit.all_operations():
key = protocols.measurement_key_name(op, default=None)
if key is not None:
qid_shape = protocols.qid_shape(op)
prev_qid_shape = qid_shapes.setdefault(key, qid_shape)
if qid_shape != prev_qid_shape:
raise ValueError(
"Different qid shapes for repeated measurement: "
f"key={key!r}, prev_qid_shape={prev_qid_shape}, qid_shape={qid_shape}"
)
num_instances[key] += 1
return {k: (num_instances[k], qid_shape) for k, qid_shape in qid_shapes.items()}
22 changes: 12 additions & 10 deletions cirq-core/cirq/work/zeros_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
# limitations under the License.

import abc
from typing import Dict, List, TYPE_CHECKING
from typing import List, TYPE_CHECKING

import numpy as np

from cirq import devices, work, study, protocols
from cirq import devices, work, study

if TYPE_CHECKING:
import cirq
Expand Down Expand Up @@ -53,17 +53,19 @@ def run_sweep(
resolver.
Raises:
ValueError if this sampler has a device and the circuit is not
valid for the device.
ValueError: circuit is not valid for the sampler, due to invalid
repeated keys or incompatibility with the sampler's device.
"""
if self.device:
self.device.validate_circuit(program)
measurements: Dict[str, np.ndarray] = {}
for op in program.all_operations():
key = protocols.measurement_key_name(op, default=None)
if key is not None:
measurements[key] = np.zeros((repetitions, len(op.qubits)), dtype=int)
shapes = self._get_measurement_shapes(program)
return [
study.ResultDict(params=param_resolver, measurements=measurements)
study.ResultDict(
params=param_resolver,
records={
k: np.zeros((repetitions, num_instances, len(qid_shape)), dtype=int)
for k, (num_instances, qid_shape) in shapes.items()
},
)
for param_resolver in study.to_resolvers(params)
]
22 changes: 22 additions & 0 deletions cirq-core/cirq/work/zeros_sampler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,28 @@ def test_sample():
assert np.all(result1 == result2)


def test_repeated_keys():
q0, q1, q2 = cirq.LineQubit.range(3)

c = cirq.Circuit(
cirq.measure(q0, key='a'),
cirq.measure(q1, q2, key='b'),
cirq.measure(q0, key='a'),
cirq.measure(q1, q2, key='b'),
cirq.measure(q1, q2, key='b'),
)
result = cirq.ZerosSampler().run(c, repetitions=10)
assert result.records['a'].shape == (10, 2, 1)
assert result.records['b'].shape == (10, 3, 2)

c2 = cirq.Circuit(
cirq.measure(q0, key='a'),
cirq.measure(q1, q2, key='a'),
)
with pytest.raises(ValueError, match="Different qid shapes for repeated measurement"):
cirq.ZerosSampler().run(c2, repetitions=10)


class OnlyMeasurementsDevice(cirq.Device):
def validate_operation(self, operation: 'cirq.Operation') -> None:
if not cirq.is_measurement(operation):
Expand Down

0 comments on commit dd55a86

Please sign in to comment.