Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add transformer primitive to merge connected component of operations in a circuit op #4974

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cirq-core/cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,8 +372,10 @@
map_moments,
map_operations,
map_operations_and_unroll,
merge_k_qubit_unitaries_to_circuit_op,
merge_moments,
merge_operations,
merge_operations_to_circuit_op,
prepare_two_qubit_state_using_cz,
prepare_two_qubit_state_using_sqrt_iswap,
single_qubit_matrix_to_gates,
Expand Down
2 changes: 2 additions & 0 deletions cirq-core/cirq/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,10 @@
map_moments,
map_operations,
map_operations_and_unroll,
merge_k_qubit_unitaries_to_circuit_op,
merge_moments,
merge_operations,
merge_operations_to_circuit_op,
toggle_tags,
unroll_circuit_op,
unroll_circuit_op_greedy_earliest,
Expand Down
93 changes: 92 additions & 1 deletion cirq-core/cirq/transformers/transformer_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
TYPE_CHECKING,
)

from cirq import circuits, ops
from cirq import circuits, ops, protocols
from cirq.circuits.circuit import CIRCUIT_TYPE

if TYPE_CHECKING:
Expand Down Expand Up @@ -281,6 +281,97 @@ def apply_merge_func(op1: ops.Operation, op2: ops.Operation) -> Optional[ops.Ope
return _to_target_circuit_type(ret_circuit, circuit)


def merge_operations_to_circuit_op(
circuit: CIRCUIT_TYPE,
can_merge: Callable[[Sequence['cirq.Operation'], Sequence['cirq.Operation']], bool],
*,
tags_to_ignore: Sequence[Hashable] = (),
merged_circuit_op_tag: str = "Merged connected component",
) -> CIRCUIT_TYPE:
"""Merges connected components of operations and wraps each component into a circuit operation.

Uses `cirq.merge_operations` to identify connected components of operations. Moment structure
is preserved for operations that do not participate in merging. For merged operations, the
newly created circuit operations are constructed by inserting operations using EARLIEST
strategy.
If you need more control on moment structure of newly created circuit operations, consider
using `cirq.merge_operations` directly with a custom `merge_func`.

Args:
circuit: Input circuit to apply the transformations on. The input circuit is not mutated.
can_merge: Callable to determine whether a new operation `right_op` can be merged into an
existing connected component of operations `left_ops` based on boolen returned by
`can_merge(left_ops, right_op)`.
tags_to_ignore: Tagged operations marked any of `tags_to_ignore` will not be considered as
potential candidates for any connected component.
merged_circuit_op_tag: Tag to be applied on circuit operations wrapping valid connected
components.

Returns:
Copy of input circuit with valid connected components wrapped in tagged circuit operations.
"""

def merge_func(op1: 'cirq.Operation', op2: 'cirq.Operation') -> Optional['cirq.Operation']:
def get_ops(op: 'cirq.Operation'):
op_untagged = op.untagged
return (
[*op_untagged.circuit.all_operations()]
if isinstance(op_untagged, circuits.CircuitOperation)
and merged_circuit_op_tag in op.tags
else [op]
)

left_ops, right_ops = get_ops(op1), get_ops(op2)
if not can_merge(left_ops, right_ops):
return None
return circuits.CircuitOperation(circuits.FrozenCircuit(left_ops, right_ops)).with_tags(
merged_circuit_op_tag
)

return merge_operations(circuit, merge_func, tags_to_ignore=tags_to_ignore)


def merge_k_qubit_unitaries_to_circuit_op(
circuit: CIRCUIT_TYPE,
k: int,
*,
tags_to_ignore: Sequence[Hashable] = (),
merged_circuit_op_tag: Optional[str] = None,
) -> CIRCUIT_TYPE:
tanujkhattar marked this conversation as resolved.
Show resolved Hide resolved
"""Merges connected components of operations, acting on <= k qubits, into circuit operations.

Uses `cirq.merge_operations_to_circuit_op` to identify and merge connected components of
unitary operations acting on at-most k-qubits. Moment structure is preserved for operations
that do not participate in merging. For merged operations, the newly created circuit operations
are constructed by inserting operations using EARLIEST strategy.

Args:
circuit: Input circuit to apply the transformations on. The input circuit is not mutated.
k: Merge-able operations acting on <= k qubits are merged into a connected component.
tags_to_ignore: Tagged operations marked any of `tags_to_ignore` will not be considered as
potential candidates for any connected component.
merged_circuit_op_tag: Tag to be applied on circuit operations wrapping valid connected
components. A default tag is applied if left None.

Returns:
Copy of input circuit with valid connected components wrapped in tagged circuit operations.
"""

def can_merge(ops1: Sequence['cirq.Operation'], ops2: Sequence['cirq.Operation']) -> bool:
return all(
protocols.has_unitary(op) and protocols.num_qubits(op) <= k
for op_list in [ops1, ops2]
for op in op_list
)

return merge_operations_to_circuit_op(
circuit,
can_merge,
tags_to_ignore=tags_to_ignore,
merged_circuit_op_tag=merged_circuit_op_tag or f"Merged {k}q unitary connected component.",
)


def merge_moments(
circuit: CIRCUIT_TYPE,
merge_func: Callable[[circuits.Moment, circuits.Moment], Optional[circuits.Moment]],
Expand Down
90 changes: 82 additions & 8 deletions cirq-core/cirq/transformers/transformer_primitives_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional
from typing import Optional, List
import pytest

import cirq
Expand Down Expand Up @@ -396,9 +396,9 @@ def fail_if_called_func(*_):
assert cirq.merge_operations(c, fail_if_called_func, tags_to_ignore=["ignore"]) == c


def test_merge_operations_merges_connected_component():
def _create_circuit_to_merge():
q = cirq.LineQubit.range(3)
c_orig = cirq.Circuit(
return cirq.Circuit(
cirq.Moment(cirq.H.on_each(*q)),
cirq.CNOT(q[0], q[2]),
cirq.CNOT(*q[0:2]),
Expand All @@ -409,18 +409,22 @@ def test_merge_operations_merges_connected_component():
cirq.CNOT(*q[0:2]),
cirq.CNOT(*q[1:3]),
cirq.X(q[0]),
cirq.Y(q[1]),
cirq.Moment(cirq.X(q[0]).with_tags("ignore"), cirq.Y(q[1])),
cirq.CNOT(*q[:2]),
strategy=cirq.InsertStrategy.NEW,
)


def test_merge_operations_merges_connected_component():
c_orig = _create_circuit_to_merge()
cirq.testing.assert_has_diagram(
c_orig,
'''
0: ───H───@───@───H───@───X───────@───────X──────@───
│ │ │ │ │
1: ───H───┼───X───────@───────Y───X───@───────Y───X───
0: ───H───@───@───H───@───X───────@───────X───X['ignore']───@───
│ │ │ │
1: ───H───┼───X───────@───────Y───X───@───────Y─────────────X───
│ │
2: ───H───X───────────────────────────X───────────────
2: ───H───X───────────────────────────X─────────────────────────
''',
)

Expand All @@ -443,6 +447,76 @@ def merge_func(op1, op2):
)


# pylint: disable=line-too-long
def test_merge_operations_to_circuit_op_merges_connected_component():
c_orig = _create_circuit_to_merge()
cirq.testing.assert_has_diagram(
c_orig,
'''
0: ───H───@───@───H───@───X───────@───────X───X['ignore']───@───
│ │ │ │ │
1: ───H───┼───X───────@───────Y───X───@───────Y─────────────X───
│ │
2: ───H───X───────────────────────────X─────────────────────────
''',
)

def can_merge(ops1: List['cirq.Operation'], ops2: List['cirq.Operation']) -> bool:
"""Artificial example where a CZ will absorb any merge-able operation."""
return any(o.gate == cirq.CZ for op_list in [ops1, ops2] for o in op_list)

c_new = cirq.merge_operations_to_circuit_op(
c_orig, can_merge, merged_circuit_op_tag="merged", tags_to_ignore=["ignore"]
)
cirq.testing.assert_has_diagram(
c_new,
'''
[ 0: ───────@───H───@───X───@───X─── ]
0: ───H───@───────────[ │ │ │ ]─────────────────────────────────X['ignore']───@───
│ [ 1: ───H───X───────@───Y───X─────── ]['merged'] │
│ │ │
1: ───────┼───────────#2─────────────────────────────────────────────────────────────@───────Y─────────────X───
│ │
2: ───H───X──────────────────────────────────────────────────────────────────────────X─────────────────────────
''',
)


def test_merge_2q_unitaries_to_circuit_op():
c_orig = _create_circuit_to_merge()
c_orig[-1] = c_orig[-1].with_operations(cirq.measure(cirq.LineQubit(2)))
cirq.testing.assert_has_diagram(
c_orig,
'''
0: ───H───@───@───H───@───X───────@───────X───X['ignore']───@───
│ │ │ │ │
1: ───H───┼───X───────@───────Y───X───@───────Y─────────────X───
│ │
2: ───H───X───────────────────────────X─────────────────────M───
''',
)

c_new = cirq.merge_k_qubit_unitaries_to_circuit_op(
c_orig, k=2, merged_circuit_op_tag="merged", tags_to_ignore=["ignore"]
)
cirq.testing.assert_has_diagram(
cirq.drop_empty_moments(c_new),
'''
[ 0: ───H───@─── ] [ 0: ───────@───H───@───X───@───X─── ]
0: ───[ │ ]─────────────[ │ │ │ ]────────────────────────────────────────────X['ignore']───@───
[ 2: ───H───X─── ]['merged'] [ 1: ───H───X───────@───Y───X─────── ]['merged'] │
│ │ │
│ │ [ 1: ───@───Y─── ] │
1: ───┼──────────────────────────────#2─────────────────────────────────────────────────[ │ ]───────────────────────────X───
│ [ 2: ───X─────── ]['merged']
│ │
2: ───#2────────────────────────────────────────────────────────────────────────────────#2───────────────────────────────────────────M───''',
)


# pylint: enable=line-too-long


def test_merge_operations_respects_tags_to_ignore():
q = cirq.LineQubit.range(2)
c = cirq.Circuit(
Expand Down