Skip to content

Commit

Permalink
Add cirq.toggle_tags helper to apply transformers on specific subse…
Browse files Browse the repository at this point in the history
…ts of operations in a circuit (quantumlib#4973)

* Add  helper to apply transformers on specific subsets of operations in a circuit

* Rename to toggle_tags and address feedback
  • Loading branch information
tanujkhattar authored and rht committed May 1, 2023
1 parent 568e0e4 commit 89dfc3f
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 0 deletions.
1 change: 1 addition & 0 deletions cirq-core/cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,7 @@
two_qubit_gate_product_tabulation,
TwoQubitGateTabulation,
TwoQubitGateTabulationResult,
toggle_tags,
unroll_circuit_op,
unroll_circuit_op_greedy_earliest,
unroll_circuit_op_greedy_frontier,
Expand Down
1 change: 1 addition & 0 deletions cirq-core/cirq/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
map_operations_and_unroll,
merge_moments,
merge_operations,
toggle_tags,
unroll_circuit_op,
unroll_circuit_op_greedy_earliest,
unroll_circuit_op_greedy_frontier,
Expand Down
34 changes: 34 additions & 0 deletions cirq-core/cirq/transformers/align_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,40 @@ def test_align_left_no_compile_context():
)


def test_align_left_subset_of_operations():
q1 = cirq.NamedQubit('q1')
q2 = cirq.NamedQubit('q2')
tag = "op_to_align"
c_orig = cirq.Circuit(
[
cirq.Moment([cirq.Y(q1)]),
cirq.Moment([cirq.X(q2)]),
cirq.Moment([cirq.X(q1).with_tags(tag)]),
cirq.Moment([cirq.Y(q2)]),
cirq.measure(*[q1, q2], key='a'),
]
)
c_exp = cirq.Circuit(
[
cirq.Moment([cirq.Y(q1)]),
cirq.Moment([cirq.X(q1).with_tags(tag), cirq.X(q2)]),
cirq.Moment(),
cirq.Moment([cirq.Y(q2)]),
cirq.measure(*[q1, q2], key='a'),
]
)
cirq.testing.assert_same_circuits(
cirq.toggle_tags(
cirq.align_left(
cirq.toggle_tags(c_orig, [tag]),
context=cirq.TransformerContext(tags_to_ignore=[tag]),
),
[tag],
),
c_exp,
)


def test_align_right_no_compile_context():
q1 = cirq.NamedQubit('q1')
q2 = cirq.NamedQubit('q2')
Expand Down
31 changes: 31 additions & 0 deletions cirq-core/cirq/transformers/transformer_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,3 +426,34 @@ def unroll_circuit_op_greedy_frontier(
)
frontier = unrolled_circuit.insert_at_frontier(sub_circuit.all_operations(), idx, frontier)
return _to_target_circuit_type(unrolled_circuit, circuit)


def toggle_tags(circuit: CIRCUIT_TYPE, tags: Sequence[Hashable], *, deep: bool = False):
"""Toggles tags applied on each operation in the circuit, via `op.tags ^= tags`
For every operations `op` in the input circuit, the tags on `op` are replaced by a symmetric
difference of `op.tags` and `tags` -- this is useful in scenarios where you mark a small subset
of operations with a specific tag and then toggle the set of marked operations s.t. every
marked operation is now unmarked and vice versa.
Often used in transformer workflows to apply a transformer on a small subset of operations.
Args:
circuit: Input circuit to apply the transformations on. The input circuit is not mutated.
tags: Sequence of tags s.t. `op.tags ^= tags` is done for every operation `op` in circuit.
deep: If true, tags will be recursively toggled for operations in circuits wrapped inside
any circuit operations contained within `circuit`.
Returns:
Copy of transformed input circuit with operation sets marked with `tags` toggled.
"""
tags_to_xor = set(tags)

def map_func(op: 'cirq.Operation', _) -> 'cirq.Operation':
return (
op
if deep and isinstance(op, circuits.CircuitOperation)
else op.untagged.with_tags(*(set(op.tags) ^ tags_to_xor))
)

return map_operations(circuit, map_func, deep=deep)
20 changes: 20 additions & 0 deletions cirq-core/cirq/transformers/transformer_primitives_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,26 @@ def test_map_operations_respects_tags_to_ignore():
)


def test_apply_tag_to_inverted_op_set():
q = cirq.LineQubit.range(2)
op = cirq.CNOT(*q)
tag = "tag_to_flip"
c_orig = cirq.Circuit(op, op.with_tags(tag), cirq.CircuitOperation(cirq.FrozenCircuit(op)))
# Toggle with deep = True.
c_toggled = cirq.Circuit(
op.with_tags(tag), op, cirq.CircuitOperation(cirq.FrozenCircuit(op.with_tags(tag)))
)
cirq.testing.assert_same_circuits(cirq.toggle_tags(c_orig, [tag], deep=True), c_toggled)
cirq.testing.assert_same_circuits(cirq.toggle_tags(c_toggled, [tag], deep=True), c_orig)

# Toggle with deep = False
c_toggled = cirq.Circuit(
op.with_tags(tag), op, cirq.CircuitOperation(cirq.FrozenCircuit(op)).with_tags(tag)
)
cirq.testing.assert_same_circuits(cirq.toggle_tags(c_orig, [tag], deep=False), c_toggled)
cirq.testing.assert_same_circuits(cirq.toggle_tags(c_toggled, [tag], deep=False), c_orig)


def test_unroll_circuit_op_and_variants():
q = cirq.LineQubit.range(2)
c = cirq.Circuit(cirq.X(q[0]), cirq.CNOT(q[0], q[1]), cirq.X(q[0]))
Expand Down

0 comments on commit 89dfc3f

Please sign in to comment.