Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for deep=True flag in cg.optimized_for_sycamore and cg.SycamoreTargetGateset transformers #5126

Merged
merged 3 commits into from
Mar 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def optimized_for_sycamore(
copy = cirq.optimize_for_target_gateset(
circuit,
gateset=_TARGET_GATESETS[optimizer_type](tolerance, tabulation),
context=cirq.TransformerContext(deep=True),
)
copy = cirq.merge_single_qubit_gates_to_phxz(copy, atol=tolerance)
copy = cirq.eject_phased_paulis(copy, atol=tolerance)
Expand Down
27 changes: 27 additions & 0 deletions cirq-google/cirq_google/optimizers/optimize_for_sycamore_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,30 @@ def test_assert_new_device_deprecated():
_ = cg.optimized_for_sycamore(
circuit0, optimizer_type='sqrt_iswap', new_device=TestDevice()
)


@pytest.mark.parametrize(
'optimizer_type, two_qubit_gate_type',
[('sycamore', cg.SycamoreGate), ('sqrt_iswap', cirq.ISwapPowGate), ('xmon', cirq.CZPowGate)],
)
def test_circuit_operation_conversion(optimizer_type, two_qubit_gate_type):
q0, q1 = cirq.LineQubit.range(2)
subcircuit = cirq.FrozenCircuit(cirq.X(q0), cirq.SWAP(q0, q1))
circuit = cirq.Circuit(cirq.CircuitOperation(subcircuit))
converted_circuit = cg.optimized_for_sycamore(circuit, optimizer_type=optimizer_type)
# Verify that the CircuitOperation was preserved.
ops = list(converted_circuit.all_operations())
assert isinstance(ops[0], cirq.CircuitOperation)
# Verify that the contents of the CircuitOperation were optimized.
converted_subcircuit = cg.optimized_for_sycamore(
subcircuit.unfreeze(), optimizer_type=optimizer_type
)
assert len(
[*converted_subcircuit.findall_operations_with_gate_type(two_qubit_gate_type)]
) == len([*ops[0].circuit.findall_operations_with_gate_type(two_qubit_gate_type)])
cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(
ops[0].circuit, converted_subcircuit, atol=1e-8
)
cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(
circuit, converted_circuit, atol=1e-8
)
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""Target gateset used for compiling circuits to Sycamore + 1-q rotations + measurement gates."""

import itertools
from typing import List, Optional, Sequence
from typing import cast, List, Optional, Sequence

import cirq
from cirq.protocols.decompose_protocol import DecomposeResult
Expand All @@ -33,6 +33,7 @@ def merge_swap_rzz_and_2q_unitaries(
context: Optional['cirq.TransformerContext'] = None,
merged_swap_rzz_tag: str = "_merged_swap_rzz",
merged_2q_component_tag: str = "_merged_2q_unitaries",
intermediate_result_tag: Optional[str] = None,
) -> 'cirq.Circuit':
"""Merges 2-qubit connected components and adjacent `cirq.SWAP` and `cirq.ZZPowGate` gates.

Expand All @@ -50,6 +51,8 @@ def merge_swap_rzz_and_2q_unitaries(
`cirq.SWAP` and `cirq.ZZPowGate`s.
merged_2q_component_tag: Tag to apply on newly introduced circuit operations wrapping
connected components of 1 and 2 qubit unitaries.
intermediate_result_tag: If specified, the tag is added to newly introduced both the newly
introduced circuit operations encapsulating swap_rzz or 2q connected component.

Returns:
Copy of the transformed input circuit.
Expand All @@ -71,19 +74,34 @@ def merge_func_swap_rzz(
return False

tags_to_ignore = context.tags_to_ignore if context else ()
deep = context.deep if context else False
circuit = cirq.merge_operations_to_circuit_op(
circuit,
merge_func_swap_rzz,
tags_to_ignore=tags_to_ignore,
merged_circuit_op_tag=merged_swap_rzz_tag,
deep=deep,
)

return cirq.merge_k_qubit_unitaries_to_circuit_op(
circuit = cirq.merge_k_qubit_unitaries_to_circuit_op(
circuit,
k=2,
tags_to_ignore=tags_to_ignore + (merged_swap_rzz_tag,),
tags_to_ignore=tuple(tags_to_ignore) + (merged_swap_rzz_tag,),
merged_circuit_op_tag=merged_2q_component_tag,
).unfreeze(copy=False)
deep=deep,
)

if intermediate_result_tag is not None:
merged_cop_tags = {merged_swap_rzz_tag, merged_2q_component_tag}
circuit = cirq.map_operations(
circuit,
map_func=lambda op, _: op
if merged_cop_tags.isdisjoint(op.tags)
else op.with_tags(cast(str, intermediate_result_tag)),
tags_to_ignore=tags_to_ignore,
deep=True,
)
return circuit.unfreeze(copy=False)


class SycamoreTargetGateset(cirq.TwoQubitCompilationTargetGateset):
Expand Down Expand Up @@ -122,7 +140,10 @@ def preprocess_transformers(self) -> List[cirq.TRANSFORMER]:
cirq.expand_composite,
no_decomp=lambda op: cirq.num_qubits(op) <= self.num_qubits,
),
merge_swap_rzz_and_2q_unitaries,
_create_transformer_with_kwargs(
merge_swap_rzz_and_2q_unitaries,
intermediate_result_tag=self._intermediate_result_tag,
),
]

def _decompose_two_qubit_operation(self, op: cirq.Operation, _) -> DecomposeResult:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,55 @@ def test_merge_swap_rzz_and_2q_unitaries_raises_if_tags_sames():
)


