Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add add_deep_support flag to @cirq.transformer decorator #5108

Merged
merged 6 commits into from
Mar 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 69 additions & 5 deletions cirq-core/cirq/transformers/transformer_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
import functools
import textwrap
from typing import (
cast,
Any,
Callable,
Tuple,
Hashable,
List,
Expand All @@ -29,9 +31,12 @@
Type,
TYPE_CHECKING,
TypeVar,
Union,
)
from typing_extensions import Protocol

from cirq import circuits

if TYPE_CHECKING:
import cirq

Expand Down Expand Up @@ -214,10 +219,13 @@ class TransformerContext:
circuit. Transformers should not transform any operation marked with a tag that
belongs to this tuple. Note that any instance of a Hashable type (like `str`,
`cirq.VirtualTag` etc.) is a valid tag.
deep: If true, the transformer should be recursively applied to all sub-circuits wrapped
inside circuit operations.
"""

logger: TransformerLogger = NoOpTransformerLogger()
tags_to_ignore: Tuple[Hashable, ...] = ()
deep: bool = False


class TRANSFORMER(Protocol):
Expand All @@ -229,19 +237,31 @@ def __call__(

_TRANSFORMER_T = TypeVar('_TRANSFORMER_T', bound=TRANSFORMER)
_TRANSFORMER_CLS_T = TypeVar('_TRANSFORMER_CLS_T', bound=Type[TRANSFORMER])
_TRANSFORMER_OR_CLS_T = TypeVar(
'_TRANSFORMER_OR_CLS_T', bound=Union[TRANSFORMER, Type[TRANSFORMER]]
)


@overload
def transformer(
*, add_deep_support: bool = False
) -> Callable[[_TRANSFORMER_OR_CLS_T], _TRANSFORMER_OR_CLS_T]:
pass


@overload
def transformer(cls_or_func: _TRANSFORMER_T) -> _TRANSFORMER_T:
def transformer(cls_or_func: _TRANSFORMER_T, *, add_deep_support: bool = False) -> _TRANSFORMER_T:
pass


@overload
def transformer(cls_or_func: _TRANSFORMER_CLS_T) -> _TRANSFORMER_CLS_T:
def transformer(
cls_or_func: _TRANSFORMER_CLS_T, *, add_deep_support: bool = False
) -> _TRANSFORMER_CLS_T:
pass


def transformer(cls_or_func: Any) -> Any:
def transformer(cls_or_func: Any = None, *, add_deep_support: bool = False) -> Any:
"""Decorator to verify API and append logging functionality to transformer functions & classes.

A transformer is a callable that takes as inputs a `cirq.AbstractCircuit` and
Expand Down Expand Up @@ -284,10 +304,22 @@ def transformer(cls_or_func: Any) -> Any:

Args:
cls_or_func: The callable class or function to be decorated.
add_deep_support: If True, the decorator adds the logic to first apply the
decorated transformer on subcircuits wrapped inside `cirq.CircuitOperation`s
before applying it on the top-level circuit, if context.deep is True.

