From f6b5ed8066c64ffcbd917816b9613638fd22fa1c Mon Sep 17 00:00:00 2001 From: Tanuj Khattar Date: Sat, 19 Feb 2022 02:28:31 +0530 Subject: [PATCH] Add `cirq.convert_to_target_gateset` transformer and `cirq.CompilationTargetGateset` interface (#5005) * Add convert_to_target_gateset transformer and CompilationTargetGateset interface * Override validation_operation to not accept intermediate results --- cirq/__init__.py | 2 + cirq/protocols/decompose_protocol.py | 6 +- cirq/protocols/decompose_protocol_test.py | 3 +- cirq/protocols/json_test_data/spec.py | 2 + cirq/transformers/__init__.py | 6 + .../optimize_for_target_gateset.py | 130 ++++++++++++ .../optimize_for_target_gateset_test.py | 198 ++++++++++++++++++ cirq/transformers/target_gatesets/__init__.py | 17 ++ .../compilation_target_gateset.py | 117 +++++++++++ .../compilation_target_gateset_test.py | 56 +++++ 10 files changed, 534 insertions(+), 3 deletions(-) create mode 100644 cirq/transformers/optimize_for_target_gateset.py create mode 100644 cirq/transformers/optimize_for_target_gateset_test.py create mode 100644 cirq/transformers/target_gatesets/__init__.py create mode 100644 cirq/transformers/target_gatesets/compilation_target_gateset.py create mode 100644 cirq/transformers/target_gatesets/compilation_target_gateset_test.py diff --git a/cirq/__init__.py b/cirq/__init__.py index b656f1c48b8..db8d46dbaea 100644 --- a/cirq/__init__.py +++ b/cirq/__init__.py @@ -355,6 +355,7 @@ from cirq.transformers import ( align_left, align_right, + CompilationTargetGateset, compute_cphase_exponents_for_fsim_decomposition, decompose_clifford_tableau_to_operations, decompose_cphase_into_two_fsim, @@ -380,6 +381,7 @@ merge_single_qubit_gates_to_phased_x_and_z, merge_single_qubit_gates_to_phxz, merge_single_qubit_moments_to_phxz, + optimize_for_target_gateset, prepare_two_qubit_state_using_cz, prepare_two_qubit_state_using_sqrt_iswap, single_qubit_matrix_to_gates, diff --git a/cirq/protocols/decompose_protocol.py b/cirq/protocols/decompose_protocol.py index a9e3ca93307..dda931d1772 100644 --- a/cirq/protocols/decompose_protocol.py +++ b/cirq/protocols/decompose_protocol.py @@ -180,7 +180,11 @@ def decompose( that doesn't satisfy the given `keep` predicate. """ - if on_stuck_raise is not _value_error_describing_bad_operation and keep is None: + if ( + on_stuck_raise is not _value_error_describing_bad_operation + and on_stuck_raise is not None + and keep is None + ): raise ValueError( "Must specify 'keep' if specifying 'on_stuck_raise', because it's " "not possible to get stuck if you don't have a criteria on what's " diff --git a/cirq/protocols/decompose_protocol_test.py b/cirq/protocols/decompose_protocol_test.py index aca655f828e..c4c9bce3631 100644 --- a/cirq/protocols/decompose_protocol_test.py +++ b/cirq/protocols/decompose_protocol_test.py @@ -182,6 +182,7 @@ def test_decompose_on_stuck_raise(): _ = cirq.decompose(NoMethod(), keep=lambda _: False) # Unless there's no operations to be unhappy about. assert cirq.decompose([], keep=lambda _: False) == [] + assert cirq.decompose([], on_stuck_raise=None) == [] # Or you say you're fine. assert cirq.decompose(no_method, keep=lambda _: False, on_stuck_raise=None) == [no_method] assert cirq.decompose(no_method, keep=lambda _: False, on_stuck_raise=lambda _: None) == [ @@ -198,8 +199,6 @@ def test_decompose_on_stuck_raise(): ) # There's a nice warning if you specify `on_stuck_raise` but not `keep`. - with pytest.raises(ValueError, match='on_stuck_raise'): - assert cirq.decompose([], on_stuck_raise=None) with pytest.raises(ValueError, match='on_stuck_raise'): assert cirq.decompose([], on_stuck_raise=TypeError('x')) diff --git a/cirq/protocols/json_test_data/spec.py b/cirq/protocols/json_test_data/spec.py index 9df089bc156..f18095db8db 100644 --- a/cirq/protocols/json_test_data/spec.py +++ b/cirq/protocols/json_test_data/spec.py @@ -92,6 +92,8 @@ 'ApplyMixtureArgs', 'ApplyUnitaryArgs', 'OperationTarget', + # Abstract base class for creating compilation targets. + 'CompilationTargetGateset', # Circuit optimizers are function-like. Only attributes # are ignore_failures, tolerance, and other feature flags 'AlignLeft', diff --git a/cirq/transformers/__init__.py b/cirq/transformers/__init__.py index 8033c6315c2..959a1a7fa1a 100644 --- a/cirq/transformers/__init__.py +++ b/cirq/transformers/__init__.py @@ -41,6 +41,10 @@ two_qubit_gate_product_tabulation, ) +from cirq.transformers.target_gatesets import ( + CompilationTargetGateset, +) + from cirq.transformers.align import align_left, align_right from cirq.transformers.stratify import stratified_circuit @@ -49,6 +53,8 @@ from cirq.transformers.eject_phased_paulis import eject_phased_paulis +from cirq.transformers.optimize_for_target_gateset import optimize_for_target_gateset + from cirq.transformers.drop_empty_moments import drop_empty_moments from cirq.transformers.drop_negligible_operations import drop_negligible_operations diff --git a/cirq/transformers/optimize_for_target_gateset.py b/cirq/transformers/optimize_for_target_gateset.py new file mode 100644 index 00000000000..d028366db14 --- /dev/null +++ b/cirq/transformers/optimize_for_target_gateset.py @@ -0,0 +1,130 @@ +# Copyright 2022 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Transformers to rewrite a circuit using gates from a given target gateset.""" + +from typing import Optional, Callable, TYPE_CHECKING + +from cirq.protocols import decompose_protocol as dp +from cirq.transformers import transformer_api, transformer_primitives + +if TYPE_CHECKING: + import cirq + + +def _create_on_stuck_raise_error(gateset: 'cirq.Gateset'): + def _value_error_describing_bad_operation(op: 'cirq.Operation') -> ValueError: + return ValueError(f"Unable to convert {op} to target gateset {gateset!r}") + + return _value_error_describing_bad_operation + + +@transformer_api.transformer +def _decompose_operations_to_target_gateset( + circuit: 'cirq.AbstractCircuit', + *, + context: Optional['cirq.TransformerContext'] = None, + gateset: Optional['cirq.Gateset'] = None, + decomposer: Callable[['cirq.Operation', int], dp.DecomposeResult] = lambda *_: NotImplemented, + ignore_failures: bool = True, +) -> 'cirq.Circuit': + """Decomposes every operation to `gateset` using `cirq.decompose` and `decomposer`. + + This transformer attempts to decompose every operation `op` in the given circuit to `gateset` + using `cirq.decompose` protocol with `decomposer` used as an intercepting decomposer. This + ensures that `op` is recursively decomposed using implicitly defined known decompositions + (eg: in `_decompose_` magic method on the gaet class) till either `decomposer` knows how to + decompose the given operation or the given operation belongs to `gateset`. + + Args: + circuit: Input circuit to transform. It will not be modified. + context: `cirq.TransformerContext` storing common configurable options for transformers. + gateset: Target gateset, which the decomposed operations should belong to. + decomposer: A callable type which accepts an (operation, moment_index) and returns + - An equivalent `cirq.OP_TREE` implementing `op` using gates from `gateset`. + - `None` or `NotImplemented` if does not know how to decompose a given `op`. + ignore_failures: If set, operations that fail to convert are left unchanged. If not set, + conversion failures raise a ValueError. + + Returns: + An equivalent circuit containing gates accepted by `gateset`. + + Raises: + ValueError: If any input operation fails to convert and `ignore_failures` is False. + """ + + def map_func(op: 'cirq.Operation', moment_index: int): + return dp.decompose( + op, + intercepting_decomposer=lambda o: decomposer(o, moment_index), + keep=gateset.validate if gateset else None, + on_stuck_raise=( + None + if ignore_failures or gateset is None + else _create_on_stuck_raise_error(gateset) + ), + ) + + return transformer_primitives.map_operations_and_unroll( + circuit, map_func, tags_to_ignore=context.tags_to_ignore if context else () + ).unfreeze(copy=False) + + +@transformer_api.transformer +def optimize_for_target_gateset( + circuit: 'cirq.AbstractCircuit', + *, + context: Optional['cirq.TransformerContext'] = None, + gateset: Optional['cirq.CompilationTargetGateset'] = None, + ignore_failures: bool = True, +) -> 'cirq.Circuit': + """Transforms the given circuit into an equivalent circuit using gates accepted by `gateset`. + + 1. Run all `gateset.preprocess_transformers` + 2. Convert operations using built-in cirq decompose + `gateset.decompose_to_target_gateset`. + 3. Run all `gateset.postprocess_transformers` + + Args: + circuit: Input circuit to transform. It will not be modified. + context: `cirq.TransformerContext` storing common configurable options for transformers. + gateset: Target gateset, which should be an instance of `cirq.CompilationTargetGateset`. + ignore_failures: If set, operations that fail to convert are left unchanged. If not set, + conversion failures raise a ValueError. + + Returns: + An equivalent circuit containing gates accepted by `gateset`. + + Raises: + ValueError: If any input operation fails to convert and `ignore_failures` is False. + """ + if gateset is None: + return _decompose_operations_to_target_gateset( + circuit, context=context, ignore_failures=ignore_failures + ) + + for transformer in gateset.preprocess_transformers: + circuit = transformer(circuit, context=context) + + circuit = _decompose_operations_to_target_gateset( + circuit, + context=context, + gateset=gateset, + decomposer=gateset.decompose_to_target_gateset, + ignore_failures=ignore_failures, + ) + + for transformer in gateset.postprocess_transformers: + circuit = transformer(circuit, context=context) + + return circuit.unfreeze(copy=False) diff --git a/cirq/transformers/optimize_for_target_gateset_test.py b/cirq/transformers/optimize_for_target_gateset_test.py new file mode 100644 index 00000000000..e923ceb60db --- /dev/null +++ b/cirq/transformers/optimize_for_target_gateset_test.py @@ -0,0 +1,198 @@ +# Copyright 2022 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import cirq +from cirq.protocols.decompose_protocol import DecomposeResult +from cirq.transformers.optimize_for_target_gateset import _decompose_operations_to_target_gateset +import pytest + + +def test_decompose_operations_raises_on_stuck(): + c_orig = cirq.Circuit(cirq.X(cirq.NamedQubit("q")).with_tags("ignore")) + gateset = cirq.Gateset(cirq.Y) + with pytest.raises(ValueError, match="Unable to convert"): + _ = _decompose_operations_to_target_gateset(c_orig, gateset=gateset, ignore_failures=False) + + # Gates marked with a no-compile tag are completely ignored. + c_new = _decompose_operations_to_target_gateset( + c_orig, + context=cirq.TransformerContext(tags_to_ignore=("ignore",)), + gateset=gateset, + ignore_failures=False, + ) + cirq.testing.assert_same_circuits(c_orig, c_new) + + +# pylint: disable=line-too-long +def test_decompose_operations_to_target_gateset_default(): + q = cirq.LineQubit.range(2) + c_orig = cirq.Circuit( + cirq.T(q[0]), + cirq.SWAP(*q), + cirq.T(q[0]), + cirq.SWAP(*q).with_tags("ignore"), + cirq.measure(q[0], key="m"), + cirq.X(q[1]).with_classical_controls("m"), + cirq.Moment(cirq.T.on_each(*q)), + cirq.SWAP(*q), + cirq.T.on_each(*q), + ) + cirq.testing.assert_has_diagram( + c_orig, + ''' +0: ───T───×───T───×['ignore']───M───────T───×───T─── + │ │ ║ │ +1: ───────×───────×─────────────╫───X───T───×───T─── + ║ ║ +m: ═════════════════════════════@═══^═══════════════''', + ) + context = cirq.TransformerContext(tags_to_ignore=("ignore",)) + c_new = _decompose_operations_to_target_gateset(c_orig, context=context) + cirq.testing.assert_has_diagram( + c_new, + ''' +0: ───T────────────@───Y^-0.5───@───Y^0.5────@───────────T───×['ignore']───M───────T────────────@───Y^-0.5───@───Y^0.5────@───────────T─── + │ │ │ │ ║ │ │ │ +1: ───────Y^-0.5───@───Y^0.5────@───Y^-0.5───@───Y^0.5───────×─────────────╫───X───T───Y^-0.5───@───Y^0.5────@───Y^-0.5───@───Y^0.5───T─── + ║ ║ +m: ════════════════════════════════════════════════════════════════════════@═══^══════════════════════════════════════════════════════════ +''', + ) + + +def test_decompose_operations_to_target_gateset(): + q = cirq.LineQubit.range(2) + c_orig = cirq.Circuit( + cirq.T(q[0]), + cirq.SWAP(*q), + cirq.T(q[0]), + cirq.SWAP(*q).with_tags("ignore"), + cirq.measure(q[0], key="m"), + cirq.X(q[1]).with_classical_controls("m"), + cirq.Moment(cirq.T.on_each(*q)), + cirq.SWAP(*q), + cirq.T.on_each(*q), + ) + gateset = cirq.Gateset(cirq.H, cirq.CNOT) + decomposer = ( + lambda op, _: cirq.H(op.qubits[0]) + if cirq.has_unitary(op) and cirq.num_qubits(op) == 1 + else NotImplemented + ) + context = cirq.TransformerContext(tags_to_ignore=("ignore",)) + c_new = _decompose_operations_to_target_gateset( + c_orig, gateset=gateset, decomposer=decomposer, context=context + ) + cirq.testing.assert_has_diagram( + c_new, + ''' +0: ───H───@───X───@───H───×['ignore']───M───────H───@───X───@───H─── + │ │ │ │ ║ │ │ │ +1: ───────X───@───X───────×─────────────╫───X───H───X───@───X───H─── + ║ ║ +m: ═════════════════════════════════════@═══^═══════════════════════''', + ) + + with pytest.raises(ValueError, match="Unable to convert"): + _ = _decompose_operations_to_target_gateset( + c_orig, gateset=gateset, decomposer=decomposer, context=context, ignore_failures=False + ) + + +class MatrixGateTargetGateset(cirq.CompilationTargetGateset): + def __init__(self): + super().__init__(cirq.MatrixGate) + + @property + def num_qubits(self) -> int: + return 2 + + def decompose_to_target_gateset(self, op: 'cirq.Operation', _) -> DecomposeResult: + if cirq.num_qubits(op) != 2 or not cirq.has_unitary(op): + return NotImplemented + return cirq.MatrixGate(cirq.unitary(op), name="M").on(*op.qubits) + + +def test_optimize_for_target_gateset_default(): + q = cirq.LineQubit.range(2) + c_orig = cirq.Circuit( + cirq.T(q[0]), + cirq.SWAP(*q), + cirq.T(q[0]), + cirq.SWAP(*q).with_tags("ignore"), + ) + context = cirq.TransformerContext(tags_to_ignore=("ignore",)) + c_new = cirq.optimize_for_target_gateset(c_orig, context=context) + cirq.testing.assert_has_diagram( + c_new, + ''' +0: ───T────────────@───Y^-0.5───@───Y^0.5────@───────────T───×['ignore']─── + │ │ │ │ +1: ───────Y^-0.5───@───Y^0.5────@───Y^-0.5───@───Y^0.5───────×───────────── +''', + ) + cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(c_orig, c_new, atol=1e-6) + + +def test_optimize_for_target_gateset(): + q = cirq.LineQubit.range(4) + c_orig = cirq.Circuit( + cirq.QuantumFourierTransformGate(4).on(*q), + cirq.Y(q[0]).with_tags("ignore"), + cirq.Y(q[1]).with_tags("ignore"), + cirq.CNOT(*q[2:]).with_tags("ignore"), + cirq.measure(*q[:2], key="m"), + cirq.CZ(*q[2:]).with_classical_controls("m"), + cirq.inverse(cirq.QuantumFourierTransformGate(4).on(*q)), + ) + + cirq.testing.assert_has_diagram( + c_orig, + ''' +0: ───qft───Y['ignore']───M───────qft^-1─── + │ ║ │ +1: ───#2────Y['ignore']───M───────#2─────── + │ ║ │ +2: ───#3────@['ignore']───╫───@───#3─────── + │ │ ║ ║ │ +3: ───#4────X─────────────╫───@───#4─────── + ║ ║ +m: ═══════════════════════@═══^════════════ +''', + ) + gateset = MatrixGateTargetGateset() + context = cirq.TransformerContext(tags_to_ignore=("ignore",)) + c_new = cirq.optimize_for_target_gateset(c_orig, gateset=gateset, context=context) + cirq.testing.assert_has_diagram( + c_new, + ''' + ┌────────┐ ┌────────┐ ┌────────┐ +0: ───M[1]──────────M[1]──────────────────────M[1]────Y['ignore']───M────────M[1]───────────────────────────M[1]────M[1]───M[1]─── + │ │ │ ║ │ │ │ │ +1: ───M[2]───M[1]───┼─────────────M[1]────M[1]┼───────Y['ignore']───M────────┼───M[1]───────────M[1]────M[1]┼───────┼──────M[2]─── + │ │ │ │ │ ║ │ │ │ │ │ │ +2: ──────────M[2]───M[2]───M[1]───┼───────M[2]┼───────@['ignore']───╫───@────┼───M[2]────M[1]───┼───────M[2]┼───────M[2]────────── + │ │ │ │ ║ ║ │ │ │ │ +3: ────────────────────────M[2]───M[2]────────M[2]────X─────────────╫───@────M[2]────────M[2]───M[2]────────M[2]────────────────── + ║ ║ +m: ═════════════════════════════════════════════════════════════════@═══^═════════════════════════════════════════════════════════ + └────────┘ └────────┘ └────────┘ + ''', + ) + + with pytest.raises(ValueError, match="Unable to convert"): + # Raises an error due to CCO and Measurement gate, which are not part of the gateset. + _ = cirq.optimize_for_target_gateset( + c_orig, gateset=gateset, context=context, ignore_failures=False + ) diff --git a/cirq/transformers/target_gatesets/__init__.py b/cirq/transformers/target_gatesets/__init__.py new file mode 100644 index 00000000000..567c2d7c2f5 --- /dev/null +++ b/cirq/transformers/target_gatesets/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2022 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Gatesets which can act as compilation targets in Cirq.""" + +from cirq.transformers.target_gatesets.compilation_target_gateset import CompilationTargetGateset diff --git a/cirq/transformers/target_gatesets/compilation_target_gateset.py b/cirq/transformers/target_gatesets/compilation_target_gateset.py new file mode 100644 index 00000000000..88edbb20cfd --- /dev/null +++ b/cirq/transformers/target_gatesets/compilation_target_gateset.py @@ -0,0 +1,117 @@ +# Copyright 2022 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Base class for creating custom target gatesets which can be used for compilation.""" + +from typing import Optional, List, Hashable, TYPE_CHECKING +import abc + +from cirq import ops, protocols, _import +from cirq.protocols.decompose_protocol import DecomposeResult +from cirq.transformers import ( + merge_k_qubit_gates, + merge_single_qubit_gates, +) + +drop_empty_moments = _import.LazyLoader('drop_empty_moments', globals(), 'cirq.transformers') +drop_negligible = _import.LazyLoader('drop_negligible_operations', globals(), 'cirq.transformers') +expand_composite = _import.LazyLoader('expand_composite', globals(), 'cirq.transformers') + +if TYPE_CHECKING: + import cirq + + +def _create_transformer_with_kwargs(func: 'cirq.TRANSFORMER', **kwargs) -> 'cirq.TRANSFORMER': + """Hack to capture additional keyword arguments to transformers while preserving mypy type.""" + + def transformer( + circuit: 'cirq.AbstractCircuit', *, context: Optional['cirq.TransformerContext'] = None + ) -> 'cirq.AbstractCircuit': + return func(circuit, context=context, **kwargs) # type: ignore + + return transformer + + +class CompilationTargetGateset(ops.Gateset, metaclass=abc.ABCMeta): + """Abstract base class to create gatesets that can be used as targets for compilation. + + An instance of this type can be passed to transformers like `cirq.convert_to_target_gateset`, + which can transform any given circuit to contain gates accepted by this gateset. + """ + + @property + @abc.abstractmethod + def num_qubits(self) -> int: + """Maximum number of qubits on which a gate from this gateset can act upon.""" + + @abc.abstractmethod + def decompose_to_target_gateset(self, op: 'cirq.Operation', moment_idx: int) -> DecomposeResult: + """Method to rewrite the given operation using gates from this gateset. + + Args: + op: `cirq.Operation` to be rewritten using gates from this gateset. + moment_idx: Moment index where the given operation `op` occurs in a circuit. + + Returns: + - An equivalent `cirq.OP_TREE` implementing `op` using gates from this gateset. + - `None` or `NotImplemented` if does not know how to decompose `op`. + """ + + def _validate_operation(self, op: 'cirq.Operation') -> bool: + """Validates whether the given `cirq.Operation` is contained in this Gateset. + + Overrides the method on the base gateset class to ensure that operations which created + as intermediate compilation results are not accepted. + For example, if a preprocessing `merge_k_qubit_unitaries` transformer merges connected + component of 2q unitaries, it should not be accepted in the gateset so that so we can + use `decompose_to_target_gateset` to determine how to expand this component. + + Args: + op: The `cirq.Operation` instance to check containment for. + + Returns: + Whether the given operation is contained in the gateset. + """ + if self._intermediate_result_tag in op.tags: + return False + return super()._validate_operation(op) + + @property + def _intermediate_result_tag(self) -> Hashable: + """A tag used to identify intermediate compilation results.""" + return "_default_merged_k_qubit_unitaries" + + @property + def preprocess_transformers(self) -> List['cirq.TRANSFORMER']: + """List of transformers which should be run before decomposing individual operations.""" + return [ + _create_transformer_with_kwargs( + expand_composite.expand_composite, + no_decomp=lambda op: protocols.num_qubits(op) <= self.num_qubits, + ), + _create_transformer_with_kwargs( + merge_k_qubit_gates.merge_k_qubit_unitaries, + k=self.num_qubits, + rewriter=lambda op: op.with_tags(self._intermediate_result_tag), + ), + ] + + @property + def postprocess_transformers(self) -> List['cirq.TRANSFORMER']: + """List of transformers which should be run after decomposing individual operations.""" + return [ + merge_single_qubit_gates.merge_single_qubit_moments_to_phxz, + drop_negligible.drop_negligible_operations, + drop_empty_moments.drop_empty_moments, + ] diff --git a/cirq/transformers/target_gatesets/compilation_target_gateset_test.py b/cirq/transformers/target_gatesets/compilation_target_gateset_test.py new file mode 100644 index 00000000000..e2c56bb9ba2 --- /dev/null +++ b/cirq/transformers/target_gatesets/compilation_target_gateset_test.py @@ -0,0 +1,56 @@ +# Copyright 2022 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List +import cirq +from cirq.protocols.decompose_protocol import DecomposeResult + + +def test_compilation_target_gateset(): + class DummyTargetGateset(cirq.CompilationTargetGateset): + def __init__(self): + super().__init__(cirq.AnyUnitaryGateFamily(2)) + + @property + def num_qubits(self) -> int: + return 2 + + def decompose_to_target_gateset(self, op: 'cirq.Operation', _) -> DecomposeResult: + return op if cirq.num_qubits(op) == 2 and cirq.has_unitary(op) else NotImplemented + + @property + def preprocess_transformers(self) -> List[cirq.TRANSFORMER]: + return [] + + gateset = DummyTargetGateset() + + q = cirq.LineQubit.range(2) + assert cirq.X(q[0]) not in gateset + assert cirq.CNOT(*q) in gateset + assert cirq.measure(*q) not in gateset + circuit_op = cirq.CircuitOperation(cirq.FrozenCircuit(cirq.CZ(*q), cirq.CNOT(*q), cirq.CZ(*q))) + assert circuit_op in gateset + assert circuit_op.with_tags(gateset._intermediate_result_tag) not in gateset + + assert gateset.num_qubits == 2 + assert gateset.decompose_to_target_gateset(cirq.X(q[0]), 1) is NotImplemented + assert gateset.decompose_to_target_gateset(cirq.CNOT(*q), 2) == cirq.CNOT(*q) + assert gateset.decompose_to_target_gateset(cirq.measure(*q), 3) is NotImplemented + + assert gateset.preprocess_transformers == [] + assert gateset.postprocess_transformers == [ + cirq.merge_single_qubit_moments_to_phxz, + cirq.drop_negligible_operations, + cirq.drop_empty_moments, + ]