diff --git a/cirq-core/cirq/transformers/routing/route_circuit_cqc.py b/cirq-core/cirq/transformers/routing/route_circuit_cqc.py index fdf6eda1fea..3ce32eb98ee 100644 --- a/cirq-core/cirq/transformers/routing/route_circuit_cqc.py +++ b/cirq-core/cirq/transformers/routing/route_circuit_cqc.py @@ -249,26 +249,32 @@ def _get_one_and_two_qubit_ops_as_timesteps( output routed circuit, single-qubit operations are inserted before two-qubit operations. Raises: - ValueError: if circuit has intermediate measurement op's that act on 3 or more qubits. + ValueError: if circuit has intermediate measurements that act on three or more + qubits with a custom key. """ two_qubit_circuit = circuits.Circuit() single_qubit_ops: List[List[cirq.Operation]] = [] - if any( - protocols.num_qubits(op) > 2 and protocols.is_measurement(op) - for op in itertools.chain(*circuit.moments[:-1]) - ): - # There is at least one non-terminal measurement on 3+ qubits - raise ValueError('Non-terminal measurements on three or more qubits are not supported') - - for moment in circuit: + for i, moment in enumerate(circuit): for op in moment: timestep = two_qubit_circuit.earliest_available_moment(op) single_qubit_ops.extend([] for _ in range(timestep + 1 - len(single_qubit_ops))) two_qubit_circuit.append( circuits.Moment() for _ in range(timestep + 1 - len(two_qubit_circuit)) ) - if protocols.num_qubits(op) == 2: + if protocols.num_qubits(op) > 2 and protocols.is_measurement(op): + key = op.gate.key # type: ignore + default_key = ops.measure(op.qubits).gate.key # type: ignore + if len(circuit.moments) == i + 1: + single_qubit_ops[timestep].append(op) + elif key in ('', default_key): + single_qubit_ops[timestep].extend(ops.measure(qubit) for qubit in op.qubits) + else: + raise ValueError( + 'Intermediate measurements on three or more qubits ' + 'with a custom key are not supported' + ) + elif protocols.num_qubits(op) == 2: two_qubit_circuit[timestep] = two_qubit_circuit[timestep].with_operation(op) else: single_qubit_ops[timestep].append(op) diff --git a/cirq-core/cirq/transformers/routing/route_circuit_cqc_test.py b/cirq-core/cirq/transformers/routing/route_circuit_cqc_test.py index a07161da936..1fffb273fdc 100644 --- a/cirq-core/cirq/transformers/routing/route_circuit_cqc_test.py +++ b/cirq-core/cirq/transformers/routing/route_circuit_cqc_test.py @@ -107,35 +107,45 @@ def test_circuit_with_measurement_gates(): cirq.testing.assert_same_circuits(routed_circuit, circuit) -def test_circuit_with_valid_intermediate_multi_qubit_measurement_gates(): - device = cirq.testing.construct_ring_device(3) +def test_circuit_with_two_qubit_intermediate_measurement_gate(): + device = cirq.testing.construct_ring_device(2) device_graph = device.metadata.nx_graph router = cirq.RouteCQC(device_graph) - q = cirq.LineQubit.range(2) - hard_coded_mapper = cirq.HardCodedInitialMapper({q[i]: q[i] for i in range(2)}) - - valid_circuit = cirq.Circuit(cirq.measure_each(*q), cirq.H.on_each(q)) - - c_routed = router( - valid_circuit, initial_mapper=hard_coded_mapper, context=cirq.TransformerContext(deep=True) + qs = cirq.LineQubit.range(2) + hard_coded_mapper = cirq.HardCodedInitialMapper({qs[i]: qs[i] for i in range(2)}) + circuit = cirq.Circuit([cirq.Moment(cirq.measure(qs)), cirq.Moment(cirq.H.on_each(qs))]) + routed_circuit = router( + circuit, initial_mapper=hard_coded_mapper, context=cirq.TransformerContext(deep=True) ) - device.validate_circuit(c_routed) + device.validate_circuit(routed_circuit) -def test_circuit_with_invalid_intermediate_multi_qubit_measurement_gates(): +def test_circuit_with_multi_qubit_intermediate_measurement_gate_and_with_default_key(): device = cirq.testing.construct_ring_device(3) device_graph = device.metadata.nx_graph router = cirq.RouteCQC(device_graph) - q = cirq.LineQubit.range(3) - hard_coded_mapper = cirq.HardCodedInitialMapper({q[i]: q[i] for i in range(3)}) + qs = cirq.LineQubit.range(3) + hard_coded_mapper = cirq.HardCodedInitialMapper({qs[i]: qs[i] for i in range(3)}) + circuit = cirq.Circuit([cirq.Moment(cirq.measure(qs)), cirq.Moment(cirq.H.on_each(qs))]) + routed_circuit = router( + circuit, initial_mapper=hard_coded_mapper, context=cirq.TransformerContext(deep=True) + ) + expected = cirq.Circuit([cirq.Moment(cirq.measure_each(qs)), cirq.Moment(cirq.H.on_each(qs))]) + cirq.testing.assert_same_circuits(routed_circuit, expected) - invalid_circuit = cirq.Circuit(cirq.MeasurementGate(3).on(*q), cirq.H.on_each(*q)) +def test_circuit_with_multi_qubit_intermediate_measurement_gate_with_custom_key(): + device = cirq.testing.construct_ring_device(3) + device_graph = device.metadata.nx_graph + router = cirq.RouteCQC(device_graph) + qs = cirq.LineQubit.range(3) + hard_coded_mapper = cirq.HardCodedInitialMapper({qs[i]: qs[i] for i in range(3)}) + circuit = cirq.Circuit( + [cirq.Moment(cirq.measure(qs, key="test")), cirq.Moment(cirq.H.on_each(qs))] + ) with pytest.raises(ValueError): _ = router( - invalid_circuit, - initial_mapper=hard_coded_mapper, - context=cirq.TransformerContext(deep=True), + circuit, initial_mapper=hard_coded_mapper, context=cirq.TransformerContext(deep=True) )