diff --git a/cirq-core/cirq/transformers/transformer_api.py b/cirq-core/cirq/transformers/transformer_api.py index 6983312ab02..ee4f6118bf6 100644 --- a/cirq-core/cirq/transformers/transformer_api.py +++ b/cirq-core/cirq/transformers/transformer_api.py @@ -20,7 +20,9 @@ import functools import textwrap from typing import ( + cast, Any, + Callable, Tuple, Hashable, List, @@ -29,9 +31,12 @@ Type, TYPE_CHECKING, TypeVar, + Union, ) from typing_extensions import Protocol +from cirq import circuits + if TYPE_CHECKING: import cirq @@ -214,10 +219,13 @@ class TransformerContext: circuit. Transformers should not transform any operation marked with a tag that belongs to this tuple. Note that any instance of a Hashable type (like `str`, `cirq.VirtualTag` etc.) is a valid tag. + deep: If true, the transformer should be recursively applied to all sub-circuits wrapped + inside circuit operations. """ logger: TransformerLogger = NoOpTransformerLogger() tags_to_ignore: Tuple[Hashable, ...] = () + deep: bool = False class TRANSFORMER(Protocol): @@ -229,19 +237,31 @@ def __call__( _TRANSFORMER_T = TypeVar('_TRANSFORMER_T', bound=TRANSFORMER) _TRANSFORMER_CLS_T = TypeVar('_TRANSFORMER_CLS_T', bound=Type[TRANSFORMER]) +_TRANSFORMER_OR_CLS_T = TypeVar( + '_TRANSFORMER_OR_CLS_T', bound=Union[TRANSFORMER, Type[TRANSFORMER]] +) + + +@overload +def transformer( + *, add_deep_support: bool = False +) -> Callable[[_TRANSFORMER_OR_CLS_T], _TRANSFORMER_OR_CLS_T]: + pass @overload -def transformer(cls_or_func: _TRANSFORMER_T) -> _TRANSFORMER_T: +def transformer(cls_or_func: _TRANSFORMER_T, *, add_deep_support: bool = False) -> _TRANSFORMER_T: pass @overload -def transformer(cls_or_func: _TRANSFORMER_CLS_T) -> _TRANSFORMER_CLS_T: +def transformer( + cls_or_func: _TRANSFORMER_CLS_T, *, add_deep_support: bool = False +) -> _TRANSFORMER_CLS_T: pass -def transformer(cls_or_func: Any) -> Any: +def transformer(cls_or_func: Any = None, *, add_deep_support: bool = False) -> Any: """Decorator to verify API and append logging functionality to transformer functions & classes. A transformer is a callable that takes as inputs a `cirq.AbstractCircuit` and @@ -284,10 +304,22 @@ def transformer(cls_or_func: Any) -> Any: Args: cls_or_func: The callable class or function to be decorated. + add_deep_support: If True, the decorator adds the logic to first apply the + decorated transformer on subcircuits wrapped inside `cirq.CircuitOperation`s + before applying it on the top-level circuit, if context.deep is True. Returns: Decorated class / function which includes additional logging boilerplate. """ + + # If keyword arguments were specified, python invokes the decorator method + # without a `cls` argument, then passes `cls` into the result. + if cls_or_func is None: + return lambda deferred_cls_or_func: transformer( + deferred_cls_or_func, + add_deep_support=add_deep_support, + ) + if isinstance(cls_or_func, type): cls = cls_or_func method = cls.__call__ @@ -298,6 +330,7 @@ def method_with_logging( self, circuit: 'cirq.AbstractCircuit', **kwargs ) -> 'cirq.AbstractCircuit': return _transform_and_log( + add_deep_support, lambda circuit, **kwargs: method(self, circuit, **kwargs), cls.__name__, circuit, @@ -315,6 +348,7 @@ def method_with_logging( @functools.wraps(func) def func_with_logging(circuit: 'cirq.AbstractCircuit', **kwargs) -> 'cirq.AbstractCircuit': return _transform_and_log( + add_deep_support, func, func.__name__, circuit, @@ -325,7 +359,7 @@ def func_with_logging(circuit: 'cirq.AbstractCircuit', **kwargs) -> 'cirq.Abstra return func_with_logging -def _get_default_context(func: TRANSFORMER): +def _get_default_context(func: TRANSFORMER) -> TransformerContext: sig = inspect.signature(func) default_context = sig.parameters["context"].default assert ( @@ -334,7 +368,35 @@ def _get_default_context(func: TRANSFORMER): return default_context +def _run_transformer_on_circuit( + add_deep_support: bool, + func: TRANSFORMER, + circuit: 'cirq.AbstractCircuit', + extracted_context: Optional[TransformerContext], + **kwargs, +) -> 'cirq.AbstractCircuit': + mutable_circuit = None + if extracted_context and extracted_context.deep and add_deep_support: + batch_replace = [] + for i, op in circuit.findall_operations( + lambda o: isinstance(o.untagged, circuits.CircuitOperation) + ): + op_untagged = cast(circuits.CircuitOperation, op.untagged) + if not set(op.tags).isdisjoint(extracted_context.tags_to_ignore): + continue + op_untagged = op_untagged.replace( + circuit=_run_transformer_on_circuit( + add_deep_support, func, op_untagged.circuit, extracted_context, **kwargs + ).freeze() + ) + batch_replace.append((i, op, op_untagged.with_tags(*op.tags))) + mutable_circuit = circuit.unfreeze(copy=True) + mutable_circuit.batch_replace(batch_replace) + return func(mutable_circuit if mutable_circuit else circuit, **kwargs) + + def _transform_and_log( + add_deep_support: bool, func: TRANSFORMER, transformer_name: str, circuit: 'cirq.AbstractCircuit', @@ -344,7 +406,9 @@ def _transform_and_log( """Helper to log initial and final circuits before and after calling the transformer.""" if extracted_context: extracted_context.logger.register_initial(circuit, transformer_name) - transformed_circuit = func(circuit, **kwargs) + transformed_circuit = _run_transformer_on_circuit( + add_deep_support, func, circuit, extracted_context, **kwargs + ) if extracted_context: extracted_context.logger.register_final(transformed_circuit, transformer_name) return transformed_circuit diff --git a/cirq-core/cirq/transformers/transformer_api_test.py b/cirq-core/cirq/transformers/transformer_api_test.py index 8b35656fca8..9ec7d506661 100644 --- a/cirq-core/cirq/transformers/transformer_api_test.py +++ b/cirq-core/cirq/transformers/transformer_api_test.py @@ -21,7 +21,7 @@ import pytest -@cirq.transformer +@cirq.transformer() class MockTransformerClass: def __init__(self): self.mock = mock.Mock() @@ -59,6 +59,11 @@ def __call__( return circuit[::-1] +@cirq.transformer(add_deep_support=True) +class MockTransformerClassSupportsDeep(MockTransformerClass): + pass + + def make_transformer_func_with_defaults() -> cirq.TRANSFORMER: my_mock = mock.Mock() @@ -77,10 +82,10 @@ def func( return func -def make_transformer_func() -> cirq.TRANSFORMER: +def make_transformer_func(add_deep_support: bool = False) -> cirq.TRANSFORMER: my_mock = mock.Mock() - @cirq.transformer + @cirq.transformer(add_deep_support=add_deep_support) def mock_tranformer_func( circuit: cirq.AbstractCircuit, *, context: Optional[cirq.TransformerContext] = None ) -> cirq.Circuit: @@ -134,6 +139,48 @@ def test_transformer_decorator_with_defaults(transformer): transformer.mock.assert_called_with(circuit, context, 1e-2, CustomArg(12)) +@pytest.mark.parametrize( + 'transformer, supports_deep', + [ + (MockTransformerClass(), False), + (make_transformer_func(), False), + (MockTransformerClassSupportsDeep(), True), + (make_transformer_func(add_deep_support=True), True), + ], +) +def test_transformer_decorator_adds_support_for_deep(transformer, supports_deep): + q = cirq.NamedQubit("q") + c_nested_x = cirq.FrozenCircuit(cirq.X(q)) + c_nested_y = cirq.FrozenCircuit(cirq.Y(q)) + c_nested_xy = cirq.FrozenCircuit( + cirq.CircuitOperation(c_nested_x).repeat(5).with_tags("ignore"), + cirq.CircuitOperation(c_nested_y).repeat(7).with_tags("preserve_tag"), + ) + c_nested_yx = cirq.FrozenCircuit( + cirq.CircuitOperation(c_nested_y).repeat(7).with_tags("ignore"), + cirq.CircuitOperation(c_nested_x).repeat(5).with_tags("preserve_tag"), + ) + c_orig = cirq.Circuit( + cirq.CircuitOperation(c_nested_xy).repeat(4), + cirq.CircuitOperation(c_nested_x).repeat(5).with_tags("ignore"), + cirq.CircuitOperation(c_nested_y).repeat(6), + cirq.CircuitOperation(c_nested_yx).repeat(7), + ) + context = cirq.TransformerContext(tags_to_ignore=["ignore"], deep=True) + transformer(c_orig, context=context) + expected_calls = [mock.call(c_orig, context)] + if supports_deep: + expected_calls = [ + mock.call(c_nested_y, context), # c_orig --> xy --> y + mock.call(c_nested_xy, context), # c_orig --> xy + mock.call(c_nested_y, context), # c_orig --> y + mock.call(c_nested_x, context), # c_orig --> yx --> x + mock.call(c_nested_yx, context), # c_orig --> yx + mock.call(c_orig, context), # c_orig + ] + transformer.mock.assert_has_calls(expected_calls) + + @cirq.transformer class T1: def __call__(