diff --git a/newsfragments/4020.added.md b/newsfragments/4020.added.md new file mode 100644 index 00000000000..53d0d6b8fd3 --- /dev/null +++ b/newsfragments/4020.added.md @@ -0,0 +1 @@ +Adds `Clone`, `Debug`, `PartialEq`, `Eq`, `PartialOrd`, `Ord` and `Hash` implementation for `PyBackedBytes` and `PyBackedStr`, and `Display` for `PyBackedStr`. diff --git a/src/pybacked.rs b/src/pybacked.rs index 01b9c2a6f00..e0bacb86144 100644 --- a/src/pybacked.rs +++ b/src/pybacked.rs @@ -1,6 +1,6 @@ //! Contains types for working with Python objects that own the underlying data. -use std::{ops::Deref, ptr::NonNull}; +use std::{ops::Deref, ptr::NonNull, sync::Arc}; use crate::{ types::{ @@ -13,6 +13,7 @@ use crate::{ /// A wrapper around `str` where the storage is owned by a Python `bytes` or `str` object. /// /// This type gives access to the underlying data via a `Deref` implementation. +#[derive(Clone)] pub struct PyBackedStr { #[allow(dead_code)] // only held so that the storage is not dropped storage: Py, @@ -44,6 +45,14 @@ impl AsRef<[u8]> for PyBackedStr { unsafe impl Send for PyBackedStr {} unsafe impl Sync for PyBackedStr {} +impl std::fmt::Display for PyBackedStr { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.deref().fmt(f) + } +} + +impl_traits!(PyBackedStr, str); + impl TryFrom> for PyBackedStr { type Error = PyErr; fn try_from(py_string: Bound<'_, PyString>) -> Result { @@ -79,6 +88,7 @@ impl FromPyObject<'_> for PyBackedStr { /// A wrapper around `[u8]` where the storage is either owned by a Python `bytes` object, or a Rust `Box<[u8]>`. /// /// This type gives access to the underlying data via a `Deref` implementation. +#[derive(Clone)] pub struct PyBackedBytes { #[allow(dead_code)] // only held so that the storage is not dropped storage: PyBackedBytesStorage, @@ -86,9 +96,10 @@ pub struct PyBackedBytes { } #[allow(dead_code)] +#[derive(Clone)] enum PyBackedBytesStorage { Python(Py), - Rust(Box<[u8]>), + Rust(Arc<[u8]>), } impl Deref for PyBackedBytes { @@ -110,6 +121,32 @@ impl AsRef<[u8]> for PyBackedBytes { unsafe impl Send for PyBackedBytes {} unsafe impl Sync for PyBackedBytes {} +impl PartialEq<[u8; N]> for PyBackedBytes { + fn eq(&self, other: &[u8; N]) -> bool { + self.deref() == other + } +} + +impl PartialEq for [u8; N] { + fn eq(&self, other: &PyBackedBytes) -> bool { + self == other.deref() + } +} + +impl PartialEq<&[u8; N]> for PyBackedBytes { + fn eq(&self, other: &&[u8; N]) -> bool { + self.deref() == *other + } +} + +impl PartialEq for &[u8; N] { + fn eq(&self, other: &PyBackedBytes) -> bool { + self == &other.deref() + } +} + +impl_traits!(PyBackedBytes, [u8]); + impl From> for PyBackedBytes { fn from(py_bytes: Bound<'_, PyBytes>) -> Self { let b = py_bytes.as_bytes(); @@ -123,7 +160,7 @@ impl From> for PyBackedBytes { impl From> for PyBackedBytes { fn from(py_bytearray: Bound<'_, PyByteArray>) -> Self { - let s = py_bytearray.to_vec().into_boxed_slice(); + let s = Arc::<[u8]>::from(py_bytearray.to_vec()); let data = NonNull::from(s.as_ref()); Self { storage: PyBackedBytesStorage::Rust(s), @@ -144,10 +181,85 @@ impl FromPyObject<'_> for PyBackedBytes { } } +macro_rules! impl_traits { + ($slf:ty, $equiv:ty) => { + impl std::fmt::Debug for $slf { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.deref().fmt(f) + } + } + + impl PartialEq for $slf { + fn eq(&self, other: &Self) -> bool { + self.deref() == other.deref() + } + } + + impl PartialEq<$equiv> for $slf { + fn eq(&self, other: &$equiv) -> bool { + self.deref() == other + } + } + + impl PartialEq<&$equiv> for $slf { + fn eq(&self, other: &&$equiv) -> bool { + self.deref() == *other + } + } + + impl PartialEq<$slf> for $equiv { + fn eq(&self, other: &$slf) -> bool { + self == other.deref() + } + } + + impl PartialEq<$slf> for &$equiv { + fn eq(&self, other: &$slf) -> bool { + self == &other.deref() + } + } + + impl Eq for $slf {} + + impl PartialOrd for $slf { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } + } + + impl PartialOrd<$equiv> for $slf { + fn partial_cmp(&self, other: &$equiv) -> Option { + self.deref().partial_cmp(other) + } + } + + impl PartialOrd<$slf> for $equiv { + fn partial_cmp(&self, other: &$slf) -> Option { + self.partial_cmp(other.deref()) + } + } + + impl Ord for $slf { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.deref().cmp(other.deref()) + } + } + + impl std::hash::Hash for $slf { + fn hash(&self, state: &mut H) { + self.deref().hash(state) + } + } + }; +} +use impl_traits; + #[cfg(test)] mod test { use super::*; use crate::Python; + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; #[test] fn py_backed_str_empty() { @@ -223,4 +335,149 @@ mod test { is_send::(); is_sync::(); } + + #[test] + fn test_backed_str_clone() { + Python::with_gil(|py| { + let s1: PyBackedStr = PyString::new_bound(py, "hello").try_into().unwrap(); + let s2 = s1.clone(); + assert_eq!(s1, s2); + + drop(s1); + assert_eq!(s2, "hello"); + }); + } + + #[test] + fn test_backed_str_eq() { + Python::with_gil(|py| { + let s1: PyBackedStr = PyString::new_bound(py, "hello").try_into().unwrap(); + let s2: PyBackedStr = PyString::new_bound(py, "hello").try_into().unwrap(); + assert_eq!(s1, "hello"); + assert_eq!(s1, s2); + + let s3: PyBackedStr = PyString::new_bound(py, "abcde").try_into().unwrap(); + assert_eq!("abcde", s3); + assert_ne!(s1, s3); + }); + } + + #[test] + fn test_backed_str_hash() { + Python::with_gil(|py| { + let h = { + let mut hasher = DefaultHasher::new(); + "abcde".hash(&mut hasher); + hasher.finish() + }; + + let s1: PyBackedStr = PyString::new_bound(py, "abcde").try_into().unwrap(); + let h1 = { + let mut hasher = DefaultHasher::new(); + s1.hash(&mut hasher); + hasher.finish() + }; + + assert_eq!(h, h1); + }); + } + + #[test] + fn test_backed_str_ord() { + Python::with_gil(|py| { + let mut a = vec!["a", "c", "d", "b", "f", "g", "e"]; + let mut b = a + .iter() + .map(|s| PyString::new_bound(py, s).try_into().unwrap()) + .collect::>(); + + a.sort(); + b.sort(); + + assert_eq!(a, b); + }) + } + + #[test] + fn test_backed_bytes_from_bytes_clone() { + Python::with_gil(|py| { + let b1: PyBackedBytes = PyBytes::new_bound(py, b"abcde").into(); + let b2 = b1.clone(); + assert_eq!(b1, b2); + + drop(b1); + assert_eq!(b2, b"abcde"); + }); + } + + #[test] + fn test_backed_bytes_from_bytearray_clone() { + Python::with_gil(|py| { + let b1: PyBackedBytes = PyByteArray::new_bound(py, b"abcde").into(); + let b2 = b1.clone(); + assert_eq!(b1, b2); + + drop(b1); + assert_eq!(b2, b"abcde"); + }); + } + + #[test] + fn test_backed_bytes_eq() { + Python::with_gil(|py| { + let b1: PyBackedBytes = PyBytes::new_bound(py, b"abcde").into(); + let b2: PyBackedBytes = PyByteArray::new_bound(py, b"abcde").into(); + + assert_eq!(b1, b"abcde"); + assert_eq!(b1, b2); + + let b3: PyBackedBytes = PyBytes::new_bound(py, b"hello").into(); + assert_eq!(b"hello", b3); + assert_ne!(b1, b3); + }); + } + + #[test] + fn test_backed_bytes_hash() { + Python::with_gil(|py| { + let h = { + let mut hasher = DefaultHasher::new(); + b"abcde".hash(&mut hasher); + hasher.finish() + }; + + let b1: PyBackedBytes = PyBytes::new_bound(py, b"abcde").into(); + let h1 = { + let mut hasher = DefaultHasher::new(); + b1.hash(&mut hasher); + hasher.finish() + }; + + let b2: PyBackedBytes = PyByteArray::new_bound(py, b"abcde").into(); + let h2 = { + let mut hasher = DefaultHasher::new(); + b2.hash(&mut hasher); + hasher.finish() + }; + + assert_eq!(h, h1); + assert_eq!(h, h2); + }); + } + + #[test] + fn test_backed_bytes_ord() { + Python::with_gil(|py| { + let mut a = vec![b"a", b"c", b"d", b"b", b"f", b"g", b"e"]; + let mut b = a + .iter() + .map(|&b| PyBytes::new_bound(py, b).into()) + .collect::>(); + + a.sort(); + b.sort(); + + assert_eq!(a, b); + }) + } }