Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add synchronize_terminal_measurements transformer to replace SynchronizeTerminalMeasurements #4911

Merged
merged 7 commits into from
Jan 31, 2022
1 change: 1 addition & 0 deletions cirq-core/cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,7 @@
single_qubit_matrix_to_phased_x_z,
single_qubit_matrix_to_phxz,
single_qubit_op_to_framed_phase_form,
synchronize_terminal_measurements,
TRANSFORMER,
TransformerContext,
TransformerLogger,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@

from typing import List, Set, Tuple, cast
from cirq import circuits, ops, protocols
from cirq._compat import deprecated_class


@deprecated_class(deadline='v1.0', fix='Use cirq.synchronize_terminal_measurements instead.')
class SynchronizeTerminalMeasurements:
"""Move measurements to the end of the circuit.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@


def assert_optimizes(before, after, measure_only_moment=True):
opt = cirq.SynchronizeTerminalMeasurements(measure_only_moment)
opt(before)
assert before == after
with cirq.testing.assert_deprecated(
"Use cirq.synchronize_terminal_measurements", deadline='v1.0'
):
opt = cirq.SynchronizeTerminalMeasurements(measure_only_moment)
opt(before)
assert before == after


def test_no_move():
Expand Down
2 changes: 2 additions & 0 deletions cirq-core/cirq/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@

from cirq.transformers.align import align_left, align_right

from cirq.transformers.synchronize_terminal_measurements import synchronize_terminal_measurements

from cirq.transformers.transformer_api import (
LogLevel,
TRANSFORMER,
Expand Down
99 changes: 99 additions & 0 deletions cirq-core/cirq/transformers/synchronize_terminal_measurements.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright 2022 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformer pass to move terminal measurements to the end of circuit."""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this defined as the last moment in the circuit ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but the precise moment to which it get's added depends upon the value of after_other_operations flag.


from typing import List, Optional, Set, Tuple, TYPE_CHECKING
from cirq import protocols, ops
from cirq.transformers import transformer_api

if TYPE_CHECKING:
import cirq


def find_terminal_measurements(
circuit: 'cirq.AbstractCircuit',
) -> List[Tuple[int, 'cirq.Operation']]:
"""Finds all terminal measurements in the given circuit.

A measurement is terminal if there are no other operations acting on the measured qubits
after the measurement operation occurs in the circuit.

Args:
circuit: The circuit to find terminal measurements in.

Returns:
List of terminal measurements, each specified as (moment_index, measurement_operation).
"""

open_qubits: Set['cirq.Qid'] = set(circuit.all_qubits())
seen_control_keys: Set['cirq.MeasurementKey'] = set()
terminal_measurements: List[Tuple[int, 'cirq.Operation']] = []
for i in range(len(circuit) - 1, -1, -1):
moment = circuit[i]
for q in open_qubits:
op = moment.operation_at(q)
seen_control_keys |= protocols.control_keys(op)
if (
op is not None
and open_qubits.issuperset(op.qubits)
and protocols.is_measurement(op)
and not (seen_control_keys & protocols.measurement_key_objs(op))
):
terminal_measurements.append((i, op))
open_qubits -= moment.qubits
if not open_qubits:
break
return terminal_measurements


@transformer_api.transformer
def synchronize_terminal_measurements(
circuit: 'cirq.AbstractCircuit',
*,
context: Optional['cirq.TransformerContext'] = None,
after_other_operations: bool = True,
) -> 'cirq.Circuit':
"""Move measurements to the end of the circuit.

Move all measurements in a circuit to the final moment, if it can accommodate them (without
overlapping with other operations). If `after_other_operations` is true, then a new moment will
be added to the end of the circuit containing all the measurements that should be brought
forward.

Args:
circuit: Input circuit to transform.
context: `cirq.TransformerContext` storing common configurable options for transformers.
after_other_operations: Set by default. If the circuit's final moment contains
non-measurement operations and this is set then a new empty moment is appended to
the circuit before pushing measurements to the end.
Returns:
Copy of the transformed input circuit.
"""
if context is None:
context = transformer_api.TransformerContext()
terminal_measurements = [
(i, op)
for i, op in find_terminal_measurements(circuit)
if set(op.tags).isdisjoint(context.ignore_tags)
]
ret = circuit.unfreeze(copy=True)
if not terminal_measurements:
return ret

ret.batch_remove(terminal_measurements)
if ret[-1] and after_other_operations:
ret.append(ops.Moment())
ret[-1] = ret[-1].with_operations(op for _, op in terminal_measurements)
return ret
206 changes: 206 additions & 0 deletions cirq-core/cirq/transformers/synchronize_terminal_measurements_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
# Copyright 2022 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import cirq

