Skip to content

Commit

Permalink
Make _create_transformer_with_kwargs a public method (#5492)
Browse files Browse the repository at this point in the history
Fixes #5491
  • Loading branch information
tanujkhattar authored Jun 14, 2022
1 parent d318432 commit 6e0e164
Show file tree
Hide file tree
Showing 7 changed files with 59 additions and 16 deletions.
1 change: 1 addition & 0 deletions cirq-core/cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,7 @@
CompilationTargetGateset,
CZTargetGateset,
compute_cphase_exponents_for_fsim_decomposition,
create_transformer_with_kwargs,
decompose_clifford_tableau_to_operations,
decompose_cphase_into_two_fsim,
decompose_multi_controlled_x,
Expand Down
1 change: 1 addition & 0 deletions cirq-core/cirq/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
)

from cirq.transformers.target_gatesets import (
create_transformer_with_kwargs,
CompilationTargetGateset,
CZTargetGateset,
SqrtIswapTargetGateset,
Expand Down
1 change: 1 addition & 0 deletions cirq-core/cirq/transformers/target_gatesets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Gatesets which can act as compilation targets in Cirq."""

from cirq.transformers.target_gatesets.compilation_target_gateset import (
create_transformer_with_kwargs,
CompilationTargetGateset,
TwoQubitCompilationTargetGateset,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,53 @@
import cirq


def _create_transformer_with_kwargs(func: 'cirq.TRANSFORMER', **kwargs) -> 'cirq.TRANSFORMER':
"""Hack to capture additional keyword arguments to transformers while preserving mypy type."""
def create_transformer_with_kwargs(transformer: 'cirq.TRANSFORMER', **kwargs) -> 'cirq.TRANSFORMER':
"""Method to capture additional keyword arguments to transformers while preserving mypy type.
Returns a `cirq.TRANSFORMER` which, when called with a circuit and transformer context, is
equivalent to calling `transformer(circuit, context=context, **kwargs)`. It is often useful to
capture keyword arguments of a transformer before passing them as an argument to an API that
expects `cirq.TRANSFORMER`. For example:
>>> def run_transformers(transformers: List[cirq.TRANSFORMER]):
>>> for transformer in transformers:
>>> transformer(circuit, context=context)
>>>
>>> transformers: List[cirq.TRANSFORMER] = []
>>> transformers.append(
>>> cirq.create_transformer_with_kwargs(
>>> cirq.expand_composite, no_decomp=lambda op: cirq.num_qubits(op) <= 2
>>> )
>>> )
>>> transformers.append(cirq.create_transformer_with_kwargs(cirq.merge_k_qubit_unitaries, k=2))
>>> run_transformers(transformers)
Args:
transformer: A `cirq.TRANSFORMER` for which additional kwargs should be captured.
**kwargs: The keyword arguments which should be captured and passed to `transformer`.
Returns:
A `cirq.TRANSFORMER` method `transformer_with_kwargs`, s.t. executing
`transformer_with_kwargs(circuit, context=context)` is equivalent to executing
`transformer(circuit, context=context, **kwargs)`.
Raises:
SyntaxError: if **kwargs contain a 'context'.
"""
if 'context' in kwargs:
raise SyntaxError('**kwargs to be captured must not contain `context`.')

def transformer(
def transformer_with_kwargs(
circuit: 'cirq.AbstractCircuit', *, context: Optional['cirq.TransformerContext'] = None
) -> 'cirq.AbstractCircuit':
return func(circuit, context=context, **kwargs) # type: ignore
# Need to ignore mypy type because `cirq.TRANSFORMER` is a callable protocol which only
# accepts circuit and context; and doesn't expect additional keyword arguments. Note
# that transformers with additional keyword arguments with a default value do satisfy the
# `cirq.TRANSFORMER` API.
return transformer(circuit, context=context, **kwargs) # type: ignore

return transformer
return transformer_with_kwargs


class CompilationTargetGateset(ops.Gateset, metaclass=abc.ABCMeta):
Expand Down Expand Up @@ -93,11 +131,11 @@ def _intermediate_result_tag(self) -> Hashable:
def preprocess_transformers(self) -> List['cirq.TRANSFORMER']:
"""List of transformers which should be run before decomposing individual operations."""
return [
_create_transformer_with_kwargs(
create_transformer_with_kwargs(
expand_composite.expand_composite,
no_decomp=lambda op: protocols.num_qubits(op) <= self.num_qubits,
),
_create_transformer_with_kwargs(
create_transformer_with_kwargs(
merge_k_qubit_gates.merge_k_qubit_unitaries,
k=self.num_qubits,
rewriter=lambda op: op.with_tags(self._intermediate_result_tag),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from typing import List
import pytest
import cirq
from cirq.protocols.decompose_protocol import DecomposeResult

Expand Down Expand Up @@ -219,3 +220,10 @@ def _decompose_single_qubit_operation(self, op: 'cirq.Operation', _) -> Decompos
c_expected = cirq.Circuit(cirq.X.on_each(*q), ops[-2:])
c_new = cirq.optimize_for_target_gateset(c_orig, gateset=DummyTargetGateset())
cirq.testing.assert_same_circuits(c_new, c_expected)


def test_create_transformer_with_kwargs_raises():
with pytest.raises(SyntaxError, match="must not contain `context`"):
cirq.create_transformer_with_kwargs(
cirq.merge_k_qubit_unitaries, k=2, context=cirq.TransformerContext()
)
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,6 @@

import cirq
from cirq.protocols.decompose_protocol import DecomposeResult
from cirq.transformers.target_gatesets.compilation_target_gateset import (
_create_transformer_with_kwargs,
)
from cirq_google import ops
from cirq_google.transformers.analytical_decompositions import two_qubit_to_sycamore

Expand Down Expand Up @@ -137,10 +134,10 @@ def __init__(
@property
def preprocess_transformers(self) -> List[cirq.TRANSFORMER]:
return [
_create_transformer_with_kwargs(
cirq.create_transformer_with_kwargs(
cirq.expand_composite, no_decomp=lambda op: cirq.num_qubits(op) <= self.num_qubits
),
_create_transformer_with_kwargs(
cirq.create_transformer_with_kwargs(
merge_swap_rzz_and_2q_unitaries,
intermediate_result_tag=self._intermediate_result_tag,
),
Expand Down
5 changes: 1 addition & 4 deletions cirq-ionq/cirq_ionq/ionq_gateset.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@
from typing import List

import cirq
from cirq.transformers.target_gatesets.compilation_target_gateset import (
_create_transformer_with_kwargs,
)


class IonQTargetGateset(cirq.TwoQubitCompilationTargetGateset):
Expand Down Expand Up @@ -85,7 +82,7 @@ def _decompose_two_qubit_operation(self, op: cirq.Operation, _) -> cirq.OP_TREE:
def preprocess_transformers(self) -> List['cirq.TRANSFORMER']:
"""List of transformers which should be run before decomposing individual operations."""
return [
_create_transformer_with_kwargs(
cirq.create_transformer_with_kwargs(
cirq.expand_composite, no_decomp=lambda op: cirq.num_qubits(op) <= self.num_qubits
)
]
Expand Down

0 comments on commit 6e0e164

Please sign in to comment.