Skip to content

Commit

Permalink
Add measurement_key_obj and measurement_key_objs protocols. (quan…
Browse files Browse the repository at this point in the history
…tumlib#4497)

* _measurement_key_obj[s]_ methods to measurement gate and operation

* Add protocols and magic methods. Also tests.

* Fix channels

* Add more caching

* Fix all types

* Renamed cached circuitop mkeys

* remove unused value import

Co-authored-by: Orion Martin <[email protected]>
  • Loading branch information
smitsanghavi and 95-martin-orion authored Oct 8, 2021
1 parent f11846e commit 978dbac
Show file tree
Hide file tree
Showing 18 changed files with 343 additions and 96 deletions.
2 changes: 2 additions & 0 deletions cirq-core/cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,8 +530,10 @@
kraus,
measurement_key,
measurement_key_name,
measurement_key_obj,
measurement_keys,
measurement_key_names,
measurement_key_objs,
mixture,
mul,
num_qubits,
Expand Down
18 changes: 12 additions & 6 deletions cirq-core/cirq/circuits/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@

import cirq._version
from cirq._compat import deprecated
from cirq import devices, ops, protocols, qis
from cirq import devices, ops, protocols, value, qis
from cirq.circuits._bucket_priority_queue import BucketPriorityQueue
from cirq.circuits.circuit_operation import CircuitOperation
from cirq.circuits.insert_strategy import InsertStrategy
Expand Down Expand Up @@ -909,6 +909,12 @@ def qid_shape(
qids = ops.QubitOrder.as_qubit_order(qubit_order).order_for(self.all_qubits())
return protocols.qid_shape(qids)

def all_measurement_key_objs(self) -> AbstractSet[value.MeasurementKey]:
return {key for op in self.all_operations() for key in protocols.measurement_key_objs(op)}

def _measurement_key_objs_(self) -> AbstractSet[value.MeasurementKey]:
return self.all_measurement_key_objs()

@deprecated(deadline='v0.13', fix='use all_measurement_key_names instead')
def all_measurement_keys(self) -> AbstractSet[str]:
return self.all_measurement_key_names()
Expand Down Expand Up @@ -2233,7 +2239,7 @@ def batch_insert(self, insertions: Iterable[Tuple[int, 'cirq.OP_TREE']]) -> None
shift = 0
# Note: python `sorted` is guaranteed to be stable. This matters.
insertions = sorted(insertions, key=lambda e: e[0])
groups = _group_until_different(insertions, key=lambda e: e[0], value=lambda e: e[1])
groups = _group_until_different(insertions, key=lambda e: e[0], val=lambda e: e[1])
for i, group in groups:
insert_index = i + shift
next_index = copy.insert(insert_index, reversed(group), InsertStrategy.EARLIEST)
Expand Down Expand Up @@ -2617,19 +2623,19 @@ def _group_until_different(

@overload
def _group_until_different(
items: Iterable[TIn], key: Callable[[TIn], TKey], value: Callable[[TIn], TOut]
items: Iterable[TIn], key: Callable[[TIn], TKey], val: Callable[[TIn], TOut]
) -> Iterable[Tuple[TKey, List[TOut]]]:
pass


def _group_until_different(items: Iterable[TIn], key: Callable[[TIn], TKey], value=lambda e: e):
def _group_until_different(items: Iterable[TIn], key: Callable[[TIn], TKey], val=lambda e: e):
"""Groups runs of items that are identical according to a keying function.
Args:
items: The items to group.
key: If two adjacent items produce the same output from this function,
they will be grouped.
value: Maps each item into a value to put in the group. Defaults to the
val: Maps each item into a value to put in the group. Defaults to the
item itself.
Examples:
Expand All @@ -2645,4 +2651,4 @@ def _group_until_different(items: Iterable[TIn], key: Callable[[TIn], TKey], val
Yields:
Tuples containing the group key and item values.
"""
return ((k, [value(i) for i in v]) for (k, v) in groupby(items, key))
return ((k, [val(i) for i in v]) for (k, v) in groupby(items, key))
43 changes: 26 additions & 17 deletions cirq-core/cirq/circuits/circuit_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ class CircuitOperation(ops.Operation):
"""

_hash: Optional[int] = dataclasses.field(default=None, init=False)
_cached_measurement_key_objs: Optional[AbstractSet[value.MeasurementKey]] = dataclasses.field(
default=None, init=False
)

circuit: 'cirq.FrozenCircuit'
repetitions: int = 1
Expand Down Expand Up @@ -172,21 +175,27 @@ def _qid_shape_(self) -> Tuple[int, ...]:
def _is_measurement_(self) -> bool:
return self.circuit._is_measurement_()

def _measurement_key_objs_(self) -> AbstractSet[value.MeasurementKey]:
if self._cached_measurement_key_objs is None:
circuit_keys = protocols.measurement_key_objs(self.circuit)
if self.repetition_ids is not None:
circuit_keys = {
key.with_key_path_prefix(repetition_id)
for repetition_id in self.repetition_ids
for key in circuit_keys
}
object.__setattr__(
self,
'_cached_measurement_key_objs',
{
protocols.with_measurement_key_mapping(key, self.measurement_key_map)
for key in circuit_keys
},
)
return self._cached_measurement_key_objs # type: ignore

def _measurement_key_names_(self) -> AbstractSet[str]:
circuit_keys = [
value.MeasurementKey.parse_serialized(key_str)
for key_str in self.circuit.all_measurement_key_names()
]
if self.repetition_ids is not None:
circuit_keys = [
key.with_key_path_prefix(repetition_id)
for repetition_id in self.repetition_ids
for key in circuit_keys
]
return {
str(protocols.with_measurement_key_mapping(key, self.measurement_key_map))
for key in circuit_keys
}
return {str(key) for key in self._measurement_key_objs_()}

def _parameter_names_(self) -> AbstractSet[str]:
return {
Expand Down Expand Up @@ -523,14 +532,14 @@ def with_measurement_key_mapping(self, key_map: Dict[str, str]) -> 'CircuitOpera
keys than this operation.
"""
new_map = {}
for k in self.circuit.all_measurement_key_names():
k = value.MeasurementKey.parse_serialized(k).name
for k_obj in self.circuit.all_measurement_key_objs():
k = k_obj.name
k_new = self.measurement_key_map.get(k, k)
k_new = key_map.get(k_new, k_new)
if k_new != k:
new_map[k] = k_new
new_op = self.replace(measurement_key_map=new_map)
if len(new_op._measurement_key_names_()) != len(self._measurement_key_names_()):
if len(new_op._measurement_key_objs_()) != len(self._measurement_key_objs_()):
raise ValueError(
f'Collision in measurement key map composition. Original map:\n'
f'{self.measurement_key_map}\nApplied changes: {key_map}'
Expand Down
3 changes: 3 additions & 0 deletions cirq-core/cirq/circuits/circuit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4309,6 +4309,9 @@ def _measurement_key_name_(self):

# Big case.
assert c.all_measurement_key_names() == {'x', 'y', 'xy', 'test'}
assert c.all_measurement_key_names() == cirq.measurement_key_names(c)
assert c.all_measurement_key_names() == c.all_measurement_key_objs()

with cirq.testing.assert_deprecated(deadline="v0.13"):
assert c.all_measurement_key_names() == c.all_measurement_keys()

Expand Down
21 changes: 15 additions & 6 deletions cirq-core/cirq/circuits/frozen_circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

import numpy as np

from cirq import devices, ops, protocols
from cirq import devices, ops, protocols, value
from cirq.circuits import AbstractCircuit, Alignment, Circuit
from cirq.circuits.insert_strategy import InsertStrategy
from cirq.type_workarounds import NotImplementedType
Expand Down Expand Up @@ -74,7 +74,7 @@ def __init__(
self._all_qubits: Optional[FrozenSet['cirq.Qid']] = None
self._all_operations: Optional[Tuple[ops.Operation, ...]] = None
self._has_measurements: Optional[bool] = None
self._all_measurement_key_names: Optional[AbstractSet[str]] = None
self._all_measurement_key_objs: Optional[AbstractSet[value.MeasurementKey]] = None
self._are_all_measurements_terminal: Optional[bool] = None

@property
Expand Down Expand Up @@ -130,10 +130,13 @@ def has_measurements(self) -> bool:
self._has_measurements = super().has_measurements()
return self._has_measurements

def all_measurement_key_names(self) -> AbstractSet[str]:
if self._all_measurement_key_names is None:
self._all_measurement_key_names = super().all_measurement_key_names()
return self._all_measurement_key_names
def all_measurement_key_objs(self) -> AbstractSet[value.MeasurementKey]:
if self._all_measurement_key_objs is None:
self._all_measurement_key_objs = super().all_measurement_key_objs()
return self._all_measurement_key_objs

def _measurement_key_objs_(self) -> AbstractSet[value.MeasurementKey]:
return self.all_measurement_key_objs()

def are_all_measurements_terminal(self) -> bool:
if self._are_all_measurements_terminal is None:
Expand All @@ -142,6 +145,12 @@ def are_all_measurements_terminal(self) -> bool:

# End of memoized methods.

def all_measurement_key_names(self) -> AbstractSet[str]:
return {str(key) for key in self.all_measurement_key_objs()}

def _measurement_key_names_(self) -> AbstractSet[str]:
return self.all_measurement_key_names()

def __add__(self, other) -> 'FrozenCircuit':
return (self.unfreeze() + other).freeze()

Expand Down
15 changes: 13 additions & 2 deletions cirq-core/cirq/ops/gate_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
Collection,
Dict,
FrozenSet,
Iterable,
List,
Optional,
Sequence,
Expand Down Expand Up @@ -245,7 +244,7 @@ def _measurement_key_name_(self) -> Optional[str]:
return getter()
return NotImplemented

def _measurement_key_names_(self) -> Optional[Iterable[str]]:
def _measurement_key_names_(self) -> Optional[AbstractSet[str]]:
getter = getattr(self.gate, '_measurement_key_names_', None)
if getter is not None:
return getter()
Expand All @@ -259,6 +258,18 @@ def _measurement_key_names_(self) -> Optional[Iterable[str]]:
return getter()
return NotImplemented

def _measurement_key_obj_(self) -> Optional[value.MeasurementKey]:
getter = getattr(self.gate, '_measurement_key_obj_', None)
if getter is not None:
return getter()
return NotImplemented

def _measurement_key_objs_(self) -> Optional[AbstractSet[value.MeasurementKey]]:
getter = getattr(self.gate, '_measurement_key_objs_', None)
if getter is not None:
return getter()
return NotImplemented

def _act_on_(self, args: 'cirq.ActOnArgs'):
getter = getattr(self.gate, '_act_on_', None)
if getter is not None:
Expand Down
7 changes: 6 additions & 1 deletion cirq-core/cirq/ops/kraus_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,12 @@ def num_qubits(self) -> int:
def _kraus_(self):
return self._kraus_ops

def _measurement_key_name_(self):
def _measurement_key_name_(self) -> str:
if self._key is None:
return NotImplemented
return str(self._key)

def _measurement_key_obj_(self) -> value.MeasurementKey:
if self._key is None:
return NotImplemented
return self._key
Expand Down
5 changes: 4 additions & 1 deletion cirq-core/cirq/ops/measurement_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,12 @@ def full_invert_mask(self):
def _is_measurement_(self) -> bool:
return True

def _measurement_key_name_(self):
def _measurement_key_name_(self) -> str:
return self.key

def _measurement_key_obj_(self) -> value.MeasurementKey:
return self.mkey

def _kraus_(self):
size = np.prod(self._qid_shape, dtype=np.int64)

Expand Down
7 changes: 6 additions & 1 deletion cirq-core/cirq/ops/mixed_unitary_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,12 @@ def num_qubits(self) -> int:
def _mixture_(self):
return self._mixture

def _measurement_key_name_(self):
def _measurement_key_name_(self) -> str:
if self._key is None:
return NotImplemented
return str(self._key)

def _measurement_key_obj_(self) -> value.MeasurementKey:
if self._key is None:
return NotImplemented
return self._key
Expand Down
15 changes: 9 additions & 6 deletions cirq-core/cirq/ops/moment.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
Union,
)

from cirq import protocols, ops
from cirq import protocols, ops, value
from cirq.ops import raw_types
from cirq.protocols import circuit_diagram_info_protocol
from cirq.type_workarounds import NotImplementedType
Expand Down Expand Up @@ -92,7 +92,7 @@ def __init__(self, *contents: 'cirq.OP_TREE') -> None:
self._qubit_to_op[q] = op

self._qubits = frozenset(self._qubit_to_op.keys())
self._measurement_key_names: Optional[AbstractSet[str]] = None
self._measurement_key_objs: Optional[AbstractSet[value.MeasurementKey]] = None

@property
def operations(self) -> Tuple['cirq.Operation', ...]:
Expand Down Expand Up @@ -220,11 +220,14 @@ def _with_measurement_key_mapping_(self, key_map: Dict[str, str]):
)

def _measurement_key_names_(self) -> AbstractSet[str]:
if self._measurement_key_names is None:
self._measurement_key_names = {
key for op in self.operations for key in protocols.measurement_key_names(op)
return {str(key) for key in self._measurement_key_objs_()}

def _measurement_key_objs_(self) -> AbstractSet[value.MeasurementKey]:
if self._measurement_key_objs is None:
self._measurement_key_objs = {
key for op in self.operations for key in protocols.measurement_key_objs(op)
}
return self._measurement_key_names
return self._measurement_key_objs

def _with_key_path_(self, path: Tuple[str, ...]):
return Moment(
Expand Down
1 change: 1 addition & 0 deletions cirq-core/cirq/ops/moment_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,7 @@ def test_measurement_keys():
assert not cirq.is_measurement(m)

m2 = cirq.Moment(cirq.measure(a, b, key='foo'))
assert cirq.measurement_key_objs(m2) == {cirq.MeasurementKey('foo')}
assert cirq.measurement_key_names(m2) == {'foo'}
assert cirq.is_measurement(m2)

Expand Down
3 changes: 3 additions & 0 deletions cirq-core/cirq/ops/pauli_measurement_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ def _is_measurement_(self) -> bool:
def _measurement_key_name_(self) -> str:
return self.key

def _measurement_key_obj_(self) -> value.MeasurementKey:
return self.mkey

def observable(self) -> 'cirq.DensePauliString':
"""Pauli observable which should be measured by the gate."""
return dense_pauli_string.DensePauliString(self._observable)
Expand Down
7 changes: 5 additions & 2 deletions cirq-core/cirq/ops/raw_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -719,8 +719,11 @@ def _has_kraus_(self) -> bool:
def _kraus_(self) -> Union[Tuple[np.ndarray], NotImplementedType]:
return protocols.kraus(self.sub_operation, NotImplemented)

def _measurement_key_name_(self) -> str:
return protocols.measurement_key_name(self.sub_operation, NotImplemented)
def _measurement_key_names_(self) -> AbstractSet[str]:
return protocols.measurement_key_names(self.sub_operation)

def _measurement_key_objs_(self) -> AbstractSet[value.MeasurementKey]:
return protocols.measurement_key_objs(self.sub_operation)

def _is_measurement_(self) -> bool:
sub = getattr(self.sub_operation, "_is_measurement_", None)
Expand Down
3 changes: 3 additions & 0 deletions cirq-core/cirq/ops/raw_types_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,6 +600,9 @@ def test_tagged_operation_forwards_protocols():
assert np.isclose(cirq.kraus(h), cirq.kraus(tagged_h)).all()

assert cirq.measurement_key_name(cirq.measure(q1, key='blah').with_tags(tag)) == 'blah'
assert cirq.measurement_key_obj(
cirq.measure(q1, key='blah').with_tags(tag)
) == cirq.MeasurementKey('blah')

parameterized_op = cirq.XPowGate(exponent=sympy.Symbol('t'))(q1).with_tags(tag)
assert cirq.is_parameterized(parameterized_op)
Expand Down
2 changes: 2 additions & 0 deletions cirq-core/cirq/protocols/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,10 @@
is_measurement,
measurement_key,
measurement_key_name,
measurement_key_obj,
measurement_keys,
measurement_key_names,
measurement_key_objs,
with_key_path,
with_measurement_key_mapping,
SupportsMeasurementKey,
Expand Down
Loading

0 comments on commit 978dbac

Please sign in to comment.