Skip to content

Commit

Permalink
Make serialization explicit
Browse files Browse the repository at this point in the history
This commit makes the pickling of cache entries explicit. Previously it
was relying on conversion traits which hid some of the complexity but
this uses a pair of conversion functions instead.
  • Loading branch information
mtreinish committed Aug 30, 2024
1 parent a05df8c commit f231c83
Showing 1 changed file with 72 additions and 76 deletions.
148 changes: 72 additions & 76 deletions crates/accelerate/src/commutation_checker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
// that they have been altered from the originals.

use approx::abs_diff_eq;
use hashbrown::hash_map::Iter;
use hashbrown::{HashMap, HashSet};
use ndarray::linalg::kron;
use ndarray::Array2;
Expand Down Expand Up @@ -45,12 +44,36 @@ static SUPPORTED_OP: Lazy<HashSet<&str>> = Lazy::new(|| {
])
});

fn get_bits<T>(
py: Python,
bits1: &Bound<PyTuple>,
bits2: &Bound<PyTuple>,
) -> PyResult<(Vec<T>, Vec<T>)>
where
T: From<BitType> + Copy,
BitType: From<T>,
{
let mut bitdata: BitData<T> = BitData::new(py, "bits".to_string());

bits1.iter().chain(bits2.iter()).for_each(|bit| {
bitdata.add(py, &bit, false).unwrap();
});

Ok((
bitdata.map_bits(bits1)?.collect(),
bitdata.map_bits(bits2)?.collect(),
))
}

/// This is the internal structure for the Python CommutationChecker class
/// It handles the actual commutation checking, cache management, and library
/// lookups. It's not meant to be a public facing Python object though and only used
/// internally by the Python class.
#[pyclass(module = "qiskit._accelerate.commutation_checker")]
struct CommutationChecker {
library: CommutationLibrary,
cache_max_entries: usize,
cache: HashMap<(String, String), CommutationCacheEntry>,
#[pyo3(get)]
current_cache_entries: usize,
#[pyo3(get)]
gates: Option<HashSet<String>>,
Expand All @@ -61,7 +84,7 @@ impl CommutationChecker {
#[pyo3(signature = (standard_gate_commutations=None, cache_max_entries=1_000_000, gates=None))]
#[new]
fn py_new(
standard_gate_commutations: Option<Bound<PyAny>>, // Send a bound here
standard_gate_commutations: Option<Bound<PyAny>>,
cache_max_entries: usize,
gates: Option<HashSet<String>>,
) -> Self {
Expand Down Expand Up @@ -152,11 +175,12 @@ impl CommutationChecker {
)
}

#[pyo3(signature=())]
/// Return the current number of cache entries
fn num_cached_entries(&self) -> usize {
self.current_cache_entries
}
#[pyo3(signature=())]

/// Clear the cache
fn clear_cached_commutations(&mut self) {
self.clear_cache()
}
Expand All @@ -165,7 +189,11 @@ impl CommutationChecker {
let out_dict = PyDict::new_bound(py);
out_dict.set_item("cache_max_entries", self.cache_max_entries)?;
out_dict.set_item("current_cache_entries", self.current_cache_entries)?;
out_dict.set_item("cache", self.cache.clone())?;
let cache_dict = PyDict::new_bound(py);
for (key, value) in &self.cache {
cache_dict.set_item(key, commutation_entry_to_pydict(py, value))?;
}
out_dict.set_item("cache", cache_dict)?;
out_dict.set_item("library", self.library.library.to_object(py))?;
out_dict.set_item("gates", self.gates.clone())?;
Ok(out_dict.unbind())
Expand All @@ -184,7 +212,15 @@ impl CommutationChecker {
self.library = CommutationLibrary {
library: dict_state.get_item("library")?.unwrap().extract()?,
};
self.cache = dict_state.get_item("cache")?.unwrap().extract()?;
let raw_cache: Bound<PyDict> = dict_state.get_item("cache")?.unwrap().extract()?;
self.cache = HashMap::with_capacity(raw_cache.len());
for (key, value) in raw_cache.iter() {
let value_dict: &Bound<PyDict> = value.downcast()?;
self.cache.insert(
key.extract()?,
commutation_cache_entry_from_pydict(value_dict)?,
);
}
self.gates = dict_state.get_item("gates")?.unwrap().extract()?;
Ok(())
}
Expand Down Expand Up @@ -323,7 +359,7 @@ impl CommutationChecker {
);
entries.insert(key, is_commuting);
self.current_cache_entries += 1;
CommutationCacheEntry { mapping: entries }
entries
});
Ok(is_commuting)
}
Expand Down Expand Up @@ -495,6 +531,7 @@ fn get_matrix(py: Python, operation: &OperationRef, params: &[Param]) -> Option<
},
}
}