NO_COMPILE_TAG = "no_compile_tag"


def assert_optimizes(before, after, measure_only_moment=True, with_context=False):
transformed_circuit = (
cirq.synchronize_terminal_measurements(before, after_other_operations=measure_only_moment)
if not with_context
else cirq.synchronize_terminal_measurements(
before,
context=cirq.TransformerContext(ignore_tags=(NO_COMPILE_TAG,)),
after_other_operations=measure_only_moment,
)
)
cirq.testing.assert_same_circuits(transformed_circuit, after)


def test_no_move():
q1 = cirq.NamedQubit('q1')
before = cirq.Circuit([cirq.Moment([cirq.H(q1)])])
after = before
assert_optimizes(before=before, after=after)


def test_simple_align():
q1 = cirq.NamedQubit('q1')
q2 = cirq.NamedQubit('q2')
before = cirq.Circuit(
[
cirq.Moment([cirq.H(q1), cirq.H(q2)]),
cirq.Moment([cirq.measure(q1).with_tags(NO_COMPILE_TAG), cirq.Z(q2)]),
cirq.Moment([cirq.measure(q2)]),
]
)
after = cirq.Circuit(
[
cirq.Moment([cirq.H(q1), cirq.H(q2)]),
cirq.Moment([cirq.Z(q2)]),
cirq.Moment([cirq.measure(q1).with_tags(NO_COMPILE_TAG), cirq.measure(q2)]),
]
)
assert_optimizes(before=before, after=after)
assert_optimizes(before=before, after=before, with_context=True)


def test_simple_partial_align():
q1 = cirq.NamedQubit('q1')
q2 = cirq.NamedQubit('q2')
before = cirq.Circuit(
[
cirq.Moment([cirq.measure(q1), cirq.Z(q2)]),
cirq.Moment([cirq.Z(q1), cirq.measure(q2).with_tags(NO_COMPILE_TAG)]),
]
)
after = cirq.Circuit(
[
cirq.Moment([cirq.measure(q1), cirq.Z(q2)]),
cirq.Moment([cirq.Z(q1)]),
cirq.Moment([cirq.measure(q2).with_tags(NO_COMPILE_TAG)]),
]
)
assert_optimizes(before=before, after=after)
assert_optimizes(before=before, after=before, with_context=True)


def test_slide_forward_one():
q1 = cirq.NamedQubit('q1')
q2 = cirq.NamedQubit('q2')
q3 = cirq.NamedQubit('q3')
before = cirq.Circuit(
[
cirq.Moment([cirq.H(q1), cirq.measure(q2).with_tags(NO_COMPILE_TAG), cirq.measure(q3)]),
]
)
after = cirq.Circuit(
[
cirq.Moment([cirq.H(q1)]),
cirq.Moment([cirq.measure(q2).with_tags(NO_COMPILE_TAG), cirq.measure(q3)]),
]
)
after_no_compile = cirq.Circuit(
[
cirq.Moment([cirq.H(q1), cirq.measure(q2).with_tags(NO_COMPILE_TAG)]),
cirq.Moment([cirq.measure(q3)]),
]
)
assert_optimizes(before=before, after=after)
assert_optimizes(before=before, after=after_no_compile, with_context=True)


def test_no_slide_forward_one():
q1 = cirq.NamedQubit('q1')
q2 = cirq.NamedQubit('q2')
q3 = cirq.NamedQubit('q3')
before = cirq.Circuit(
[
cirq.Moment([cirq.H(q1), cirq.measure(q2), cirq.measure(q3)]),
]
)
after = cirq.Circuit(
[
cirq.Moment([cirq.H(q1), cirq.measure(q2), cirq.measure(q3)]),
]
)
assert_optimizes(before=before, after=after, measure_only_moment=False)


def test_blocked_shift_one():
q1 = cirq.NamedQubit('q1')
q2 = cirq.NamedQubit('q2')
before = cirq.Circuit(
[
cirq.Moment([cirq.H(q1), cirq.H(q2)]),
cirq.Moment([cirq.measure(q1), cirq.Z(q2)]),
cirq.Moment([cirq.H(q1), cirq.measure(q2).with_tags(NO_COMPILE_TAG)]),
]
)
after = cirq.Circuit(
[
cirq.Moment([cirq.H(q1), cirq.H(q2)]),
cirq.Moment([cirq.measure(q1), cirq.Z(q2)]),
cirq.Moment([cirq.H(q1)]),
cirq.Moment([cirq.measure(q2).with_tags(NO_COMPILE_TAG)]),
]
)
assert_optimizes(before=before, after=after)
assert_optimizes(before=before, after=before, with_context=True)


