Skip to content

Commit

Permalink
review comments & more docs
Browse files Browse the repository at this point in the history
  • Loading branch information
Cryoris committed Sep 4, 2024
1 parent 7207998 commit c08cfae
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 133 deletions.
257 changes: 126 additions & 131 deletions crates/accelerate/src/commutation_analysis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,11 @@
// copyright notice, and modified files need to carry a notice indicating
// that they have been altered from the originals.

use pyo3::exceptions::PyValueError;
use pyo3::prelude::PyModule;
use pyo3::{pyfunction, pymodule, wrap_pyfunction, Bound, PyResult, Python};
use qiskit_circuit::operations::Param;
use qiskit_circuit::Qubit;
use smallvec::{smallvec, SmallVec};
use std::hash::BuildHasherDefault;

use crate::commutation_checker::CommutationChecker;
use ahash::AHasher;
use hashbrown::HashMap;
use indexmap::IndexSet;
use pyo3::prelude::*;
Expand All @@ -28,116 +23,117 @@ use pyo3::types::{PyDict, PyList};
use qiskit_circuit::dag_circuit::{DAGCircuit, NodeType, Wire};
use rustworkx_core::petgraph::stable_graph::NodeIndex;

type AIndexSet<T> = IndexSet<T, BuildHasherDefault<AHasher>>;
#[derive(Clone, Debug)]
pub enum CommutationSetEntry {
Index(usize),
SetExists(Vec<AIndexSet<NodeIndex>>),
}

// custom types: IndexSet allows to iterate over the elements in insertion
// order, but uses the std hasher, which is slower than ahash
type AIndexSet<T> = IndexSet<T, ::ahash::RandomState>;
type CommutingNodes = Vec<AIndexSet<NodeIndex>>;

