diff --git a/crates/accelerate/src/commutation_analysis.rs b/crates/accelerate/src/commutation_analysis.rs index cf33774037fd..2da8cfce931e 100644 --- a/crates/accelerate/src/commutation_analysis.rs +++ b/crates/accelerate/src/commutation_analysis.rs @@ -50,7 +50,7 @@ const MAX_NUM_QUBITS: u32 = 3; /// commutation_set = {0: [[0], [2, 3], [4], [1]]} /// node_indices = {(0, 0): 0, (1, 0): 3, (2, 0): 1, (3, 0): 1, (4, 0): 2} /// -pub fn analyze_commutations_inner( +pub(crate) fn analyze_commutations_inner( py: Python, dag: &mut DAGCircuit, commutation_checker: &mut CommutationChecker, diff --git a/crates/accelerate/src/commutation_cancellation.rs b/crates/accelerate/src/commutation_cancellation.rs index 89914bb4fb57..b6a459b1ab7e 100644 --- a/crates/accelerate/src/commutation_cancellation.rs +++ b/crates/accelerate/src/commutation_cancellation.rs @@ -12,15 +12,19 @@ use crate::commutation_analysis::analyze_commutations_inner; use crate::commutation_checker::CommutationChecker; -use crate::target_transpiler::Target; +use crate::{euler_one_qubit_decomposer, QiskitError}; use hashbrown::{HashMap, HashSet}; use pyo3::prelude::*; use pyo3::{pyfunction, pymodule, wrap_pyfunction, Bound, PyResult, Python}; use qiskit_circuit::dag_circuit::{DAGCircuit, NodeType, Wire}; -use qiskit_circuit::operations::StandardGate::{PhaseGate, RXGate, RZGate, U1Gate}; +use qiskit_circuit::operations::StandardGate::{ + CXGate, CYGate, CZGate, HGate, PhaseGate, RXGate, RZGate, SGate, TGate, U1Gate, XGate, YGate, + ZGate, +}; use qiskit_circuit::operations::{Operation, Param, StandardGate}; use qiskit_circuit::Qubit; use rustworkx_core::petgraph::stable_graph::NodeIndex; +use smallvec::{smallvec, SmallVec}; use std::f64::consts::PI; const _CUTOFF_PRECISION: f64 = 1e-5; @@ -29,41 +33,56 @@ static HALF_TURNS: [&str; 2] = ["z", "x"]; static QUARTER_TURNS: [&str; 1] = ["s"]; static EIGHTH_TURNS: [&str; 1] = ["t"]; -const Z_ROTATION: &str = "z_rotation"; -const X_ROTATION: &str = "x_rotation"; +static VAR_Z_MAP: [(&str, StandardGate); 3] = [("rz", RZGate), ("p", PhaseGate), ("u1", U1Gate)]; +static Z_ROTATIONS: [StandardGate; 6] = [PhaseGate, ZGate, U1Gate, RZGate, TGate, SGate]; +static X_ROTATIONS: [StandardGate; 2] = [XGate, RXGate]; +static SUPPORTED_GATES: [StandardGate; 5] = [CXGate, CYGate, CZGate, HGate, YGate]; + +#[derive(Hash, Eq, PartialEq, Debug)] +enum GateOrRotation { + Gate(StandardGate), + ZRotation, + XRotation, +} +#[derive(Hash, Eq, PartialEq, Debug)] +struct CancellationSetKey { + gate: GateOrRotation, + qubits: SmallVec<[Qubit; 2]>, + com_set_index: usize, + second_index: Option, +} #[pyfunction] -#[pyo3(signature = (dag, commutation_checker, basis_gates=None, target=None))] +#[pyo3(signature = (dag, commutation_checker, basis_gates=None))] pub(crate) fn cancel_commutations( py: Python, dag: &mut DAGCircuit, commutation_checker: &mut CommutationChecker, basis_gates: Option>, - target: Option<&Target>, ) -> PyResult<()> { - let basis: HashSet = if let Some(tar) = target { - HashSet::from_iter(tar.operation_names().map(String::from)) - } else if let Some(basis) = basis_gates { + let basis: HashSet = if let Some(basis) = basis_gates { basis } else { HashSet::new() }; - - let _var_z_map: HashMap<&str, StandardGate> = - HashMap::from([("rz", RZGate), ("p", PhaseGate), ("u1", U1Gate)]); - - let _z_rotations: HashSet<&str> = HashSet::from(["p", "z", "u1", "rz", "t", "s"]); - let _x_rotations: HashSet<&str> = HashSet::from(["x", "rx"]); - let _gates: HashSet<&str> = HashSet::from(["cx", "cy", "cz", "h", "y"]); - let z_var_gate = dag .op_names .keys() - .find(|g| _var_z_map.contains_key(g.as_str())) - // Fallback to the first matching key from basis if there is no match in dag.op_names - .or_else(|| basis.iter().find(|g| _var_z_map.contains_key(g.as_str()))) - // get the StandardGate associated with that string - .and_then(|key| _var_z_map.get(key.as_str())); + .find_map(|g| { + VAR_Z_MAP + .iter() + .find(|(key, _)| *key == g.as_str()) + .map(|(_, gate)| gate) + }) + .or_else(|| { + basis.iter().find_map(|g| { + VAR_Z_MAP + .iter() + .find(|(key, _)| *key == g.as_str()) + .map(|(_, gate)| gate) + }) + }); + // Fallback to the first matching key from basis if there is no match in dag.op_names // Gate sets to be cancelled /* Traverse each qubit to generate the cancel dictionaries @@ -75,10 +94,7 @@ pub(crate) fn cancel_commutations( qubits and commutation sets. */ let (commutation_set, node_indices) = analyze_commutations_inner(py, dag, commutation_checker)?; - let mut single_q_cancellation_sets: HashMap<(String, Qubit, usize), Vec> = - HashMap::new(); - let mut two_q_cancellation_sets: HashMap<(String, Qubit, Qubit, usize, usize), Vec> = - HashMap::new(); + let mut cancellation_sets: HashMap> = HashMap::new(); (0..dag.num_qubits() as u32).for_each(|qubit| { let wire = Qubit(qubit); @@ -101,45 +117,59 @@ pub(crate) fn cancel_commutations( .iter() .all(|p| !matches!(p, Param::ParameterExpression(_))) { - let op_name = op.op.name().to_string(); - if num_qargs == 1usize && _gates.contains(op_name.as_str()) { - single_q_cancellation_sets - .entry((op_name.clone(), wire, com_set_idx)) - .or_insert_with(Vec::new) - .push(*node); - } + if let Some(op_gate) = op.op.try_standard_gate() { + if num_qargs == 1usize && SUPPORTED_GATES.contains(&op_gate) { + cancellation_sets + .entry(CancellationSetKey { + gate: GateOrRotation::Gate(op_gate), + qubits: smallvec![wire], + com_set_index: com_set_idx, + second_index: None, + }) + .or_insert_with(Vec::new) + .push(*node); + } - if num_qargs == 1usize && _z_rotations.contains(op_name.as_str()) { - single_q_cancellation_sets - .entry((Z_ROTATION.to_string(), wire, com_set_idx)) - .or_insert_with(Vec::new) - .push(*node); - } - if num_qargs == 1usize && _x_rotations.contains(op_name.as_str()) { - single_q_cancellation_sets - .entry((X_ROTATION.to_string(), wire, com_set_idx)) - .or_insert_with(Vec::new) - .push(*node); - } - // Don't deal with Y rotation, because Y rotation doesn't commute with - // CNOT, so it should be dealt with by optimized1qgate pass - if num_qargs == 2usize - && dag.get_qargs(op.qubits).first().unwrap() == &wire - { - let second_qarg = dag.get_qargs(op.qubits)[1]; - let q2_key = ( - op_name, - wire, - second_qarg, - com_set_idx, - *node_indices - .get(&(*node, Wire::Qubit(second_qarg))) - .unwrap(), - ); - two_q_cancellation_sets - .entry(q2_key) - .or_insert_with(Vec::new) - .push(*node); + if num_qargs == 1usize && Z_ROTATIONS.contains(&op_gate) { + cancellation_sets + .entry(CancellationSetKey { + gate: GateOrRotation::ZRotation, + qubits: smallvec![wire], + com_set_index: com_set_idx, + second_index: None, + }) + .or_insert_with(Vec::new) + .push(*node); + } + if num_qargs == 1usize && X_ROTATIONS.contains(&op_gate) { + cancellation_sets + .entry(CancellationSetKey { + gate: GateOrRotation::XRotation, + qubits: smallvec![wire], + com_set_index: com_set_idx, + second_index: None, + }) + .or_insert_with(Vec::new) + .push(*node); + } + // Don't deal with Y rotation, because Y rotation doesn't commute with + // CNOT, so it should be dealt with by optimized1qgate pass + if num_qargs == 2usize + && dag.get_qargs(op.qubits).first().unwrap() == &wire + { + let second_qarg = dag.get_qargs(op.qubits)[1]; + cancellation_sets + .entry(CancellationSetKey { + gate: GateOrRotation::Gate(op_gate), + qubits: smallvec![wire, second_qarg], + com_set_index: com_set_idx, + second_index: node_indices + .get(&(*node, Wire::Qubit(second_qarg))) + .copied(), + }) + .or_insert_with(Vec::new) + .push(*node); + } } } }) @@ -148,136 +178,112 @@ pub(crate) fn cancel_commutations( } }); - for (cancel_key, cancel_set) in &two_q_cancellation_sets { - if cancel_set.len() > 1 && _gates.contains(cancel_key.0.as_str()) { - for &c_node in &cancel_set[0..(cancel_set.len() / 2) * 2] { - dag.remove_op_node(c_node); + for (cancel_key, cancel_set) in &cancellation_sets { + if cancel_set.len() > 1 { + if let GateOrRotation::Gate(g) = cancel_key.gate { + if SUPPORTED_GATES.contains(&g) { + for &c_node in &cancel_set[0..(cancel_set.len() / 2) * 2] { + dag.remove_op_node(c_node); + } + } } - } - } - - for (cancel_key, cancel_set) in &single_q_cancellation_sets { - if cancel_key.0 == Z_ROTATION && z_var_gate.is_none() { - continue; - } - if cancel_set.len() > 1 && _gates.contains(cancel_key.0.as_str()) { - for &c_node in &cancel_set[0..(cancel_set.len() / 2) * 2] { - dag.remove_op_node(c_node); + if matches!(cancel_key.gate, GateOrRotation::ZRotation) && z_var_gate.is_none() { + continue; } - } else if cancel_set.len() > 1 && (cancel_key.0 == Z_ROTATION || cancel_key.0 == X_ROTATION) - { - let run_op = match &dag.dag[*cancel_set.first().unwrap()] { - NodeType::Operation(instr) => instr, - _ => panic!("Unexpected type in commutation set run."), - }; - - let run_qarg = dag.get_qargs(run_op.qubits).first().unwrap(); - let mut total_angle: f64 = 0.0f64; - let mut total_phase: f64 = 0.0f64; - for current_node in cancel_set { - let node_op = match &dag.dag[*current_node] { + if matches!(cancel_key.gate, GateOrRotation::ZRotation) + || matches!(cancel_key.gate, GateOrRotation::XRotation) + { + let run_op = match &dag.dag[*cancel_set.first().unwrap()] { NodeType::Operation(instr) => instr, _ => panic!("Unexpected type in commutation set run."), }; - let node_op_name = node_op.op.name(); - let node_qargs = dag.get_qargs(node_op.qubits); - if node_op - .extra_attrs - .as_deref() - .is_some_and(|attr| attr.condition.is_some()) - || node_qargs.len() > 1 - || &node_qargs[0] != run_qarg - { - panic!("internal error"); - } + let run_qarg = dag.get_qargs(run_op.qubits).first().unwrap(); + let mut total_angle: f64 = 0.0; + let mut total_phase: f64 = 0.0; + for current_node in cancel_set { + let node_op = match &dag.dag[*current_node] { + NodeType::Operation(instr) => instr, + _ => panic!("Unexpected type in commutation set run."), + }; + let node_op_name = node_op.op.name(); - let node_angle = if ROTATION_GATES.contains(&node_op_name) { - match node_op.params_view().first() { - Some(Param::Float(f)) => *f, - _ => panic!( - "Rotational gate with parameter expression encoutned in cancellation" - ), + let node_qargs = dag.get_qargs(node_op.qubits); + if node_op + .extra_attrs + .as_deref() + .is_some_and(|attr| attr.condition.is_some()) + || node_qargs.len() > 1 + || &node_qargs[0] != run_qarg + { + panic!("internal error"); } - } else if HALF_TURNS.contains(&node_op_name) { - PI - } else if QUARTER_TURNS.contains(&node_op_name) { - PI / 2.0 - } else if EIGHTH_TURNS.contains(&node_op_name) { - PI / 4.0 - } else { - panic!("Angle for operation {node_op_name} is not defined") - }; - total_angle += node_angle; - if let Some(definition) = node_op.op.definition(node_op.params_view()) { - //TODO check for PyNone global phase? - //total_phase += match definition.global_phase() {Param::Float(f) => f, Param::Obj(pyop) => , Param::ParameterExpression(_) => panic!("PackedInstruction with definition has global phase set as parameter expression")}; - total_phase += match definition.global_phase() {Param::Float(f) => f, _ => panic!("PackedInstruction with definition has no global phase set as floating point number")}; + let node_angle = if ROTATION_GATES.contains(&node_op_name) { + match node_op.params_view().first() { + Some(Param::Float(f)) => *f, + _ => return Err(QiskitError::new_err(format!( + "Rotational gate with parameter expression encoutned in cancellation {:?}", + node_op.op + ))) + } + } else if HALF_TURNS.contains(&node_op_name) { + PI + } else if QUARTER_TURNS.contains(&node_op_name) { + PI / 2.0 + } else if EIGHTH_TURNS.contains(&node_op_name) { + PI / 4.0 + } else { + panic!("Angle for operation {node_op_name} is not defined") + }; + total_angle += node_angle; + + if let Some(definition) = node_op.op.definition(node_op.params_view()) { + total_phase += match definition.global_phase() {Param::Float(f) => f, _ => panic!("PackedInstruction with definition has no global phase set as floating point number")}; + } } - } - let new_op = if cancel_key.0 == Z_ROTATION { - z_var_gate.unwrap() - } else if cancel_key.0 == X_ROTATION { - &RXGate - } else { - panic!("impossible case!"); - }; + let new_op = if matches!(cancel_key.gate, GateOrRotation::ZRotation) { + z_var_gate.unwrap() + } else if matches!(cancel_key.gate, GateOrRotation::XRotation) { + &RXGate + } else { + return Err(QiskitError::new_err("impossible case!")); + }; - let gate_angle = mod_2pi(total_angle, 0.); + let gate_angle = euler_one_qubit_decomposer::mod_2pi(total_angle, 0.); - let new_op_phase: f64 = if gate_angle.abs() > _CUTOFF_PRECISION { - let new_index = dag.insert_1q_on_incoming_qubit( - (*new_op, &[total_angle]), - *cancel_set.first().unwrap(), - ); - let new_node = match &dag.dag[new_index] { - NodeType::Operation(instr) => instr, - _ => panic!("Unexpected type in commutation set run."), - }; + let new_op_phase: f64 = if gate_angle.abs() > _CUTOFF_PRECISION { + let new_index = dag.insert_1q_on_incoming_qubit( + (*new_op, &[total_angle]), + *cancel_set.first().unwrap(), + ); + let new_node = match &dag.dag[new_index] { + NodeType::Operation(instr) => instr, + _ => panic!("Unexpected type in commutation set run."), + }; - if let Some(definition) = new_node.op.definition(new_node.params_view()) { - //TODO check for PyNone global phase? - match definition.global_phase() {Param::Float(f) => *f, _ => panic!("PackedInstruction with definition has no global phase set as floating point number")} + if let Some(definition) = new_node.op.definition(new_node.params_view()) { + match definition.global_phase() {Param::Float(f) => *f, _ => panic!("PackedInstruction with definition has no global phase set as floating point number")} + } else { + 0.0 + } } else { 0.0 - } - } else { - 0.0 - }; + }; - dag.add_global_phase(py, &Param::Float(total_phase - new_op_phase))?; + dag.add_global_phase(py, &Param::Float(total_phase - new_op_phase))?; - for node in cancel_set { - dag.remove_op_node(*node); + for node in cancel_set { + dag.remove_op_node(*node); + } } - - //TODO do we need this due to numerical instability? - /* - if np.mod(total_angle, (2 * np.pi)) < _CUTOFF_PRECISION: - dag.remove_op_node(run[0]) - */ } } Ok(()) } -/// Wrap angle into interval [-π,π). If within atol of the endpoint, clamp to -π -#[inline] -fn mod_2pi(angle: f64, atol: f64) -> f64 { - // f64::rem_euclid() isn't exactly the same as Python's % operator, but because - // the RHS here is a constant and positive it is effectively equivalent for - // this case - let wrapped = (angle + PI).rem_euclid(2. * PI) - PI; - if (wrapped - PI).abs() < atol { - -PI - } else { - wrapped - } -} - #[pymodule] pub fn commutation_cancellation(m: &Bound) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(cancel_commutations))?; diff --git a/crates/accelerate/src/euler_one_qubit_decomposer.rs b/crates/accelerate/src/euler_one_qubit_decomposer.rs index a3cb11ea45a2..4c0a8539cf34 100644 --- a/crates/accelerate/src/euler_one_qubit_decomposer.rs +++ b/crates/accelerate/src/euler_one_qubit_decomposer.rs @@ -924,7 +924,7 @@ pub fn det_one_qubit(mat: ArrayView2) -> Complex64 { /// Wrap angle into interval [-π,π). If within atol of the endpoint, clamp to -π #[inline] -fn mod_2pi(angle: f64, atol: f64) -> f64 { +pub(crate) fn mod_2pi(angle: f64, atol: f64) -> f64 { // f64::rem_euclid() isn't exactly the same as Python's % operator, but because // the RHS here is a constant and positive it is effectively equivalent for // this case diff --git a/qiskit/transpiler/passes/optimization/commutative_cancellation.py b/qiskit/transpiler/passes/optimization/commutative_cancellation.py index 98a9aac2aa77..130ff0609354 100644 --- a/qiskit/transpiler/passes/optimization/commutative_cancellation.py +++ b/qiskit/transpiler/passes/optimization/commutative_cancellation.py @@ -64,7 +64,7 @@ def __init__(self, basis_gates=None, target=None): # build a commutation checker restricted to the gates we cancel -- the others we # do not have to investigate, which allows to save time - self.commutation_checker = CommutationChecker( + self._commutation_checker = CommutationChecker( StandardGateCommutations, gates=self._gates | self._z_rotations | self._x_rotations ) @@ -78,7 +78,5 @@ def run(self, dag): Returns: DAGCircuit: the optimized DAG. """ - commutation_cancellation.cancel_commutations( - dag, self.commutation_checker, self.basis, self.target - ) + commutation_cancellation.cancel_commutations(dag, self._commutation_checker, self.basis) return dag