diff --git a/crates/accelerate/src/commutation_cancellation.rs b/crates/accelerate/src/commutation_cancellation.rs index 9387ee2cc2ca..74351c70b471 100644 --- a/crates/accelerate/src/commutation_cancellation.rs +++ b/crates/accelerate/src/commutation_cancellation.rs @@ -13,6 +13,7 @@ use std::f64::consts::PI; use hashbrown::{HashMap, HashSet}; +use pyo3::exceptions::PyRuntimeError; use pyo3::prelude::*; use pyo3::{pyfunction, pymodule, wrap_pyfunction, Bound, PyResult, Python}; use rustworkx_core::petgraph::stable_graph::NodeIndex; @@ -110,70 +111,69 @@ pub(crate) fn cancel_commutations( } else { continue; } - com_set.iter().for_each(|node| { + for node in com_set.iter() { let instr = match &dag.dag[*node] { NodeType::Operation(instr) => instr, _ => panic!("Unexpected type in commutation set."), }; let num_qargs = dag.get_qargs(instr.qubits).len(); // no support for cancellation of parameterized gates - if !instr.is_parameterized() { - if let Some(op_gate) = instr.op.try_standard_gate() { - if num_qargs == 1 && SUPPORTED_GATES.contains(&op_gate) { - cancellation_sets - .entry(CancellationSetKey { - gate: GateOrRotation::Gate(op_gate), - qubits: smallvec![wire], - com_set_index: com_set_idx, - second_index: None, - }) - .or_insert_with(Vec::new) - .push(*node); - } + if instr.is_parameterized() { + continue; + } + if let Some(op_gate) = instr.op.try_standard_gate() { + if num_qargs == 1 && SUPPORTED_GATES.contains(&op_gate) { + cancellation_sets + .entry(CancellationSetKey { + gate: GateOrRotation::Gate(op_gate), + qubits: smallvec![wire], + com_set_index: com_set_idx, + second_index: None, + }) + .or_insert_with(Vec::new) + .push(*node); + } - if num_qargs == 1 && Z_ROTATIONS.contains(&op_gate) { - cancellation_sets - .entry(CancellationSetKey { - gate: GateOrRotation::ZRotation, - qubits: smallvec![wire], - com_set_index: com_set_idx, - second_index: None, - }) - .or_insert_with(Vec::new) - .push(*node); - } - if num_qargs == 1 && X_ROTATIONS.contains(&op_gate) { - cancellation_sets - .entry(CancellationSetKey { - gate: GateOrRotation::XRotation, - qubits: smallvec![wire], - com_set_index: com_set_idx, - second_index: None, - }) - .or_insert_with(Vec::new) - .push(*node); - } - // Don't deal with Y rotation, because Y rotation doesn't commute with - // CNOT, so it should be dealt with by optimized1qgate pass - if num_qargs == 2 - && dag.get_qargs(instr.qubits)[0] == &wire - { - let second_qarg = dag.get_qargs(instr.qubits)[1]; - cancellation_sets - .entry(CancellationSetKey { - gate: GateOrRotation::Gate(op_gate), - qubits: smallvec![wire, second_qarg], - com_set_index: com_set_idx, - second_index: node_indices - .get(&(*node, Wire::Qubit(second_qarg))) - .copied(), - }) - .or_insert_with(Vec::new) - .push(*node); - } + if num_qargs == 1 && Z_ROTATIONS.contains(&op_gate) { + cancellation_sets + .entry(CancellationSetKey { + gate: GateOrRotation::ZRotation, + qubits: smallvec![wire], + com_set_index: com_set_idx, + second_index: None, + }) + .or_insert_with(Vec::new) + .push(*node); + } + if num_qargs == 1 && X_ROTATIONS.contains(&op_gate) { + cancellation_sets + .entry(CancellationSetKey { + gate: GateOrRotation::XRotation, + qubits: smallvec![wire], + com_set_index: com_set_idx, + second_index: None, + }) + .or_insert_with(Vec::new) + .push(*node); + } + // Don't deal with Y rotation, because Y rotation doesn't commute with + // CNOT, so it should be dealt with by optimized1qgate pass + if num_qargs == 2 && dag.get_qargs(instr.qubits)[0] == wire { + let second_qarg = dag.get_qargs(instr.qubits)[1]; + cancellation_sets + .entry(CancellationSetKey { + gate: GateOrRotation::Gate(op_gate), + qubits: smallvec![wire, second_qarg], + com_set_index: com_set_idx, + second_index: node_indices + .get(&(*node, Wire::Qubit(second_qarg))) + .copied(), + }) + .or_insert_with(Vec::new) + .push(*node); } } - }) + } } } }); @@ -191,12 +191,10 @@ pub(crate) fn cancel_commutations( if matches!(cancel_key.gate, GateOrRotation::ZRotation) && z_var_gate.is_none() { continue; } - if matches!(cancel_key.gate, GateOrRotation::ZRotation | GateOrRotation::XRotation) - { - let run_qarg = match &dag.dag[*cancel_set[0]] { - NodeType::Operation(instr) => dag.get_qargs(instr.qubits)[0], - _ => panic!("Unexpected type in commutation set run."), - }; + if matches!( + cancel_key.gate, + GateOrRotation::ZRotation | GateOrRotation::XRotation + ) { let mut total_angle: f64 = 0.0; let mut total_phase: f64 = 0.0; for current_node in cancel_set { @@ -206,14 +204,6 @@ pub(crate) fn cancel_commutations( }; let node_op_name = node_op.op.name(); - let node_qargs = dag.get_qargs(node_op.qubits); - if node_op.condition().is_some() - || node_qargs.len() > 1 - || node_qargs[0] != run_qarg - { - return Err(QiskitError::new_err("internal error")); - } - let node_angle = if ROTATION_GATES.contains(&node_op_name) { match node_op.params_view().first() { Some(Param::Float(f)) => Ok(*f), @@ -229,7 +219,7 @@ pub(crate) fn cancel_commutations( } else if EIGHTH_TURNS.contains(&node_op_name) { Ok(PI / 4.0) } else { - Err(QiskitError::new_err(format!( + Err(PyRuntimeError::new_err(format!( "Angle for operation {} is not defined", node_op_name ))) @@ -240,10 +230,10 @@ pub(crate) fn cancel_commutations( total_phase += new_phase } - let new_op = match cancel_key.gate + let new_op = match cancel_key.gate { GateOrRotation::ZRotation => z_var_gate.unwrap(), - GateOrRotation::XRotation) => &RXGate, - _ => unreachable!() + GateOrRotation::XRotation => &RXGate, + _ => unreachable!(), }; let gate_angle = euler_one_qubit_decomposer::mod_2pi(total_angle, 0.);