diff --git a/cirq-google/cirq_google/optimizers/optimize_for_sycamore.py b/cirq-google/cirq_google/optimizers/optimize_for_sycamore.py index cff37ae0843..b641e74930f 100644 --- a/cirq-google/cirq_google/optimizers/optimize_for_sycamore.py +++ b/cirq-google/cirq_google/optimizers/optimize_for_sycamore.py @@ -105,6 +105,7 @@ def optimized_for_sycamore( copy = cirq.optimize_for_target_gateset( circuit, gateset=_TARGET_GATESETS[optimizer_type](tolerance, tabulation), + context=cirq.TransformerContext(deep=True), ) copy = cirq.merge_single_qubit_gates_to_phxz(copy, atol=tolerance) copy = cirq.eject_phased_paulis(copy, atol=tolerance) diff --git a/cirq-google/cirq_google/optimizers/optimize_for_sycamore_test.py b/cirq-google/cirq_google/optimizers/optimize_for_sycamore_test.py index 2422d0acc1f..1c4484f9a58 100644 --- a/cirq-google/cirq_google/optimizers/optimize_for_sycamore_test.py +++ b/cirq-google/cirq_google/optimizers/optimize_for_sycamore_test.py @@ -134,3 +134,30 @@ def test_assert_new_device_deprecated(): _ = cg.optimized_for_sycamore( circuit0, optimizer_type='sqrt_iswap', new_device=TestDevice() ) + + +@pytest.mark.parametrize( + 'optimizer_type, two_qubit_gate_type', + [('sycamore', cg.SycamoreGate), ('sqrt_iswap', cirq.ISwapPowGate), ('xmon', cirq.CZPowGate)], +) +def test_circuit_operation_conversion(optimizer_type, two_qubit_gate_type): + q0, q1 = cirq.LineQubit.range(2) + subcircuit = cirq.FrozenCircuit(cirq.X(q0), cirq.SWAP(q0, q1)) + circuit = cirq.Circuit(cirq.CircuitOperation(subcircuit)) + converted_circuit = cg.optimized_for_sycamore(circuit, optimizer_type=optimizer_type) + # Verify that the CircuitOperation was preserved. + ops = list(converted_circuit.all_operations()) + assert isinstance(ops[0], cirq.CircuitOperation) + # Verify that the contents of the CircuitOperation were optimized. + converted_subcircuit = cg.optimized_for_sycamore( + subcircuit.unfreeze(), optimizer_type=optimizer_type + ) + assert len( + [*converted_subcircuit.findall_operations_with_gate_type(two_qubit_gate_type)] + ) == len([*ops[0].circuit.findall_operations_with_gate_type(two_qubit_gate_type)]) + cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent( + ops[0].circuit, converted_subcircuit, atol=1e-8 + ) + cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent( + circuit, converted_circuit, atol=1e-8 + ) diff --git a/cirq-google/cirq_google/transformers/target_gatesets/sycamore_gateset.py b/cirq-google/cirq_google/transformers/target_gatesets/sycamore_gateset.py index eb2fb8b5a94..eb262b702ea 100644 --- a/cirq-google/cirq_google/transformers/target_gatesets/sycamore_gateset.py +++ b/cirq-google/cirq_google/transformers/target_gatesets/sycamore_gateset.py @@ -15,7 +15,7 @@ """Target gateset used for compiling circuits to Sycamore + 1-q rotations + measurement gates.""" import itertools -from typing import List, Optional, Sequence +from typing import cast, List, Optional, Sequence import cirq from cirq.protocols.decompose_protocol import DecomposeResult @@ -33,6 +33,7 @@ def merge_swap_rzz_and_2q_unitaries( context: Optional['cirq.TransformerContext'] = None, merged_swap_rzz_tag: str = "_merged_swap_rzz", merged_2q_component_tag: str = "_merged_2q_unitaries", + intermediate_result_tag: Optional[str] = None, ) -> 'cirq.Circuit': """Merges 2-qubit connected components and adjacent `cirq.SWAP` and `cirq.ZZPowGate` gates. @@ -50,6 +51,8 @@ def merge_swap_rzz_and_2q_unitaries( `cirq.SWAP` and `cirq.ZZPowGate`s. merged_2q_component_tag: Tag to apply on newly introduced circuit operations wrapping connected components of 1 and 2 qubit unitaries. + intermediate_result_tag: If specified, the tag is added to newly introduced both the newly + introduced circuit operations encapsulating swap_rzz or 2q connected component. Returns: Copy of the transformed input circuit. @@ -71,19 +74,34 @@ def merge_func_swap_rzz( return False tags_to_ignore = context.tags_to_ignore if context else () + deep = context.deep if context else False circuit = cirq.merge_operations_to_circuit_op( circuit, merge_func_swap_rzz, tags_to_ignore=tags_to_ignore, merged_circuit_op_tag=merged_swap_rzz_tag, + deep=deep, ) - return cirq.merge_k_qubit_unitaries_to_circuit_op( + circuit = cirq.merge_k_qubit_unitaries_to_circuit_op( circuit, k=2, - tags_to_ignore=tags_to_ignore + (merged_swap_rzz_tag,), + tags_to_ignore=tuple(tags_to_ignore) + (merged_swap_rzz_tag,), merged_circuit_op_tag=merged_2q_component_tag, - ).unfreeze(copy=False) + deep=deep, + ) + + if intermediate_result_tag is not None: + merged_cop_tags = {merged_swap_rzz_tag, merged_2q_component_tag} + circuit = cirq.map_operations( + circuit, + map_func=lambda op, _: op + if merged_cop_tags.isdisjoint(op.tags) + else op.with_tags(cast(str, intermediate_result_tag)), + tags_to_ignore=tags_to_ignore, + deep=True, + ) + return circuit.unfreeze(copy=False) class SycamoreTargetGateset(cirq.TwoQubitCompilationTargetGateset): @@ -122,7 +140,10 @@ def preprocess_transformers(self) -> List[cirq.TRANSFORMER]: cirq.expand_composite, no_decomp=lambda op: cirq.num_qubits(op) <= self.num_qubits, ), - merge_swap_rzz_and_2q_unitaries, + _create_transformer_with_kwargs( + merge_swap_rzz_and_2q_unitaries, + intermediate_result_tag=self._intermediate_result_tag, + ), ] def _decompose_two_qubit_operation(self, op: cirq.Operation, _) -> DecomposeResult: diff --git a/cirq-google/cirq_google/transformers/target_gatesets/sycamore_gateset_test.py b/cirq-google/cirq_google/transformers/target_gatesets/sycamore_gateset_test.py index 4e57b7f24ea..c9d65d5e289 100644 --- a/cirq-google/cirq_google/transformers/target_gatesets/sycamore_gateset_test.py +++ b/cirq-google/cirq_google/transformers/target_gatesets/sycamore_gateset_test.py @@ -97,6 +97,55 @@ def test_merge_swap_rzz_and_2q_unitaries_raises_if_tags_sames(): ) +def test_merge_swap_rzz_and_2q_unitaries_deep(): + q = cirq.LineQubit.range(3) + swap_rzz = cirq.FrozenCircuit(cirq.SWAP(*q[:2]), cirq.ZZ(*q[:2]) ** 0.5) + rzz_swap = cirq.FrozenCircuit(cirq.ZZ(*q[1:]) ** 0.25, cirq.SWAP(*q[1:])) + x_cnot_x = cirq.FrozenCircuit(cirq.X(q[0]), cirq.CNOT(*q[:2]), cirq.X(q[0])) + x_cz_x = cirq.FrozenCircuit(cirq.X(q[2]), cirq.CZ(*q[1:]), cirq.X(q[2])) + c_orig = cirq.Circuit( + cirq.CircuitOperation(swap_rzz).repeat(3).with_tags("ignore"), + cirq.CircuitOperation(rzz_swap).repeat(5).with_tags("preserve_tag"), + cirq.CircuitOperation(x_cnot_x).repeat(7).with_tags("ignore"), + cirq.CircuitOperation(x_cz_x).repeat(9).with_tags("preserve_tag"), + cirq.CircuitOperation( + cirq.FrozenCircuit( + [swap_rzz, rzz_swap, x_cnot_x, x_cz_x], + cirq.Moment(cirq.Y(qq).with_tags("ignore") for qq in q), + ) + ), + ) + t_swap_rzz = "_merged_swap_rzz_tag" + t_2q_cmp = "_merged_2q_unitaries_component" + t_all = "_intermediate_result_tag_apply_to_all" + + def _wrap_cop(c: cirq.FrozenCircuit, *tags) -> cirq.FrozenCircuit: + return cirq.FrozenCircuit(cirq.CircuitOperation(c).with_tags(*tags, t_all)) + + c_expected = cirq.Circuit( + cirq.CircuitOperation(swap_rzz).repeat(3).with_tags("ignore"), + cirq.CircuitOperation(_wrap_cop(rzz_swap, t_swap_rzz)).repeat(5).with_tags("preserve_tag"), + cirq.CircuitOperation(x_cnot_x).repeat(7).with_tags("ignore"), + cirq.CircuitOperation(_wrap_cop(x_cz_x, t_2q_cmp)).repeat(9).with_tags("preserve_tag"), + cirq.CircuitOperation( + cirq.FrozenCircuit( + [_wrap_cop(swap_rzz, t_swap_rzz), _wrap_cop(rzz_swap, t_swap_rzz)], + [_wrap_cop(x_cnot_x, t_2q_cmp), _wrap_cop(x_cz_x, t_2q_cmp)], + cirq.Moment(cirq.Y(qq).with_tags("ignore") for qq in q), + ) + ), + ) + context = cirq.TransformerContext(tags_to_ignore=["ignore"], deep=True) + c_new = sycamore_gateset.merge_swap_rzz_and_2q_unitaries( + c_orig, + context=context, + merged_swap_rzz_tag=t_swap_rzz, + merged_2q_component_tag=t_2q_cmp, + intermediate_result_tag=t_all, + ) + cirq.testing.assert_same_circuits(cirq.drop_empty_moments(c_new, context=context), c_expected) + + def test_sycamore_gateset_compiles_swap_zz(): qubits = cirq.LineQubit.range(3)