Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use a Protocol for TRANSFORMER to ensure common arg names #4871

Merged
merged 6 commits into from
Jan 24, 2022
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 58 additions & 63 deletions cirq-core/cirq/transformers/transformer_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,21 @@

"""Defines the API for circuit transformers in Cirq."""

import textwrap
import dataclasses
import enum
import functools
import textwrap
from typing import (
Any,
Callable,
Tuple,
Hashable,
List,
Type,
overload,
Type,
TYPE_CHECKING,
TypeVar,
)
import dataclasses
import enum
from cirq.circuits.circuit import CIRCUIT_TYPE
from typing_extensions import Protocol

if TYPE_CHECKING:
import cirq
Expand Down Expand Up @@ -218,77 +218,41 @@ class TransformerContext:
ignore_tags: Tuple[Hashable, ...] = ()


TRANSFORMER = Callable[['cirq.AbstractCircuit', TransformerContext], 'cirq.AbstractCircuit']
_TRANSFORMER_TYPE = Callable[['cirq.AbstractCircuit', TransformerContext], CIRCUIT_TYPE]


def _transform_and_log(
func: _TRANSFORMER_TYPE[CIRCUIT_TYPE],
transformer_name: str,
circuit: 'cirq.AbstractCircuit',
context: TransformerContext,
) -> CIRCUIT_TYPE:
"""Helper to log initial and final circuits before and after calling the transformer."""

context.logger.register_initial(circuit, transformer_name)
transformed_circuit = func(circuit, context)
context.logger.register_final(transformed_circuit, transformer_name)
return transformed_circuit

class TRANSFORMER(Protocol):
def __call__(
self, circuit: 'cirq.AbstractCircuit', context: TransformerContext
) -> 'cirq.AbstractCircuit':
...

def _transformer_class(
cls: Type[_TRANSFORMER_TYPE[CIRCUIT_TYPE]],
) -> Type[_TRANSFORMER_TYPE[CIRCUIT_TYPE]]:
old_func = cls.__call__

def transformer_with_logging_cls(
self: Type[_TRANSFORMER_TYPE[CIRCUIT_TYPE]],
circuit: 'cirq.AbstractCircuit',
context: TransformerContext,
) -> CIRCUIT_TYPE:
def call_old_func(c: 'cirq.AbstractCircuit', ct: TransformerContext) -> CIRCUIT_TYPE:
return old_func(self, c, ct)

return _transform_and_log(call_old_func, cls.__name__, circuit, context)

setattr(cls, '__call__', transformer_with_logging_cls)
return cls


def _transformer_func(func: _TRANSFORMER_TYPE[CIRCUIT_TYPE]) -> _TRANSFORMER_TYPE[CIRCUIT_TYPE]:
@functools.wraps(func)
def transformer_with_logging_func(
circuit: 'cirq.AbstractCircuit',
context: TransformerContext,
) -> CIRCUIT_TYPE:
return _transform_and_log(func, func.__name__, circuit, context)

return transformer_with_logging_func
_TRANSFORMER_T = TypeVar('_TRANSFORMER_T', bound=TRANSFORMER)
_TRANSFORMER_CLS_T = TypeVar('_TRANSFORMER_CLS_T', bound=Type[TRANSFORMER])


@overload
def transformer(cls_or_func: _TRANSFORMER_TYPE[CIRCUIT_TYPE]) -> _TRANSFORMER_TYPE[CIRCUIT_TYPE]:
def transformer(cls_or_func: _TRANSFORMER_T) -> _TRANSFORMER_T:
pass


@overload
def transformer(
cls_or_func: Type[_TRANSFORMER_TYPE[CIRCUIT_TYPE]],
) -> Type[_TRANSFORMER_TYPE[CIRCUIT_TYPE]]:
def transformer(cls_or_func: _TRANSFORMER_CLS_T) -> _TRANSFORMER_CLS_T:
pass


def transformer(cls_or_func: Any) -> Any:
"""Decorator to verify API and append logging functionality to transformer functions & classes.

The decorated function or class must satisfy
`Callable[[cirq.Circuit, cirq.TransformerContext], cirq.Circuit]` API. For Example:
A transformer is a callable that takes as inputs a cirq.AbstractCircuit and
maffoo marked this conversation as resolved.
Show resolved Hide resolved
cirq.TransformerContext, and returns another cirq.AbstractCircuit without
maffoo marked this conversation as resolved.
Show resolved Hide resolved
modifying the input circuit. A transformer could be a function, for example:

>>> @cirq.transformer
>>> def convert_to_cz(circuit: cirq.Circuit, context: cirq.TransformerContext) -> cirq.Circuit:
>>> def convert_to_cz(
>>> circuit: cirq.AbstractCircuit, context: cirq.TransformerContext
>>> ) -> cirq.Circuit:
>>> ...

The decorated class must implement the `__call__` method to satisfy the above API.
Or it could be a class that implements `__call__` with the same API, for example:

>>> @cirq.transformer
>>> class ConvertToSqrtISwaps:
maffoo marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -300,14 +264,45 @@ def transformer(cls_or_func: Any) -> Any:
>>> ...

Args:
cls_or_func: The callable class or method to be decorated.
cls_or_func: The callable class or function to be decorated.

Returns:
Decorated class / method which includes additional logging boilerplate. The decorated
callable always receives a copy of the input circuit so that the input is never mutated.
Decorated class / function which includes additional logging boilerplate.
"""
if isinstance(cls_or_func, type):
return _transformer_class(cls_or_func)
cls = cls_or_func
method = cls.__call__

@functools.wraps(method)
def method_with_logging(self, circuit, context):
maffoo marked this conversation as resolved.
Show resolved Hide resolved
return _transform_and_log(
maffoo marked this conversation as resolved.
Show resolved Hide resolved
lambda circuit, context: method(self, circuit, context),
cls.__name__,
circuit,
context,
)

setattr(cls, '__call__', method_with_logging)
return cls
else:
assert callable(cls_or_func)
return _transformer_func(cls_or_func)
func = cls_or_func

@functools.wraps(func)
def func_with_logging(circuit, context):
return _transform_and_log(func, func.__name__, circuit, context)

return func_with_logging


def _transform_and_log(
func: TRANSFORMER,
transformer_name: str,
circuit: 'cirq.AbstractCircuit',
context: TransformerContext,
) -> 'cirq.AbstractCircuit':
"""Helper to log initial and final circuits before and after calling the transformer."""
context.logger.register_initial(circuit, transformer_name)
transformed_circuit = func(circuit, context)
context.logger.register_final(transformed_circuit, transformer_name)
return transformed_circuit