diff --git a/cirq-core/cirq/__init__.py b/cirq-core/cirq/__init__.py index 67e8942da6f..31e2b474521 100644 --- a/cirq-core/cirq/__init__.py +++ b/cirq-core/cirq/__init__.py @@ -358,6 +358,7 @@ CompilationTargetGateset, CZTargetGateset, compute_cphase_exponents_for_fsim_decomposition, + create_transformer_with_kwargs, decompose_clifford_tableau_to_operations, decompose_cphase_into_two_fsim, decompose_multi_controlled_x, diff --git a/cirq-core/cirq/transformers/__init__.py b/cirq-core/cirq/transformers/__init__.py index d6cbb0639ff..a1bc125ac3f 100644 --- a/cirq-core/cirq/transformers/__init__.py +++ b/cirq-core/cirq/transformers/__init__.py @@ -43,6 +43,7 @@ ) from cirq.transformers.target_gatesets import ( + create_transformer_with_kwargs, CompilationTargetGateset, CZTargetGateset, SqrtIswapTargetGateset, diff --git a/cirq-core/cirq/transformers/target_gatesets/__init__.py b/cirq-core/cirq/transformers/target_gatesets/__init__.py index 222e58ef46d..9c5369d6120 100644 --- a/cirq-core/cirq/transformers/target_gatesets/__init__.py +++ b/cirq-core/cirq/transformers/target_gatesets/__init__.py @@ -15,6 +15,7 @@ """Gatesets which can act as compilation targets in Cirq.""" from cirq.transformers.target_gatesets.compilation_target_gateset import ( + create_transformer_with_kwargs, CompilationTargetGateset, TwoQubitCompilationTargetGateset, ) diff --git a/cirq-core/cirq/transformers/target_gatesets/compilation_target_gateset.py b/cirq-core/cirq/transformers/target_gatesets/compilation_target_gateset.py index 45801676282..743ba3c303b 100644 --- a/cirq-core/cirq/transformers/target_gatesets/compilation_target_gateset.py +++ b/cirq-core/cirq/transformers/target_gatesets/compilation_target_gateset.py @@ -29,15 +29,53 @@ import cirq -def _create_transformer_with_kwargs(func: 'cirq.TRANSFORMER', **kwargs) -> 'cirq.TRANSFORMER': - """Hack to capture additional keyword arguments to transformers while preserving mypy type.""" +def create_transformer_with_kwargs(transformer: 'cirq.TRANSFORMER', **kwargs) -> 'cirq.TRANSFORMER': + """Method to capture additional keyword arguments to transformers while preserving mypy type. + + Returns a `cirq.TRANSFORMER` which, when called with a circuit and transformer context, is + equivalent to calling `transformer(circuit, context=context, **kwargs)`. It is often useful to + capture keyword arguments of a transformer before passing them as an argument to an API that + expects `cirq.TRANSFORMER`. For example: + + >>> def run_transformers(transformers: List[cirq.TRANSFORMER]): + >>> for transformer in transformers: + >>> transformer(circuit, context=context) + >>> + >>> transformers: List[cirq.TRANSFORMER] = [] + >>> transformers.append( + >>> cirq.create_transformer_with_kwargs( + >>> cirq.expand_composite, no_decomp=lambda op: cirq.num_qubits(op) <= 2 + >>> ) + >>> ) + >>> transformers.append(cirq.create_transformer_with_kwargs(cirq.merge_k_qubit_unitaries, k=2)) + >>> run_transformers(transformers) + + + Args: + transformer: A `cirq.TRANSFORMER` for which additional kwargs should be captured. + **kwargs: The keyword arguments which should be captured and passed to `transformer`. + + Returns: + A `cirq.TRANSFORMER` method `transformer_with_kwargs`, s.t. executing + `transformer_with_kwargs(circuit, context=context)` is equivalent to executing + `transformer(circuit, context=context, **kwargs)`. + + Raises: + SyntaxError: if **kwargs contain a 'context'. + """ + if 'context' in kwargs: + raise SyntaxError('**kwargs to be captured must not contain `context`.') - def transformer( + def transformer_with_kwargs( circuit: 'cirq.AbstractCircuit', *, context: Optional['cirq.TransformerContext'] = None ) -> 'cirq.AbstractCircuit': - return func(circuit, context=context, **kwargs) # type: ignore + # Need to ignore mypy type because `cirq.TRANSFORMER` is a callable protocol which only + # accepts circuit and context; and doesn't expect additional keyword arguments. Note + # that transformers with additional keyword arguments with a default value do satisfy the + # `cirq.TRANSFORMER` API. + return transformer(circuit, context=context, **kwargs) # type: ignore - return transformer + return transformer_with_kwargs class CompilationTargetGateset(ops.Gateset, metaclass=abc.ABCMeta): @@ -93,11 +131,11 @@ def _intermediate_result_tag(self) -> Hashable: def preprocess_transformers(self) -> List['cirq.TRANSFORMER']: """List of transformers which should be run before decomposing individual operations.""" return [ - _create_transformer_with_kwargs( + create_transformer_with_kwargs( expand_composite.expand_composite, no_decomp=lambda op: protocols.num_qubits(op) <= self.num_qubits, ), - _create_transformer_with_kwargs( + create_transformer_with_kwargs( merge_k_qubit_gates.merge_k_qubit_unitaries, k=self.num_qubits, rewriter=lambda op: op.with_tags(self._intermediate_result_tag), diff --git a/cirq-core/cirq/transformers/target_gatesets/compilation_target_gateset_test.py b/cirq-core/cirq/transformers/target_gatesets/compilation_target_gateset_test.py index c3af371e90f..80bea204e0d 100644 --- a/cirq-core/cirq/transformers/target_gatesets/compilation_target_gateset_test.py +++ b/cirq-core/cirq/transformers/target_gatesets/compilation_target_gateset_test.py @@ -13,6 +13,7 @@ # limitations under the License. from typing import List +import pytest import cirq from cirq.protocols.decompose_protocol import DecomposeResult @@ -219,3 +220,10 @@ def _decompose_single_qubit_operation(self, op: 'cirq.Operation', _) -> Decompos c_expected = cirq.Circuit(cirq.X.on_each(*q), ops[-2:]) c_new = cirq.optimize_for_target_gateset(c_orig, gateset=DummyTargetGateset()) cirq.testing.assert_same_circuits(c_new, c_expected) + + +def test_create_transformer_with_kwargs_raises(): + with pytest.raises(SyntaxError, match="must not contain `context`"): + cirq.create_transformer_with_kwargs( + cirq.merge_k_qubit_unitaries, k=2, context=cirq.TransformerContext() + ) diff --git a/cirq-google/cirq_google/transformers/target_gatesets/sycamore_gateset.py b/cirq-google/cirq_google/transformers/target_gatesets/sycamore_gateset.py index 7545ddf4537..10605ef6395 100644 --- a/cirq-google/cirq_google/transformers/target_gatesets/sycamore_gateset.py +++ b/cirq-google/cirq_google/transformers/target_gatesets/sycamore_gateset.py @@ -19,9 +19,6 @@ import cirq from cirq.protocols.decompose_protocol import DecomposeResult -from cirq.transformers.target_gatesets.compilation_target_gateset import ( - _create_transformer_with_kwargs, -) from cirq_google import ops from cirq_google.transformers.analytical_decompositions import two_qubit_to_sycamore @@ -137,10 +134,10 @@ def __init__( @property def preprocess_transformers(self) -> List[cirq.TRANSFORMER]: return [ - _create_transformer_with_kwargs( + cirq.create_transformer_with_kwargs( cirq.expand_composite, no_decomp=lambda op: cirq.num_qubits(op) <= self.num_qubits ), - _create_transformer_with_kwargs( + cirq.create_transformer_with_kwargs( merge_swap_rzz_and_2q_unitaries, intermediate_result_tag=self._intermediate_result_tag, ), diff --git a/cirq-ionq/cirq_ionq/ionq_gateset.py b/cirq-ionq/cirq_ionq/ionq_gateset.py index a4d75e118f6..e2ea3b4f294 100644 --- a/cirq-ionq/cirq_ionq/ionq_gateset.py +++ b/cirq-ionq/cirq_ionq/ionq_gateset.py @@ -18,9 +18,6 @@ from typing import List import cirq -from cirq.transformers.target_gatesets.compilation_target_gateset import ( - _create_transformer_with_kwargs, -) class IonQTargetGateset(cirq.TwoQubitCompilationTargetGateset): @@ -85,7 +82,7 @@ def _decompose_two_qubit_operation(self, op: cirq.Operation, _) -> cirq.OP_TREE: def preprocess_transformers(self) -> List['cirq.TRANSFORMER']: """List of transformers which should be run before decomposing individual operations.""" return [ - _create_transformer_with_kwargs( + cirq.create_transformer_with_kwargs( cirq.expand_composite, no_decomp=lambda op: cirq.num_qubits(op) <= self.num_qubits ) ]