Skip to content

Commit

Permalink
Preserve circuit tags in transformer_primitives.map_operations (#6505)
Browse files Browse the repository at this point in the history
  • Loading branch information
maffoo authored Mar 18, 2024
1 parent dadfdcb commit 7780c01
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 6 deletions.
9 changes: 3 additions & 6 deletions cirq-core/cirq/transformers/transformer_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,9 @@ def _to_target_circuit_type(


def _create_target_circuit_type(ops: ops.OP_TREE, target_circuit: CIRCUIT_TYPE) -> CIRCUIT_TYPE:
return cast(
CIRCUIT_TYPE,
circuits.Circuit(ops)
if isinstance(target_circuit, circuits.Circuit)
else circuits.FrozenCircuit(ops),
)
if isinstance(target_circuit, circuits.FrozenCircuit):
return cast(CIRCUIT_TYPE, circuits.FrozenCircuit(ops).with_tags(*target_circuit.tags))
return cast(CIRCUIT_TYPE, circuits.Circuit(ops))


def map_moments(
Expand Down
27 changes: 27 additions & 0 deletions cirq-core/cirq/transformers/transformer_primitives_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,33 @@ def map_func(op: cirq.Operation, _: int) -> cirq.OP_TREE:
# pylint: enable=line-too-long


@pytest.mark.parametrize("deep", [False, True])
def test_map_operations_preserves_circuit_tags(deep: bool) -> None:
tag = "should be preserved"

def func(op: cirq.Operation, idx: int) -> cirq.Operation:
return cirq.Y(op.qubits[0]) if op.gate == cirq.X else op

x = cirq.X(cirq.q(0))
circuit = cirq.FrozenCircuit.from_moments(x, cirq.FrozenCircuit(x)).with_tags(tag)
mapped = cirq.map_operations(circuit, func, deep=deep)

assert mapped.tags == (tag,)


def test_map_operations_deep_preserves_subcircuit_tags():
tag = "should be preserved"

def func(op: cirq.Operation, idx: int) -> cirq.Operation:
return cirq.Y(op.qubits[0]) if op.gate == cirq.X else op

x = cirq.X(cirq.q(0))
circuit = cirq.FrozenCircuit.from_moments(x, cirq.FrozenCircuit(x).with_tags(tag))
mapped = cirq.map_operations(circuit, func, deep=True)

assert mapped[1].operations[0].circuit.tags == (tag,)


def test_map_operations_deep_respects_tags_to_ignore():
q = cirq.LineQubit.range(2)
c_nested = cirq.FrozenCircuit(cirq.CX(*q), cirq.CX(*q).with_tags("ignore"), cirq.CX(*q))
Expand Down

0 comments on commit 7780c01

Please sign in to comment.