/// Compute the commutation sets for a given DAG.
///
/// We return two HashMaps:
/// * {wire: commutation_sets}: For each wire, we keep a vector of index sets, where each index
/// set contains mutually commuting nodes. Note that these include the input and output nodes
/// which do not commute with anything.
/// * {(node, wire): index}: For each (node, wire) pair we store the index indicating in which
/// commutation set the node appears on a given wire.
///
/// For example, if we have a circuit
///
/// |0> -- X -- SX -- Z (out)
/// 0 2 3 4 1 <-- node indices including input (0) and output (1) nodes
///
/// Then we would have
///
/// commutation_set = {0: [[0], [2, 3], [4], [1]]}
/// node_indices = {(0, 0): 0, (1, 0): 3, (2, 0): 1, (3, 0): 1, (4, 0): 2}
///
#[allow(clippy::type_complexity)]
fn analyze_commutations_inner(
py: Python,
dag: &mut DAGCircuit,
commutation_checker: &mut CommutationChecker,
) -> PyResult<HashMap<(Option<NodeIndex>, Wire), CommutationSetEntry>> {
// The commutation set stores two types of keys:
// * (None, wire): The indices of the commuting nodes on the wire
// * (node, wire): The index containing the node on a given wire in above vector
// The Option<NodeIndex> thus captures None/node and the CommutationSetEntry enum captures
// the fact that the value could be an index or a set of nodes.
let mut commutation_set: HashMap<(Option<NodeIndex>, Wire), CommutationSetEntry> =
HashMap::new();
let max_num_qubits = 3;
) -> PyResult<(
HashMap<Wire, CommutingNodes>,
HashMap<(NodeIndex, Wire), usize>,
)> {
let mut commutation_set: HashMap<Wire, CommutingNodes> = HashMap::new();
let mut node_indices: HashMap<(NodeIndex, Wire), usize> = HashMap::new();

// placeholder parameters we can pass to the commutation checker, in case there are no
// parameters in our instructions
let empty_params: Box<SmallVec<[Param; 3]>> = Box::new(smallvec![]);
let max_num_qubits = 3;

for qubit in 0..dag.num_qubits() {
let wire = Wire::Qubit(Qubit(qubit as u32));

for current_gate_idx in dag.nodes_on_wire(py, &wire, false) {
// get the commutation set associated with the current wire
if let CommutationSetEntry::SetExists(ref mut commutation_entry) = commutation_set
.entry((None, wire.clone()))
.or_insert_with(|| {
CommutationSetEntry::SetExists(vec![AIndexSet::from_iter([current_gate_idx])])
})
{
let last = commutation_entry.last_mut().unwrap();

if !last.contains(&current_gate_idx) {
let mut all_commute = true;
for prev_gate_idx in last.iter() {
if let (
NodeType::Operation(packed_inst0),
NodeType::Operation(packed_inst1),
) = (&dag.dag[current_gate_idx], &dag.dag[*prev_gate_idx])
{
let op1 = packed_inst0.op.view();
let op2 = packed_inst1.op.view();
let params1 = match packed_inst0.params.as_ref() {
Some(params) => params,
None => &empty_params,
};
let params2 = match packed_inst1.params.as_ref() {
Some(params) => params,
None => &empty_params,
};
let qargs1 = dag.qargs_interner.get(packed_inst0.qubits);
let qargs2 = dag.qargs_interner.get(packed_inst1.qubits);
let cargs1 = dag.cargs_interner.get(packed_inst0.clbits);
let cargs2 = dag.cargs_interner.get(packed_inst1.clbits);

all_commute = commutation_checker.commute_inner(
py,
&op1,
params1,
packed_inst0.extra_attrs.as_deref(),
qargs1,
cargs1,
&op2,
params2,
packed_inst1.extra_attrs.as_deref(),
qargs2,
cargs2,
max_num_qubits,
)?;
if !all_commute {
break;
}
} else {
all_commute = false;
// get the commutation set associated with the current wire, or create a new
// index set containing the current gate
let commutation_entry = commutation_set
.entry(wire.clone())
.or_insert_with(|| vec![AIndexSet::from_iter([current_gate_idx])]);

// we can unwrap as we know the commutation entry has at least one element
let last = commutation_entry.last_mut().unwrap();

// if the current gate index is not in the set, check whether it commutes with
// the previous nodes -- if yes, add it to the commutation set
if !last.contains(&current_gate_idx) {
let mut all_commute = true;

for prev_gate_idx in last.iter() {
// if the node is an input/output node, they do not commute, so we only
// continue if the nodes are operation nodes
if let (NodeType::Operation(packed_inst0), NodeType::Operation(packed_inst1)) =
(&dag.dag[current_gate_idx], &dag.dag[*prev_gate_idx])
{
let op1 = packed_inst0.op.view();
let op2 = packed_inst1.op.view();
let params1 = packed_inst0.params_view();
let params2 = packed_inst1.params_view();
let qargs1 = dag.get_qargs(packed_inst0.qubits);
let qargs2 = dag.get_qargs(packed_inst1.qubits);
let cargs1 = dag.get_cargs(packed_inst0.clbits);
let cargs2 = dag.get_cargs(packed_inst1.clbits);

all_commute = commutation_checker.commute_inner(
py,
&op1,
params1,
packed_inst0.extra_attrs.as_deref(),
qargs1,
cargs1,
&op2,
params2,
packed_inst1.extra_attrs.as_deref(),
qargs2,
cargs2,
max_num_qubits,
)?;
if !all_commute {
break;
}
}

if all_commute {
// all commute, add to current list
last.insert(current_gate_idx);
} else {
// does not commute, create new list
commutation_entry.push(AIndexSet::from_iter([current_gate_idx]))
all_commute = false;
break;
}
}
} else {
return Err(PyValueError::new_err(
"Wrong format in commutation analysis, expected SetExists but got Index",
));
}

if let CommutationSetEntry::SetExists(last_entry) =
commutation_set.get(&(None, wire.clone())).unwrap()
{
commutation_set.insert(
(Some(current_gate_idx), wire.clone()),
CommutationSetEntry::Index(last_entry.len() - 1),
);
if all_commute {
// all commute, add to current list
last.insert(current_gate_idx);
} else {
// does not commute, create new list
commutation_entry.push(AIndexSet::from_iter([current_gate_idx]))
}
}

node_indices.insert(
(current_gate_idx, wire.clone()),
commutation_entry.len() - 1,
);
}
}

Ok(commutation_set)
Ok((commutation_set, node_indices))
}

#[pyfunction]
Expand All @@ -147,51 +143,50 @@ pub(crate) fn analyze_commutations(
dag: &mut DAGCircuit,
commutation_checker: &mut CommutationChecker,
) -> PyResult<Py<PyDict>> {
let commutations = analyze_commutations_inner(py, dag, commutation_checker)?;
// This returns two HashMaps:
// * The commuting nodes per wire: {wire: [commuting_nodes_1, commuting_nodes_2, ...]}
// * The index in which commutation set a given node is located on a wire: {(node, wire): index}
// The Python dict will store both of these dictionaries in one.
let (commutation_set, node_indices) = analyze_commutations_inner(py, dag, commutation_checker)?;

let out_dict = PyDict::new_bound(py);
for (k, comms) in commutations {
let nidx = k.0;
let wire = match k.1 {

// First set the {wire: [commuting_nodes_1, ...]} bit
for (wire, commutations) in commutation_set {
let py_wire = match wire {
Wire::Qubit(q) => dag.qubits.get(q).unwrap().to_object(py),
Wire::Clbit(c) => dag.clbits.get(c).unwrap().to_object(py),
Wire::Var(v) => v,
};

if nidx.is_some() {
match comms {
CommutationSetEntry::Index(idx) => {
out_dict.set_item((dag.get_node(py, nidx.unwrap())?, wire), idx)?
}
_ => {
return Err(PyValueError::new_err(
"Wrong format in commutation analysis, expected Index but found SetExists",
));
}
};
} else {
match comms {
CommutationSetEntry::SetExists(comm_set) => out_dict.set_item(
wire,
out_dict.set_item(
py_wire,
PyList::new_bound(
py,
commutations.iter().map(|inner| {
PyList::new_bound(
py,
comm_set.iter().map(|inner| {
PyList::new_bound(
py,
inner
.into_iter()
.map(|ndidx| dag.get_node(py, *ndidx).unwrap()),
)
}),
),
)?,
_ => {
return Err(PyValueError::new_err(
"Wrong format in commutation analysis, expected SetExists but found Index",
));
}
}
}
inner
.into_iter()
.map(|node_index| dag.get_node(py, *node_index).unwrap()),
)
}),
),
)?;
}

// Then we add the {(node, wire): index} dictionary
for ((node_index, wire), index) in node_indices {
// we could cache the py_wires to avoid this match and the python object creation,
// but this didn't make a noticable difference in runtime
let py_wire = match wire {
Wire::Qubit(q) => dag.qubits.get(q).unwrap().to_object(py),
Wire::Clbit(c) => dag.clbits.get(c).unwrap().to_object(py),
Wire::Var(v) => v,
};
out_dict.set_item((dag.get_node(py, node_index)?, py_wire), index)?;
}

Ok(out_dict.unbind())
}

Expand Down
4 changes: 2 additions & 2 deletions crates/circuit/src/dag_circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,9 +231,9 @@ pub struct DAGCircuit {
cregs: Py<PyDict>,

/// The cache used to intern instruction qargs.
pub qargs_interner: Interner<[Qubit]>,
qargs_interner: Interner<[Qubit]>,
/// The cache used to intern instruction cargs.
pub cargs_interner: Interner<[Clbit]>,
cargs_interner: Interner<[Clbit]>,
/// Qubits registered in the circuit.
pub qubits: BitData<Qubit>,
/// Clbits registered in the circuit.
Expand Down

0 comments on commit c08cfae

Please sign in to comment.