Skip to content

Commit

Permalink
Use vec over IndexSet + clippy
Browse files Browse the repository at this point in the history
- vec<vec> is slightly faster than vec<indexset>
- add custom types to satisfies clippy's complex type complaint
- don't handle Clbit/Var
  • Loading branch information
Cryoris committed Sep 4, 2024
1 parent c08cfae commit 7073570
Showing 1 changed file with 20 additions and 25 deletions.
45 changes: 20 additions & 25 deletions crates/accelerate/src/commutation_analysis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,26 @@
// copyright notice, and modified files need to carry a notice indicating
// that they have been altered from the originals.

use pyo3::exceptions::PyValueError;
use pyo3::prelude::PyModule;
use pyo3::{pyfunction, pymodule, wrap_pyfunction, Bound, PyResult, Python};
use qiskit_circuit::Qubit;

use crate::commutation_checker::CommutationChecker;
use hashbrown::HashMap;
use indexmap::IndexSet;
use pyo3::prelude::*;

use pyo3::types::{PyDict, PyList};
use qiskit_circuit::dag_circuit::{DAGCircuit, NodeType, Wire};
use rustworkx_core::petgraph::stable_graph::NodeIndex;

// custom types: IndexSet allows to iterate over the elements in insertion
// order, but uses the std hasher, which is slower than ahash
type AIndexSet<T> = IndexSet<T, ::ahash::RandomState>;
type CommutingNodes = Vec<AIndexSet<NodeIndex>>;
// Custom types to store the commutation sets and node indices,
// see the docstring below for more information.
type CommutationSet = HashMap<Wire, Vec<Vec<NodeIndex>>>;
type NodeIndices = HashMap<(NodeIndex, Wire), usize>;

// the maximum number of qubits we check commutativity for
const MAX_NUM_QUBITS: u32 = 3;

/// Compute the commutation sets for a given DAG.
///
Expand All @@ -47,19 +50,13 @@ type CommutingNodes = Vec<AIndexSet<NodeIndex>>;
/// commutation_set = {0: [[0], [2, 3], [4], [1]]}
/// node_indices = {(0, 0): 0, (1, 0): 3, (2, 0): 1, (3, 0): 1, (4, 0): 2}
///
#[allow(clippy::type_complexity)]
fn analyze_commutations_inner(
py: Python,
dag: &mut DAGCircuit,
commutation_checker: &mut CommutationChecker,
) -> PyResult<(
HashMap<Wire, CommutingNodes>,
HashMap<(NodeIndex, Wire), usize>,
)> {
let mut commutation_set: HashMap<Wire, CommutingNodes> = HashMap::new();
let mut node_indices: HashMap<(NodeIndex, Wire), usize> = HashMap::new();

let max_num_qubits = 3;
) -> PyResult<(CommutationSet, NodeIndices)> {
let mut commutation_set: CommutationSet = HashMap::new();
let mut node_indices: NodeIndices = HashMap::new();

for qubit in 0..dag.num_qubits() {
let wire = Wire::Qubit(Qubit(qubit as u32));
Expand All @@ -69,7 +66,7 @@ fn analyze_commutations_inner(
// index set containing the current gate
let commutation_entry = commutation_set
.entry(wire.clone())
.or_insert_with(|| vec![AIndexSet::from_iter([current_gate_idx])]);
.or_insert_with(|| vec![vec![current_gate_idx]]);

// we can unwrap as we know the commutation entry has at least one element
let last = commutation_entry.last_mut().unwrap();
Expand Down Expand Up @@ -106,7 +103,7 @@ fn analyze_commutations_inner(
packed_inst1.extra_attrs.as_deref(),
qargs2,
cargs2,
max_num_qubits,
MAX_NUM_QUBITS,
)?;
if !all_commute {
break;
Expand All @@ -119,10 +116,10 @@ fn analyze_commutations_inner(

if all_commute {
// all commute, add to current list
last.insert(current_gate_idx);
last.push(current_gate_idx);
} else {
// does not commute, create new list
commutation_entry.push(AIndexSet::from_iter([current_gate_idx]))
commutation_entry.push(vec![current_gate_idx]);
}
}

Expand Down Expand Up @@ -153,10 +150,11 @@ pub(crate) fn analyze_commutations(

// First set the {wire: [commuting_nodes_1, ...]} bit
for (wire, commutations) in commutation_set {
// we know all wires are of type Wire::Qubit, since in analyze_commutations_inner
// we only iterater over the qubits
let py_wire = match wire {
Wire::Qubit(q) => dag.qubits.get(q).unwrap().to_object(py),
Wire::Clbit(c) => dag.clbits.get(c).unwrap().to_object(py),
Wire::Var(v) => v,
_ => return Err(PyValueError::new_err("Unexpected wire type.")),
};

out_dict.set_item(
Expand All @@ -167,7 +165,7 @@ pub(crate) fn analyze_commutations(
PyList::new_bound(
py,
inner
.into_iter()
.iter()
.map(|node_index| dag.get_node(py, *node_index).unwrap()),
)
}),
Expand All @@ -177,12 +175,9 @@ pub(crate) fn analyze_commutations(

// Then we add the {(node, wire): index} dictionary
for ((node_index, wire), index) in node_indices {
// we could cache the py_wires to avoid this match and the python object creation,
// but this didn't make a noticable difference in runtime
let py_wire = match wire {
Wire::Qubit(q) => dag.qubits.get(q).unwrap().to_object(py),
Wire::Clbit(c) => dag.clbits.get(c).unwrap().to_object(py),
Wire::Var(v) => v,
_ => return Err(PyValueError::new_err("Unexpected wire type.")),
};
out_dict.set_item((dag.get_node(py, node_index)?, py_wire), index)?;
}
Expand Down

0 comments on commit 7073570

Please sign in to comment.