Skip to content

Commit

Permalink
Add support for deep=True flag in cg.optimized_for_sycamore and `cg…
Browse files Browse the repository at this point in the history
….SycamoreTargetGateset` transformers (#5126)

- Adds support for deep=True flag in `sycamore_gateset.merge_swap_rzz_and_2q_unitaries` transformer
- Updates `cg.optimized_for_sycamore` to call `cirq.optimize_for_target_gateset` with `deep=True` by default, such that the method preserves circuit structure by default (which corresponds to its old behavior). 
- Fixes #5039
  • Loading branch information
tanujkhattar authored Mar 22, 2022
1 parent cdd3f8c commit 45624ff
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def optimized_for_sycamore(
copy = cirq.optimize_for_target_gateset(
circuit,
gateset=_TARGET_GATESETS[optimizer_type](tolerance, tabulation),
context=cirq.TransformerContext(deep=True),
)
copy = cirq.merge_single_qubit_gates_to_phxz(copy, atol=tolerance)
copy = cirq.eject_phased_paulis(copy, atol=tolerance)
Expand Down
27 changes: 27 additions & 0 deletions cirq-google/cirq_google/optimizers/optimize_for_sycamore_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,30 @@ def test_assert_new_device_deprecated():
_ = cg.optimized_for_sycamore(
circuit0, optimizer_type='sqrt_iswap', new_device=TestDevice()
)


@pytest.mark.parametrize(
'optimizer_type, two_qubit_gate_type',
[('sycamore', cg.SycamoreGate), ('sqrt_iswap', cirq.ISwapPowGate), ('xmon', cirq.CZPowGate)],
)
def test_circuit_operation_conversion(optimizer_type, two_qubit_gate_type):
q0, q1 = cirq.LineQubit.range(2)
subcircuit = cirq.FrozenCircuit(cirq.X(q0), cirq.SWAP(q0, q1))
circuit = cirq.Circuit(cirq.CircuitOperation(subcircuit))
converted_circuit = cg.optimized_for_sycamore(circuit, optimizer_type=optimizer_type)
# Verify that the CircuitOperation was preserved.
ops = list(converted_circuit.all_operations())
assert isinstance(ops[0], cirq.CircuitOperation)
# Verify that the contents of the CircuitOperation were optimized.
converted_subcircuit = cg.optimized_for_sycamore(
subcircuit.unfreeze(), optimizer_type=optimizer_type
)
assert len(
[*converted_subcircuit.findall_operations_with_gate_type(two_qubit_gate_type)]
) == len([*ops[0].circuit.findall_operations_with_gate_type(two_qubit_gate_type)])
cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(
ops[0].circuit, converted_subcircuit, atol=1e-8
)
cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(
circuit, converted_circuit, atol=1e-8
)
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""Target gateset used for compiling circuits to Sycamore + 1-q rotations + measurement gates."""

import itertools
from typing import List, Optional, Sequence
from typing import cast, List, Optional, Sequence

import cirq
from cirq.protocols.decompose_protocol import DecomposeResult
Expand All @@ -33,6 +33,7 @@ def merge_swap_rzz_and_2q_unitaries(
context: Optional['cirq.TransformerContext'] = None,
merged_swap_rzz_tag: str = "_merged_swap_rzz",
merged_2q_component_tag: str = "_merged_2q_unitaries",
intermediate_result_tag: Optional[str] = None,
) -> 'cirq.Circuit':
"""Merges 2-qubit connected components and adjacent `cirq.SWAP` and `cirq.ZZPowGate` gates.
Expand All @@ -50,6 +51,8 @@ def merge_swap_rzz_and_2q_unitaries(
`cirq.SWAP` and `cirq.ZZPowGate`s.
merged_2q_component_tag: Tag to apply on newly introduced circuit operations wrapping
connected components of 1 and 2 qubit unitaries.
intermediate_result_tag: If specified, the tag is added to newly introduced both the newly
introduced circuit operations encapsulating swap_rzz or 2q connected component.
Returns:
Copy of the transformed input circuit.
Expand All @@ -71,19 +74,34 @@ def merge_func_swap_rzz(
return False

tags_to_ignore = context.tags_to_ignore if context else ()
deep = context.deep if context else False
circuit = cirq.merge_operations_to_circuit_op(
circuit,
merge_func_swap_rzz,
tags_to_ignore=tags_to_ignore,
merged_circuit_op_tag=merged_swap_rzz_tag,
deep=deep,
)

return cirq.merge_k_qubit_unitaries_to_circuit_op(
circuit = cirq.merge_k_qubit_unitaries_to_circuit_op(
circuit,
k=2,
tags_to_ignore=tags_to_ignore + (merged_swap_rzz_tag,),
tags_to_ignore=tuple(tags_to_ignore) + (merged_swap_rzz_tag,),
merged_circuit_op_tag=merged_2q_component_tag,
).unfreeze(copy=False)
deep=deep,
)

if intermediate_result_tag is not None:
merged_cop_tags = {merged_swap_rzz_tag, merged_2q_component_tag}
circuit = cirq.map_operations(
circuit,
map_func=lambda op, _: op
if merged_cop_tags.isdisjoint(op.tags)
else op.with_tags(cast(str, intermediate_result_tag)),
tags_to_ignore=tags_to_ignore,
deep=True,
)
return circuit.unfreeze(copy=False)


class SycamoreTargetGateset(cirq.TwoQubitCompilationTargetGateset):
Expand Down Expand Up @@ -122,7 +140,10 @@ def preprocess_transformers(self) -> List[cirq.TRANSFORMER]:
cirq.expand_composite,
no_decomp=lambda op: cirq.num_qubits(op) <= self.num_qubits,
),
merge_swap_rzz_and_2q_unitaries,
_create_transformer_with_kwargs(
merge_swap_rzz_and_2q_unitaries,
intermediate_result_tag=self._intermediate_result_tag,
),
]

def _decompose_two_qubit_operation(self, op: cirq.Operation, _) -> DecomposeResult:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,55 @@ def test_merge_swap_rzz_and_2q_unitaries_raises_if_tags_sames():
)


def test_merge_swap_rzz_and_2q_unitaries_deep():
q = cirq.LineQubit.range(3)
swap_rzz = cirq.FrozenCircuit(cirq.SWAP(*q[:2]), cirq.ZZ(*q[:2]) ** 0.5)
rzz_swap = cirq.FrozenCircuit(cirq.ZZ(*q[1:]) ** 0.25, cirq.SWAP(*q[1:]))
x_cnot_x = cirq.FrozenCircuit(cirq.X(q[0]), cirq.CNOT(*q[:2]), cirq.X(q[0]))
x_cz_x = cirq.FrozenCircuit(cirq.X(q[2]), cirq.CZ(*q[1:]), cirq.X(q[2]))
c_orig = cirq.Circuit(
cirq.CircuitOperation(swap_rzz).repeat(3).with_tags("ignore"),
cirq.CircuitOperation(rzz_swap).repeat(5).with_tags("preserve_tag"),
cirq.CircuitOperation(x_cnot_x).repeat(7).with_tags("ignore"),
cirq.CircuitOperation(x_cz_x).repeat(9).with_tags("preserve_tag"),
cirq.CircuitOperation(
cirq.FrozenCircuit(
[swap_rzz, rzz_swap, x_cnot_x, x_cz_x],
cirq.Moment(cirq.Y(qq).with_tags("ignore") for qq in q),
)
),
)
t_swap_rzz = "_merged_swap_rzz_tag"
t_2q_cmp = "_merged_2q_unitaries_component"
t_all = "_intermediate_result_tag_apply_to_all"

def _wrap_cop(c: cirq.FrozenCircuit, *tags) -> cirq.FrozenCircuit:
return cirq.FrozenCircuit(cirq.CircuitOperation(c).with_tags(*tags, t_all))

c_expected = cirq.Circuit(
cirq.CircuitOperation(swap_rzz).repeat(3).with_tags("ignore"),
cirq.CircuitOperation(_wrap_cop(rzz_swap, t_swap_rzz)).repeat(5).with_tags("preserve_tag"),
cirq.CircuitOperation(x_cnot_x).repeat(7).with_tags("ignore"),
cirq.CircuitOperation(_wrap_cop(x_cz_x, t_2q_cmp)).repeat(9).with_tags("preserve_tag"),
cirq.CircuitOperation(
cirq.FrozenCircuit(
[_wrap_cop(swap_rzz, t_swap_rzz), _wrap_cop(rzz_swap, t_swap_rzz)],
[_wrap_cop(x_cnot_x, t_2q_cmp), _wrap_cop(x_cz_x, t_2q_cmp)],
cirq.Moment(cirq.Y(qq).with_tags("ignore") for qq in q),
)
),
)
context = cirq.TransformerContext(tags_to_ignore=["ignore"], deep=True)
c_new = sycamore_gateset.merge_swap_rzz_and_2q_unitaries(
c_orig,
context=context,
merged_swap_rzz_tag=t_swap_rzz,
merged_2q_component_tag=t_2q_cmp,
intermediate_result_tag=t_all,
)
cirq.testing.assert_same_circuits(cirq.drop_empty_moments(c_new, context=context), c_expected)


def test_sycamore_gateset_compiles_swap_zz():
qubits = cirq.LineQubit.range(3)

Expand Down

0 comments on commit 45624ff

Please sign in to comment.