Skip to content

Commit

Permalink
Add synchronize_terminal_measurements transformer to replace `Synch…
Browse files Browse the repository at this point in the history
…ronizeTerminalMeasurements` (quantumlib#4911)

- Part of quantumlib#4722
- Follows the new Transformer API quantumlib#4483
- Supports no compile tags NoCompile Tag for optimizers quantumlib#4253
- Fixes quantumlib#4907
  • Loading branch information
tanujkhattar authored and rht committed May 1, 2023
1 parent 4bb20a5 commit 03ddbd8
Show file tree
Hide file tree
Showing 7 changed files with 328 additions and 7 deletions.
1 change: 1 addition & 0 deletions cirq-core/cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,7 @@
single_qubit_matrix_to_phased_x_z,
single_qubit_matrix_to_phxz,
single_qubit_op_to_framed_phase_form,
synchronize_terminal_measurements,
TRANSFORMER,
TransformerContext,
TransformerLogger,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@

from typing import List, Set, Tuple, cast
from cirq import circuits, ops, protocols
from cirq._compat import deprecated_class


@deprecated_class(deadline='v1.0', fix='Use cirq.synchronize_terminal_measurements instead.')
class SynchronizeTerminalMeasurements:
"""Move measurements to the end of the circuit.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@


def assert_optimizes(before, after, measure_only_moment=True):
opt = cirq.SynchronizeTerminalMeasurements(measure_only_moment)
opt(before)
assert before == after
with cirq.testing.assert_deprecated(
"Use cirq.synchronize_terminal_measurements", deadline='v1.0'
):
opt = cirq.SynchronizeTerminalMeasurements(measure_only_moment)
opt(before)
assert before == after


def test_no_move():
Expand Down
2 changes: 2 additions & 0 deletions cirq-core/cirq/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@

from cirq.transformers.align import align_left, align_right

from cirq.transformers.synchronize_terminal_measurements import synchronize_terminal_measurements

from cirq.transformers.transformer_api import (
LogLevel,
TRANSFORMER,
Expand Down
99 changes: 99 additions & 0 deletions cirq-core/cirq/transformers/synchronize_terminal_measurements.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright 2022 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformer pass to move terminal measurements to the end of circuit."""

from typing import List, Optional, Set, Tuple, TYPE_CHECKING
from cirq import protocols, ops
from cirq.transformers import transformer_api

if TYPE_CHECKING:
import cirq


def find_terminal_measurements(
circuit: 'cirq.AbstractCircuit',
) -> List[Tuple[int, 'cirq.Operation']]:
"""Finds all terminal measurements in the given circuit.
A measurement is terminal if there are no other operations acting on the measured qubits
after the measurement operation occurs in the circuit.
Args:
circuit: The circuit to find terminal measurements in.
Returns:
List of terminal measurements, each specified as (moment_index, measurement_operation).
"""

open_qubits: Set['cirq.Qid'] = set(circuit.all_qubits())
seen_control_keys: Set['cirq.MeasurementKey'] = set()
terminal_measurements: List[Tuple[int, 'cirq.Operation']] = []
for i in range(len(circuit) - 1, -1, -1):
moment = circuit[i]
for q in open_qubits:
op = moment.operation_at(q)
seen_control_keys |= protocols.control_keys(op)
if (
op is not None
and open_qubits.issuperset(op.qubits)
and protocols.is_measurement(op)
and not (seen_control_keys & protocols.measurement_key_objs(op))
):
terminal_measurements.append((i, op))
open_qubits -= moment.qubits
if not open_qubits:
break
return terminal_measurements


@transformer_api.transformer
def synchronize_terminal_measurements(
circuit: 'cirq.AbstractCircuit',
*,
context: Optional['cirq.TransformerContext'] = None,
after_other_operations: bool = True,
) -> 'cirq.Circuit':
"""Move measurements to the end of the circuit.
Move all measurements in a circuit to the final moment, if it can accommodate them (without
overlapping with other operations). If `after_other_operations` is true, then a new moment will
be added to the end of the circuit containing all the measurements that should be brought
forward.
Args:
circuit: Input circuit to transform.
context: `cirq.TransformerContext` storing common configurable options for transformers.
after_other_operations: Set by default. If the circuit's final moment contains
non-measurement operations and this is set then a new empty moment is appended to
the circuit before pushing measurements to the end.
Returns:
Copy of the transformed input circuit.
"""
if context is None:
context = transformer_api.TransformerContext()
terminal_measurements = [
(i, op)
for i, op in find_terminal_measurements(circuit)
if set(op.tags).isdisjoint(context.ignore_tags)
]
ret = circuit.unfreeze(copy=True)
if not terminal_measurements:
return ret

ret.batch_remove(terminal_measurements)
if ret[-1] and after_other_operations:
ret.append(ops.Moment())
ret[-1] = ret[-1].with_operations(op for _, op in terminal_measurements)
return ret
214 changes: 214 additions & 0 deletions cirq-core/cirq/transformers/synchronize_terminal_measurements_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
# Copyright 2022 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import cirq

NO_COMPILE_TAG = "no_compile_tag"


def assert_optimizes(before, after, measure_only_moment=True, with_context=False):
transformed_circuit = (
cirq.synchronize_terminal_measurements(before, after_other_operations=measure_only_moment)
if not with_context
else cirq.synchronize_terminal_measurements(
before,
context=cirq.TransformerContext(ignore_tags=(NO_COMPILE_TAG,)),
after_other_operations=measure_only_moment,
)
)
cirq.testing.assert_same_circuits(transformed_circuit, after)


def test_no_move():
q1 = cirq.NamedQubit('q1')
before = cirq.Circuit([cirq.Moment([cirq.H(q1)])])
after = before
assert_optimizes(before=before, after=after)


def test_simple_align():
q1 = cirq.NamedQubit('q1')
q2 = cirq.NamedQubit('q2')
before = cirq.Circuit(
[
cirq.Moment([cirq.H(q1), cirq.H(q2)]),
cirq.Moment([cirq.measure(q1).with_tags(NO_COMPILE_TAG), cirq.Z(q2)]),
cirq.Moment([cirq.measure(q2)]),
]
)
after = cirq.Circuit(
[
cirq.Moment([cirq.H(q1), cirq.H(q2)]),
cirq.Moment([cirq.Z(q2)]),
cirq.Moment([cirq.measure(q1).with_tags(NO_COMPILE_TAG), cirq.measure(q2)]),
]
)
assert_optimizes(before=before, after=after)
assert_optimizes(before=before, after=before, with_context=True)


def test_simple_partial_align():
q1 = cirq.NamedQubit('q1')
q2 = cirq.NamedQubit('q2')
before = cirq.Circuit(
[
cirq.Moment([cirq.measure(q1), cirq.Z(q2)]),
cirq.Moment([cirq.Z(q1), cirq.measure(q2).with_tags(NO_COMPILE_TAG)]),
]
)
after = cirq.Circuit(
[
cirq.Moment([cirq.measure(q1), cirq.Z(q2)]),
cirq.Moment([cirq.Z(q1)]),
cirq.Moment([cirq.measure(q2).with_tags(NO_COMPILE_TAG)]),
]
)
assert_optimizes(before=before, after=after)
assert_optimizes(before=before, after=before, with_context=True)


def test_slide_forward_one():
q1 = cirq.NamedQubit('q1')
q2 = cirq.NamedQubit('q2')
q3 = cirq.NamedQubit('q3')
before = cirq.Circuit(
[
cirq.Moment([cirq.H(q1), cirq.measure(q2).with_tags(NO_COMPILE_TAG), cirq.measure(q3)]),
]
)
after = cirq.Circuit(
[
cirq.Moment([cirq.H(q1)]),
cirq.Moment([cirq.measure(q2).with_tags(NO_COMPILE_TAG), cirq.measure(q3)]),
]
)
after_no_compile = cirq.Circuit(
[
cirq.Moment([cirq.H(q1), cirq.measure(q2).with_tags(NO_COMPILE_TAG)]),
cirq.Moment([cirq.measure(q3)]),
]
)
assert_optimizes(before=before, after=after)
assert_optimizes(before=before, after=after_no_compile, with_context=True)


def test_no_slide_forward_one():
q1 = cirq.NamedQubit('q1')
q2 = cirq.NamedQubit('q2')
q3 = cirq.NamedQubit('q3')
before = cirq.Circuit(
[
cirq.Moment([cirq.H(q1), cirq.measure(q2), cirq.measure(q3)]),
]
)
after = cirq.Circuit(
[
cirq.Moment([cirq.H(q1), cirq.measure(q2), cirq.measure(q3)]),
]
)
assert_optimizes(before=before, after=after, measure_only_moment=False)


def test_blocked_shift_one():
q1 = cirq.NamedQubit('q1')
q2 = cirq.NamedQubit('q2')
before = cirq.Circuit(
[
cirq.Moment([cirq.H(q1), cirq.H(q2)]),
cirq.Moment([cirq.measure(q1), cirq.Z(q2)]),
cirq.Moment([cirq.H(q1), cirq.measure(q2).with_tags(NO_COMPILE_TAG)]),
]
)
after = cirq.Circuit(
[
cirq.Moment([cirq.H(q1), cirq.H(q2)]),
cirq.Moment([cirq.measure(q1), cirq.Z(q2)]),
cirq.Moment([cirq.H(q1)]),
cirq.Moment([cirq.measure(q2).with_tags(NO_COMPILE_TAG)]),
]
)
assert_optimizes(before=before, after=after)
assert_optimizes(before=before, after=before, with_context=True)


def test_complex_move():
q1 = cirq.NamedQubit('q1')
q2 = cirq.NamedQubit('q2')
q3 = cirq.NamedQubit('q3')
before = cirq.Circuit(
[
cirq.Moment([cirq.H(q1), cirq.H(q2)]),
cirq.Moment([cirq.measure(q1), cirq.Z(q2)]),
cirq.Moment([cirq.H(q1), cirq.measure(q2).with_tags(NO_COMPILE_TAG)]),
cirq.Moment([cirq.H(q3)]),
cirq.Moment([cirq.X(q1), cirq.measure(q3).with_tags(NO_COMPILE_TAG)]),
]
)
after = cirq.Circuit(
[
cirq.Moment([cirq.H(q1), cirq.H(q2)]),
cirq.Moment([cirq.measure(q1), cirq.Z(q2)]),
cirq.Moment([cirq.H(q1)]),
cirq.Moment([cirq.H(q3)]),
cirq.Moment([cirq.X(q1)]),
cirq.Moment(
[
cirq.measure(q2).with_tags(NO_COMPILE_TAG),
cirq.measure(q3).with_tags(NO_COMPILE_TAG),
]
),
]
)
assert_optimizes(before=before, after=after)
assert_optimizes(before=before, after=before, with_context=True)


def test_complex_move_no_slide():
q1 = cirq.NamedQubit('q1')
q2 = cirq.NamedQubit('q2')
q3 = cirq.NamedQubit('q3')
before = cirq.Circuit(
[
cirq.Moment([cirq.H(q1), cirq.H(q2)]),
cirq.Moment([cirq.measure(q1), cirq.Z(q2)]),
cirq.Moment([cirq.H(q1), cirq.measure(q2).with_tags(NO_COMPILE_TAG)]),
cirq.Moment([cirq.H(q3)]),
cirq.Moment([cirq.X(q1), cirq.measure(q3)]),
]
)
after = cirq.Circuit(
[
cirq.Moment(cirq.H(q1), cirq.H(q2)),
cirq.Moment(cirq.measure(q1), cirq.Z(q2)),
cirq.Moment(cirq.H(q1)),
cirq.Moment(cirq.H(q3)),
cirq.Moment(cirq.X(q1), cirq.measure(q2).with_tags(NO_COMPILE_TAG), cirq.measure(q3)),
]
)
assert_optimizes(before=before, after=after, measure_only_moment=False)
assert_optimizes(before=before, after=before, measure_only_moment=False, with_context=True)


def test_multi_qubit():
q0, q1 = cirq.LineQubit.range(2)
circuit = cirq.Circuit(cirq.measure(q0, q1, key='m'), cirq.H(q1))
assert_optimizes(before=circuit, after=circuit)


def test_classically_controlled_op():
q0, q1 = cirq.LineQubit.range(2)
circuit = cirq.Circuit(
cirq.H(q0), cirq.measure(q0, key='m'), cirq.X(q1).with_classical_controls('m')
)
assert_optimizes(before=circuit, after=circuit)
8 changes: 4 additions & 4 deletions docs/tutorials/google/spin_echoes.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@
" # Alignment.\n",
" if with_alignment:\n",
" circuit = cirq.align_right(circuit)\n",
" cirq.SynchronizeTerminalMeasurements().optimize_circuit(circuit)\n",
" circuit = synchronize_terminal_measurements(circuit)\n",
"\n",
" return circuit\n",
"\n",
Expand Down Expand Up @@ -819,7 +819,7 @@
"id": "J537zBXPHny4"
},
"source": [
"Note: Optimizers can cause terminal measurements to become misaligned, but this can be fixed with `cirq.SynchronizeTerminalMeasurements` as discussed below."
"Note: Optimizers can cause terminal measurements to become misaligned, but this can be fixed with `cirq.synchronize_terminal_measurements` as discussed below."
]
},
{
Expand Down Expand Up @@ -1065,7 +1065,7 @@
"id": "q0J0zRqjKKhP"
},
"source": [
"You can use the `cirq.SynchronizeTerminalMeasurements` to move all measurements to the final moment if it can accommodate them (without overlapping with other operations)."
"You can use the `cirq.synchronize_terminal_measurements` to move all measurements to the final moment if it can accommodate them (without overlapping with other operations)."
]
},
{
Expand Down Expand Up @@ -1114,7 +1114,7 @@
}
],
"source": [
"cirq.SynchronizeTerminalMeasurements().optimize_circuit(circuit)\n",
"circuit = cirq.synchronize_terminal_measurements(circuit)\n",
"circuit"
]
},
Expand Down

0 comments on commit 03ddbd8

Please sign in to comment.