diff --git a/crates/accelerate/src/commutation_checker.rs b/crates/accelerate/src/commutation_checker.rs index 2eb30fe58bab..0b2104cdf1d8 100644 --- a/crates/accelerate/src/commutation_checker.rs +++ b/crates/accelerate/src/commutation_checker.rs @@ -11,7 +11,6 @@ // that they have been altered from the originals. use approx::abs_diff_eq; -use hashbrown::hash_map::Iter; use hashbrown::{HashMap, HashSet}; use ndarray::linalg::kron; use ndarray::Array2; @@ -45,12 +44,36 @@ static SUPPORTED_OP: Lazy> = Lazy::new(|| { ]) }); +fn get_bits( + py: Python, + bits1: &Bound, + bits2: &Bound, +) -> PyResult<(Vec, Vec)> +where + T: From + Copy, + BitType: From, +{ + let mut bitdata: BitData = BitData::new(py, "bits".to_string()); + + bits1.iter().chain(bits2.iter()).for_each(|bit| { + bitdata.add(py, &bit, false).unwrap(); + }); + + Ok(( + bitdata.map_bits(bits1)?.collect(), + bitdata.map_bits(bits2)?.collect(), + )) +} + +/// This is the internal structure for the Python CommutationChecker class +/// It handles the actual commutation checking, cache management, and library +/// lookups. It's not meant to be a public facing Python object though and only used +/// internally by the Python class. #[pyclass(module = "qiskit._accelerate.commutation_checker")] struct CommutationChecker { library: CommutationLibrary, cache_max_entries: usize, cache: HashMap<(String, String), CommutationCacheEntry>, - #[pyo3(get)] current_cache_entries: usize, #[pyo3(get)] gates: Option>, @@ -61,7 +84,7 @@ impl CommutationChecker { #[pyo3(signature = (standard_gate_commutations=None, cache_max_entries=1_000_000, gates=None))] #[new] fn py_new( - standard_gate_commutations: Option>, // Send a bound here + standard_gate_commutations: Option>, cache_max_entries: usize, gates: Option>, ) -> Self { @@ -152,11 +175,12 @@ impl CommutationChecker { ) } - #[pyo3(signature=())] + /// Return the current number of cache entries fn num_cached_entries(&self) -> usize { self.current_cache_entries } - #[pyo3(signature=())] + + /// Clear the cache fn clear_cached_commutations(&mut self) { self.clear_cache() } @@ -165,7 +189,11 @@ impl CommutationChecker { let out_dict = PyDict::new_bound(py); out_dict.set_item("cache_max_entries", self.cache_max_entries)?; out_dict.set_item("current_cache_entries", self.current_cache_entries)?; - out_dict.set_item("cache", self.cache.clone())?; + let cache_dict = PyDict::new_bound(py); + for (key, value) in &self.cache { + cache_dict.set_item(key, commutation_entry_to_pydict(py, value))?; + } + out_dict.set_item("cache", cache_dict)?; out_dict.set_item("library", self.library.library.to_object(py))?; out_dict.set_item("gates", self.gates.clone())?; Ok(out_dict.unbind()) @@ -184,7 +212,15 @@ impl CommutationChecker { self.library = CommutationLibrary { library: dict_state.get_item("library")?.unwrap().extract()?, }; - self.cache = dict_state.get_item("cache")?.unwrap().extract()?; + let raw_cache: Bound = dict_state.get_item("cache")?.unwrap().extract()?; + self.cache = HashMap::with_capacity(raw_cache.len()); + for (key, value) in raw_cache.iter() { + let value_dict: &Bound = value.downcast()?; + self.cache.insert( + key.extract()?, + commutation_cache_entry_from_pydict(value_dict)?, + ); + } self.gates = dict_state.get_item("gates")?.unwrap().extract()?; Ok(()) } @@ -323,7 +359,7 @@ impl CommutationChecker { ); entries.insert(key, is_commuting); self.current_cache_entries += 1; - CommutationCacheEntry { mapping: entries } + entries }); Ok(is_commuting) } @@ -495,6 +531,7 @@ fn get_matrix(py: Python, operation: &OperationRef, params: &[Param]) -> Option< }, } } + fn matrix_via_operator(py: Python, py_obj: &PyObject) -> Option> { Some( QI_OPERATOR @@ -632,61 +669,41 @@ type CacheKey = ( (SmallVec<[ParameterKey; 3]>, SmallVec<[ParameterKey; 3]>), ); -// Need a struct instead of a type definition because we cannot implement serialization traits otherwise -#[derive(Clone, Debug)] -struct CommutationCacheEntry { - mapping: HashMap, -} -impl CommutationCacheEntry { - fn get(&self, key: &CacheKey) -> Option<&bool> { - self.mapping.get(key) - } - fn iter(&self) -> Iter<'_, CacheKey, bool> { - self.mapping.iter() - } - - fn insert(&mut self, k: CacheKey, v: bool) -> Option { - self.mapping.insert(k, v) +type CommutationCacheEntry = HashMap; + +fn commutation_entry_to_pydict(py: Python, entry: &CommutationCacheEntry) -> Py { + let out_dict = PyDict::new_bound(py); + for (k, v) in entry.iter() { + let qubits = PyTuple::new_bound(py, k.0.iter().map(|q| q.map(|t| t.0))); + let params0 = PyTuple::new_bound(py, k.1 .0.iter().map(|pk| pk.0)); + let params1 = PyTuple::new_bound(py, k.1 .1.iter().map(|pk| pk.0)); + out_dict + .set_item( + PyTuple::new_bound(py, [qubits, PyTuple::new_bound(py, [params0, params1])]), + PyBool::new_bound(py, *v), + ) + .expect("Failed to construct commutation cache for serialization"); } + out_dict.unbind() } -impl ToPyObject for CommutationCacheEntry { - fn to_object(&self, py: Python) -> PyObject { - let out_dict = PyDict::new_bound(py); - for (k, v) in self.iter() { - let qubits = PyTuple::new_bound(py, k.0.iter().map(|q| q.map(|t| t.0))); - let params0 = PyTuple::new_bound(py, k.1 .0.iter().map(|pk| pk.0)); - let params1 = PyTuple::new_bound(py, k.1 .1.iter().map(|pk| pk.0)); - out_dict - .set_item( - PyTuple::new_bound(py, [qubits, PyTuple::new_bound(py, [params0, params1])]), - PyBool::new_bound(py, *v), - ) - .expect("Failed to construct commutation cache for serialization"); - } - out_dict.into_any().unbind() +fn commutation_cache_entry_from_pydict(dict: &Bound) -> PyResult { + let mut ret = hashbrown::HashMap::with_capacity(dict.len()); + for (k, v) in dict { + let raw_key: CacheKeyRaw = k.extract()?; + let qubits = raw_key.0.iter().map(|q| q.map(Qubit)).collect(); + let params0: SmallVec<_> = raw_key.1 .0; + let params1: SmallVec<_> = raw_key.1 .1; + let v: bool = v.extract()?; + ret.insert((qubits, (params0, params1)), v); } + Ok(ret) } type CacheKeyRaw = ( SmallVec<[Option; 2]>, - (SmallVec<[f64; 3]>, SmallVec<[f64; 3]>), + (SmallVec<[ParameterKey; 3]>, SmallVec<[ParameterKey; 3]>), ); -impl<'py> FromPyObject<'py> for CommutationCacheEntry { - fn extract_bound(b: &Bound<'py, PyAny>) -> Result { - let dict = b.downcast::()?; - let mut ret = hashbrown::HashMap::with_capacity(dict.len()); - for (k, v) in dict { - let raw_key: CacheKeyRaw = k.extract()?; - let qubits = raw_key.0.iter().map(|q| q.map(Qubit)).collect(); - let params0: SmallVec<_> = raw_key.1 .0.iter().map(|p| ParameterKey(*p)).collect(); - let params1: SmallVec<_> = raw_key.1 .1.iter().map(|p| ParameterKey(*p)).collect(); - let v: bool = v.extract()?; - ret.insert((qubits, (params0, params1)), v); - } - Ok(CommutationCacheEntry { mapping: ret }) - } -} /// This newtype wraps a f64 to make it hashable so we can cache parameterized gates /// based on the parameter value (assuming it's a float angle). However, Rust doesn't do @@ -694,7 +711,7 @@ impl<'py> FromPyObject<'py> for CommutationCacheEntry { /// is this does not work with f64::NAN, f64::INFINITY, or f64::NEG_INFINITY /// If you try to use these values with this type they will not work as expected. /// This should only be used with the cache hashmap's keys and not used beyond that. -#[derive(Debug, Copy, Clone, PartialEq)] +#[derive(Debug, Copy, Clone, PartialEq, FromPyObject)] struct ParameterKey(f64); impl ParameterKey { @@ -751,24 +768,3 @@ pub fn commutation_checker(m: &Bound) -> PyResult<()> { m.add_class::()?; Ok(()) } - -fn get_bits( - py: Python, - bits1: &Bound, - bits2: &Bound, -) -> PyResult<(Vec, Vec)> -where - T: From + Copy, - BitType: From, -{ - let mut bitdata: BitData = BitData::new(py, "bits".to_string()); - - bits1.iter().chain(bits2.iter()).for_each(|bit| { - bitdata.add(py, &bit, false).unwrap(); - }); - - Ok(( - bitdata.map_bits(bits1)?.collect(), - bitdata.map_bits(bits2)?.collect(), - )) -}