diff --git a/src/ffi_ptr_ext.rs b/src/ffi_ptr_ext.rs index 0034ea853f1..5e48a3913d6 100644 --- a/src/ffi_ptr_ext.rs +++ b/src/ffi_ptr_ext.rs @@ -35,6 +35,9 @@ pub(crate) trait FfiPtrExt: Sealed { /// Same as `assume_borrowed_or_err`, but panics on NULL. unsafe fn assume_borrowed<'a>(self, py: Python<'_>) -> Py2Borrowed<'a, '_, PyAny>; + + /// Same as `assume_borrowed_or_err`, but does not check for NULL. + unsafe fn assume_borrowed_unchecked<'a>(self, py: Python<'_>) -> Py2Borrowed<'a, '_, PyAny>; } impl FfiPtrExt for *mut ffi::PyObject { @@ -68,4 +71,9 @@ impl FfiPtrExt for *mut ffi::PyObject { unsafe fn assume_borrowed<'a>(self, py: Python<'_>) -> Py2Borrowed<'a, '_, PyAny> { Py2Borrowed::from_ptr(py, self) } + + #[inline] + unsafe fn assume_borrowed_unchecked<'a>(self, py: Python<'_>) -> Py2Borrowed<'a, '_, PyAny> { + Py2Borrowed::from_ptr_unchecked(py, self) + } } diff --git a/src/instance.rs b/src/instance.rs index 7568b28f108..b3dc8609503 100644 --- a/src/instance.rs +++ b/src/instance.rs @@ -262,6 +262,14 @@ impl<'a, 'py> Py2Borrowed<'a, 'py, PyAny> { py, ) } + + /// # Safety + /// This is similar to `std::slice::from_raw_parts`, the lifetime `'a` is completely defined by + /// the caller and it's the caller's responsibility to ensure that the reference this is + /// derived from is valid for the lifetime `'a`. + pub(crate) unsafe fn from_ptr_unchecked(py: Python<'py>, ptr: *mut ffi::PyObject) -> Self { + Self(NonNull::new_unchecked(ptr), PhantomData, py) + } } impl<'a, 'py, T> From<&'a Py2<'py, T>> for Py2Borrowed<'a, 'py, T> { diff --git a/src/prelude.rs b/src/prelude.rs index 85409ffebbb..772a661420a 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -29,5 +29,6 @@ pub use crate::wrap_pyfunction; // pub(crate) use crate::types::boolobject::PyBoolMethods; // pub(crate) use crate::types::bytearray::PyByteArrayMethods; // pub(crate) use crate::types::bytes::PyBytesMethods; +// pub(crate) use crate::types::dict::PyDictMethods; // pub(crate) use crate::types::float::PyFloatMethods; // pub(crate) use crate::types::sequence::PySequenceMethods; diff --git a/src/types/dict.rs b/src/types/dict.rs index fd32074b167..c03a03ec534 100644 --- a/src/types/dict.rs +++ b/src/types/dict.rs @@ -1,8 +1,12 @@ use super::PyMapping; use crate::err::{self, PyErr, PyResult}; use crate::ffi::Py_ssize_t; +use crate::ffi_ptr_ext::FfiPtrExt; +use crate::instance::Py2; +use crate::py_result_ext::PyResultExt; +use crate::types::any::PyAnyMethods; use crate::types::{PyAny, PyList}; -use crate::{ffi, PyObject, Python, ToPyObject}; +use crate::{ffi, Python, ToPyObject}; /// Represents a Python `dict`. #[repr(transparent)] @@ -54,7 +58,7 @@ pyobject_native_type_core!( impl PyDict { /// Creates a new empty dictionary. pub fn new(py: Python<'_>) -> &PyDict { - unsafe { py.from_owned_ptr::(ffi::PyDict_New()) } + unsafe { py.from_owned_ptr(ffi::PyDict_New()) } } /// Creates a new dictionary from the sequence given. @@ -78,39 +82,26 @@ impl PyDict { /// /// This is equivalent to the Python expression `self.copy()`. pub fn copy(&self) -> PyResult<&PyDict> { - unsafe { - self.py() - .from_owned_ptr_or_err::(ffi::PyDict_Copy(self.as_ptr())) - } + Py2::borrowed_from_gil_ref(&self) + .copy() + .map(Py2::into_gil_ref) } /// Empties an existing dictionary of all key-value pairs. pub fn clear(&self) { - unsafe { ffi::PyDict_Clear(self.as_ptr()) } + Py2::borrowed_from_gil_ref(&self).clear() } /// Return the number of items in the dictionary. /// /// This is equivalent to the Python expression `len(self)`. pub fn len(&self) -> usize { - self._len() as usize - } - - fn _len(&self) -> Py_ssize_t { - #[cfg(any(not(Py_3_8), PyPy, Py_LIMITED_API))] - unsafe { - ffi::PyDict_Size(self.as_ptr()) - } - - #[cfg(all(Py_3_8, not(PyPy), not(Py_LIMITED_API)))] - unsafe { - (*self.as_ptr().cast::()).ma_used - } + Py2::borrowed_from_gil_ref(&self).len() } /// Checks if the dict is empty, i.e. `len(self) == 0`. pub fn is_empty(&self) -> bool { - self.len() == 0 + Py2::borrowed_from_gil_ref(&self).is_empty() } /// Determines if the dictionary contains the specified key. @@ -120,15 +111,7 @@ impl PyDict { where K: ToPyObject, { - fn inner(dict: &PyDict, key: PyObject) -> PyResult { - match unsafe { ffi::PyDict_Contains(dict.as_ptr(), key.as_ptr()) } { - 1 => Ok(true), - 0 => Ok(false), - _ => Err(PyErr::fetch(dict.py())), - } - } - - inner(self, key.to_object(self.py())) + Py2::borrowed_from_gil_ref(&self).contains(key) } /// Gets an item from the dictionary. @@ -177,22 +160,11 @@ impl PyDict { where K: ToPyObject, { - fn inner(dict: &PyDict, key: PyObject) -> PyResult> { - let py = dict.py(); - // PyDict_GetItemWithError returns a borrowed ptr, must make it owned for safety (see #890). - // PyObject::from_borrowed_ptr_or_opt will take ownership in this way. - unsafe { - PyObject::from_borrowed_ptr_or_opt( - py, - ffi::PyDict_GetItemWithError(dict.as_ptr(), key.as_ptr()), - ) - } - .map(|pyobject| Ok(pyobject.into_ref(py))) - .or_else(|| PyErr::take(py).map(Err)) - .transpose() + match Py2::borrowed_from_gil_ref(&self).get_item(key) { + Ok(Some(item)) => Ok(Some(item.into_gil_ref())), + Ok(None) => Ok(None), + Err(e) => Err(e), } - - inner(self, key.to_object(self.py())) } /// Deprecated version of `get_item`. @@ -216,14 +188,7 @@ impl PyDict { K: ToPyObject, V: ToPyObject, { - fn inner(dict: &PyDict, key: PyObject, value: PyObject) -> PyResult<()> { - err::error_on_minusone(dict.py(), unsafe { - ffi::PyDict_SetItem(dict.as_ptr(), key.as_ptr(), value.as_ptr()) - }) - } - - let py = self.py(); - inner(self, key.to_object(py), value.to_object(py)) + Py2::borrowed_from_gil_ref(&self).set_item(key, value) } /// Deletes an item. @@ -233,43 +198,28 @@ impl PyDict { where K: ToPyObject, { - fn inner(dict: &PyDict, key: PyObject) -> PyResult<()> { - err::error_on_minusone(dict.py(), unsafe { - ffi::PyDict_DelItem(dict.as_ptr(), key.as_ptr()) - }) - } - - inner(self, key.to_object(self.py())) + Py2::borrowed_from_gil_ref(&self).del_item(key) } /// Returns a list of dict keys. /// /// This is equivalent to the Python expression `list(dict.keys())`. pub fn keys(&self) -> &PyList { - unsafe { - self.py() - .from_owned_ptr::(ffi::PyDict_Keys(self.as_ptr())) - } + Py2::borrowed_from_gil_ref(&self).keys().into_gil_ref() } /// Returns a list of dict values. /// /// This is equivalent to the Python expression `list(dict.values())`. pub fn values(&self) -> &PyList { - unsafe { - self.py() - .from_owned_ptr::(ffi::PyDict_Values(self.as_ptr())) - } + Py2::borrowed_from_gil_ref(&self).values().into_gil_ref() } /// Returns a list of dict items. /// /// This is equivalent to the Python expression `list(dict.items())`. pub fn items(&self) -> &PyList { - unsafe { - self.py() - .from_owned_ptr::(ffi::PyDict_Items(self.as_ptr())) - } + Py2::borrowed_from_gil_ref(&self).items().into_gil_ref() } /// Returns an iterator of `(key, value)` pairs in this dictionary. @@ -280,7 +230,7 @@ impl PyDict { /// It is allowed to modify values as you iterate over the dictionary, but only /// so long as the set of keys does not change. pub fn iter(&self) -> PyDictIterator<'_> { - IntoIterator::into_iter(self) + PyDictIterator(Py2::borrowed_from_gil_ref(&self).iter()) } /// Returns `self` cast as a `PyMapping`. @@ -293,10 +243,7 @@ impl PyDict { /// This is equivalent to the Python expression `self.update(other)`. If `other` is a `PyDict`, you may want /// to use `self.update(other.as_mapping())`, note: `PyDict::as_mapping` is a zero-cost conversion. pub fn update(&self, other: &PyMapping) -> PyResult<()> { - let py = self.py(); - err::error_on_minusone(py, unsafe { - ffi::PyDict_Update(self.as_ptr(), other.as_ptr()) - }) + Py2::borrowed_from_gil_ref(&self).update(Py2::borrowed_from_gil_ref(&other)) } /// Add key/value pairs from another dictionary to this one only when they do not exist in this. @@ -308,27 +255,309 @@ impl PyDict { /// This method uses [`PyDict_Merge`](https://docs.python.org/3/c-api/dict.html#c.PyDict_Merge) internally, /// so should have the same performance as `update`. pub fn update_if_missing(&self, other: &PyMapping) -> PyResult<()> { + Py2::borrowed_from_gil_ref(&self).update_if_missing(Py2::borrowed_from_gil_ref(&other)) + } +} + +/// Implementation of functionality for [`PyDict`]. +/// +/// These methods are defined for the `Py2<'py, PyDict>` smart pointer, so to use method call +/// syntax these methods are separated into a trait, because stable Rust does not yet support +/// `arbitrary_self_types`. +#[doc(alias = "PyDict")] +pub(crate) trait PyDictMethods<'py> { + /// Returns a new dictionary that contains the same key-value pairs as self. + /// + /// This is equivalent to the Python expression `self.copy()`. + fn copy(&self) -> PyResult>; + + /// Empties an existing dictionary of all key-value pairs. + fn clear(&self); + + /// Return the number of items in the dictionary. + /// + /// This is equivalent to the Python expression `len(self)`. + fn len(&self) -> usize; + + /// Checks if the dict is empty, i.e. `len(self) == 0`. + fn is_empty(&self) -> bool; + + /// Determines if the dictionary contains the specified key. + /// + /// This is equivalent to the Python expression `key in self`. + fn contains(&self, key: K) -> PyResult + where + K: ToPyObject; + + /// Gets an item from the dictionary. + /// + /// Returns `None` if the item is not present, or if an error occurs. + /// + /// To get a `KeyError` for non-existing keys, use `PyAny::get_item`. + fn get_item(&self, key: K) -> PyResult>> + where + K: ToPyObject; + + /// Sets an item value. + /// + /// This is equivalent to the Python statement `self[key] = value`. + fn set_item(&self, key: K, value: V) -> PyResult<()> + where + K: ToPyObject, + V: ToPyObject; + + /// Deletes an item. + /// + /// This is equivalent to the Python statement `del self[key]`. + fn del_item(&self, key: K) -> PyResult<()> + where + K: ToPyObject; + + /// Returns a list of dict keys. + /// + /// This is equivalent to the Python expression `list(dict.keys())`. + fn keys(&self) -> Py2<'py, PyList>; + + /// Returns a list of dict values. + /// + /// This is equivalent to the Python expression `list(dict.values())`. + fn values(&self) -> Py2<'py, PyList>; + + /// Returns a list of dict items. + /// + /// This is equivalent to the Python expression `list(dict.items())`. + fn items(&self) -> Py2<'py, PyList>; + + /// Returns an iterator of `(key, value)` pairs in this dictionary. + /// + /// # Panics + /// + /// If PyO3 detects that the dictionary is mutated during iteration, it will panic. + /// It is allowed to modify values as you iterate over the dictionary, but only + /// so long as the set of keys does not change. + fn iter(&self) -> PyDictIterator2<'py>; + + /// Returns `self` cast as a `PyMapping`. + fn as_mapping(&self) -> &Py2<'py, PyMapping>; + + /// Update this dictionary with the key/value pairs from another. + /// + /// This is equivalent to the Python expression `self.update(other)`. If `other` is a `PyDict`, you may want + /// to use `self.update(other.as_mapping())`, note: `PyDict::as_mapping` is a zero-cost conversion. + fn update(&self, other: &Py2<'_, PyMapping>) -> PyResult<()>; + + /// Add key/value pairs from another dictionary to this one only when they do not exist in this. + /// + /// This is equivalent to the Python expression `self.update({k: v for k, v in other.items() if k not in self})`. + /// If `other` is a `PyDict`, you may want to use `self.update_if_missing(other.as_mapping())`, + /// note: `PyDict::as_mapping` is a zero-cost conversion. + /// + /// This method uses [`PyDict_Merge`](https://docs.python.org/3/c-api/dict.html#c.PyDict_Merge) internally, + /// so should have the same performance as `update`. + fn update_if_missing(&self, other: &Py2<'_, PyMapping>) -> PyResult<()>; +} + +impl<'py> PyDictMethods<'py> for Py2<'py, PyDict> { + fn copy(&self) -> PyResult> { + unsafe { + ffi::PyDict_Copy(self.as_ptr()) + .assume_owned_or_err(self.py()) + .downcast_into_unchecked() + } + } + + fn clear(&self) { + unsafe { ffi::PyDict_Clear(self.as_ptr()) } + } + + fn len(&self) -> usize { + dict_len(self) as usize + } + + fn is_empty(&self) -> bool { + self.len() == 0 + } + + fn contains(&self, key: K) -> PyResult + where + K: ToPyObject, + { + fn inner(dict: &Py2<'_, PyDict>, key: Py2<'_, PyAny>) -> PyResult { + match unsafe { ffi::PyDict_Contains(dict.as_ptr(), key.as_ptr()) } { + 1 => Ok(true), + 0 => Ok(false), + _ => Err(PyErr::fetch(dict.py())), + } + } + let py = self.py(); - err::error_on_minusone(py, unsafe { + inner(self, key.to_object(py).attach_into(py)) + } + + fn get_item(&self, key: K) -> PyResult>> + where + K: ToPyObject, + { + fn inner<'py>( + dict: &Py2<'py, PyDict>, + key: Py2<'_, PyAny>, + ) -> PyResult>> { + let py = dict.py(); + match unsafe { + ffi::PyDict_GetItemWithError(dict.as_ptr(), key.as_ptr()) + .assume_borrowed_or_opt(py) + .map(|borrowed_any| borrowed_any.clone()) + } { + some @ Some(_) => Ok(some), + None => PyErr::take(py).map(Err).transpose(), + } + } + + let py = self.py(); + inner(self, key.to_object(py).attach_into(py)) + } + + fn set_item(&self, key: K, value: V) -> PyResult<()> + where + K: ToPyObject, + V: ToPyObject, + { + fn inner( + dict: &Py2<'_, PyDict>, + key: Py2<'_, PyAny>, + value: Py2<'_, PyAny>, + ) -> PyResult<()> { + err::error_on_minusone(dict.py(), unsafe { + ffi::PyDict_SetItem(dict.as_ptr(), key.as_ptr(), value.as_ptr()) + }) + } + + let py = self.py(); + inner( + self, + key.to_object(py).attach_into(py), + value.to_object(py).attach_into(py), + ) + } + + fn del_item(&self, key: K) -> PyResult<()> + where + K: ToPyObject, + { + fn inner(dict: &Py2<'_, PyDict>, key: Py2<'_, PyAny>) -> PyResult<()> { + err::error_on_minusone(dict.py(), unsafe { + ffi::PyDict_DelItem(dict.as_ptr(), key.as_ptr()) + }) + } + + let py = self.py(); + inner(self, key.to_object(py).attach_into(py)) + } + + fn keys(&self) -> Py2<'py, PyList> { + unsafe { + ffi::PyDict_Keys(self.as_ptr()) + .assume_owned(self.py()) + .downcast_into_unchecked() + } + } + + fn values(&self) -> Py2<'py, PyList> { + unsafe { + ffi::PyDict_Values(self.as_ptr()) + .assume_owned(self.py()) + .downcast_into_unchecked() + } + } + + fn items(&self) -> Py2<'py, PyList> { + unsafe { + ffi::PyDict_Items(self.as_ptr()) + .assume_owned(self.py()) + .downcast_into_unchecked() + } + } + + fn iter(&self) -> PyDictIterator2<'py> { + PyDictIterator2::new(self.clone()) + } + + fn as_mapping(&self) -> &Py2<'py, PyMapping> { + unsafe { self.downcast_unchecked() } + } + + fn update(&self, other: &Py2<'_, PyMapping>) -> PyResult<()> { + err::error_on_minusone(self.py(), unsafe { + ffi::PyDict_Update(self.as_ptr(), other.as_ptr()) + }) + } + + fn update_if_missing(&self, other: &Py2<'_, PyMapping>) -> PyResult<()> { + err::error_on_minusone(self.py(), unsafe { ffi::PyDict_Merge(self.as_ptr(), other.as_ptr(), 0) }) } } +fn dict_len(dict: &Py2<'_, PyDict>) -> Py_ssize_t { + #[cfg(any(not(Py_3_8), PyPy, Py_LIMITED_API))] + unsafe { + ffi::PyDict_Size(dict.as_ptr()) + } + + #[cfg(all(Py_3_8, not(PyPy), not(Py_LIMITED_API)))] + unsafe { + (*dict.as_ptr().cast::()).ma_used + } +} + +/// PyO3 implementation of an iterator for a Python `dict` object. +pub struct PyDictIterator<'py>(PyDictIterator2<'py>); + +impl<'py> Iterator for PyDictIterator<'py> { + type Item = (&'py PyAny, &'py PyAny); + + #[inline] + fn next(&mut self) -> Option { + let (key, value) = self.0.next()?; + Some((key.into_gil_ref(), value.into_gil_ref())) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + self.0.size_hint() + } +} + +impl<'py> ExactSizeIterator for PyDictIterator<'py> { + fn len(&self) -> usize { + self.0.len() + } +} + +impl<'a> IntoIterator for &'a PyDict { + type Item = (&'a PyAny, &'a PyAny); + type IntoIter = PyDictIterator<'a>; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + /// PyO3 implementation of an iterator for a Python `dict` object. -pub struct PyDictIterator<'py> { - dict: &'py PyDict, +pub(crate) struct PyDictIterator2<'py> { + dict: Py2<'py, PyDict>, ppos: ffi::Py_ssize_t, di_used: ffi::Py_ssize_t, len: ffi::Py_ssize_t, } -impl<'py> Iterator for PyDictIterator<'py> { - type Item = (&'py PyAny, &'py PyAny); +impl<'py> Iterator for PyDictIterator2<'py> { + type Item = (Py2<'py, PyAny>, Py2<'py, PyAny>); #[inline] fn next(&mut self) -> Option { - let ma_used = self.dict._len(); + let ma_used = dict_len(&self.dict); // These checks are similar to what CPython does. // @@ -354,11 +583,24 @@ impl<'py> Iterator for PyDictIterator<'py> { panic!("dictionary keys changed during iteration"); }; - let ret = unsafe { self.next_unchecked() }; - if ret.is_some() { - self.len -= 1 + let mut key: *mut ffi::PyObject = std::ptr::null_mut(); + let mut value: *mut ffi::PyObject = std::ptr::null_mut(); + + if unsafe { ffi::PyDict_Next(self.dict.as_ptr(), &mut self.ppos, &mut key, &mut value) } + != 0 + { + self.len -= 1; + let py = self.dict.py(); + // Safety: + // - PyDict_Next returns borrowed values + // - we have already checked that `PyDict_Next` succeeded, so we can assume these to be non-null + Some(( + unsafe { key.assume_borrowed_unchecked(py) }.clone(), + unsafe { value.assume_borrowed_unchecked(py) }.clone(), + )) + } else { + None } - ret } #[inline] @@ -368,45 +610,39 @@ impl<'py> Iterator for PyDictIterator<'py> { } } -impl<'py> ExactSizeIterator for PyDictIterator<'py> { +impl<'py> ExactSizeIterator for PyDictIterator2<'py> { fn len(&self) -> usize { self.len as usize } } -impl<'a> std::iter::IntoIterator for &'a PyDict { - type Item = (&'a PyAny, &'a PyAny); - type IntoIter = PyDictIterator<'a>; - - fn into_iter(self) -> Self::IntoIter { - PyDictIterator { - dict: self, +impl<'py> PyDictIterator2<'py> { + fn new(dict: Py2<'py, PyDict>) -> Self { + let len = dict_len(&dict); + PyDictIterator2 { + dict, ppos: 0, - di_used: self._len(), - len: self._len(), + di_used: len, + len, } } } -impl<'py> PyDictIterator<'py> { - /// Advances the iterator without checking for concurrent modification. - /// - /// See [`PyDict_Next`](https://docs.python.org/3/c-api/dict.html#c.PyDict_Next) - /// for more information. - unsafe fn next_unchecked(&mut self) -> Option<(&'py PyAny, &'py PyAny)> { - let mut key: *mut ffi::PyObject = std::ptr::null_mut(); - let mut value: *mut ffi::PyObject = std::ptr::null_mut(); +impl<'py> IntoIterator for &'_ Py2<'py, PyDict> { + type Item = (Py2<'py, PyAny>, Py2<'py, PyAny>); + type IntoIter = PyDictIterator2<'py>; - if ffi::PyDict_Next(self.dict.as_ptr(), &mut self.ppos, &mut key, &mut value) != 0 { - let py = self.dict.py(); - // PyDict_Next returns borrowed values; for safety must make them owned (see #890) - Some(( - py.from_owned_ptr(ffi::_Py_NewRef(key)), - py.from_owned_ptr(ffi::_Py_NewRef(value)), - )) - } else { - None - } + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +impl<'py> IntoIterator for Py2<'py, PyDict> { + type Item = (Py2<'py, PyAny>, Py2<'py, PyAny>); + type IntoIter = PyDictIterator2<'py>; + + fn into_iter(self) -> Self::IntoIter { + PyDictIterator2::new(self) } } @@ -523,9 +759,11 @@ mod tests { .extract::() .unwrap() ); - let map: HashMap<&str, i32> = [("a", 1), ("b", 2)].iter().cloned().collect(); + let map: HashMap = + [("a".into(), 1), ("b".into(), 2)].into_iter().collect(); assert_eq!(map, dict.extract().unwrap()); - let map: BTreeMap<&str, i32> = [("a", 1), ("b", 2)].iter().cloned().collect(); + let map: BTreeMap = + [("a".into(), 1), ("b".into(), 2)].into_iter().collect(); assert_eq!(map, dict.extract().unwrap()); }); } diff --git a/src/types/mod.rs b/src/types/mod.rs index a7ea5b631f8..97ecc82e8f8 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -281,7 +281,7 @@ mod code; mod complex; #[cfg(not(Py_LIMITED_API))] pub(crate) mod datetime; -mod dict; +pub(crate) mod dict; mod ellipsis; pub(crate) mod float; #[cfg(all(not(Py_LIMITED_API), not(PyPy)))]