Skip to content

Commit

Permalink
Add add_deep_support flag to @cirq.transformer decorator (#5108)
Browse files Browse the repository at this point in the history
* Add  flag to @cirq.transformer decorator

* Fix mypy type errors and remove typos

* Rename add_support_for_deep to add_deep_support
  • Loading branch information
tanujkhattar authored Mar 21, 2022
1 parent aed4eb8 commit cfa255a
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 8 deletions.
74 changes: 69 additions & 5 deletions cirq-core/cirq/transformers/transformer_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
import functools
import textwrap
from typing import (
cast,
Any,
Callable,
Tuple,
Hashable,
List,
Expand All @@ -29,9 +31,12 @@
Type,
TYPE_CHECKING,
TypeVar,
Union,
)
from typing_extensions import Protocol

from cirq import circuits

if TYPE_CHECKING:
import cirq

Expand Down Expand Up @@ -214,10 +219,13 @@ class TransformerContext:
circuit. Transformers should not transform any operation marked with a tag that
belongs to this tuple. Note that any instance of a Hashable type (like `str`,
`cirq.VirtualTag` etc.) is a valid tag.
deep: If true, the transformer should be recursively applied to all sub-circuits wrapped
inside circuit operations.
"""

logger: TransformerLogger = NoOpTransformerLogger()
tags_to_ignore: Tuple[Hashable, ...] = ()
deep: bool = False


class TRANSFORMER(Protocol):
Expand All @@ -229,19 +237,31 @@ def __call__(

_TRANSFORMER_T = TypeVar('_TRANSFORMER_T', bound=TRANSFORMER)
_TRANSFORMER_CLS_T = TypeVar('_TRANSFORMER_CLS_T', bound=Type[TRANSFORMER])
_TRANSFORMER_OR_CLS_T = TypeVar(
'_TRANSFORMER_OR_CLS_T', bound=Union[TRANSFORMER, Type[TRANSFORMER]]
)


@overload
def transformer(
*, add_deep_support: bool = False
) -> Callable[[_TRANSFORMER_OR_CLS_T], _TRANSFORMER_OR_CLS_T]:
pass


@overload
def transformer(cls_or_func: _TRANSFORMER_T) -> _TRANSFORMER_T:
def transformer(cls_or_func: _TRANSFORMER_T, *, add_deep_support: bool = False) -> _TRANSFORMER_T:
pass


@overload
def transformer(cls_or_func: _TRANSFORMER_CLS_T) -> _TRANSFORMER_CLS_T:
def transformer(
cls_or_func: _TRANSFORMER_CLS_T, *, add_deep_support: bool = False
) -> _TRANSFORMER_CLS_T:
pass


def transformer(cls_or_func: Any) -> Any:
def transformer(cls_or_func: Any = None, *, add_deep_support: bool = False) -> Any:
"""Decorator to verify API and append logging functionality to transformer functions & classes.
A transformer is a callable that takes as inputs a `cirq.AbstractCircuit` and
Expand Down Expand Up @@ -284,10 +304,22 @@ def transformer(cls_or_func: Any) -> Any:
Args:
cls_or_func: The callable class or function to be decorated.
add_deep_support: If True, the decorator adds the logic to first apply the
decorated transformer on subcircuits wrapped inside `cirq.CircuitOperation`s
before applying it on the top-level circuit, if context.deep is True.
Returns:
Decorated class / function which includes additional logging boilerplate.
"""

# If keyword arguments were specified, python invokes the decorator method
# without a `cls` argument, then passes `cls` into the result.
if cls_or_func is None:
return lambda deferred_cls_or_func: transformer(
deferred_cls_or_func,
add_deep_support=add_deep_support,
)

if isinstance(cls_or_func, type):
cls = cls_or_func
method = cls.__call__
Expand All @@ -298,6 +330,7 @@ def method_with_logging(
self, circuit: 'cirq.AbstractCircuit', **kwargs
) -> 'cirq.AbstractCircuit':
return _transform_and_log(
add_deep_support,
lambda circuit, **kwargs: method(self, circuit, **kwargs),
cls.__name__,
circuit,
Expand All @@ -315,6 +348,7 @@ def method_with_logging(
@functools.wraps(func)
def func_with_logging(circuit: 'cirq.AbstractCircuit', **kwargs) -> 'cirq.AbstractCircuit':
return _transform_and_log(
add_deep_support,
func,
func.__name__,
circuit,
Expand All @@ -325,7 +359,7 @@ def func_with_logging(circuit: 'cirq.AbstractCircuit', **kwargs) -> 'cirq.Abstra
return func_with_logging


def _get_default_context(func: TRANSFORMER):
def _get_default_context(func: TRANSFORMER) -> TransformerContext:
sig = inspect.signature(func)
default_context = sig.parameters["context"].default
assert (
Expand All @@ -334,7 +368,35 @@ def _get_default_context(func: TRANSFORMER):
return default_context


def _run_transformer_on_circuit(
add_deep_support: bool,
func: TRANSFORMER,
circuit: 'cirq.AbstractCircuit',
extracted_context: Optional[TransformerContext],
**kwargs,
) -> 'cirq.AbstractCircuit':
mutable_circuit = None
if extracted_context and extracted_context.deep and add_deep_support:
batch_replace = []
for i, op in circuit.findall_operations(
lambda o: isinstance(o.untagged, circuits.CircuitOperation)
):
op_untagged = cast(circuits.CircuitOperation, op.untagged)
if not set(op.tags).isdisjoint(extracted_context.tags_to_ignore):
continue
op_untagged = op_untagged.replace(
circuit=_run_transformer_on_circuit(
add_deep_support, func, op_untagged.circuit, extracted_context, **kwargs
).freeze()
)
batch_replace.append((i, op, op_untagged.with_tags(*op.tags)))
mutable_circuit = circuit.unfreeze(copy=True)
mutable_circuit.batch_replace(batch_replace)
return func(mutable_circuit if mutable_circuit else circuit, **kwargs)


def _transform_and_log(
add_deep_support: bool,
func: TRANSFORMER,
transformer_name: str,
circuit: 'cirq.AbstractCircuit',
Expand All @@ -344,7 +406,9 @@ def _transform_and_log(
"""Helper to log initial and final circuits before and after calling the transformer."""
if extracted_context:
extracted_context.logger.register_initial(circuit, transformer_name)
transformed_circuit = func(circuit, **kwargs)
transformed_circuit = _run_transformer_on_circuit(
add_deep_support, func, circuit, extracted_context, **kwargs
)
if extracted_context:
extracted_context.logger.register_final(transformed_circuit, transformer_name)
return transformed_circuit
53 changes: 50 additions & 3 deletions cirq-core/cirq/transformers/transformer_api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import pytest


@cirq.transformer
@cirq.transformer()
class MockTransformerClass:
def __init__(self):
self.mock = mock.Mock()
Expand Down Expand Up @@ -59,6 +59,11 @@ def __call__(
return circuit[::-1]


@cirq.transformer(add_deep_support=True)
class MockTransformerClassSupportsDeep(MockTransformerClass):
pass


def make_transformer_func_with_defaults() -> cirq.TRANSFORMER:
my_mock = mock.Mock()

Expand All @@ -77,10 +82,10 @@ def func(
return func


def make_transformer_func() -> cirq.TRANSFORMER:
def make_transformer_func(add_deep_support: bool = False) -> cirq.TRANSFORMER:
my_mock = mock.Mock()

@cirq.transformer
@cirq.transformer(add_deep_support=add_deep_support)
def mock_tranformer_func(
circuit: cirq.AbstractCircuit, *, context: Optional[cirq.TransformerContext] = None
) -> cirq.Circuit:
Expand Down Expand Up @@ -134,6 +139,48 @@ def test_transformer_decorator_with_defaults(transformer):
transformer.mock.assert_called_with(circuit, context, 1e-2, CustomArg(12))


@pytest.mark.parametrize(
'transformer, supports_deep',
[
(MockTransformerClass(), False),
(make_transformer_func(), False),
(MockTransformerClassSupportsDeep(), True),
(make_transformer_func(add_deep_support=True), True),
],
)
def test_transformer_decorator_adds_support_for_deep(transformer, supports_deep):
q = cirq.NamedQubit("q")
c_nested_x = cirq.FrozenCircuit(cirq.X(q))
c_nested_y = cirq.FrozenCircuit(cirq.Y(q))
c_nested_xy = cirq.FrozenCircuit(
cirq.CircuitOperation(c_nested_x).repeat(5).with_tags("ignore"),
cirq.CircuitOperation(c_nested_y).repeat(7).with_tags("preserve_tag"),
)
c_nested_yx = cirq.FrozenCircuit(
cirq.CircuitOperation(c_nested_y).repeat(7).with_tags("ignore"),
cirq.CircuitOperation(c_nested_x).repeat(5).with_tags("preserve_tag"),
)
c_orig = cirq.Circuit(
cirq.CircuitOperation(c_nested_xy).repeat(4),
cirq.CircuitOperation(c_nested_x).repeat(5).with_tags("ignore"),
cirq.CircuitOperation(c_nested_y).repeat(6),
cirq.CircuitOperation(c_nested_yx).repeat(7),
)
context = cirq.TransformerContext(tags_to_ignore=["ignore"], deep=True)
transformer(c_orig, context=context)
expected_calls = [mock.call(c_orig, context)]
if supports_deep:
expected_calls = [
mock.call(c_nested_y, context), # c_orig --> xy --> y
mock.call(c_nested_xy, context), # c_orig --> xy
mock.call(c_nested_y, context), # c_orig --> y
mock.call(c_nested_x, context), # c_orig --> yx --> x
mock.call(c_nested_yx, context), # c_orig --> yx
mock.call(c_orig, context), # c_orig
]
transformer.mock.assert_has_calls(expected_calls)


@cirq.transformer
class T1:
def __call__(
Expand Down

0 comments on commit cfa255a

Please sign in to comment.