Skip to content

Commit

Permalink
Add control_keys protocol (simplified) (#4610)
Browse files Browse the repository at this point in the history
Use control keys in circuit construction to block controlled gates "earliest" append method from falling back to before the measurement.

Parts 5&7 of https://tinyurl.com/cirq-feedforward. (I grouped them together in a single PR so that circuit.append could be a POC that control_keys protocol design works as intended).

Replaces/simplifies #4490 since we will handle "extern" control keys in subcircuits as a separate PR.
  • Loading branch information
daxfohl authored Nov 4, 2021
1 parent d137a7d commit 00230f0
Show file tree
Hide file tree
Showing 7 changed files with 132 additions and 4 deletions.
2 changes: 2 additions & 0 deletions cirq-core/cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,7 @@
CircuitDiagramInfo,
CircuitDiagramInfoArgs,
commutes,
control_keys,
decompose,
decompose_once,
decompose_once_with_qubits,
Expand Down Expand Up @@ -550,6 +551,7 @@
SupportsConsistentApplyUnitary,
SupportsCircuitDiagramInfo,
SupportsCommutes,
SupportsControlKey,
SupportsDecompose,
SupportsDecomposeWithQubits,
SupportsEqualUpToGlobalPhase,
Expand Down
10 changes: 6 additions & 4 deletions cirq-core/cirq/circuits/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -1868,9 +1868,14 @@ def transform_qubits(
def _prev_moment_available(self, op: 'cirq.Operation', end_moment_index: int) -> Optional[int]:
last_available = end_moment_index
k = end_moment_index
op_control_keys = protocols.control_keys(op)
op_qubits = op.qubits
while k > 0:
k -= 1
if not self._can_commute_past(k, op):
moment = self._moments[k]
if moment.operates_on(op_qubits) or (
op_control_keys & protocols.measurement_key_objs(moment)
):
return last_available
if self._can_add_op_at(k, op):
last_available = k
Expand Down Expand Up @@ -1924,9 +1929,6 @@ def _can_add_op_at(self, moment_index: int, operation: 'cirq.Operation') -> bool
return True
return self._device.can_add_operation_into_moment(operation, self._moments[moment_index])

def _can_commute_past(self, moment_index: int, operation: 'cirq.Operation') -> bool:
return not self._moments[moment_index].operates_on(operation.qubits)

def insert(
self,
index: int,
Expand Down
30 changes: 30 additions & 0 deletions cirq-core/cirq/circuits/circuit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,21 @@ def validate_moment(self, moment):
moment_and_op_type_validating_device = _MomentAndOpTypeValidatingDeviceType()


class ControlOp(cirq.Operation):
def __init__(self, keys):
self._keys = keys

def with_qubits(self, *new_qids):
pass # coverage: ignore

@property
def qubits(self):
return [] # coverage: ignore

def _control_keys_(self):
return self._keys


def test_alignment():
assert repr(cirq.Alignment.LEFT) == 'cirq.Alignment.LEFT'
assert repr(cirq.Alignment.RIGHT) == 'cirq.Alignment.RIGHT'
Expand Down Expand Up @@ -224,6 +239,21 @@ def test_append_single():
assert c == cirq.Circuit([cirq.Moment([cirq.X(a)])])


def test_append_control_key():
q = cirq.LineQubit(0)

c = cirq.Circuit()
c.append(cirq.measure(q, key='a'))
c.append(ControlOp([cirq.MeasurementKey('a')]))
assert len(c) == 2

c = cirq.Circuit()
c.append(cirq.measure(q, key='a'))
c.append(ControlOp([cirq.MeasurementKey('b')]))
c.append(ControlOp([cirq.MeasurementKey('b')]))
assert len(c) == 1


def test_append_multiple():
a = cirq.NamedQubit('a')
b = cirq.NamedQubit('b')
Expand Down
4 changes: 4 additions & 0 deletions cirq-core/cirq/protocols/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@
definitely_commutes,
SupportsCommutes,
)
from cirq.protocols.control_key_protocol import (
control_keys,
SupportsControlKey,
)
from cirq.protocols.circuit_diagram_info_protocol import (
circuit_diagram_info,
CircuitDiagramInfo,
Expand Down
60 changes: 60 additions & 0 deletions cirq-core/cirq/protocols/control_key_protocol.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright 2021 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Protocol for object that have control keys."""

from typing import AbstractSet, Any, Iterable, TYPE_CHECKING

from typing_extensions import Protocol

from cirq._doc import doc_private

if TYPE_CHECKING:
import cirq


class SupportsControlKey(Protocol):
"""An object that is a has a classical control key or keys.
Control keys are used in referencing the results of a measurement.
Users should implement `_control_keys_` returning an iterable of
`MeasurementKey`.
"""

@doc_private
def _control_keys_(self) -> Iterable['cirq.MeasurementKey']:
"""Return the keys for controls referenced by the receiving object.
Returns:
The measurement keys the value is controlled by. If the value is not
classically controlled, the result is the empty tuple.
"""


def control_keys(val: Any) -> AbstractSet['cirq.MeasurementKey']:
"""Gets the keys that the value is classically controlled by.
Args:
val: The object that may be classically controlled.
Returns:
The measurement keys the value is controlled by. If the value is not
classically controlled, the result is the empty tuple.
"""
getter = getattr(val, '_control_keys_', None)
result = NotImplemented if getter is None else getter()
if result is not NotImplemented and result is not None:
return set(result)

return set()
29 changes: 29 additions & 0 deletions cirq-core/cirq/protocols/control_key_protocol_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright 2021 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import cirq


def test_control_key():
class Named:
def _control_keys_(self):
return [cirq.MeasurementKey('key')]

class NoImpl:
def _control_keys_(self):
return NotImplemented

assert cirq.control_keys(Named()) == {cirq.MeasurementKey('key')}
assert not cirq.control_keys(NoImpl())
assert not cirq.control_keys(5)
1 change: 1 addition & 0 deletions cirq-core/cirq/protocols/json_test_data/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@
'SupportsCircuitDiagramInfo',
'SupportsCommutes',
'SupportsConsistentApplyUnitary',
'SupportsControlKey',
'SupportsDecompose',
'SupportsDecomposeWithQubits',
'SupportsEqualUpToGlobalPhase',
Expand Down

0 comments on commit 00230f0

Please sign in to comment.