diff --git a/cirq-core/cirq/__init__.py b/cirq-core/cirq/__init__.py index b6814bdb156..8b18b63daa9 100644 --- a/cirq-core/cirq/__init__.py +++ b/cirq-core/cirq/__init__.py @@ -372,8 +372,10 @@ map_moments, map_operations, map_operations_and_unroll, + merge_k_qubit_unitaries_to_circuit_op, merge_moments, merge_operations, + merge_operations_to_circuit_op, prepare_two_qubit_state_using_cz, prepare_two_qubit_state_using_sqrt_iswap, single_qubit_matrix_to_gates, diff --git a/cirq-core/cirq/transformers/__init__.py b/cirq-core/cirq/transformers/__init__.py index 1ee9de1f165..77009de16eb 100644 --- a/cirq-core/cirq/transformers/__init__.py +++ b/cirq-core/cirq/transformers/__init__.py @@ -74,8 +74,10 @@ map_moments, map_operations, map_operations_and_unroll, + merge_k_qubit_unitaries_to_circuit_op, merge_moments, merge_operations, + merge_operations_to_circuit_op, toggle_tags, unroll_circuit_op, unroll_circuit_op_greedy_earliest, diff --git a/cirq-core/cirq/transformers/transformer_primitives.py b/cirq-core/cirq/transformers/transformer_primitives.py index 6015072ccf3..671384771ec 100644 --- a/cirq-core/cirq/transformers/transformer_primitives.py +++ b/cirq-core/cirq/transformers/transformer_primitives.py @@ -27,7 +27,7 @@ TYPE_CHECKING, ) -from cirq import circuits, ops +from cirq import circuits, ops, protocols from cirq.circuits.circuit import CIRCUIT_TYPE if TYPE_CHECKING: @@ -281,6 +281,97 @@ def apply_merge_func(op1: ops.Operation, op2: ops.Operation) -> Optional[ops.Ope return _to_target_circuit_type(ret_circuit, circuit) +def merge_operations_to_circuit_op( + circuit: CIRCUIT_TYPE, + can_merge: Callable[[Sequence['cirq.Operation'], Sequence['cirq.Operation']], bool], + *, + tags_to_ignore: Sequence[Hashable] = (), + merged_circuit_op_tag: str = "Merged connected component", +) -> CIRCUIT_TYPE: + """Merges connected components of operations and wraps each component into a circuit operation. + + Uses `cirq.merge_operations` to identify connected components of operations. Moment structure + is preserved for operations that do not participate in merging. For merged operations, the + newly created circuit operations are constructed by inserting operations using EARLIEST + strategy. + If you need more control on moment structure of newly created circuit operations, consider + using `cirq.merge_operations` directly with a custom `merge_func`. + + Args: + circuit: Input circuit to apply the transformations on. The input circuit is not mutated. + can_merge: Callable to determine whether a new operation `right_op` can be merged into an + existing connected component of operations `left_ops` based on boolen returned by + `can_merge(left_ops, right_op)`. + tags_to_ignore: Tagged operations marked any of `tags_to_ignore` will not be considered as + potential candidates for any connected component. + merged_circuit_op_tag: Tag to be applied on circuit operations wrapping valid connected + components. + + Returns: + Copy of input circuit with valid connected components wrapped in tagged circuit operations. + """ + + def merge_func(op1: 'cirq.Operation', op2: 'cirq.Operation') -> Optional['cirq.Operation']: + def get_ops(op: 'cirq.Operation'): + op_untagged = op.untagged + return ( + [*op_untagged.circuit.all_operations()] + if isinstance(op_untagged, circuits.CircuitOperation) + and merged_circuit_op_tag in op.tags + else [op] + ) + + left_ops, right_ops = get_ops(op1), get_ops(op2) + if not can_merge(left_ops, right_ops): + return None + return circuits.CircuitOperation(circuits.FrozenCircuit(left_ops, right_ops)).with_tags( + merged_circuit_op_tag + ) + + return merge_operations(circuit, merge_func, tags_to_ignore=tags_to_ignore) + + +def merge_k_qubit_unitaries_to_circuit_op( + circuit: CIRCUIT_TYPE, + k: int, + *, + tags_to_ignore: Sequence[Hashable] = (), + merged_circuit_op_tag: Optional[str] = None, +) -> CIRCUIT_TYPE: + """Merges connected components of operations, acting on <= k qubits, into circuit operations. + + Uses `cirq.merge_operations_to_circuit_op` to identify and merge connected components of + unitary operations acting on at-most k-qubits. Moment structure is preserved for operations + that do not participate in merging. For merged operations, the newly created circuit operations + are constructed by inserting operations using EARLIEST strategy. + + Args: + circuit: Input circuit to apply the transformations on. The input circuit is not mutated. + k: Merge-able operations acting on <= k qubits are merged into a connected component. + tags_to_ignore: Tagged operations marked any of `tags_to_ignore` will not be considered as + potential candidates for any connected component. + merged_circuit_op_tag: Tag to be applied on circuit operations wrapping valid connected + components. A default tag is applied if left None. + + Returns: + Copy of input circuit with valid connected components wrapped in tagged circuit operations. + """ + + def can_merge(ops1: Sequence['cirq.Operation'], ops2: Sequence['cirq.Operation']) -> bool: + return all( + protocols.has_unitary(op) and protocols.num_qubits(op) <= k + for op_list in [ops1, ops2] + for op in op_list + ) + + return merge_operations_to_circuit_op( + circuit, + can_merge, + tags_to_ignore=tags_to_ignore, + merged_circuit_op_tag=merged_circuit_op_tag or f"Merged {k}q unitary connected component.", + ) + + def merge_moments( circuit: CIRCUIT_TYPE, merge_func: Callable[[circuits.Moment, circuits.Moment], Optional[circuits.Moment]], diff --git a/cirq-core/cirq/transformers/transformer_primitives_test.py b/cirq-core/cirq/transformers/transformer_primitives_test.py index 6e56cc5f734..8dd71426286 100644 --- a/cirq-core/cirq/transformers/transformer_primitives_test.py +++ b/cirq-core/cirq/transformers/transformer_primitives_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import Optional, List import pytest import cirq @@ -396,9 +396,9 @@ def fail_if_called_func(*_): assert cirq.merge_operations(c, fail_if_called_func, tags_to_ignore=["ignore"]) == c -def test_merge_operations_merges_connected_component(): +def _create_circuit_to_merge(): q = cirq.LineQubit.range(3) - c_orig = cirq.Circuit( + return cirq.Circuit( cirq.Moment(cirq.H.on_each(*q)), cirq.CNOT(q[0], q[2]), cirq.CNOT(*q[0:2]), @@ -409,18 +409,22 @@ def test_merge_operations_merges_connected_component(): cirq.CNOT(*q[0:2]), cirq.CNOT(*q[1:3]), cirq.X(q[0]), - cirq.Y(q[1]), + cirq.Moment(cirq.X(q[0]).with_tags("ignore"), cirq.Y(q[1])), cirq.CNOT(*q[:2]), strategy=cirq.InsertStrategy.NEW, ) + + +def test_merge_operations_merges_connected_component(): + c_orig = _create_circuit_to_merge() cirq.testing.assert_has_diagram( c_orig, ''' -0: ───H───@───@───H───@───X───────@───────X───────@─── - │ │ │ │ │ -1: ───H───┼───X───────@───────Y───X───@───────Y───X─── +0: ───H───@───@───H───@───X───────@───────X───X['ignore']───@─── + │ │ │ │ │ +1: ───H───┼───X───────@───────Y───X───@───────Y─────────────X─── │ │ -2: ───H───X───────────────────────────X─────────────── +2: ───H───X───────────────────────────X───────────────────────── ''', ) @@ -443,6 +447,76 @@ def merge_func(op1, op2): ) +# pylint: disable=line-too-long +def test_merge_operations_to_circuit_op_merges_connected_component(): + c_orig = _create_circuit_to_merge() + cirq.testing.assert_has_diagram( + c_orig, + ''' +0: ───H───@───@───H───@───X───────@───────X───X['ignore']───@─── + │ │ │ │ │ +1: ───H───┼───X───────@───────Y───X───@───────Y─────────────X─── + │ │ +2: ───H───X───────────────────────────X───────────────────────── +''', + ) + + def can_merge(ops1: List['cirq.Operation'], ops2: List['cirq.Operation']) -> bool: + """Artificial example where a CZ will absorb any merge-able operation.""" + return any(o.gate == cirq.CZ for op_list in [ops1, ops2] for o in op_list) + + c_new = cirq.merge_operations_to_circuit_op( + c_orig, can_merge, merged_circuit_op_tag="merged", tags_to_ignore=["ignore"] + ) + cirq.testing.assert_has_diagram( + c_new, + ''' + [ 0: ───────@───H───@───X───@───X─── ] +0: ───H───@───────────[ │ │ │ ]─────────────────────────────────X['ignore']───@─── + │ [ 1: ───H───X───────@───Y───X─────── ]['merged'] │ + │ │ │ +1: ───────┼───────────#2─────────────────────────────────────────────────────────────@───────Y─────────────X─── + │ │ +2: ───H───X──────────────────────────────────────────────────────────────────────────X───────────────────────── +''', + ) + + +def test_merge_2q_unitaries_to_circuit_op(): + c_orig = _create_circuit_to_merge() + c_orig[-1] = c_orig[-1].with_operations(cirq.measure(cirq.LineQubit(2))) + cirq.testing.assert_has_diagram( + c_orig, + ''' +0: ───H───@───@───H───@───X───────@───────X───X['ignore']───@─── + │ │ │ │ │ +1: ───H───┼───X───────@───────Y───X───@───────Y─────────────X─── + │ │ +2: ───H───X───────────────────────────X─────────────────────M─── +''', + ) + + c_new = cirq.merge_k_qubit_unitaries_to_circuit_op( + c_orig, k=2, merged_circuit_op_tag="merged", tags_to_ignore=["ignore"] + ) + cirq.testing.assert_has_diagram( + cirq.drop_empty_moments(c_new), + ''' + [ 0: ───H───@─── ] [ 0: ───────@───H───@───X───@───X─── ] +0: ───[ │ ]─────────────[ │ │ │ ]────────────────────────────────────────────X['ignore']───@─── + [ 2: ───H───X─── ]['merged'] [ 1: ───H───X───────@───Y───X─────── ]['merged'] │ + │ │ │ + │ │ [ 1: ───@───Y─── ] │ +1: ───┼──────────────────────────────#2─────────────────────────────────────────────────[ │ ]───────────────────────────X─── + │ [ 2: ───X─────── ]['merged'] + │ │ +2: ───#2────────────────────────────────────────────────────────────────────────────────#2───────────────────────────────────────────M───''', + ) + + +# pylint: enable=line-too-long + + def test_merge_operations_respects_tags_to_ignore(): q = cirq.LineQubit.range(2) c = cirq.Circuit(