Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
Co-authored-by: Raynel Sanchez <[email protected]>
  • Loading branch information
sbrandhsn and raynelfss authored Aug 12, 2024
1 parent b59499a commit f5bb696
Showing 1 changed file with 26 additions and 60 deletions.
86 changes: 26 additions & 60 deletions crates/accelerate/src/commutation_checker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,7 @@ impl CommutationChecker {
#[pyo3(signature = (standard_gate_commutations=None, cache_max_entries=1_000_000, gates=None))]
#[new]
fn py_new(
py: Python,
standard_gate_commutations: Option<Py<PyAny>>,
standard_gate_commutations: Option<Bound<PyAny>>, // Send a bound here
cache_max_entries: usize,
gates: Option<HashSet<String>>,
) -> Self {
Expand All @@ -86,27 +85,19 @@ impl CommutationChecker {
max_num_qubits: u32,
) -> PyResult<bool> {
let mut bq: BitData<Qubit> = BitData::new(py, "qubits".to_string());
op1.instruction
.qubits
.bind(py)
let op1_bound_qubits = op1.instruction.qubits.bind(py);
let op2_bound_qubits = op2.instruction.qubits.bind(py);
op1_bound_qubits
.iter()
.for_each(|q| bq.add(py, &q, false).unwrap());
op2.instruction
.qubits
.bind(py)
op2_bound_qubits
.iter()
.for_each(|q| bq.add(py, &q, false).unwrap());
let qargs1 = op1
.instruction
.qubits
.bind(py)
let qargs1 = op1_bound_qubits
.iter()
.map(|q| bq.find(&q).unwrap().0 as usize)
.collect::<Vec<_>>();
let qargs2 = op2
.instruction
.qubits
.bind(py)
let qargs2 = op2_bound_qubits
.iter()
.map(|q| bq.find(&q).unwrap().0 as usize)
.collect::<Vec<_>>();
Expand Down Expand Up @@ -292,22 +283,22 @@ impl CommutationChecker {

let skip_cache: bool = NO_CACHE_NAMES.contains(&first_op.name()) ||
NO_CACHE_NAMES.contains(&second_op.name()) ||
//skip params that do not evaluate to floats for caching and commutation library
// Skip params that do not evaluate to floats for caching and commutation library
first_instr.params.iter().any(|p| !matches!(p, Param::Float(_))) ||
second_instr.params.iter().any(|p| !matches!(p, Param::Float(_)));

if skip_cache {
return self.commute_matmul(first_instr, first_qargs, second_instr, second_qargs);
}

//query commutation library
// Query commutation library
if let Some(is_commuting) =
self.library
.check_commutation_entries(&first_op, first_qargs, &second_op, second_qargs)
{
return is_commuting;
}
//query cache
// Query cache
if let Some(commutation_dict) = self
.cache
.get(&(first_op.name().to_string(), second_op.name().to_string()))
Expand All @@ -328,15 +319,15 @@ impl CommutationChecker {
self._cache_miss += 1;
}

// perform matrix multiplication to determine commutation
// Perform matrix multiplication to determine commutation
let is_commuting =
self.commute_matmul(first_instr, first_qargs, second_instr, second_qargs);

// TODO: implement a LRU cache for this
if self.current_cache_entries >= self.cache_max_entries {
self.clear_cache();
}
// cache results from is_commuting
// Cache results from is_commuting
self.cache
.entry((first_op.name().to_string(), second_op.name().to_string()))
.and_modify(|entries| {
Expand Down Expand Up @@ -387,14 +378,8 @@ impl CommutationChecker {
}
}

let first_qarg: Vec<_> = first_qargs
.iter()
.map(|q| *qarg.get_item(q).unwrap())
.collect();
let second_qarg: Vec<_> = second_qargs
.iter()
.map(|q| *qarg.get_item(q).unwrap())
.collect();
let first_qarg: Vec<_> = first_qargs.iter().map(|q| *qarg.get(q).unwrap()).collect();
let second_qarg: Vec<_> = second_qargs.iter().map(|q| *qarg.get(q).unwrap()).collect();

assert!(
first_qarg.len() <= second_qarg.len(),
Expand Down Expand Up @@ -547,7 +532,6 @@ impl CommutationChecker {
self.current_cache_entries = 0;
self._cache_miss = 0;
self._cache_hit = 0;
self._cache_hit = 0;
}
}

Expand Down Expand Up @@ -627,7 +611,7 @@ impl<'py> FromPyObject<'py> for CommutationLibraryEntry {
impl ToPyObject for CommutationLibraryEntry {
fn to_object(&self, py: Python) -> PyObject {
match self {
CommutationLibraryEntry::Commutes(b) => PyBool::new_bound(py, *b).to_object(py),
CommutationLibraryEntry::Commutes(b) => b.into_py(py),
CommutationLibraryEntry::QubitMapping(qm) => {
let out_dict = PyDict::new_bound(py);

Expand All @@ -639,24 +623,22 @@ impl ToPyObject for CommutationLibraryEntry {
k.iter()
.map(|q| q.map(|t| t.0))
.collect::<Vec<Option<u32>>>(),
)
.to_object(py),
PyBool::new_bound(py, *v).to_object(py),
),
PyBool::new_bound(py, *v),
)
.ok()
.unwrap()
});
out_dict.to_object(py)
out_dict.unbind().into_any()
}
}
}
}

type CacheKey = (
SmallVec<[Option<Qubit>; 2]>,
(SmallVec<[ParameterKey; 3]>, SmallVec<[ParameterKey; 3]>),
);
//need a struct instead of a type definition because we cannot implement serialization traits otherwise
// Need a struct instead of a type definition because we cannot implement serialization traits otherwise
#[derive(Clone)]
struct CommutationCacheEntry {
mapping: HashMap<CacheKey, bool>,
Expand All @@ -678,40 +660,24 @@ impl ToPyObject for CommutationCacheEntry {
fn to_object(&self, py: Python) -> PyObject {
let out_dict = PyDict::new_bound(py);
for (k, v) in self.iter() {
let qubits = PyTuple::new_bound(
let qubits = PyTuple::new_bound(
py,
k.0.iter()
.map(|q| q.map(|t| t.0))
.collect::<Vec<Option<u32>>>(),
)
.to_object(py);
);
let params0 =
PyTuple::new_bound(py, k.1 .0.iter().map(|pk| pk.0).collect::<Vec<f64>>())
.to_object(py);
PyTuple::new_bound(py, k.1 .0.iter().map(|pk| pk.0).collect::<Vec<f64>>());
let params1 =
PyTuple::new_bound(py, k.1 .1.iter().map(|pk| pk.0).collect::<Vec<f64>>())
.to_object(py);
PyTuple::new_bound(py, k.1 .1.iter().map(|pk| pk.0).collect::<Vec<f64>>());
out_dict
.set_item(
PyTuple::new_bound(
py,
[
qubits,
PyTuple::new_bound(
py,
[params0, params1].iter().collect::<Vec<&PyObject>>(),
)
.to_object(py),
]
.iter()
.collect::<Vec<&PyObject>>(),
),
PyBool::new_bound(py, *v).to_object(py),
PyTuple::new_bound(py, [qubits, PyTuple::new_bound(py, [params0, params1])]),
PyBool::new_bound(py, *v),
)
.expect("Failed to construct commutation cache for serialization");
}
out_dict.to_object(py)
}
out_dict.into_any().unbind()
}

type CacheKeyRaw = (
Expand Down

0 comments on commit f5bb696

Please sign in to comment.