diff --git a/cirq-core/cirq/__init__.py b/cirq-core/cirq/__init__.py index 1cbf03d05e8..657eba57b60 100644 --- a/cirq-core/cirq/__init__.py +++ b/cirq-core/cirq/__init__.py @@ -371,6 +371,7 @@ single_qubit_matrix_to_phased_x_z, single_qubit_matrix_to_phxz, single_qubit_op_to_framed_phase_form, + synchronize_terminal_measurements, TRANSFORMER, TransformerContext, TransformerLogger, diff --git a/cirq-core/cirq/optimizers/synchronize_terminal_measurements.py b/cirq-core/cirq/optimizers/synchronize_terminal_measurements.py index 6181c35090f..5f8c091d3d9 100644 --- a/cirq-core/cirq/optimizers/synchronize_terminal_measurements.py +++ b/cirq-core/cirq/optimizers/synchronize_terminal_measurements.py @@ -15,8 +15,10 @@ from typing import List, Set, Tuple, cast from cirq import circuits, ops, protocols +from cirq._compat import deprecated_class +@deprecated_class(deadline='v1.0', fix='Use cirq.synchronize_terminal_measurements instead.') class SynchronizeTerminalMeasurements: """Move measurements to the end of the circuit. diff --git a/cirq-core/cirq/optimizers/synchronize_terminal_measurements_test.py b/cirq-core/cirq/optimizers/synchronize_terminal_measurements_test.py index 04b1fc35851..1e2331eaf03 100644 --- a/cirq-core/cirq/optimizers/synchronize_terminal_measurements_test.py +++ b/cirq-core/cirq/optimizers/synchronize_terminal_measurements_test.py @@ -16,9 +16,12 @@ def assert_optimizes(before, after, measure_only_moment=True): - opt = cirq.SynchronizeTerminalMeasurements(measure_only_moment) - opt(before) - assert before == after + with cirq.testing.assert_deprecated( + "Use cirq.synchronize_terminal_measurements", deadline='v1.0' + ): + opt = cirq.SynchronizeTerminalMeasurements(measure_only_moment) + opt(before) + assert before == after def test_no_move(): diff --git a/cirq-core/cirq/transformers/__init__.py b/cirq-core/cirq/transformers/__init__.py index 7b1dac91f5b..5911472a57a 100644 --- a/cirq-core/cirq/transformers/__init__.py +++ b/cirq-core/cirq/transformers/__init__.py @@ -43,6 +43,8 @@ from cirq.transformers.align import align_left, align_right +from cirq.transformers.synchronize_terminal_measurements import synchronize_terminal_measurements + from cirq.transformers.transformer_api import ( LogLevel, TRANSFORMER, diff --git a/cirq-core/cirq/transformers/synchronize_terminal_measurements.py b/cirq-core/cirq/transformers/synchronize_terminal_measurements.py new file mode 100644 index 00000000000..d182abab107 --- /dev/null +++ b/cirq-core/cirq/transformers/synchronize_terminal_measurements.py @@ -0,0 +1,99 @@ +# Copyright 2022 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Transformer pass to move terminal measurements to the end of circuit.""" + +from typing import List, Optional, Set, Tuple, TYPE_CHECKING +from cirq import protocols, ops +from cirq.transformers import transformer_api + +if TYPE_CHECKING: + import cirq + + +def find_terminal_measurements( + circuit: 'cirq.AbstractCircuit', +) -> List[Tuple[int, 'cirq.Operation']]: + """Finds all terminal measurements in the given circuit. + + A measurement is terminal if there are no other operations acting on the measured qubits + after the measurement operation occurs in the circuit. + + Args: + circuit: The circuit to find terminal measurements in. + + Returns: + List of terminal measurements, each specified as (moment_index, measurement_operation). + """ + + open_qubits: Set['cirq.Qid'] = set(circuit.all_qubits()) + seen_control_keys: Set['cirq.MeasurementKey'] = set() + terminal_measurements: List[Tuple[int, 'cirq.Operation']] = [] + for i in range(len(circuit) - 1, -1, -1): + moment = circuit[i] + for q in open_qubits: + op = moment.operation_at(q) + seen_control_keys |= protocols.control_keys(op) + if ( + op is not None + and open_qubits.issuperset(op.qubits) + and protocols.is_measurement(op) + and not (seen_control_keys & protocols.measurement_key_objs(op)) + ): + terminal_measurements.append((i, op)) + open_qubits -= moment.qubits + if not open_qubits: + break + return terminal_measurements + + +@transformer_api.transformer +def synchronize_terminal_measurements( + circuit: 'cirq.AbstractCircuit', + *, + context: Optional['cirq.TransformerContext'] = None, + after_other_operations: bool = True, +) -> 'cirq.Circuit': + """Move measurements to the end of the circuit. + + Move all measurements in a circuit to the final moment, if it can accommodate them (without + overlapping with other operations). If `after_other_operations` is true, then a new moment will + be added to the end of the circuit containing all the measurements that should be brought + forward. + + Args: + circuit: Input circuit to transform. + context: `cirq.TransformerContext` storing common configurable options for transformers. + after_other_operations: Set by default. If the circuit's final moment contains + non-measurement operations and this is set then a new empty moment is appended to + the circuit before pushing measurements to the end. + Returns: + Copy of the transformed input circuit. + """ + if context is None: + context = transformer_api.TransformerContext() + terminal_measurements = [ + (i, op) + for i, op in find_terminal_measurements(circuit) + if set(op.tags).isdisjoint(context.ignore_tags) + ] + ret = circuit.unfreeze(copy=True) + if not terminal_measurements: + return ret + + ret.batch_remove(terminal_measurements) + if ret[-1] and after_other_operations: + ret.append(ops.Moment()) + ret[-1] = ret[-1].with_operations(op for _, op in terminal_measurements) + return ret diff --git a/cirq-core/cirq/transformers/synchronize_terminal_measurements_test.py b/cirq-core/cirq/transformers/synchronize_terminal_measurements_test.py new file mode 100644 index 00000000000..ad69aa11050 --- /dev/null +++ b/cirq-core/cirq/transformers/synchronize_terminal_measurements_test.py @@ -0,0 +1,214 @@ +# Copyright 2022 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import cirq + +NO_COMPILE_TAG = "no_compile_tag" + + +def assert_optimizes(before, after, measure_only_moment=True, with_context=False): + transformed_circuit = ( + cirq.synchronize_terminal_measurements(before, after_other_operations=measure_only_moment) + if not with_context + else cirq.synchronize_terminal_measurements( + before, + context=cirq.TransformerContext(ignore_tags=(NO_COMPILE_TAG,)), + after_other_operations=measure_only_moment, + ) + ) + cirq.testing.assert_same_circuits(transformed_circuit, after) + + +def test_no_move(): + q1 = cirq.NamedQubit('q1') + before = cirq.Circuit([cirq.Moment([cirq.H(q1)])]) + after = before + assert_optimizes(before=before, after=after) + + +def test_simple_align(): + q1 = cirq.NamedQubit('q1') + q2 = cirq.NamedQubit('q2') + before = cirq.Circuit( + [ + cirq.Moment([cirq.H(q1), cirq.H(q2)]), + cirq.Moment([cirq.measure(q1).with_tags(NO_COMPILE_TAG), cirq.Z(q2)]), + cirq.Moment([cirq.measure(q2)]), + ] + ) + after = cirq.Circuit( + [ + cirq.Moment([cirq.H(q1), cirq.H(q2)]), + cirq.Moment([cirq.Z(q2)]), + cirq.Moment([cirq.measure(q1).with_tags(NO_COMPILE_TAG), cirq.measure(q2)]), + ] + ) + assert_optimizes(before=before, after=after) + assert_optimizes(before=before, after=before, with_context=True) + + +def test_simple_partial_align(): + q1 = cirq.NamedQubit('q1') + q2 = cirq.NamedQubit('q2') + before = cirq.Circuit( + [ + cirq.Moment([cirq.measure(q1), cirq.Z(q2)]), + cirq.Moment([cirq.Z(q1), cirq.measure(q2).with_tags(NO_COMPILE_TAG)]), + ] + ) + after = cirq.Circuit( + [ + cirq.Moment([cirq.measure(q1), cirq.Z(q2)]), + cirq.Moment([cirq.Z(q1)]), + cirq.Moment([cirq.measure(q2).with_tags(NO_COMPILE_TAG)]), + ] + ) + assert_optimizes(before=before, after=after) + assert_optimizes(before=before, after=before, with_context=True) + + +def test_slide_forward_one(): + q1 = cirq.NamedQubit('q1') + q2 = cirq.NamedQubit('q2') + q3 = cirq.NamedQubit('q3') + before = cirq.Circuit( + [ + cirq.Moment([cirq.H(q1), cirq.measure(q2).with_tags(NO_COMPILE_TAG), cirq.measure(q3)]), + ] + ) + after = cirq.Circuit( + [ + cirq.Moment([cirq.H(q1)]), + cirq.Moment([cirq.measure(q2).with_tags(NO_COMPILE_TAG), cirq.measure(q3)]), + ] + ) + after_no_compile = cirq.Circuit( + [ + cirq.Moment([cirq.H(q1), cirq.measure(q2).with_tags(NO_COMPILE_TAG)]), + cirq.Moment([cirq.measure(q3)]), + ] + ) + assert_optimizes(before=before, after=after) + assert_optimizes(before=before, after=after_no_compile, with_context=True) + + +def test_no_slide_forward_one(): + q1 = cirq.NamedQubit('q1') + q2 = cirq.NamedQubit('q2') + q3 = cirq.NamedQubit('q3') + before = cirq.Circuit( + [ + cirq.Moment([cirq.H(q1), cirq.measure(q2), cirq.measure(q3)]), + ] + ) + after = cirq.Circuit( + [ + cirq.Moment([cirq.H(q1), cirq.measure(q2), cirq.measure(q3)]), + ] + ) + assert_optimizes(before=before, after=after, measure_only_moment=False) + + +def test_blocked_shift_one(): + q1 = cirq.NamedQubit('q1') + q2 = cirq.NamedQubit('q2') + before = cirq.Circuit( + [ + cirq.Moment([cirq.H(q1), cirq.H(q2)]), + cirq.Moment([cirq.measure(q1), cirq.Z(q2)]), + cirq.Moment([cirq.H(q1), cirq.measure(q2).with_tags(NO_COMPILE_TAG)]), + ] + ) + after = cirq.Circuit( + [ + cirq.Moment([cirq.H(q1), cirq.H(q2)]), + cirq.Moment([cirq.measure(q1), cirq.Z(q2)]), + cirq.Moment([cirq.H(q1)]), + cirq.Moment([cirq.measure(q2).with_tags(NO_COMPILE_TAG)]), + ] + ) + assert_optimizes(before=before, after=after) + assert_optimizes(before=before, after=before, with_context=True) + + +def test_complex_move(): + q1 = cirq.NamedQubit('q1') + q2 = cirq.NamedQubit('q2') + q3 = cirq.NamedQubit('q3') + before = cirq.Circuit( + [ + cirq.Moment([cirq.H(q1), cirq.H(q2)]), + cirq.Moment([cirq.measure(q1), cirq.Z(q2)]), + cirq.Moment([cirq.H(q1), cirq.measure(q2).with_tags(NO_COMPILE_TAG)]), + cirq.Moment([cirq.H(q3)]), + cirq.Moment([cirq.X(q1), cirq.measure(q3).with_tags(NO_COMPILE_TAG)]), + ] + ) + after = cirq.Circuit( + [ + cirq.Moment([cirq.H(q1), cirq.H(q2)]), + cirq.Moment([cirq.measure(q1), cirq.Z(q2)]), + cirq.Moment([cirq.H(q1)]), + cirq.Moment([cirq.H(q3)]), + cirq.Moment([cirq.X(q1)]), + cirq.Moment( + [ + cirq.measure(q2).with_tags(NO_COMPILE_TAG), + cirq.measure(q3).with_tags(NO_COMPILE_TAG), + ] + ), + ] + ) + assert_optimizes(before=before, after=after) + assert_optimizes(before=before, after=before, with_context=True) + + +def test_complex_move_no_slide(): + q1 = cirq.NamedQubit('q1') + q2 = cirq.NamedQubit('q2') + q3 = cirq.NamedQubit('q3') + before = cirq.Circuit( + [ + cirq.Moment([cirq.H(q1), cirq.H(q2)]), + cirq.Moment([cirq.measure(q1), cirq.Z(q2)]), + cirq.Moment([cirq.H(q1), cirq.measure(q2).with_tags(NO_COMPILE_TAG)]), + cirq.Moment([cirq.H(q3)]), + cirq.Moment([cirq.X(q1), cirq.measure(q3)]), + ] + ) + after = cirq.Circuit( + [ + cirq.Moment(cirq.H(q1), cirq.H(q2)), + cirq.Moment(cirq.measure(q1), cirq.Z(q2)), + cirq.Moment(cirq.H(q1)), + cirq.Moment(cirq.H(q3)), + cirq.Moment(cirq.X(q1), cirq.measure(q2).with_tags(NO_COMPILE_TAG), cirq.measure(q3)), + ] + ) + assert_optimizes(before=before, after=after, measure_only_moment=False) + assert_optimizes(before=before, after=before, measure_only_moment=False, with_context=True) + + +def test_multi_qubit(): + q0, q1 = cirq.LineQubit.range(2) + circuit = cirq.Circuit(cirq.measure(q0, q1, key='m'), cirq.H(q1)) + assert_optimizes(before=circuit, after=circuit) + + +def test_classically_controlled_op(): + q0, q1 = cirq.LineQubit.range(2) + circuit = cirq.Circuit( + cirq.H(q0), cirq.measure(q0, key='m'), cirq.X(q1).with_classical_controls('m') + ) + assert_optimizes(before=circuit, after=circuit) diff --git a/docs/tutorials/google/spin_echoes.ipynb b/docs/tutorials/google/spin_echoes.ipynb index c881bf5c2e4..990504eb551 100644 --- a/docs/tutorials/google/spin_echoes.ipynb +++ b/docs/tutorials/google/spin_echoes.ipynb @@ -250,7 +250,7 @@ " # Alignment.\n", " if with_alignment:\n", " circuit = cirq.align_right(circuit)\n", - " cirq.SynchronizeTerminalMeasurements().optimize_circuit(circuit)\n", + " circuit = synchronize_terminal_measurements(circuit)\n", "\n", " return circuit\n", "\n", @@ -819,7 +819,7 @@ "id": "J537zBXPHny4" }, "source": [ - "Note: Optimizers can cause terminal measurements to become misaligned, but this can be fixed with `cirq.SynchronizeTerminalMeasurements` as discussed below." + "Note: Optimizers can cause terminal measurements to become misaligned, but this can be fixed with `cirq.synchronize_terminal_measurements` as discussed below." ] }, { @@ -1065,7 +1065,7 @@ "id": "q0J0zRqjKKhP" }, "source": [ - "You can use the `cirq.SynchronizeTerminalMeasurements` to move all measurements to the final moment if it can accommodate them (without overlapping with other operations)." + "You can use the `cirq.synchronize_terminal_measurements` to move all measurements to the final moment if it can accommodate them (without overlapping with other operations)." ] }, { @@ -1114,7 +1114,7 @@ } ], "source": [ - "cirq.SynchronizeTerminalMeasurements().optimize_circuit(circuit)\n", + "circuit = cirq.synchronize_terminal_measurements(circuit)\n", "circuit" ] },