Skip to content

Commit

Permalink
wip - Support repeated measurement keys in cirq.Result
Browse files Browse the repository at this point in the history
  • Loading branch information
maffoo committed Jan 31, 2022
1 parent b0519b3 commit faa7446
Show file tree
Hide file tree
Showing 4 changed files with 146 additions and 8 deletions.
59 changes: 56 additions & 3 deletions cirq-core/cirq/study/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,17 @@ def measurements(self) -> Mapping[str, np.ndarray]:
qubits for the corresponding measurements.
"""

@property
@abc.abstractmethod
def records(self) -> Mapping[str, np.ndarray]:
"""A mapping from measurement key to measurement records.
The value for each key is a 3-D array of booleans, with the first index
running over circuit repetitions, the second index running over instances
of the measurement key in the circuit, and the third index running over
the qubits for the corresponding measurements.
"""

@property
@abc.abstractmethod
def data(self) -> pd.DataFrame:
Expand Down Expand Up @@ -296,7 +307,8 @@ def __init__(
self,
*, # Forces keyword args.
params: resolver.ParamResolver,
measurements: Mapping[str, np.ndarray],
measurements: Optional[Mapping[str, np.ndarray]] = None,
records: Optional[Mapping[str, np.ndarray]] = None,
) -> None:
"""Inits Result.
Expand All @@ -307,9 +319,20 @@ def __init__(
with the first index running over the repetitions, and the
second index running over the qubits for the corresponding
measurements.
records: A dictionary from measurement gate key to measurement
results. The value for each key is a 3D array of booleans,
with the first index running over the repetitions, the second
index running over "instances" of that key in the circuit, and
the last index running over the qubits for the corresponding
measurements.
"""
if measurements is None and records is None:
# For backwards compatibility, allow constructing with None.
measurements = {}
records = {}
self._params = params
self._measurements = measurements
self._records = records
self._data: Optional[pd.DataFrame] = None

@property
Expand All @@ -318,15 +341,45 @@ def params(self) -> 'cirq.ParamResolver':

@property
def measurements(self) -> Mapping[str, np.ndarray]:
if self._measurements is None:
assert self._records is not None
self._measurements = {}
for key, data in self._records.items():
reps, instances, qubits = data.shape
if instances != 1:
raise ValueError('Cannot extract 2D measurements for repeated keys')
self._measurements[key] = data.reshape((reps, qubits))
return self._measurements

@property
def records(self) -> Mapping[str, np.ndarray]:
if self._records is None:
assert self._measurements is not None
self._records = {
key: data[:, np.newaxis, :] for key, data in self._measurements.items()
}
return self._records

@property
def repetitions(self) -> int:
if self._records is not None:
if not self._records:
return 0
# Get the length quickly from one of the keyed results.
return len(next(iter(self._records.values())))
else:
if not self._measurements:
return 0
# Get the length quickly from one of the keyed results.
return len(next(iter(self._measurements.values())))

@property
def data(self) -> pd.DataFrame:
if self._data is None:
# Convert to a DataFrame with columns as measurement keys, rows as
# repetitions and a big endian integer for individual measurements.
converted_dict = {}
for key, val in self._measurements.items():
for key, val in self.measurements.items():
converted_dict[key] = [value.big_endian_bits_to_int(m_vals) for m_vals in val]
# Note that when a numpy array is produced from this data frame,
# Pandas will try to use np.int64 as dtype, but will upgrade to
Expand Down Expand Up @@ -393,7 +446,7 @@ def _pack_digits(digits: np.ndarray, pack_bits: str = 'auto') -> Tuple[str, bool
if pack_bits == 'force':
return _pack_bits(digits), True
if pack_bits not in ['auto', 'never']:
raise ValueError("Please set `pack_bits` to 'auto', " "'force', or 'never'.")
raise ValueError("Please set `pack_bits` to 'auto', 'force', or 'never'.")
# Do error checking here, otherwise the following logic will work
# for both "auto" and "never".

Expand Down
48 changes: 48 additions & 0 deletions cirq-core/cirq/study/result_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,54 @@ def test_from_single_parameter_set_deprecation():
assert result.repetitions == 0


def test_construct_from_measurements():
r = cirq.ResultDict(
params=None,
measurements={
'a': np.array([[0, 0], [1, 1]]),
'b': np.array([[0, 0, 0], [1, 1, 1]]),
},
)
assert np.all(r.measurements['a'] == np.array([[0, 0], [1, 1]]))
assert np.all(r.measurements['b'] == np.array([[0, 0, 0], [1, 1, 1]]))
assert np.all(r.records['a'] == np.array([[[0, 0]], [[1, 1]]]))
assert np.all(r.records['b'] == np.array([[[0, 0, 0]], [[1, 1, 1]]]))


def test_construct_from_repeated_measurements():
r = cirq.ResultDict(
params=None,
records={
'a': np.array([[[0, 0], [0, 1]], [[1, 0], [1, 1]]]),
'b': np.array([[[0, 0, 0]], [[1, 1, 1]]]),
},
)
with pytest.raises(ValueError):
_ = r.measurements
assert np.all(r.records['a'] == np.array([[[0, 0], [0, 1]], [[1, 0], [1, 1]]]))
assert np.all(r.records['b'] == np.array([[[0, 0, 0]], [[1, 1, 1]]]))
assert r.repetitions == 2

r2 = cirq.ResultDict(
params=None,
records={
'a': np.array([[[0, 0]], [[1, 1]]]),
'b': np.array([[[0, 0, 0]], [[1, 1, 1]]]),
},
)
assert np.all(r2.measurements['a'] == np.array([[0, 0], [1, 1]]))
assert np.all(r2.measurements['b'] == np.array([[0, 0, 0], [1, 1, 1]]))
assert np.all(r2.records['a'] == np.array([[[0, 0]], [[1, 1]]]))
assert np.all(r2.records['b'] == np.array([[[0, 0, 0]], [[1, 1, 1]]]))
assert r2.repetitions == 2


def test_empty_measurements():
assert cirq.ResultDict(params=None).repetitions == 0
assert cirq.ResultDict(params=None, measurements={}).repetitions == 0
assert cirq.ResultDict(params=None, records={}).repetitions == 0


def test_str():
result = cirq.ResultDict(
params=cirq.ParamResolver({}),
Expand Down
25 changes: 20 additions & 5 deletions cirq-core/cirq/work/zeros_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import abc
import collections
from typing import Dict, List, TYPE_CHECKING

import numpy as np
Expand Down Expand Up @@ -53,17 +54,31 @@ def run_sweep(
resolver.
Raises:
ValueError if this sampler has a device and the circuit is not
valid for the device.
ValueError: circuit is not valid for the sampler, due to invalid
repeated keys or incompatibility with the sampler's device.
"""
if self.device:
self.device.validate_circuit(program)
measurements: Dict[str, np.ndarray] = {}
num_qubits: Dict[str, int] = {}
num_instances: Dict[str, int] = collections.Counter()
for op in program.all_operations():
key = protocols.measurement_key_name(op, default=None)
if key is not None:
measurements[key] = np.zeros((repetitions, len(op.qubits)), dtype=int)
n = len(op.qubits)
prev_n = num_qubits.setdefault(key, n)
if n != prev_n:
raise ValueError(
"Different num qubits for repeated measurement: "
f"key={key!r}, prev_n={prev_n}, n={n}"
)
num_instances[key] += 1
return [
study.ResultDict(params=param_resolver, measurements=measurements)
study.ResultDict(
params=param_resolver,
records={
k: np.zeros((repetitions, num_instances[k], n), dtype=int)
for k, n in num_qubits.items()
},
)
for param_resolver in study.to_resolvers(params)
]
22 changes: 22 additions & 0 deletions cirq-core/cirq/work/zeros_sampler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,28 @@ def test_sample():
assert np.all(result1 == result2)


def test_repeated_keys():
q0, q1, q2 = cirq.LineQubit.range(3)

c = cirq.Circuit(
cirq.measure(q0, key='a'),
cirq.measure(q1, q2, key='b'),
cirq.measure(q0, key='a'),
cirq.measure(q1, q2, key='b'),
cirq.measure(q1, q2, key='b'),
)
result = cirq.ZerosSampler().run(c, repetitions=10)
assert result.records['a'].shape == (10, 2, 1)
assert result.records['b'].shape == (10, 3, 2)

c2 = cirq.Circuit(
cirq.measure(q0, key='a'),
cirq.measure(q1, q2, key='a'),
)
with pytest.raises(ValueError, match="Different num qubits for repeated measurement"):
cirq.ZerosSampler().run(c2, repetitions=10)


class OnlyMeasurementsDevice(cirq.Device):
def validate_operation(self, operation: 'cirq.Operation') -> None:
if not cirq.is_measurement(operation):
Expand Down

0 comments on commit faa7446

Please sign in to comment.