Skip to content

Commit

Permalink
code review
Browse files Browse the repository at this point in the history
  • Loading branch information
sbrandhsn committed Sep 10, 2024
1 parent d0cfc8c commit 55c0730
Showing 1 changed file with 63 additions and 73 deletions.
136 changes: 63 additions & 73 deletions crates/accelerate/src/commutation_cancellation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
use std::f64::consts::PI;

use hashbrown::{HashMap, HashSet};
use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::*;
use pyo3::{pyfunction, pymodule, wrap_pyfunction, Bound, PyResult, Python};
use rustworkx_core::petgraph::stable_graph::NodeIndex;
Expand Down Expand Up @@ -110,70 +111,69 @@ pub(crate) fn cancel_commutations(
} else {
continue;
}
com_set.iter().for_each(|node| {
for node in com_set.iter() {
let instr = match &dag.dag[*node] {
NodeType::Operation(instr) => instr,
_ => panic!("Unexpected type in commutation set."),
};
let num_qargs = dag.get_qargs(instr.qubits).len();
// no support for cancellation of parameterized gates
if !instr.is_parameterized() {
if let Some(op_gate) = instr.op.try_standard_gate() {
if num_qargs == 1 && SUPPORTED_GATES.contains(&op_gate) {
cancellation_sets
.entry(CancellationSetKey {
gate: GateOrRotation::Gate(op_gate),
qubits: smallvec![wire],
com_set_index: com_set_idx,
second_index: None,
})
.or_insert_with(Vec::new)
.push(*node);
}
if instr.is_parameterized() {
continue;
}
if let Some(op_gate) = instr.op.try_standard_gate() {
if num_qargs == 1 && SUPPORTED_GATES.contains(&op_gate) {
cancellation_sets
.entry(CancellationSetKey {
gate: GateOrRotation::Gate(op_gate),
qubits: smallvec![wire],
com_set_index: com_set_idx,
second_index: None,
})
.or_insert_with(Vec::new)
.push(*node);
}

if num_qargs == 1 && Z_ROTATIONS.contains(&op_gate) {
cancellation_sets
.entry(CancellationSetKey {
gate: GateOrRotation::ZRotation,
qubits: smallvec![wire],
com_set_index: com_set_idx,
second_index: None,
})
.or_insert_with(Vec::new)
.push(*node);
}
if num_qargs == 1 && X_ROTATIONS.contains(&op_gate) {
cancellation_sets
.entry(CancellationSetKey {
gate: GateOrRotation::XRotation,
qubits: smallvec![wire],
com_set_index: com_set_idx,
second_index: None,
})
.or_insert_with(Vec::new)
.push(*node);
}
// Don't deal with Y rotation, because Y rotation doesn't commute with
// CNOT, so it should be dealt with by optimized1qgate pass
if num_qargs == 2
&& dag.get_qargs(instr.qubits)[0] == &wire
{
let second_qarg = dag.get_qargs(instr.qubits)[1];
cancellation_sets
.entry(CancellationSetKey {
gate: GateOrRotation::Gate(op_gate),
qubits: smallvec![wire, second_qarg],
com_set_index: com_set_idx,
second_index: node_indices
.get(&(*node, Wire::Qubit(second_qarg)))
.copied(),
})
.or_insert_with(Vec::new)
.push(*node);
}
if num_qargs == 1 && Z_ROTATIONS.contains(&op_gate) {
cancellation_sets
.entry(CancellationSetKey {
gate: GateOrRotation::ZRotation,
qubits: smallvec![wire],
com_set_index: com_set_idx,
second_index: None,
})
.or_insert_with(Vec::new)
.push(*node);
}
if num_qargs == 1 && X_ROTATIONS.contains(&op_gate) {
cancellation_sets
.entry(CancellationSetKey {
gate: GateOrRotation::XRotation,
qubits: smallvec![wire],
com_set_index: com_set_idx,
second_index: None,
})
.or_insert_with(Vec::new)
.push(*node);
}
// Don't deal with Y rotation, because Y rotation doesn't commute with
// CNOT, so it should be dealt with by optimized1qgate pass
if num_qargs == 2 && dag.get_qargs(instr.qubits)[0] == wire {
let second_qarg = dag.get_qargs(instr.qubits)[1];
cancellation_sets
.entry(CancellationSetKey {
gate: GateOrRotation::Gate(op_gate),
qubits: smallvec![wire, second_qarg],
com_set_index: com_set_idx,
second_index: node_indices
.get(&(*node, Wire::Qubit(second_qarg)))
.copied(),
})
.or_insert_with(Vec::new)
.push(*node);
}
}
})
}
}
}
});
Expand All @@ -191,12 +191,10 @@ pub(crate) fn cancel_commutations(
if matches!(cancel_key.gate, GateOrRotation::ZRotation) && z_var_gate.is_none() {
continue;
}
if matches!(cancel_key.gate, GateOrRotation::ZRotation | GateOrRotation::XRotation)
{
let run_qarg = match &dag.dag[*cancel_set[0]] {
NodeType::Operation(instr) => dag.get_qargs(instr.qubits)[0],
_ => panic!("Unexpected type in commutation set run."),
};
if matches!(
cancel_key.gate,
GateOrRotation::ZRotation | GateOrRotation::XRotation
) {
let mut total_angle: f64 = 0.0;
let mut total_phase: f64 = 0.0;
for current_node in cancel_set {
Expand All @@ -206,14 +204,6 @@ pub(crate) fn cancel_commutations(
};
let node_op_name = node_op.op.name();

let node_qargs = dag.get_qargs(node_op.qubits);
if node_op.condition().is_some()
|| node_qargs.len() > 1
|| node_qargs[0] != run_qarg
{
return Err(QiskitError::new_err("internal error"));
}

let node_angle = if ROTATION_GATES.contains(&node_op_name) {
match node_op.params_view().first() {
Some(Param::Float(f)) => Ok(*f),
Expand All @@ -229,7 +219,7 @@ pub(crate) fn cancel_commutations(
} else if EIGHTH_TURNS.contains(&node_op_name) {
Ok(PI / 4.0)
} else {
Err(QiskitError::new_err(format!(
Err(PyRuntimeError::new_err(format!(
"Angle for operation {} is not defined",
node_op_name
)))
Expand All @@ -240,10 +230,10 @@ pub(crate) fn cancel_commutations(
total_phase += new_phase
}

let new_op = match cancel_key.gate
let new_op = match cancel_key.gate {
GateOrRotation::ZRotation => z_var_gate.unwrap(),
GateOrRotation::XRotation) => &RXGate,
_ => unreachable!()
GateOrRotation::XRotation => &RXGate,
_ => unreachable!(),
};

let gate_angle = euler_one_qubit_decomposer::mod_2pi(total_angle, 0.);
Expand Down

0 comments on commit 55c0730

Please sign in to comment.