fn matrix_via_operator(py: Python, py_obj: &PyObject) -> Option<Array2<Complex64>> {
Some(
QI_OPERATOR
Expand Down Expand Up @@ -632,69 +669,49 @@ type CacheKey = (
(SmallVec<[ParameterKey; 3]>, SmallVec<[ParameterKey; 3]>),
);

// Need a struct instead of a type definition because we cannot implement serialization traits otherwise
#[derive(Clone, Debug)]
struct CommutationCacheEntry {
mapping: HashMap<CacheKey, bool>,
}
impl CommutationCacheEntry {
fn get(&self, key: &CacheKey) -> Option<&bool> {
self.mapping.get(key)
}
fn iter(&self) -> Iter<'_, CacheKey, bool> {
self.mapping.iter()
}

fn insert(&mut self, k: CacheKey, v: bool) -> Option<bool> {
self.mapping.insert(k, v)
type CommutationCacheEntry = HashMap<CacheKey, bool>;

fn commutation_entry_to_pydict(py: Python, entry: &CommutationCacheEntry) -> Py<PyDict> {
let out_dict = PyDict::new_bound(py);
for (k, v) in entry.iter() {
let qubits = PyTuple::new_bound(py, k.0.iter().map(|q| q.map(|t| t.0)));
let params0 = PyTuple::new_bound(py, k.1 .0.iter().map(|pk| pk.0));
let params1 = PyTuple::new_bound(py, k.1 .1.iter().map(|pk| pk.0));
out_dict
.set_item(
PyTuple::new_bound(py, [qubits, PyTuple::new_bound(py, [params0, params1])]),
PyBool::new_bound(py, *v),
)
.expect("Failed to construct commutation cache for serialization");
}
out_dict.unbind()
}

impl ToPyObject for CommutationCacheEntry {
fn to_object(&self, py: Python) -> PyObject {
let out_dict = PyDict::new_bound(py);
for (k, v) in self.iter() {
let qubits = PyTuple::new_bound(py, k.0.iter().map(|q| q.map(|t| t.0)));
let params0 = PyTuple::new_bound(py, k.1 .0.iter().map(|pk| pk.0));
let params1 = PyTuple::new_bound(py, k.1 .1.iter().map(|pk| pk.0));
out_dict
.set_item(
PyTuple::new_bound(py, [qubits, PyTuple::new_bound(py, [params0, params1])]),
PyBool::new_bound(py, *v),
)
.expect("Failed to construct commutation cache for serialization");
}
out_dict.into_any().unbind()
fn commutation_cache_entry_from_pydict(dict: &Bound<PyDict>) -> PyResult<CommutationCacheEntry> {
let mut ret = hashbrown::HashMap::with_capacity(dict.len());
for (k, v) in dict {
let raw_key: CacheKeyRaw = k.extract()?;
let qubits = raw_key.0.iter().map(|q| q.map(Qubit)).collect();
let params0: SmallVec<_> = raw_key.1 .0;
let params1: SmallVec<_> = raw_key.1 .1;
let v: bool = v.extract()?;
ret.insert((qubits, (params0, params1)), v);
}
Ok(ret)
}

type CacheKeyRaw = (
SmallVec<[Option<u32>; 2]>,
(SmallVec<[f64; 3]>, SmallVec<[f64; 3]>),
(SmallVec<[ParameterKey; 3]>, SmallVec<[ParameterKey; 3]>),
);
impl<'py> FromPyObject<'py> for CommutationCacheEntry {
fn extract_bound(b: &Bound<'py, PyAny>) -> Result<Self, PyErr> {
let dict = b.downcast::<PyDict>()?;
let mut ret = hashbrown::HashMap::with_capacity(dict.len());
for (k, v) in dict {
let raw_key: CacheKeyRaw = k.extract()?;
let qubits = raw_key.0.iter().map(|q| q.map(Qubit)).collect();
let params0: SmallVec<_> = raw_key.1 .0.iter().map(|p| ParameterKey(*p)).collect();
let params1: SmallVec<_> = raw_key.1 .1.iter().map(|p| ParameterKey(*p)).collect();
let v: bool = v.extract()?;
ret.insert((qubits, (params0, params1)), v);
}
Ok(CommutationCacheEntry { mapping: ret })
}
}

/// This newtype wraps a f64 to make it hashable so we can cache parameterized gates
/// based on the parameter value (assuming it's a float angle). However, Rust doesn't do
/// this by default and there are edge cases to track around it's usage. The biggest one
/// is this does not work with f64::NAN, f64::INFINITY, or f64::NEG_INFINITY
/// If you try to use these values with this type they will not work as expected.
/// This should only be used with the cache hashmap's keys and not used beyond that.
#[derive(Debug, Copy, Clone, PartialEq)]
#[derive(Debug, Copy, Clone, PartialEq, FromPyObject)]
struct ParameterKey(f64);

impl ParameterKey {
Expand Down Expand Up @@ -751,24 +768,3 @@ pub fn commutation_checker(m: &Bound<PyModule>) -> PyResult<()> {
m.add_class::<CommutationChecker>()?;
Ok(())
}

fn get_bits<T>(
py: Python,
bits1: &Bound<PyTuple>,
bits2: &Bound<PyTuple>,
) -> PyResult<(Vec<T>, Vec<T>)>
where
T: From<BitType> + Copy,
BitType: From<T>,
{
let mut bitdata: BitData<T> = BitData::new(py, "bits".to_string());

bits1.iter().chain(bits2.iter()).for_each(|bit| {
bitdata.add(py, &bit, false).unwrap();
});

Ok((
bitdata.map_bits(bits1)?.collect(),
bitdata.map_bits(bits2)?.collect(),
))
}

0 comments on commit f231c83

Please sign in to comment.