def test_complex_move():
q1 = cirq.NamedQubit('q1')
q2 = cirq.NamedQubit('q2')
q3 = cirq.NamedQubit('q3')
before = cirq.Circuit(
[
cirq.Moment([cirq.H(q1), cirq.H(q2)]),
cirq.Moment([cirq.measure(q1), cirq.Z(q2)]),
cirq.Moment([cirq.H(q1), cirq.measure(q2).with_tags(NO_COMPILE_TAG)]),
cirq.Moment([cirq.H(q3)]),
cirq.Moment([cirq.X(q1), cirq.measure(q3).with_tags(NO_COMPILE_TAG)]),
]
)
after = cirq.Circuit(
[
cirq.Moment([cirq.H(q1), cirq.H(q2)]),
cirq.Moment([cirq.measure(q1), cirq.Z(q2)]),
cirq.Moment([cirq.H(q1)]),
cirq.Moment([cirq.H(q3)]),
cirq.Moment([cirq.X(q1)]),
cirq.Moment(
[
cirq.measure(q2).with_tags(NO_COMPILE_TAG),
cirq.measure(q3).with_tags(NO_COMPILE_TAG),
]
),
]
)
assert_optimizes(before=before, after=after)
assert_optimizes(before=before, after=before, with_context=True)


def test_complex_move_no_slide():
q1 = cirq.NamedQubit('q1')
q2 = cirq.NamedQubit('q2')
q3 = cirq.NamedQubit('q3')
before = cirq.Circuit(
[
cirq.Moment([cirq.H(q1), cirq.H(q2)]),
cirq.Moment([cirq.measure(q1), cirq.Z(q2)]),
cirq.Moment([cirq.H(q1), cirq.measure(q2).with_tags(NO_COMPILE_TAG)]),
cirq.Moment([cirq.H(q3)]),
cirq.Moment([cirq.X(q1), cirq.measure(q3)]),
]
)
after = cirq.Circuit(
[
cirq.Moment(cirq.H(q1), cirq.H(q2)),
cirq.Moment(cirq.measure(q1), cirq.Z(q2)),
cirq.Moment(cirq.H(q1)),
cirq.Moment(cirq.H(q3)),
cirq.Moment(cirq.X(q1), cirq.measure(q2).with_tags(NO_COMPILE_TAG), cirq.measure(q3)),
]
)
assert_optimizes(before=before, after=after, measure_only_moment=False)
assert_optimizes(before=before, after=before, measure_only_moment=False, with_context=True)


def test_multi_qubit():
q0, q1 = cirq.LineQubit.range(2)
circuit = cirq.Circuit(cirq.measure(q0, q1, key='m'), cirq.H(q1))
assert_optimizes(before=circuit, after=circuit)
8 changes: 4 additions & 4 deletions docs/tutorials/google/spin_echoes.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@
" # Alignment.\n",
" if with_alignment:\n",
" circuit = cirq.align_right(circuit)\n",
" cirq.SynchronizeTerminalMeasurements().optimize_circuit(circuit)\n",
" circuit = synchronize_terminal_measurements(circuit)\n",
"\n",
" return circuit\n",
"\n",
Expand Down Expand Up @@ -819,7 +819,7 @@
"id": "J537zBXPHny4"
},
"source": [
"Note: Optimizers can cause terminal measurements to become misaligned, but this can be fixed with `cirq.SynchronizeTerminalMeasurements` as discussed below."
"Note: Optimizers can cause terminal measurements to become misaligned, but this can be fixed with `cirq.synchronize_terminal_measurements` as discussed below."
]
},
{
Expand Down Expand Up @@ -1065,7 +1065,7 @@
"id": "q0J0zRqjKKhP"
},
"source": [
"You can use the `cirq.SynchronizeTerminalMeasurements` to move all measurements to the final moment if it can accommodate them (without overlapping with other operations)."
"You can use the `cirq.synchronize_terminal_measurements` to move all measurements to the final moment if it can accommodate them (without overlapping with other operations)."
]
},
{
Expand Down Expand Up @@ -1114,7 +1114,7 @@
}
],
"source": [
"cirq.SynchronizeTerminalMeasurements().optimize_circuit(circuit)\n",
"circuit = cirq.synchronize_terminal_measurements(circuit)\n",
"circuit"
]
},
Expand Down