Skip to content

Commit

Permalink
Allow sympy expressions as classical controls (#4740)
Browse files Browse the repository at this point in the history
Part 14 of https://tinyurl.com/cirq-feedforward.

Adds the ability to create classical control conditions based on sympy expressions. 

To account for the fact that measurement key strings can contain characters not allowed in sympy variables, the measurement keys in a sympy condition string must be wrapped in curly braces to denote them. For example, to create an expression that checks if measurement A was greater than measurement B, the proper syntax is  `cirq.parse_sympy_condition('{A} > {B}')`.

This PR does not yet handle qudits completely, as multi-qubit measurements are interpreted as base-2 when converting to integer. A subsequent PR (https://github.com/daxfohl/Cirq/compare/sympy3...daxfohl:qudits?expand=1) will allow this functionality.
  • Loading branch information
daxfohl authored Dec 23, 2021
1 parent 65d783e commit ff671ae
Show file tree
Hide file tree
Showing 16 changed files with 545 additions and 59 deletions.
3 changes: 3 additions & 0 deletions cirq-core/cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,15 +483,18 @@
canonicalize_half_turns,
chosen_angle_to_canonical_half_turns,
chosen_angle_to_half_turns,
Condition,
Duration,
DURATION_LIKE,
GenericMetaImplementAnyOneOf,
KeyCondition,
LinearDict,
MEASUREMENT_KEY_SEPARATOR,
MeasurementKey,
PeriodicValue,
RANDOM_STATE_OR_SEED_LIKE,
state_vector_to_probabilities,
SympyCondition,
Timestamp,
TParamKey,
TParamVal,
Expand Down
1 change: 1 addition & 0 deletions cirq-core/cirq/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import numpy as np
import pandas as pd
import sympy
import sympy.printing.repr


def proper_repr(value: Any) -> str:
Expand Down
2 changes: 2 additions & 0 deletions cirq-core/cirq/json_resolver_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def _parallel_gate_op(gate, qubits):
'ISwapPowGate': cirq.ISwapPowGate,
'IdentityGate': cirq.IdentityGate,
'InitObsSetting': cirq.work.InitObsSetting,
'KeyCondition': cirq.KeyCondition,
'KrausChannel': cirq.KrausChannel,
'LinearDict': cirq.LinearDict,
'LineQubit': cirq.LineQubit,
Expand Down Expand Up @@ -150,6 +151,7 @@ def _parallel_gate_op(gate, qubits):
'StatePreparationChannel': cirq.StatePreparationChannel,
'SwapPowGate': cirq.SwapPowGate,
'SymmetricalQidPair': cirq.SymmetricalQidPair,
'SympyCondition': cirq.SympyCondition,
'TaggedOperation': cirq.TaggedOperation,
'TiltedSquareLattice': cirq.TiltedSquareLattice,
'TrialResult': cirq.Result, # keep support for Cirq < 0.11.
Expand Down
95 changes: 53 additions & 42 deletions cirq-core/cirq/ops/classically_controlled_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,16 @@
Any,
Dict,
FrozenSet,
List,
Optional,
Sequence,
TYPE_CHECKING,
Tuple,
Union,
)

import sympy

from cirq import protocols, value
from cirq.ops import raw_types

Expand All @@ -46,7 +49,7 @@ class ClassicallyControlledOperation(raw_types.Operation):
def __init__(
self,
sub_operation: 'cirq.Operation',
conditions: Sequence[Union[str, 'cirq.MeasurementKey']],
conditions: Sequence[Union[str, 'cirq.MeasurementKey', 'cirq.Condition', sympy.Basic]],
):
"""Initializes a `ClassicallyControlledOperation`.
Expand All @@ -68,13 +71,26 @@ def __init__(
raise ValueError(
f'Cannot conditionally run operations with measurements: {sub_operation}'
)
keys = tuple(value.MeasurementKey(c) if isinstance(c, str) else c for c in conditions)
conditions = tuple(conditions)
if isinstance(sub_operation, ClassicallyControlledOperation):
keys += sub_operation._control_keys
conditions += sub_operation._conditions
sub_operation = sub_operation._sub_operation
self._control_keys: Tuple['cirq.MeasurementKey', ...] = keys
conds: List['cirq.Condition'] = []
for c in conditions:
if isinstance(c, str):
c = value.MeasurementKey.parse_serialized(c)
if isinstance(c, value.MeasurementKey):
c = value.KeyCondition(c)
if isinstance(c, sympy.Basic):
c = value.SympyCondition(c)
conds.append(c)
self._conditions: Tuple['cirq.Condition', ...] = tuple(conds)
self._sub_operation: 'cirq.Operation' = sub_operation

@property
def classical_controls(self) -> FrozenSet['cirq.Condition']:
return frozenset(self._conditions).union(self._sub_operation.classical_controls)

def without_classical_controls(self) -> 'cirq.Operation':
return self._sub_operation.without_classical_controls()

Expand All @@ -84,27 +100,27 @@ def qubits(self):

def with_qubits(self, *new_qubits):
return self._sub_operation.with_qubits(*new_qubits).with_classical_controls(
*self._control_keys
*self._conditions
)

def _decompose_(self):
result = protocols.decompose_once(self._sub_operation, NotImplemented)
if result is NotImplemented:
return NotImplemented

return [ClassicallyControlledOperation(op, self._control_keys) for op in result]
return [ClassicallyControlledOperation(op, self._conditions) for op in result]

def _value_equality_values_(self):
return (frozenset(self._control_keys), self._sub_operation)
return (frozenset(self._conditions), self._sub_operation)

def __str__(self) -> str:
keys = ', '.join(map(str, self._control_keys))
keys = ', '.join(map(str, self._conditions))
return f'{self._sub_operation}.with_classical_controls({keys})'

def __repr__(self):
return (
f'cirq.ClassicallyControlledOperation('
f'{self._sub_operation!r}, {list(self._control_keys)!r})'
f'{self._sub_operation!r}, {list(self._conditions)!r})'
)

def _is_parameterized_(self) -> bool:
Expand All @@ -117,7 +133,7 @@ def _resolve_parameters_(
self, resolver: 'cirq.ParamResolver', recursive: bool
) -> 'ClassicallyControlledOperation':
new_sub_op = protocols.resolve_parameters(self._sub_operation, resolver, recursive)
return new_sub_op.with_classical_controls(*self._control_keys)
return new_sub_op.with_classical_controls(*self._conditions)

def _circuit_diagram_info_(
self, args: 'cirq.CircuitDiagramInfoArgs'
Expand All @@ -133,12 +149,20 @@ def _circuit_diagram_info_(
if sub_info is None:
return NotImplemented # coverage: ignore

wire_symbols = sub_info.wire_symbols + ('^',) * len(self._control_keys)
control_count = len({k for c in self._conditions for k in c.keys})
wire_symbols = sub_info.wire_symbols + ('^',) * control_count
if any(not isinstance(c, value.KeyCondition) for c in self._conditions):
wire_symbols = (
wire_symbols[0]
+ '(conditions=['
+ ', '.join(str(c) for c in self._conditions)
+ '])',
) + wire_symbols[1:]
exponent_qubit_index = None
if sub_info.exponent_qubit_index is not None:
exponent_qubit_index = sub_info.exponent_qubit_index + len(self._control_keys)
exponent_qubit_index = sub_info.exponent_qubit_index + control_count
elif sub_info.exponent is not None:
exponent_qubit_index = len(self._control_keys)
exponent_qubit_index = control_count
return protocols.CircuitDiagramInfo(
wire_symbols=wire_symbols,
exponent=sub_info.exponent,
Expand All @@ -148,58 +172,45 @@ def _circuit_diagram_info_(
def _json_dict_(self) -> Dict[str, Any]:
return {
'cirq_type': self.__class__.__name__,
'conditions': self._control_keys,
'conditions': self._conditions,
'sub_operation': self._sub_operation,
}

def _act_on_(self, args: 'cirq.ActOnArgs') -> bool:
def not_zero(measurement):
return any(i != 0 for i in measurement)

measurements = [
args.log_of_measurement_results.get(str(key), str(key)) for key in self._control_keys
]
missing = [m for m in measurements if isinstance(m, str)]
if missing:
raise ValueError(f'Measurement keys {missing} missing when performing {self}')
if all(not_zero(measurement) for measurement in measurements):
if all(c.resolve(args.log_of_measurement_results) for c in self._conditions):
protocols.act_on(self._sub_operation, args)
return True

def _with_measurement_key_mapping_(
self, key_map: Dict[str, str]
) -> 'ClassicallyControlledOperation':
conditions = [protocols.with_measurement_key_mapping(c, key_map) for c in self._conditions]
sub_operation = protocols.with_measurement_key_mapping(self._sub_operation, key_map)
sub_operation = self._sub_operation if sub_operation is NotImplemented else sub_operation
return sub_operation.with_classical_controls(
*[protocols.with_measurement_key_mapping(k, key_map) for k in self._control_keys]
)
return sub_operation.with_classical_controls(*conditions)

def _with_key_path_prefix_(self, path: Tuple[str, ...]) -> 'ClassicallyControlledOperation':
keys = [protocols.with_key_path_prefix(k, path) for k in self._control_keys]
return self._sub_operation.with_classical_controls(*keys)
def _with_key_path_prefix_(self, prefix: Tuple[str, ...]) -> 'ClassicallyControlledOperation':
conditions = [protocols.with_key_path_prefix(c, prefix) for c in self._conditions]
sub_operation = protocols.with_key_path_prefix(self._sub_operation, prefix)
sub_operation = self._sub_operation if sub_operation is NotImplemented else sub_operation
return sub_operation.with_classical_controls(*conditions)

def _with_rescoped_keys_(
self,
path: Tuple[str, ...],
bindable_keys: FrozenSet['cirq.MeasurementKey'],
) -> 'ClassicallyControlledOperation':
def map_key(key: 'cirq.MeasurementKey') -> 'cirq.MeasurementKey':
for i in range(len(path) + 1):
back_path = path[: len(path) - i]
new_key = key.with_key_path_prefix(*back_path)
if new_key in bindable_keys:
return new_key
return key

conds = [protocols.with_rescoped_keys(c, path, bindable_keys) for c in self._conditions]
sub_operation = protocols.with_rescoped_keys(self._sub_operation, path, bindable_keys)
return sub_operation.with_classical_controls(*[map_key(k) for k in self._control_keys])
return sub_operation.with_classical_controls(*conds)

def _control_keys_(self) -> FrozenSet['cirq.MeasurementKey']:
return frozenset(self._control_keys).union(protocols.control_keys(self._sub_operation))
local_keys: FrozenSet['cirq.MeasurementKey'] = frozenset(
k for condition in self._conditions for k in condition.keys
)
return local_keys.union(protocols.control_keys(self._sub_operation))

def _qasm_(self, args: 'cirq.QasmArgs') -> Optional[str]:
args.validate_version('2.0')
keys = [f'm_{key}!=0' for key in self._control_keys]
all_keys = " && ".join(keys)
all_keys = " && ".join(c.qasm for c in self._conditions)
return args.format('if ({0}) {1}', all_keys, protocols.qasm(self._sub_operation, args=args))
Loading

0 comments on commit ff671ae

Please sign in to comment.