def test_merge_swap_rzz_and_2q_unitaries_deep():
q = cirq.LineQubit.range(3)
swap_rzz = cirq.FrozenCircuit(cirq.SWAP(*q[:2]), cirq.ZZ(*q[:2]) ** 0.5)
rzz_swap = cirq.FrozenCircuit(cirq.ZZ(*q[1:]) ** 0.25, cirq.SWAP(*q[1:]))
x_cnot_x = cirq.FrozenCircuit(cirq.X(q[0]), cirq.CNOT(*q[:2]), cirq.X(q[0]))
x_cz_x = cirq.FrozenCircuit(cirq.X(q[2]), cirq.CZ(*q[1:]), cirq.X(q[2]))
c_orig = cirq.Circuit(
cirq.CircuitOperation(swap_rzz).repeat(3).with_tags("ignore"),
cirq.CircuitOperation(rzz_swap).repeat(5).with_tags("preserve_tag"),
cirq.CircuitOperation(x_cnot_x).repeat(7).with_tags("ignore"),
cirq.CircuitOperation(x_cz_x).repeat(9).with_tags("preserve_tag"),
cirq.CircuitOperation(
cirq.FrozenCircuit(
[swap_rzz, rzz_swap, x_cnot_x, x_cz_x],
cirq.Moment(cirq.Y(qq).with_tags("ignore") for qq in q),
)
),
)
t_swap_rzz = "_merged_swap_rzz_tag"
t_2q_cmp = "_merged_2q_unitaries_component"
t_all = "_intermediate_result_tag_apply_to_all"

def _wrap_cop(c: cirq.FrozenCircuit, *tags) -> cirq.FrozenCircuit:
return cirq.FrozenCircuit(cirq.CircuitOperation(c).with_tags(*tags, t_all))

c_expected = cirq.Circuit(
cirq.CircuitOperation(swap_rzz).repeat(3).with_tags("ignore"),
cirq.CircuitOperation(_wrap_cop(rzz_swap, t_swap_rzz)).repeat(5).with_tags("preserve_tag"),
cirq.CircuitOperation(x_cnot_x).repeat(7).with_tags("ignore"),
cirq.CircuitOperation(_wrap_cop(x_cz_x, t_2q_cmp)).repeat(9).with_tags("preserve_tag"),
cirq.CircuitOperation(
cirq.FrozenCircuit(
[_wrap_cop(swap_rzz, t_swap_rzz), _wrap_cop(rzz_swap, t_swap_rzz)],
[_wrap_cop(x_cnot_x, t_2q_cmp), _wrap_cop(x_cz_x, t_2q_cmp)],
cirq.Moment(cirq.Y(qq).with_tags("ignore") for qq in q),
)
),
)
context = cirq.TransformerContext(tags_to_ignore=["ignore"], deep=True)
c_new = sycamore_gateset.merge_swap_rzz_and_2q_unitaries(
c_orig,
context=context,
merged_swap_rzz_tag=t_swap_rzz,
merged_2q_component_tag=t_2q_cmp,
intermediate_result_tag=t_all,
)
cirq.testing.assert_same_circuits(cirq.drop_empty_moments(c_new, context=context), c_expected)


def test_sycamore_gateset_compiles_swap_zz():
qubits = cirq.LineQubit.range(3)

Expand Down