From ff671ae211eb37af52c81e551098ada0220c2bad Mon Sep 17 00:00:00 2001 From: Dax Fohl Date: Thu, 23 Dec 2021 12:47:12 -0800 Subject: [PATCH] Allow sympy expressions as classical controls (#4740) Part 14 of https://tinyurl.com/cirq-feedforward. Adds the ability to create classical control conditions based on sympy expressions. To account for the fact that measurement key strings can contain characters not allowed in sympy variables, the measurement keys in a sympy condition string must be wrapped in curly braces to denote them. For example, to create an expression that checks if measurement A was greater than measurement B, the proper syntax is `cirq.parse_sympy_condition('{A} > {B}')`. This PR does not yet handle qudits completely, as multi-qubit measurements are interpreted as base-2 when converting to integer. A subsequent PR (https://github.com/daxfohl/Cirq/compare/sympy3...daxfohl:qudits?expand=1) will allow this functionality. --- cirq-core/cirq/__init__.py | 3 + cirq-core/cirq/_compat.py | 1 + cirq-core/cirq/json_resolver_cache.py | 2 + .../ops/classically_controlled_operation.py | 95 +++++----- .../classically_controlled_operation_test.py | 158 ++++++++++++++++- cirq-core/cirq/ops/raw_types.py | 18 +- .../ClassicallyControlledOperation.json | 18 +- .../ClassicallyControlledOperation.repr | 2 +- .../json_test_data/KeyCondition.json | 8 + .../json_test_data/KeyCondition.repr | 1 + .../json_test_data/SympyCondition.json | 17 ++ .../json_test_data/SympyCondition.repr | 1 + .../protocols/measurement_key_protocol.py | 3 + cirq-core/cirq/value/__init__.py | 6 + cirq-core/cirq/value/condition.py | 166 ++++++++++++++++++ cirq-core/cirq/value/condition_test.py | 105 +++++++++++ 16 files changed, 545 insertions(+), 59 deletions(-) create mode 100644 cirq-core/cirq/protocols/json_test_data/KeyCondition.json create mode 100644 cirq-core/cirq/protocols/json_test_data/KeyCondition.repr create mode 100644 cirq-core/cirq/protocols/json_test_data/SympyCondition.json create mode 100644 cirq-core/cirq/protocols/json_test_data/SympyCondition.repr create mode 100644 cirq-core/cirq/value/condition.py create mode 100644 cirq-core/cirq/value/condition_test.py diff --git a/cirq-core/cirq/__init__.py b/cirq-core/cirq/__init__.py index f45004b9347..ac57b1b07dd 100644 --- a/cirq-core/cirq/__init__.py +++ b/cirq-core/cirq/__init__.py @@ -483,15 +483,18 @@ canonicalize_half_turns, chosen_angle_to_canonical_half_turns, chosen_angle_to_half_turns, + Condition, Duration, DURATION_LIKE, GenericMetaImplementAnyOneOf, + KeyCondition, LinearDict, MEASUREMENT_KEY_SEPARATOR, MeasurementKey, PeriodicValue, RANDOM_STATE_OR_SEED_LIKE, state_vector_to_probabilities, + SympyCondition, Timestamp, TParamKey, TParamVal, diff --git a/cirq-core/cirq/_compat.py b/cirq-core/cirq/_compat.py index 6dc089e0f6b..facfdd8e45a 100644 --- a/cirq-core/cirq/_compat.py +++ b/cirq-core/cirq/_compat.py @@ -27,6 +27,7 @@ import numpy as np import pandas as pd import sympy +import sympy.printing.repr def proper_repr(value: Any) -> str: diff --git a/cirq-core/cirq/json_resolver_cache.py b/cirq-core/cirq/json_resolver_cache.py index 8213f007868..6a017320dad 100644 --- a/cirq-core/cirq/json_resolver_cache.py +++ b/cirq-core/cirq/json_resolver_cache.py @@ -94,6 +94,7 @@ def _parallel_gate_op(gate, qubits): 'ISwapPowGate': cirq.ISwapPowGate, 'IdentityGate': cirq.IdentityGate, 'InitObsSetting': cirq.work.InitObsSetting, + 'KeyCondition': cirq.KeyCondition, 'KrausChannel': cirq.KrausChannel, 'LinearDict': cirq.LinearDict, 'LineQubit': cirq.LineQubit, @@ -150,6 +151,7 @@ def _parallel_gate_op(gate, qubits): 'StatePreparationChannel': cirq.StatePreparationChannel, 'SwapPowGate': cirq.SwapPowGate, 'SymmetricalQidPair': cirq.SymmetricalQidPair, + 'SympyCondition': cirq.SympyCondition, 'TaggedOperation': cirq.TaggedOperation, 'TiltedSquareLattice': cirq.TiltedSquareLattice, 'TrialResult': cirq.Result, # keep support for Cirq < 0.11. diff --git a/cirq-core/cirq/ops/classically_controlled_operation.py b/cirq-core/cirq/ops/classically_controlled_operation.py index fe3093a3d4e..74a4c3dbb53 100644 --- a/cirq-core/cirq/ops/classically_controlled_operation.py +++ b/cirq-core/cirq/ops/classically_controlled_operation.py @@ -16,6 +16,7 @@ Any, Dict, FrozenSet, + List, Optional, Sequence, TYPE_CHECKING, @@ -23,6 +24,8 @@ Union, ) +import sympy + from cirq import protocols, value from cirq.ops import raw_types @@ -46,7 +49,7 @@ class ClassicallyControlledOperation(raw_types.Operation): def __init__( self, sub_operation: 'cirq.Operation', - conditions: Sequence[Union[str, 'cirq.MeasurementKey']], + conditions: Sequence[Union[str, 'cirq.MeasurementKey', 'cirq.Condition', sympy.Basic]], ): """Initializes a `ClassicallyControlledOperation`. @@ -68,13 +71,26 @@ def __init__( raise ValueError( f'Cannot conditionally run operations with measurements: {sub_operation}' ) - keys = tuple(value.MeasurementKey(c) if isinstance(c, str) else c for c in conditions) + conditions = tuple(conditions) if isinstance(sub_operation, ClassicallyControlledOperation): - keys += sub_operation._control_keys + conditions += sub_operation._conditions sub_operation = sub_operation._sub_operation - self._control_keys: Tuple['cirq.MeasurementKey', ...] = keys + conds: List['cirq.Condition'] = [] + for c in conditions: + if isinstance(c, str): + c = value.MeasurementKey.parse_serialized(c) + if isinstance(c, value.MeasurementKey): + c = value.KeyCondition(c) + if isinstance(c, sympy.Basic): + c = value.SympyCondition(c) + conds.append(c) + self._conditions: Tuple['cirq.Condition', ...] = tuple(conds) self._sub_operation: 'cirq.Operation' = sub_operation + @property + def classical_controls(self) -> FrozenSet['cirq.Condition']: + return frozenset(self._conditions).union(self._sub_operation.classical_controls) + def without_classical_controls(self) -> 'cirq.Operation': return self._sub_operation.without_classical_controls() @@ -84,7 +100,7 @@ def qubits(self): def with_qubits(self, *new_qubits): return self._sub_operation.with_qubits(*new_qubits).with_classical_controls( - *self._control_keys + *self._conditions ) def _decompose_(self): @@ -92,19 +108,19 @@ def _decompose_(self): if result is NotImplemented: return NotImplemented - return [ClassicallyControlledOperation(op, self._control_keys) for op in result] + return [ClassicallyControlledOperation(op, self._conditions) for op in result] def _value_equality_values_(self): - return (frozenset(self._control_keys), self._sub_operation) + return (frozenset(self._conditions), self._sub_operation) def __str__(self) -> str: - keys = ', '.join(map(str, self._control_keys)) + keys = ', '.join(map(str, self._conditions)) return f'{self._sub_operation}.with_classical_controls({keys})' def __repr__(self): return ( f'cirq.ClassicallyControlledOperation(' - f'{self._sub_operation!r}, {list(self._control_keys)!r})' + f'{self._sub_operation!r}, {list(self._conditions)!r})' ) def _is_parameterized_(self) -> bool: @@ -117,7 +133,7 @@ def _resolve_parameters_( self, resolver: 'cirq.ParamResolver', recursive: bool ) -> 'ClassicallyControlledOperation': new_sub_op = protocols.resolve_parameters(self._sub_operation, resolver, recursive) - return new_sub_op.with_classical_controls(*self._control_keys) + return new_sub_op.with_classical_controls(*self._conditions) def _circuit_diagram_info_( self, args: 'cirq.CircuitDiagramInfoArgs' @@ -133,12 +149,20 @@ def _circuit_diagram_info_( if sub_info is None: return NotImplemented # coverage: ignore - wire_symbols = sub_info.wire_symbols + ('^',) * len(self._control_keys) + control_count = len({k for c in self._conditions for k in c.keys}) + wire_symbols = sub_info.wire_symbols + ('^',) * control_count + if any(not isinstance(c, value.KeyCondition) for c in self._conditions): + wire_symbols = ( + wire_symbols[0] + + '(conditions=[' + + ', '.join(str(c) for c in self._conditions) + + '])', + ) + wire_symbols[1:] exponent_qubit_index = None if sub_info.exponent_qubit_index is not None: - exponent_qubit_index = sub_info.exponent_qubit_index + len(self._control_keys) + exponent_qubit_index = sub_info.exponent_qubit_index + control_count elif sub_info.exponent is not None: - exponent_qubit_index = len(self._control_keys) + exponent_qubit_index = control_count return protocols.CircuitDiagramInfo( wire_symbols=wire_symbols, exponent=sub_info.exponent, @@ -148,58 +172,45 @@ def _circuit_diagram_info_( def _json_dict_(self) -> Dict[str, Any]: return { 'cirq_type': self.__class__.__name__, - 'conditions': self._control_keys, + 'conditions': self._conditions, 'sub_operation': self._sub_operation, } def _act_on_(self, args: 'cirq.ActOnArgs') -> bool: - def not_zero(measurement): - return any(i != 0 for i in measurement) - - measurements = [ - args.log_of_measurement_results.get(str(key), str(key)) for key in self._control_keys - ] - missing = [m for m in measurements if isinstance(m, str)] - if missing: - raise ValueError(f'Measurement keys {missing} missing when performing {self}') - if all(not_zero(measurement) for measurement in measurements): + if all(c.resolve(args.log_of_measurement_results) for c in self._conditions): protocols.act_on(self._sub_operation, args) return True def _with_measurement_key_mapping_( self, key_map: Dict[str, str] ) -> 'ClassicallyControlledOperation': + conditions = [protocols.with_measurement_key_mapping(c, key_map) for c in self._conditions] sub_operation = protocols.with_measurement_key_mapping(self._sub_operation, key_map) sub_operation = self._sub_operation if sub_operation is NotImplemented else sub_operation - return sub_operation.with_classical_controls( - *[protocols.with_measurement_key_mapping(k, key_map) for k in self._control_keys] - ) + return sub_operation.with_classical_controls(*conditions) - def _with_key_path_prefix_(self, path: Tuple[str, ...]) -> 'ClassicallyControlledOperation': - keys = [protocols.with_key_path_prefix(k, path) for k in self._control_keys] - return self._sub_operation.with_classical_controls(*keys) + def _with_key_path_prefix_(self, prefix: Tuple[str, ...]) -> 'ClassicallyControlledOperation': + conditions = [protocols.with_key_path_prefix(c, prefix) for c in self._conditions] + sub_operation = protocols.with_key_path_prefix(self._sub_operation, prefix) + sub_operation = self._sub_operation if sub_operation is NotImplemented else sub_operation + return sub_operation.with_classical_controls(*conditions) def _with_rescoped_keys_( self, path: Tuple[str, ...], bindable_keys: FrozenSet['cirq.MeasurementKey'], ) -> 'ClassicallyControlledOperation': - def map_key(key: 'cirq.MeasurementKey') -> 'cirq.MeasurementKey': - for i in range(len(path) + 1): - back_path = path[: len(path) - i] - new_key = key.with_key_path_prefix(*back_path) - if new_key in bindable_keys: - return new_key - return key - + conds = [protocols.with_rescoped_keys(c, path, bindable_keys) for c in self._conditions] sub_operation = protocols.with_rescoped_keys(self._sub_operation, path, bindable_keys) - return sub_operation.with_classical_controls(*[map_key(k) for k in self._control_keys]) + return sub_operation.with_classical_controls(*conds) def _control_keys_(self) -> FrozenSet['cirq.MeasurementKey']: - return frozenset(self._control_keys).union(protocols.control_keys(self._sub_operation)) + local_keys: FrozenSet['cirq.MeasurementKey'] = frozenset( + k for condition in self._conditions for k in condition.keys + ) + return local_keys.union(protocols.control_keys(self._sub_operation)) def _qasm_(self, args: 'cirq.QasmArgs') -> Optional[str]: args.validate_version('2.0') - keys = [f'm_{key}!=0' for key in self._control_keys] - all_keys = " && ".join(keys) + all_keys = " && ".join(c.qasm for c in self._conditions) return args.format('if ({0}) {1}', all_keys, protocols.qasm(self._sub_operation, args=args)) diff --git a/cirq-core/cirq/ops/classically_controlled_operation_test.py b/cirq-core/cirq/ops/classically_controlled_operation_test.py index a9896dbed0d..1baf7612f37 100644 --- a/cirq-core/cirq/ops/classically_controlled_operation_test.py +++ b/cirq-core/cirq/ops/classically_controlled_operation_test.py @@ -11,9 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import re import pytest import sympy +from sympy.parsing import sympy_parser import cirq @@ -331,10 +331,17 @@ def test_key_set_in_subcircuit_outer_scope(): assert result.measurements['b'] == 1 +def test_condition_types(): + q0 = cirq.LineQubit(0) + sympy_cond = sympy_parser.parse_expr('a >= 2') + op = cirq.X(q0).with_classical_controls(cirq.MeasurementKey('a'), 'b', 'a > b', sympy_cond) + assert set(map(str, op.classical_controls)) == {'a', 'b', 'a > b', 'a >= 2'} + + def test_condition_flattening(): q0 = cirq.LineQubit(0) op = cirq.X(q0).with_classical_controls('a').with_classical_controls('b') - assert set(map(str, op._control_keys)) == {'a', 'b'} + assert set(map(str, op.classical_controls)) == {'a', 'b'} assert isinstance(op._sub_operation, cirq.GateOperation) @@ -342,6 +349,7 @@ def test_condition_stacking(): q0 = cirq.LineQubit(0) op = cirq.X(q0).with_classical_controls('a').with_tags('t').with_classical_controls('b') assert set(map(str, cirq.control_keys(op))) == {'a', 'b'} + assert set(map(str, op.classical_controls)) == {'a', 'b'} assert not op.tags @@ -356,6 +364,7 @@ def test_condition_removal(): ) op = op.without_classical_controls() assert not cirq.control_keys(op) + assert not op.classical_controls assert set(map(str, op.tags)) == {'t1'} @@ -604,7 +613,7 @@ def test_repr(): op = cirq.X(q0).with_classical_controls('a') assert repr(op) == ( "cirq.ClassicallyControlledOperation(" - "cirq.X(cirq.LineQubit(0)), [cirq.MeasurementKey(name='a')]" + "cirq.X(cirq.LineQubit(0)), [cirq.KeyCondition(cirq.MeasurementKey(name='a'))]" ")" ) @@ -619,10 +628,7 @@ def test_unmeasured_condition(): q0 = cirq.LineQubit(0) bad_circuit = cirq.Circuit(cirq.X(q0).with_classical_controls('a')) with pytest.raises( - ValueError, - match=re.escape( - "Measurement keys ['a'] missing when performing X(0).with_classical_controls(a)" - ), + ValueError, match='Measurement key a missing when testing classical control' ): _ = cirq.Simulator().simulate(bad_circuit) @@ -669,3 +675,141 @@ def test_layered_circuit_operations_with_controls_in_between(): """, use_unicode_characters=True, ) + + +def test_sympy(): + q0, q1, q2, q3, q_result = cirq.LineQubit.range(5) + for i in range(4): + for j in range(4): + # Put first two qubits into a state representing bitstring(i), next two qubits into a + # state representing bitstring(j) and measure those into m_i and m_j respectively. Then + # add a conditional X(q_result) based on m_i > m_j and measure that. + bitstring_i = cirq.big_endian_int_to_bits(i, bit_count=2) + bitstring_j = cirq.big_endian_int_to_bits(j, bit_count=2) + circuit = cirq.Circuit( + cirq.X(q0) ** bitstring_i[0], + cirq.X(q1) ** bitstring_i[1], + cirq.X(q2) ** bitstring_j[0], + cirq.X(q3) ** bitstring_j[1], + cirq.measure(q0, q1, key='m_i'), + cirq.measure(q2, q3, key='m_j'), + cirq.X(q_result).with_classical_controls(sympy_parser.parse_expr('m_j > m_i')), + cirq.measure(q_result, key='m_result'), + ) + + # m_result should now be set iff j > i. + result = cirq.Simulator().run(circuit) + assert result.measurements['m_result'][0][0] == (j > i) + + +def test_sympy_path_prefix(): + q = cirq.LineQubit(0) + op = cirq.X(q).with_classical_controls(sympy.Symbol('b')) + prefixed = cirq.with_key_path_prefix(op, ('0',)) + assert cirq.control_keys(prefixed) == {'0:b'} + + +def test_sympy_scope(): + q = cirq.LineQubit(0) + a, b, c, d = sympy.symbols('a b c d') + inner = cirq.Circuit( + cirq.measure(q, key='a'), + cirq.X(q).with_classical_controls(a & b).with_classical_controls(c | d), + ) + middle = cirq.Circuit( + cirq.measure(q, key='b'), + cirq.measure(q, key=cirq.MeasurementKey('c', ('0',))), + cirq.CircuitOperation(inner.freeze(), repetitions=2), + ) + outer_subcircuit = cirq.CircuitOperation(middle.freeze(), repetitions=2) + circuit = outer_subcircuit.mapped_circuit(deep=True) + internal_controls = [str(k) for op in circuit.all_operations() for k in cirq.control_keys(op)] + assert set(internal_controls) == {'0:0:a', '0:1:a', '1:0:a', '1:1:a', '0:b', '1:b', 'c', 'd'} + assert cirq.control_keys(outer_subcircuit) == {'c', 'd'} + assert cirq.control_keys(circuit) == {'c', 'd'} + assert circuit == cirq.Circuit(cirq.decompose(outer_subcircuit)) + cirq.testing.assert_has_diagram( + cirq.Circuit(outer_subcircuit), + """ + [ [ 0: ───M───X(conditions=[c | d, a & b])─── ] ] + [ [ ║ ║ ] ] + [ [ a: ═══@═══^══════════════════════════════ ] ] + [ [ ║ ] ] + [ 0: ───M───M('0:c')───[ b: ═══════^══════════════════════════════ ]──────────── ] + [ ║ [ ║ ] ] + [ ║ [ c: ═══════^══════════════════════════════ ] ] +0: ───[ ║ [ ║ ] ]──────────── + [ ║ [ d: ═══════^══════════════════════════════ ](loops=2) ] + [ ║ ║ ] + [ b: ═══@══════════════╬════════════════════════════════════════════════════════ ] + [ ║ ] + [ c: ══════════════════╬════════════════════════════════════════════════════════ ] + [ ║ ] + [ d: ══════════════════╩════════════════════════════════════════════════════════ ](loops=2) + ║ +c: ═══╬═════════════════════════════════════════════════════════════════════════════════════════════ + ║ +d: ═══╩═════════════════════════════════════════════════════════════════════════════════════════════ +""", + use_unicode_characters=True, + ) + + # pylint: disable=line-too-long + cirq.testing.assert_has_diagram( + circuit, + """ +0: ───────M───M('0:0:c')───M───X(conditions=[c | d, 0:0:a & 0:b])───M───X(conditions=[c | d, 0:1:a & 0:b])───M───M('1:0:c')───M───X(conditions=[c | d, 1:0:a & 1:b])───M───X(conditions=[c | d, 1:1:a & 1:b])─── + ║ ║ ║ ║ ║ ║ ║ ║ ║ ║ +0:0:a: ═══╬════════════════@═══^════════════════════════════════════╬═══╬════════════════════════════════════╬════════════════╬═══╬════════════════════════════════════╬═══╬════════════════════════════════════ + ║ ║ ║ ║ ║ ║ ║ ║ ║ +0:1:a: ═══╬════════════════════╬════════════════════════════════════@═══^════════════════════════════════════╬════════════════╬═══╬════════════════════════════════════╬═══╬════════════════════════════════════ + ║ ║ ║ ║ ║ ║ ║ ║ +0:b: ═════@════════════════════^════════════════════════════════════════^════════════════════════════════════╬════════════════╬═══╬════════════════════════════════════╬═══╬════════════════════════════════════ + ║ ║ ║ ║ ║ ║ ║ +1:0:a: ════════════════════════╬════════════════════════════════════════╬════════════════════════════════════╬════════════════@═══^════════════════════════════════════╬═══╬════════════════════════════════════ + ║ ║ ║ ║ ║ ║ +1:1:a: ════════════════════════╬════════════════════════════════════════╬════════════════════════════════════╬════════════════════╬════════════════════════════════════@═══^════════════════════════════════════ + ║ ║ ║ ║ ║ +1:b: ══════════════════════════╬════════════════════════════════════════╬════════════════════════════════════@════════════════════^════════════════════════════════════════^════════════════════════════════════ + ║ ║ ║ ║ +c: ════════════════════════════^════════════════════════════════════════^═════════════════════════════════════════════════════════^════════════════════════════════════════^════════════════════════════════════ + ║ ║ ║ ║ +d: ════════════════════════════^════════════════════════════════════════^═════════════════════════════════════════════════════════^════════════════════════════════════════^════════════════════════════════════ +""", + use_unicode_characters=True, + ) + # pylint: enable=line-too-long + + +def test_sympy_scope_simulation(): + q0, q1, q2, q3, q_ignored, q_result = cirq.LineQubit.range(6) + condition = sympy_parser.parse_expr('a & b | c & d') + # We set up condition (a & b | c & d) plus an ignored measurement key, and run through the + # combinations of possible values of those (by doing X(q_i)**bits[i] on each), then verify + # that the final measurement into m_result is True iff that condition was met. + for i in range(32): + bits = cirq.big_endian_int_to_bits(i, bit_count=5) + inner = cirq.Circuit( + cirq.X(q0) ** bits[0], + cirq.measure(q0, key='a'), + cirq.X(q_result).with_classical_controls(condition), + cirq.measure(q_result, key='m_result'), + ) + middle = cirq.Circuit( + cirq.X(q1) ** bits[1], + cirq.measure(q1, key='b'), + cirq.X(q_ignored) ** bits[4], + cirq.measure(q_ignored, key=cirq.MeasurementKey('c', ('0',))), + cirq.CircuitOperation(inner.freeze(), repetition_ids=['0']), + ) + circuit = cirq.Circuit( + cirq.X(q2) ** bits[2], + cirq.measure(q2, key='c'), + cirq.X(q3) ** bits[3], + cirq.measure(q3, key='d'), + cirq.CircuitOperation(middle.freeze(), repetition_ids=['0']), + ) + result = cirq.CliffordSimulator().run(circuit) + assert result.measurements['0:0:m_result'][0][0] == ( + bits[0] and bits[1] or bits[2] and bits[3] # bits[4] irrelevant + ) diff --git a/cirq-core/cirq/ops/raw_types.py b/cirq-core/cirq/ops/raw_types.py index c88ba65eeb8..ca665008655 100644 --- a/cirq-core/cirq/ops/raw_types.py +++ b/cirq-core/cirq/ops/raw_types.py @@ -22,6 +22,7 @@ Callable, Collection, Dict, + FrozenSet, Hashable, Iterable, List, @@ -34,6 +35,7 @@ ) import numpy as np +import sympy from cirq import protocols, value from cirq._import import LazyLoader @@ -590,8 +592,13 @@ def _commutes_( return np.allclose(m12, m21, atol=atol) + @property + def classical_controls(self) -> FrozenSet['cirq.Condition']: + """The classical controls gating this operation.""" + return frozenset() + def with_classical_controls( - self, *conditions: Union[str, 'cirq.MeasurementKey'] + self, *conditions: Union[str, 'cirq.MeasurementKey', 'cirq.Condition', sympy.Expr] ) -> 'cirq.ClassicallyControlledOperation': """Returns a classically controlled version of this operation. @@ -604,8 +611,9 @@ def with_classical_controls( since tags are considered a local attribute. Args: - conditions: A list of measurement keys, or strings that can be - parsed into measurement keys. + conditions: A list of measurement keys, strings that can be parsed + into measurement keys, or sympy expressions where the free + symbols are measurement key strings. Returns: A `ClassicallyControlledOperation` wrapping the operation. @@ -821,6 +829,10 @@ def _equal_up_to_global_phase_( ) -> Union[NotImplementedType, bool]: return protocols.equal_up_to_global_phase(self.sub_operation, other, atol=atol) + @property + def classical_controls(self) -> FrozenSet['cirq.Condition']: + return self.sub_operation.classical_controls + def without_classical_controls(self) -> 'cirq.Operation': new_sub_operation = self.sub_operation.without_classical_controls() return self if new_sub_operation is self.sub_operation else new_sub_operation diff --git a/cirq-core/cirq/protocols/json_test_data/ClassicallyControlledOperation.json b/cirq-core/cirq/protocols/json_test_data/ClassicallyControlledOperation.json index a22c2720095..8fbae9b27c7 100644 --- a/cirq-core/cirq/protocols/json_test_data/ClassicallyControlledOperation.json +++ b/cirq-core/cirq/protocols/json_test_data/ClassicallyControlledOperation.json @@ -2,14 +2,20 @@ "cirq_type": "ClassicallyControlledOperation", "conditions": [ { - "cirq_type": "MeasurementKey", - "name": "a", - "path": [] + "cirq_type": "KeyCondition", + "key": { + "cirq_type": "MeasurementKey", + "name": "a", + "path": [] + } }, { - "cirq_type": "MeasurementKey", - "name": "b", - "path": [] + "cirq_type": "KeyCondition", + "key": { + "cirq_type": "MeasurementKey", + "name": "b", + "path": [] + } } ], "sub_operation": { diff --git a/cirq-core/cirq/protocols/json_test_data/ClassicallyControlledOperation.repr b/cirq-core/cirq/protocols/json_test_data/ClassicallyControlledOperation.repr index bbc3a1dc22b..423551a8501 100644 --- a/cirq-core/cirq/protocols/json_test_data/ClassicallyControlledOperation.repr +++ b/cirq-core/cirq/protocols/json_test_data/ClassicallyControlledOperation.repr @@ -1 +1 @@ -cirq.ClassicallyControlledOperation(cirq.Y.on(cirq.NamedQubit('target')), [cirq.MeasurementKey('a'), cirq.MeasurementKey('b')]) +cirq.ClassicallyControlledOperation(cirq.Y.on(cirq.NamedQubit('target')), [cirq.KeyCondition(key=cirq.MeasurementKey('a')), cirq.KeyCondition(key=cirq.MeasurementKey('b'))]) \ No newline at end of file diff --git a/cirq-core/cirq/protocols/json_test_data/KeyCondition.json b/cirq-core/cirq/protocols/json_test_data/KeyCondition.json new file mode 100644 index 00000000000..f5b81ba63dc --- /dev/null +++ b/cirq-core/cirq/protocols/json_test_data/KeyCondition.json @@ -0,0 +1,8 @@ +{ + "cirq_type": "KeyCondition", + "key": { + "cirq_type": "MeasurementKey", + "name": "a", + "path": [] + } +} \ No newline at end of file diff --git a/cirq-core/cirq/protocols/json_test_data/KeyCondition.repr b/cirq-core/cirq/protocols/json_test_data/KeyCondition.repr new file mode 100644 index 00000000000..fb9fa3232ec --- /dev/null +++ b/cirq-core/cirq/protocols/json_test_data/KeyCondition.repr @@ -0,0 +1 @@ +cirq.KeyCondition(key=cirq.MeasurementKey('a')) \ No newline at end of file diff --git a/cirq-core/cirq/protocols/json_test_data/SympyCondition.json b/cirq-core/cirq/protocols/json_test_data/SympyCondition.json new file mode 100644 index 00000000000..1dc17ec7710 --- /dev/null +++ b/cirq-core/cirq/protocols/json_test_data/SympyCondition.json @@ -0,0 +1,17 @@ +{ + "cirq_type": "SympyCondition", + "expr": + { + "cirq_type": "sympy.GreaterThan", + "args": [ + { + "cirq_type": "sympy.Symbol", + "name": "a" + }, + { + "cirq_type": "sympy.Symbol", + "name": "b" + } + ] + } +} \ No newline at end of file diff --git a/cirq-core/cirq/protocols/json_test_data/SympyCondition.repr b/cirq-core/cirq/protocols/json_test_data/SympyCondition.repr new file mode 100644 index 00000000000..6c961a2a1f6 --- /dev/null +++ b/cirq-core/cirq/protocols/json_test_data/SympyCondition.repr @@ -0,0 +1 @@ +cirq.SympyCondition(sympy.GreaterThan(sympy.Symbol('a'), sympy.Symbol('b'))) \ No newline at end of file diff --git a/cirq-core/cirq/protocols/measurement_key_protocol.py b/cirq-core/cirq/protocols/measurement_key_protocol.py index 639df1aa180..f94ca105148 100644 --- a/cirq-core/cirq/protocols/measurement_key_protocol.py +++ b/cirq-core/cirq/protocols/measurement_key_protocol.py @@ -20,6 +20,9 @@ from cirq import value from cirq._doc import doc_private +if TYPE_CHECKING: + import cirq + if TYPE_CHECKING: import cirq diff --git a/cirq-core/cirq/value/__init__.py b/cirq-core/cirq/value/__init__.py index 390db1e4a11..da6cfc2b058 100644 --- a/cirq-core/cirq/value/__init__.py +++ b/cirq-core/cirq/value/__init__.py @@ -25,6 +25,12 @@ chosen_angle_to_half_turns, ) +from cirq.value.condition import ( + Condition, + KeyCondition, + SympyCondition, +) + from cirq.value.digits import ( big_endian_bits_to_int, big_endian_digits_to_int, diff --git a/cirq-core/cirq/value/condition.py b/cirq-core/cirq/value/condition.py new file mode 100644 index 00000000000..ef432b7506f --- /dev/null +++ b/cirq-core/cirq/value/condition.py @@ -0,0 +1,166 @@ +# Copyright 2021 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +import dataclasses +from typing import Dict, Mapping, Sequence, Tuple, TYPE_CHECKING, FrozenSet + +import sympy + +from cirq._compat import proper_repr +from cirq.protocols import json_serialization, measurement_key_protocol as mkp +from cirq.value import digits, measurement_key + +if TYPE_CHECKING: + import cirq + + +class Condition(abc.ABC): + """A classical control condition that can gate an operation.""" + + @property + @abc.abstractmethod + def keys(self) -> Tuple['cirq.MeasurementKey', ...]: + """Gets the control keys.""" + + @abc.abstractmethod + def replace_key(self, current: 'cirq.MeasurementKey', replacement: 'cirq.MeasurementKey'): + """Replaces the control keys.""" + + @abc.abstractmethod + def resolve(self, measurements: Mapping[str, Sequence[int]]) -> bool: + """Resolves the condition based on the measurements.""" + + @property + @abc.abstractmethod + def qasm(self): + """Returns the qasm of this condition.""" + + def _with_measurement_key_mapping_(self, key_map: Dict[str, str]) -> 'cirq.Condition': + condition = self + for k in self.keys: + condition = condition.replace_key(k, mkp.with_measurement_key_mapping(k, key_map)) + return condition + + def _with_key_path_prefix_(self, path: Tuple[str, ...]) -> 'cirq.Condition': + condition = self + for k in self.keys: + condition = condition.replace_key(k, mkp.with_key_path_prefix(k, path)) + return condition + + def _with_rescoped_keys_( + self, + path: Tuple[str, ...], + bindable_keys: FrozenSet['cirq.MeasurementKey'], + ) -> 'cirq.Condition': + condition = self + for key in self.keys: + for i in range(len(path) + 1): + back_path = path[: len(path) - i] + new_key = key.with_key_path_prefix(*back_path) + if new_key in bindable_keys: + condition = condition.replace_key(key, new_key) + break + return condition + + +@dataclasses.dataclass(frozen=True) +class KeyCondition(Condition): + """A classical control condition based on a single measurement key. + + This condition resolves to True iff the measurement key is non-zero at the + time of resolution. + """ + + key: 'cirq.MeasurementKey' + + @property + def keys(self): + return (self.key,) + + def replace_key(self, current: 'cirq.MeasurementKey', replacement: 'cirq.MeasurementKey'): + return KeyCondition(replacement) if self.key == current else self + + def __str__(self): + return str(self.key) + + def __repr__(self): + return f'cirq.KeyCondition({self.key!r})' + + def resolve(self, measurements: Mapping[str, Sequence[int]]) -> bool: + key = str(self.key) + if key not in measurements: + raise ValueError(f'Measurement key {key} missing when testing classical control') + return any(measurements[key]) + + def _json_dict_(self): + return json_serialization.dataclass_json_dict(self) + + @classmethod + def _from_json_dict_(cls, key, **kwargs): + return cls(key=key) + + @property + def qasm(self): + return f'm_{self.key}!=0' + + +@dataclasses.dataclass(frozen=True) +class SympyCondition(Condition): + """A classical control condition based on a sympy expression. + + This condition resolves to True iff the sympy expression resolves to a + truthy value (i.e. `bool(x) == True`) when the measurement keys are + substituted in as the free variables. + """ + + expr: sympy.Basic + + @property + def keys(self): + return tuple( + measurement_key.MeasurementKey.parse_serialized(symbol.name) + for symbol in self.expr.free_symbols + ) + + def replace_key(self, current: 'cirq.MeasurementKey', replacement: 'cirq.MeasurementKey'): + return SympyCondition(self.expr.subs({str(current): sympy.Symbol(str(replacement))})) + + def __str__(self): + return str(self.expr) + + def __repr__(self): + return f'cirq.SympyCondition({proper_repr(self.expr)})' + + def resolve(self, measurements: Mapping[str, Sequence[int]]) -> bool: + missing = [str(k) for k in self.keys if str(k) not in measurements] + if missing: + raise ValueError(f'Measurement keys {missing} missing when testing classical control') + + def value(k): + return digits.big_endian_bits_to_int(measurements[str(k)]) + + replacements = {str(k): value(k) for k in self.keys} + return bool(self.expr.subs(replacements)) + + def _json_dict_(self): + return json_serialization.dataclass_json_dict(self) + + @classmethod + def _from_json_dict_(cls, expr, **kwargs): + return cls(expr=expr) + + @property + def qasm(self): + raise NotImplementedError() diff --git a/cirq-core/cirq/value/condition_test.py b/cirq-core/cirq/value/condition_test.py new file mode 100644 index 00000000000..fd80033a29a --- /dev/null +++ b/cirq-core/cirq/value/condition_test.py @@ -0,0 +1,105 @@ +# Copyright 2021 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re + +import pytest +import sympy + +import cirq + +key_a = cirq.MeasurementKey.parse_serialized('0:a') +key_b = cirq.MeasurementKey.parse_serialized('0:b') +key_c = cirq.MeasurementKey.parse_serialized('0:c') +init_key_condition = cirq.KeyCondition(key_a) +init_sympy_condition = cirq.SympyCondition(sympy.Symbol('0:a') >= 1) + + +def test_key_condition_with_keys(): + c = init_key_condition.replace_key(key_a, key_b) + assert c.key is key_b + c = init_key_condition.replace_key(key_b, key_c) + assert c.key is key_a + + +def test_key_condition_str(): + assert str(init_key_condition) == '0:a' + + +def test_key_condition_repr(): + cirq.testing.assert_equivalent_repr(init_key_condition) + + +def test_key_condition_resolve(): + assert init_key_condition.resolve({'0:a': [1]}) + assert init_key_condition.resolve({'0:a': [2]}) + assert init_key_condition.resolve({'0:a': [0, 1]}) + assert init_key_condition.resolve({'0:a': [1, 0]}) + assert not init_key_condition.resolve({'0:a': [0]}) + assert not init_key_condition.resolve({'0:a': [0, 0]}) + assert not init_key_condition.resolve({'0:a': []}) + assert not init_key_condition.resolve({'0:a': [0], 'b': [1]}) + with pytest.raises( + ValueError, match='Measurement key 0:a missing when testing classical control' + ): + _ = init_key_condition.resolve({}) + with pytest.raises( + ValueError, match='Measurement key 0:a missing when testing classical control' + ): + _ = init_key_condition.resolve({'0:b': [1]}) + + +def test_key_condition_qasm(): + assert cirq.KeyCondition(cirq.MeasurementKey('a')).qasm == 'm_a!=0' + + +def test_sympy_condition_with_keys(): + c = init_sympy_condition.replace_key(key_a, key_b) + assert c.keys == (key_b,) + c = init_sympy_condition.replace_key(key_b, key_c) + assert c.keys == (key_a,) + + +def test_sympy_condition_str(): + assert str(init_sympy_condition) == '0:a >= 1' + + +def test_sympy_condition_repr(): + cirq.testing.assert_equivalent_repr(init_sympy_condition) + + +def test_sympy_condition_resolve(): + assert init_sympy_condition.resolve({'0:a': [1]}) + assert init_sympy_condition.resolve({'0:a': [2]}) + assert init_sympy_condition.resolve({'0:a': [0, 1]}) + assert init_sympy_condition.resolve({'0:a': [1, 0]}) + assert not init_sympy_condition.resolve({'0:a': [0]}) + assert not init_sympy_condition.resolve({'0:a': [0, 0]}) + assert not init_sympy_condition.resolve({'0:a': []}) + assert not init_sympy_condition.resolve({'0:a': [0], 'b': [1]}) + with pytest.raises( + ValueError, + match=re.escape("Measurement keys ['0:a'] missing when testing classical control"), + ): + _ = init_sympy_condition.resolve({}) + with pytest.raises( + ValueError, + match=re.escape("Measurement keys ['0:a'] missing when testing classical control"), + ): + _ = init_sympy_condition.resolve({'0:b': [1]}) + + +def test_sympy_condition_qasm(): + with pytest.raises(NotImplementedError): + _ = init_sympy_condition.qasm