diff --git a/cirq-core/cirq/protocols/decompose_protocol.py b/cirq-core/cirq/protocols/decompose_protocol.py index 870678cf943..6aed78b23f4 100644 --- a/cirq-core/cirq/protocols/decompose_protocol.py +++ b/cirq-core/cirq/protocols/decompose_protocol.py @@ -11,13 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import dataclasses from typing import ( TYPE_CHECKING, Any, Callable, Dict, Iterable, + Iterator, List, Optional, overload, @@ -128,6 +129,60 @@ def _decompose_(self, qubits: Tuple['cirq.Qid', ...]) -> DecomposeResult: pass +def _try_op_decomposer(val: Any, decomposer: Optional[OpDecomposer]) -> DecomposeResult: + if decomposer is None or not isinstance(val, ops.Operation): + return None + return decomposer(val) + + +@dataclasses.dataclass(frozen=True) +class _DecomposeArgs: + intercepting_decomposer: Optional[OpDecomposer] + fallback_decomposer: Optional[OpDecomposer] + keep: Optional[Callable[['cirq.Operation'], bool]] + on_stuck_raise: Union[None, Exception, Callable[['cirq.Operation'], Optional[Exception]]] + preserve_structure: bool + + +def _decompose_dfs(item: Any, args: _DecomposeArgs) -> Iterator['cirq.Operation']: + from cirq.circuits import CircuitOperation, FrozenCircuit + + if isinstance(item, ops.Operation): + item_untagged = item.untagged + if args.preserve_structure and isinstance(item_untagged, CircuitOperation): + new_fc = FrozenCircuit(_decompose_dfs(item_untagged.circuit, args)) + yield item_untagged.replace(circuit=new_fc).with_tags(*item.tags) + return + if args.keep is not None and args.keep(item): + yield item + return + + decomposed = _try_op_decomposer(item, args.intercepting_decomposer) + + if decomposed is NotImplemented or decomposed is None: + decomposed = decompose_once(item, default=None) + + if decomposed is NotImplemented or decomposed is None: + decomposed = _try_op_decomposer(item, args.fallback_decomposer) + + if decomposed is NotImplemented or decomposed is None: + if not isinstance(item, ops.Operation) and isinstance(item, Iterable): + decomposed = item + + if decomposed is NotImplemented or decomposed is None: + if args.keep is not None and args.on_stuck_raise is not None: + if isinstance(args.on_stuck_raise, Exception): + raise args.on_stuck_raise + elif callable(args.on_stuck_raise): + error = args.on_stuck_raise(item) + if error is not None: + raise error + yield item + else: + for val in ops.flatten_to_ops(decomposed): + yield from _decompose_dfs(val, args) + + def decompose( val: Any, *, @@ -200,55 +255,14 @@ def decompose( "acceptable to keep." ) - if preserve_structure: - return _decompose_preserving_structure( - val, - intercepting_decomposer=intercepting_decomposer, - fallback_decomposer=fallback_decomposer, - keep=keep, - on_stuck_raise=on_stuck_raise, - ) - - def try_op_decomposer(val: Any, decomposer: Optional[OpDecomposer]) -> DecomposeResult: - if decomposer is None or not isinstance(val, ops.Operation): - return None - return decomposer(val) - - output = [] - queue: List[Any] = [val] - while queue: - item = queue.pop(0) - if isinstance(item, ops.Operation) and keep is not None and keep(item): - output.append(item) - continue - - decomposed = try_op_decomposer(item, intercepting_decomposer) - - if decomposed is NotImplemented or decomposed is None: - decomposed = decompose_once(item, default=None) - - if decomposed is NotImplemented or decomposed is None: - decomposed = try_op_decomposer(item, fallback_decomposer) - - if decomposed is not NotImplemented and decomposed is not None: - queue[:0] = ops.flatten_to_ops(decomposed) - continue - - if not isinstance(item, ops.Operation) and isinstance(item, Iterable): - queue[:0] = ops.flatten_to_ops(item) - continue - - if keep is not None and on_stuck_raise is not None: - if isinstance(on_stuck_raise, Exception): - raise on_stuck_raise - elif callable(on_stuck_raise): - error = on_stuck_raise(item) - if error is not None: - raise error - - output.append(item) - - return output + args = _DecomposeArgs( + intercepting_decomposer=intercepting_decomposer, + fallback_decomposer=fallback_decomposer, + keep=keep, + on_stuck_raise=on_stuck_raise, + preserve_structure=preserve_structure, + ) + return [*_decompose_dfs(val, args)] # pylint: disable=function-redefined @@ -383,65 +397,4 @@ def _try_decompose_into_operations_and_qubits( qid_shape_dict[q] = max(qid_shape_dict[q], level) qubits = sorted(qubit_set) return result, qubits, tuple(qid_shape_dict[q] for q in qubits) - return None, (), () - - -def _decompose_preserving_structure( - val: Any, - *, - intercepting_decomposer: Optional[OpDecomposer] = None, - fallback_decomposer: Optional[OpDecomposer] = None, - keep: Optional[Callable[['cirq.Operation'], bool]] = None, - on_stuck_raise: Union[ - None, Exception, Callable[['cirq.Operation'], Optional[Exception]] - ] = _value_error_describing_bad_operation, -) -> List['cirq.Operation']: - """Preserves structure (e.g. subcircuits) while decomposing ops. - - This can be used to reduce a circuit to a particular gateset without - increasing its serialization size. See tests for examples. - """ - - # This method provides a generated 'keep' to its decompose() calls. - # If the user-provided keep is not set, on_stuck_raise must be unset to - # ensure that failure to decompose does not generate errors. - on_stuck_raise = on_stuck_raise if keep is not None else None - - from cirq.circuits import CircuitOperation, FrozenCircuit - - visited_fcs = set() - - def keep_structure(op: 'cirq.Operation'): - circuit = getattr(op.untagged, 'circuit', None) - if circuit is not None: - return circuit in visited_fcs - if keep is not None and keep(op): - return True - - def dps_interceptor(op: 'cirq.Operation'): - if not isinstance(op.untagged, CircuitOperation): - if intercepting_decomposer is None: - return NotImplemented - return intercepting_decomposer(op) - - new_fc = FrozenCircuit( - decompose( - op.untagged.circuit, - intercepting_decomposer=dps_interceptor, - fallback_decomposer=fallback_decomposer, - keep=keep_structure, - on_stuck_raise=on_stuck_raise, - ) - ) - visited_fcs.add(new_fc) - new_co = op.untagged.replace(circuit=new_fc) - return new_co if not op.tags else new_co.with_tags(*op.tags) - - return decompose( - val, - intercepting_decomposer=dps_interceptor, - fallback_decomposer=fallback_decomposer, - keep=keep_structure, - on_stuck_raise=on_stuck_raise, - )