From 5e9bef6d8dc5fd47d225256661ab14999ee9ae15 Mon Sep 17 00:00:00 2001 From: Tanuj Khattar Date: Mon, 22 Nov 2021 18:26:06 -0800 Subject: [PATCH] Add `map_operations` and `map_moments` transformer primitives (#4692) * Add map_operations and map_moments transformer primitives * Improve circuit type conversion efficiency * Add return entries to docstrings --- cirq-core/cirq/__init__.py | 6 + cirq-core/cirq/optimizers/__init__.py | 10 + .../cirq/optimizers/transformer_primitives.py | 215 ++++++++++++++++++ .../optimizers/transformer_primitives_test.py | 200 ++++++++++++++++ 4 files changed, 431 insertions(+) create mode 100644 cirq-core/cirq/optimizers/transformer_primitives.py create mode 100644 cirq-core/cirq/optimizers/transformer_primitives_test.py diff --git a/cirq-core/cirq/__init__.py b/cirq-core/cirq/__init__.py index a0f5f54b19a..062423bf431 100644 --- a/cirq-core/cirq/__init__.py +++ b/cirq-core/cirq/__init__.py @@ -336,6 +336,9 @@ EjectZ, ExpandComposite, is_negligible_turn, + map_moments, + map_operations, + map_operations_and_unroll, merge_single_qubit_gates_into_phased_x_z, merge_single_qubit_gates_into_phxz, MergeInteractions, @@ -352,6 +355,9 @@ two_qubit_matrix_to_diagonal_and_operations, two_qubit_matrix_to_sqrt_iswap_operations, three_qubit_matrix_to_operations, + unroll_circuit_op, + unroll_circuit_op_greedy_earliest, + unroll_circuit_op_greedy_frontier, ) from cirq.qis import ( diff --git a/cirq-core/cirq/optimizers/__init__.py b/cirq-core/cirq/optimizers/__init__.py index 9a1ab286dca..a25d7c29734 100644 --- a/cirq-core/cirq/optimizers/__init__.py +++ b/cirq-core/cirq/optimizers/__init__.py @@ -92,6 +92,16 @@ three_qubit_matrix_to_operations, ) +from cirq.optimizers.transformer_primitives import ( + map_moments, + map_operations, + map_operations_and_unroll, + unroll_circuit_op, + unroll_circuit_op_greedy_earliest, + unroll_circuit_op_greedy_frontier, +) + + from cirq.optimizers.two_qubit_decompositions import ( two_qubit_matrix_to_operations, two_qubit_matrix_to_diagonal_and_operations, diff --git a/cirq-core/cirq/optimizers/transformer_primitives.py b/cirq-core/cirq/optimizers/transformer_primitives.py new file mode 100644 index 00000000000..56bc2fe960c --- /dev/null +++ b/cirq-core/cirq/optimizers/transformer_primitives.py @@ -0,0 +1,215 @@ +# Copyright 2021 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Defines primitives for common transformer patterns.""" + +from collections import defaultdict +from typing import ( + cast, + Callable, + Dict, + Hashable, + Optional, + Sequence, + TYPE_CHECKING, +) + +from cirq import circuits, ops, protocols +from cirq.circuits.circuit import CIRCUIT_TYPE + +if TYPE_CHECKING: + import cirq + +MAPPED_CIRCUIT_OP_TAG = '' + + +def _to_target_circuit_type( + circuit: circuits.AbstractCircuit, target_circuit: CIRCUIT_TYPE +) -> CIRCUIT_TYPE: + return cast( + CIRCUIT_TYPE, + circuit.unfreeze(copy=False) + if isinstance(target_circuit, circuits.Circuit) + else circuit.freeze(), + ) + + +def _create_target_circuit_type(ops: ops.OP_TREE, target_circuit: CIRCUIT_TYPE) -> CIRCUIT_TYPE: + return cast( + CIRCUIT_TYPE, + circuits.Circuit(ops) + if isinstance(target_circuit, circuits.Circuit) + else circuits.FrozenCircuit(ops), + ) + + +def map_moments( + circuit: CIRCUIT_TYPE, + map_func: Callable[[ops.Moment, int], Sequence[ops.Moment]], +) -> CIRCUIT_TYPE: + """Applies local transformation on moments, by calling `map_func(moment)` for each moment. + + Args: + circuit: Input circuit to apply the transformations on. The input circuit is not mutated. + map_func: Mapping function from (cirq.Moment, moment_index) to a sequence of moments. + + Returns: + Copy of input circuit with mapped moments. + """ + return _create_target_circuit_type( + (map_func(circuit[i], i) for i in range(len(circuit))), circuit + ) + + +def map_operations( + circuit: CIRCUIT_TYPE, + map_func: Callable[[ops.Operation, int], ops.OP_TREE], +) -> CIRCUIT_TYPE: + """Applies local transformations on operations, by calling `map_func(op)` for each op. + + Note that the function assumes `issubset(qubit_set(map_func(op)), op.qubits)` is True. + + Args: + circuit: Input circuit to apply the transformations on. The input circuit is not mutated. + map_func: Mapping function from (cirq.Operation, moment_index) to a cirq.OP_TREE. If the + resulting optree spans more than 1 moment, it's inserted in-place in the same moment as + `cirq.CircuitOperation(cirq.FrozenCircuit(op_tree)).with_tags(MAPPED_CIRCUIT_OP_TAG)` + to preserve moment structure. Utility methods like `cirq.unroll_circuit_op` can + subsequently be used to unroll the mapped circuit operation. + + Raises: + ValueError if `issubset(qubit_set(map_func(op)), op.qubits) is False`. + + Returns: + Copy of input circuit with mapped operations (wrapped in a tagged CircuitOperation). + """ + + def apply_map(op: ops.Operation, idx: int) -> ops.OP_TREE: + c = circuits.FrozenCircuit(map_func(op, idx)) + if not c.all_qubits().issubset(op.qubits): + raise ValueError( + f"Mapped operations {c.all_operations()} should act on a subset " + f"of qubits of the original operation {op}" + ) + if len(c) == 1: + # All operations act in the same moment; so we don't need to wrap them in a circuit_op. + return c[0].operations + circuit_op = circuits.CircuitOperation(c).with_tags(MAPPED_CIRCUIT_OP_TAG) + return circuit_op + + return map_moments(circuit, lambda m, i: [ops.Moment(apply_map(op, i) for op in m.operations)]) + + +def map_operations_and_unroll( + circuit: CIRCUIT_TYPE, + map_func: Callable[[ops.Operation, int], ops.OP_TREE], +) -> CIRCUIT_TYPE: + """Applies local transformations via `cirq.map_operations` & unrolls intermediate circuit ops. + + See `cirq.map_operations` and `cirq.unroll_circuit_op` for more details. + + Args: + circuit: Input circuit to apply the transformations on. The input circuit is not mutated. + map_func: Mapping function from (cirq.Operation, moment_index) to a cirq.OP_TREE. + + Returns: + Copy of input circuit with mapped operations, unrolled in a moment preserving way. + """ + return unroll_circuit_op(map_operations(circuit, map_func)) + + +def _check_circuit_op(op, tags_to_check: Optional[Sequence[Hashable]]): + return isinstance(op.untagged, circuits.CircuitOperation) and ( + tags_to_check is None or any(tag in op.tags for tag in tags_to_check) + ) + + +def unroll_circuit_op( + circuit: CIRCUIT_TYPE, *, tags_to_check: Optional[Sequence[Hashable]] = (MAPPED_CIRCUIT_OP_TAG,) +) -> CIRCUIT_TYPE: + """Unrolls (tagged) `cirq.CircuitOperation`s while preserving the moment structure. + + Each moment containing a matching circuit operation is expanded into a list of moments with the + unrolled operations, hence preserving the original moment structure. + + Args: + circuit: Input circuit to apply the transformations on. The input circuit is not mutated. + tags_to_check: If specified, only circuit operations tagged with one of the `tags_to_check` + are unrolled. + + Returns: + Copy of input circuit with (Tagged) CircuitOperation's expanded in a moment preserving way. + """ + + def map_func(m: ops.Moment, _: int): + to_zip = [ + cast(circuits.CircuitOperation, op.untagged).mapped_circuit() + if _check_circuit_op(op, tags_to_check) + else circuits.Circuit(op) + for op in m + ] + return circuits.Circuit.zip(*to_zip).moments + + return map_moments(circuit, map_func) + + +def unroll_circuit_op_greedy_earliest( + circuit: CIRCUIT_TYPE, *, tags_to_check=(MAPPED_CIRCUIT_OP_TAG,) +) -> CIRCUIT_TYPE: + """Unrolls (tagged) `cirq.CircuitOperation`s by inserting operations using EARLIEST strategy. + + Each matching `cirq.CircuitOperation` is replaced by inserting underlying operations using the + `cirq.InsertStrategy.EARLIEST` strategy. The greedy approach attempts to minimize circuit depth + of the resulting circuit. + + Args: + circuit: Input circuit to apply the transformations on. The input circuit is not mutated. + tags_to_check: If specified, only circuit operations tagged with one of the `tags_to_check` + are unrolled. + + Returns: + Copy of input circuit with (Tagged) CircuitOperation's expanded using EARLIEST strategy. + """ + batch_removals = [*circuit.findall_operations(lambda op: _check_circuit_op(op, tags_to_check))] + batch_inserts = [(i, protocols.decompose_once(op)) for i, op in batch_removals] + unrolled_circuit = circuit.unfreeze(copy=True) + unrolled_circuit.batch_remove(batch_removals) + unrolled_circuit.batch_insert(batch_inserts) + return _to_target_circuit_type(unrolled_circuit, circuit) + + +def unroll_circuit_op_greedy_frontier( + circuit: CIRCUIT_TYPE, *, tags_to_check=(MAPPED_CIRCUIT_OP_TAG,) +) -> CIRCUIT_TYPE: + """Unrolls (tagged) `cirq.CircuitOperation`s by inserting operations inline at qubit frontier. + + Each matching `cirq.CircuitOperation` is replaced by inserting underlying operations using the + `circuit.insert_at_frontier` method. The greedy approach attempts to reuse any available space + in existing moments on the right of circuit_op before inserting new moments. + + Args: + circuit: Input circuit to apply the transformations on. The input circuit is not mutated. + tags_to_check: If specified, only circuit operations tagged with one of the `tags_to_check` + are unrolled. + + Returns: + Copy of input circuit with (Tagged) CircuitOperation's expanded inline at qubit frontier. + """ + unrolled_circuit = circuit.unfreeze(copy=True) + frontier: Dict['cirq.Qid', int] = defaultdict(lambda: 0) + for idx, op in circuit.findall_operations(lambda op: _check_circuit_op(op, tags_to_check)): + idx = max(idx, max(frontier[q] for q in op.qubits)) + unrolled_circuit.clear_operations_touching(op.qubits, [idx]) + frontier = unrolled_circuit.insert_at_frontier(protocols.decompose_once(op), idx, frontier) + return _to_target_circuit_type(unrolled_circuit, circuit) diff --git a/cirq-core/cirq/optimizers/transformer_primitives_test.py b/cirq-core/cirq/optimizers/transformer_primitives_test.py new file mode 100644 index 00000000000..11f260c5c2b --- /dev/null +++ b/cirq-core/cirq/optimizers/transformer_primitives_test.py @@ -0,0 +1,200 @@ +# Copyright 2021 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import cirq +from cirq.optimizers.transformer_primitives import MAPPED_CIRCUIT_OP_TAG + + +def test_map_operations_can_write_new_gates_inline(): + x = cirq.NamedQubit('x') + y = cirq.NamedQubit('y') + z = cirq.NamedQubit('z') + c = cirq.Circuit( + cirq.CZ(x, y), + cirq.Y(x), + cirq.Z(x), + cirq.X(y), + cirq.CNOT(y, z), + cirq.Z(y), + cirq.Z(x), + cirq.CNOT(y, z), + cirq.CNOT(z, y), + ) + cirq.testing.assert_has_diagram( + c, + ''' +x: ───@───Y───Z───Z─────────── + │ +y: ───@───X───@───Z───@───X─── + │ │ │ +z: ───────────X───────X───@─── +''', + ) + expected_diagram = ''' +x: ───X───X───X───X─────────── + +y: ───X───X───X───X───X───X─── + +z: ───────────X───────X───X─── +''' + cirq.testing.assert_has_diagram( + cirq.map_operations(c, lambda op, _: cirq.X.on_each(*op.qubits)), expected_diagram + ) + cirq.testing.assert_has_diagram( + cirq.map_operations_and_unroll(c, lambda op, _: cirq.X.on_each(*op.qubits)), + expected_diagram, + ) + + +def test_map_operations_does_not_insert_too_many_moments(): + q = cirq.LineQubit.range(5) + c_orig = cirq.Circuit( + cirq.CX(q[0], q[1]), + cirq.CX(q[3], q[2]), + cirq.CX(q[3], q[4]), + ) + + def map_func(op: cirq.Operation, _: int) -> cirq.OP_TREE: + if op.gate == cirq.CX: + yield cirq.Z.on_each(*op.qubits) + yield cirq.CX(*op.qubits) + yield cirq.Z.on_each(*op.qubits) + return op + + cirq.testing.assert_has_diagram( + c_orig, + ''' +0: ───@─────── + │ +1: ───X─────── + +2: ───X─────── + │ +3: ───@───@─── + │ +4: ───────X─── +''', + ) + + c_mapped = cirq.map_operations(c_orig, map_func) + circuit_op = cirq.CircuitOperation( + cirq.FrozenCircuit( + cirq.Z.on_each(q[0], q[1]), cirq.CNOT(q[0], q[1]), cirq.Z.on_each(q[0], q[1]) + ) + ) + c_expected = cirq.Circuit( + circuit_op.with_qubits(q[0], q[1]).mapped_op().with_tags(''), + circuit_op.with_qubits(q[3], q[2]).mapped_op().with_tags(''), + circuit_op.with_qubits(q[3], q[4]).mapped_op().with_tags(''), + ) + cirq.testing.assert_same_circuits(c_mapped, c_expected) + + cirq.testing.assert_has_diagram( + cirq.map_operations_and_unroll(c_orig, map_func), + ''' +0: ───Z───@───Z─────────────── + │ +1: ───Z───X───Z─────────────── + +2: ───Z───X───Z─────────────── + │ +3: ───Z───@───Z───Z───@───Z─── + │ +4: ───────────────Z───X───Z─── +''', + ) + + +def test_unroll_circuit_op_and_variants(): + q = cirq.LineQubit.range(2) + c = cirq.Circuit(cirq.X(q[0]), cirq.CNOT(q[0], q[1]), cirq.X(q[0])) + cirq.testing.assert_has_diagram( + c, + ''' +0: ───X───@───X─── + │ +1: ───────X─────── +''', + ) + mapped_circuit = cirq.map_operations( + c, lambda op, i: [cirq.Z(q[1])] * 2 if op.gate == cirq.CNOT else op + ) + cirq.testing.assert_has_diagram( + cirq.unroll_circuit_op(mapped_circuit), + ''' +0: ───X───────────X─── + +1: ───────Z───Z─────── +''', + ) + cirq.testing.assert_has_diagram( + cirq.unroll_circuit_op_greedy_earliest(mapped_circuit), + ''' +0: ───X───────X─── + +1: ───Z───Z─────── +''', + ) + cirq.testing.assert_has_diagram( + cirq.unroll_circuit_op_greedy_frontier(mapped_circuit), + ''' +0: ───X───────X─── + +1: ───────Z───Z─── +''', + ) + + +def test_unroll_circuit_op_no_tags(): + q = cirq.LineQubit.range(2) + op_list = [cirq.X(q[0]), cirq.Y(q[1])] + op1 = cirq.CircuitOperation(cirq.FrozenCircuit(op_list)) + op2 = op1.with_tags("custom tag") + op3 = op1.with_tags(MAPPED_CIRCUIT_OP_TAG) + c = cirq.Circuit(op1, op2, op3) + for unroller in [ + cirq.unroll_circuit_op, + cirq.unroll_circuit_op_greedy_earliest, + cirq.unroll_circuit_op_greedy_frontier, + ]: + cirq.testing.assert_same_circuits( + unroller(c, tags_to_check=None), cirq.Circuit([op_list] * 3) + ) + cirq.testing.assert_same_circuits(unroller(c), cirq.Circuit([op1, op2, op_list])) + cirq.testing.assert_same_circuits( + unroller(c, tags_to_check=("custom tag",)), cirq.Circuit([op1, op_list, op3]) + ) + cirq.testing.assert_same_circuits( + unroller( + c, + tags_to_check=("custom tag", MAPPED_CIRCUIT_OP_TAG), + ), + cirq.Circuit([op1, op_list, op_list]), + ) + + +def test_map_operations_raises_qubits_not_subset(): + q = cirq.LineQubit.range(3) + with pytest.raises(ValueError, match='should act on a subset'): + _ = cirq.map_operations( + cirq.Circuit(cirq.CNOT(q[0], q[1])), lambda op, i: cirq.CNOT(q[1], q[2]) + ) + + +def test_map_moments_drop_empty_moments(): + op = cirq.X(cirq.NamedQubit("x")) + c = cirq.Circuit(cirq.Moment(op), cirq.Moment(), cirq.Moment(op)) + c_mapped = cirq.map_moments(c, lambda m, i: [] if len(m) == 0 else [m]) + cirq.testing.assert_same_circuits(c_mapped, cirq.Circuit(c[0], c[0]))