Returns:
Decorated class / function which includes additional logging boilerplate.
"""

# If keyword arguments were specified, python invokes the decorator method
# without a `cls` argument, then passes `cls` into the result.
if cls_or_func is None:
return lambda deferred_cls_or_func: transformer(
deferred_cls_or_func,
add_deep_support=add_deep_support,
)

if isinstance(cls_or_func, type):
cls = cls_or_func
method = cls.__call__
Expand All @@ -298,6 +330,7 @@ def method_with_logging(
self, circuit: 'cirq.AbstractCircuit', **kwargs
) -> 'cirq.AbstractCircuit':
return _transform_and_log(
add_deep_support,
lambda circuit, **kwargs: method(self, circuit, **kwargs),
cls.__name__,
circuit,
Expand All @@ -315,6 +348,7 @@ def method_with_logging(
@functools.wraps(func)
def func_with_logging(circuit: 'cirq.AbstractCircuit', **kwargs) -> 'cirq.AbstractCircuit':
return _transform_and_log(
add_deep_support,
func,
func.__name__,
circuit,
Expand All @@ -325,7 +359,7 @@ def func_with_logging(circuit: 'cirq.AbstractCircuit', **kwargs) -> 'cirq.Abstra
return func_with_logging


def _get_default_context(func: TRANSFORMER):
def _get_default_context(func: TRANSFORMER) -> TransformerContext:
sig = inspect.signature(func)
default_context = sig.parameters["context"].default
assert (
Expand All @@ -334,7 +368,35 @@ def _get_default_context(func: TRANSFORMER):
return default_context


def _run_transformer_on_circuit(
add_deep_support: bool,
func: TRANSFORMER,
circuit: 'cirq.AbstractCircuit',
extracted_context: Optional[TransformerContext],
**kwargs,
) -> 'cirq.AbstractCircuit':
mutable_circuit = None
if extracted_context and extracted_context.deep and add_deep_support:
batch_replace = []
for i, op in circuit.findall_operations(
lambda o: isinstance(o.untagged, circuits.CircuitOperation)
):
op_untagged = cast(circuits.CircuitOperation, op.untagged)
if not set(op.tags).isdisjoint(extracted_context.tags_to_ignore):
continue
op_untagged = op_untagged.replace(
circuit=_run_transformer_on_circuit(
add_deep_support, func, op_untagged.circuit, extracted_context, **kwargs
).freeze()
)
batch_replace.append((i, op, op_untagged.with_tags(*op.tags)))
mutable_circuit = circuit.unfreeze(copy=True)
mutable_circuit.batch_replace(batch_replace)
return func(mutable_circuit if mutable_circuit else circuit, **kwargs)


def _transform_and_log(
add_deep_support: bool,
func: TRANSFORMER,
transformer_name: str,
circuit: 'cirq.AbstractCircuit',
Expand All @@ -344,7 +406,9 @@ def _transform_and_log(
"""Helper to log initial and final circuits before and after calling the transformer."""
if extracted_context:
extracted_context.logger.register_initial(circuit, transformer_name)
transformed_circuit = func(circuit, **kwargs)
transformed_circuit = _run_transformer_on_circuit(
add_deep_support, func, circuit, extracted_context, **kwargs
)
if extracted_context:
extracted_context.logger.register_final(transformed_circuit, transformer_name)
return transformed_circuit
53 changes: 50 additions & 3 deletions cirq-core/cirq/transformers/transformer_api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import pytest


@cirq.transformer
@cirq.transformer()
class MockTransformerClass:
def __init__(self):
self.mock = mock.Mock()
Expand Down Expand Up @@ -59,6 +59,11 @@ def __call__(
return circuit[::-1]


@cirq.transformer(add_deep_support=True)
class MockTransformerClassSupportsDeep(MockTransformerClass):
pass


def make_transformer_func_with_defaults() -> cirq.TRANSFORMER:
my_mock = mock.Mock()

Expand All @@ -77,10 +82,10 @@ def func(
return func


def make_transformer_func() -> cirq.TRANSFORMER:
def make_transformer_func(add_deep_support: bool = False) -> cirq.TRANSFORMER:
my_mock = mock.Mock()

@cirq.transformer
@cirq.transformer(add_deep_support=add_deep_support)
def mock_tranformer_func(
circuit: cirq.AbstractCircuit, *, context: Optional[cirq.TransformerContext] = None
) -> cirq.Circuit:
Expand Down Expand Up @@ -134,6 +139,48 @@ def test_transformer_decorator_with_defaults(transformer):
transformer.mock.assert_called_with(circuit, context, 1e-2, CustomArg(12))


@pytest.mark.parametrize(
'transformer, supports_deep',
[
(MockTransformerClass(), False),
(make_transformer_func(), False),
(MockTransformerClassSupportsDeep(), True),
(make_transformer_func(add_deep_support=True), True),
],
)
def test_transformer_decorator_adds_support_for_deep(transformer, supports_deep):
q = cirq.NamedQubit("q")
c_nested_x = cirq.FrozenCircuit(cirq.X(q))
c_nested_y = cirq.FrozenCircuit(cirq.Y(q))
c_nested_xy = cirq.FrozenCircuit(
cirq.CircuitOperation(c_nested_x).repeat(5).with_tags("ignore"),
cirq.CircuitOperation(c_nested_y).repeat(7).with_tags("preserve_tag"),
)
c_nested_yx = cirq.FrozenCircuit(
cirq.CircuitOperation(c_nested_y).repeat(7).with_tags("ignore"),
cirq.CircuitOperation(c_nested_x).repeat(5).with_tags("preserve_tag"),
)
c_orig = cirq.Circuit(
cirq.CircuitOperation(c_nested_xy).repeat(4),
cirq.CircuitOperation(c_nested_x).repeat(5).with_tags("ignore"),
cirq.CircuitOperation(c_nested_y).repeat(6),
cirq.CircuitOperation(c_nested_yx).repeat(7),
)
context = cirq.TransformerContext(tags_to_ignore=["ignore"], deep=True)
transformer(c_orig, context=context)
expected_calls = [mock.call(c_orig, context)]
if supports_deep:
expected_calls = [
mock.call(c_nested_y, context), # c_orig --> xy --> y
mock.call(c_nested_xy, context), # c_orig --> xy
mock.call(c_nested_y, context), # c_orig --> y
mock.call(c_nested_x, context), # c_orig --> yx --> x
mock.call(c_nested_yx, context), # c_orig --> yx
mock.call(c_orig, context), # c_orig
]
transformer.mock.assert_has_calls(expected_calls)


@cirq.transformer
class T1:
def __call__(
Expand Down