diff --git a/src/pybacked.rs b/src/pybacked.rs index 8e105ee3b1a..d090b261a4c 100644 --- a/src/pybacked.rs +++ b/src/pybacked.rs @@ -1,6 +1,6 @@ //! Contains types for working with Python objects that own the underlying data. -use std::{ops::Deref, ptr::NonNull}; +use std::{ops::Deref, ptr::NonNull, sync::Arc}; use crate::{ types::{ @@ -99,7 +99,7 @@ pub struct PyBackedBytes { #[derive(Clone)] enum PyBackedBytesStorage { Python(Py), - Rust(Box<[u8]>), + Rust(Arc<[u8]>), } impl Deref for PyBackedBytes { @@ -121,6 +121,30 @@ impl AsRef<[u8]> for PyBackedBytes { unsafe impl Send for PyBackedBytes {} unsafe impl Sync for PyBackedBytes {} +impl PartialEq<[u8; N]> for PyBackedBytes { + fn eq(&self, other: &[u8; N]) -> bool { + self.deref() == other + } +} + +impl PartialEq for [u8; N] { + fn eq(&self, other: &PyBackedBytes) -> bool { + self == other.deref() + } +} + +impl PartialEq<&[u8; N]> for PyBackedBytes { + fn eq(&self, other: &&[u8; N]) -> bool { + self.deref() == *other + } +} + +impl PartialEq for &[u8; N] { + fn eq(&self, other: &PyBackedBytes) -> bool { + self == &other.deref() + } +} + impl_traits!(PyBackedBytes, [u8]); impl From> for PyBackedBytes { @@ -136,7 +160,7 @@ impl From> for PyBackedBytes { impl From> for PyBackedBytes { fn from(py_bytearray: Bound<'_, PyByteArray>) -> Self { - let s = py_bytearray.to_vec().into_boxed_slice(); + let s = Arc::<[u8]>::from(py_bytearray.to_vec()); let data = NonNull::from(s.as_ref()); Self { storage: PyBackedBytesStorage::Rust(s), @@ -177,12 +201,24 @@ macro_rules! impl_traits { } } + impl PartialEq<&$equiv> for $slf { + fn eq(&self, other: &&$equiv) -> bool { + self.deref() == *other + } + } + impl PartialEq<$slf> for $equiv { fn eq(&self, other: &$slf) -> bool { self == other.deref() } } + impl PartialEq<$slf> for &$equiv { + fn eq(&self, other: &$slf) -> bool { + self == &other.deref() + } + } + impl Eq for $slf {} impl PartialOrd for $slf { @@ -222,6 +258,7 @@ use impl_traits; mod test { use super::*; use crate::Python; + use std::hash::{DefaultHasher, Hash, Hasher}; #[test] fn py_backed_str_empty() { @@ -297,4 +334,149 @@ mod test { is_send::(); is_sync::(); } + + #[test] + fn test_backed_str_clone() { + Python::with_gil(|py| { + let s1: PyBackedStr = PyString::new_bound(py, "hello").try_into().unwrap(); + let s2 = s1.clone(); + assert_eq!(s1, s2); + + drop(s1); + assert_eq!(s2, "hello"); + }); + } + + #[test] + fn test_backed_str_eq() { + Python::with_gil(|py| { + let s1: PyBackedStr = PyString::new_bound(py, "hello").try_into().unwrap(); + let s2: PyBackedStr = PyString::new_bound(py, "hello").try_into().unwrap(); + assert_eq!(s1, "hello"); + assert_eq!(s1, s2); + + let s3: PyBackedStr = PyString::new_bound(py, "abcde").try_into().unwrap(); + assert_eq!("abcde", s3); + assert_ne!(s1, s3); + }); + } + + #[test] + fn test_backed_str_hash() { + Python::with_gil(|py| { + let h = { + let mut hasher = DefaultHasher::new(); + "abcde".hash(&mut hasher); + hasher.finish() + }; + + let s1: PyBackedStr = PyString::new_bound(py, "abcde").try_into().unwrap(); + let h1 = { + let mut hasher = DefaultHasher::new(); + s1.hash(&mut hasher); + hasher.finish() + }; + + assert_eq!(h, h1); + }); + } + + #[test] + fn test_backed_str_ord() { + Python::with_gil(|py| { + let mut a = vec!["a", "c", "d", "b", "f", "g", "e"]; + let mut b = a + .iter() + .map(|s| PyString::new_bound(py, s).try_into().unwrap()) + .collect::>(); + + a.sort(); + b.sort(); + + assert_eq!(a, b); + }) + } + + #[test] + fn test_backed_bytes_from_bytes_clone() { + Python::with_gil(|py| { + let b1: PyBackedBytes = PyBytes::new_bound(py, b"abcde").into(); + let b2 = b1.clone(); + assert_eq!(b1, b2); + + drop(b1); + assert_eq!(b2, b"abcde"); + }); + } + + #[test] + fn test_backed_bytes_from_bytearray_clone() { + Python::with_gil(|py| { + let b1: PyBackedBytes = PyByteArray::new_bound(py, b"abcde").into(); + let b2 = b1.clone(); + assert_eq!(b1, b2); + + drop(b1); + assert_eq!(b2, b"abcde"); + }); + } + + #[test] + fn test_backed_bytes_eq() { + Python::with_gil(|py| { + let b1: PyBackedBytes = PyBytes::new_bound(py, b"abcde").into(); + let b2: PyBackedBytes = PyByteArray::new_bound(py, b"abcde").into(); + + assert_eq!(b1, b"abcde"); + assert_eq!(b1, b2); + + let b3: PyBackedBytes = PyBytes::new_bound(py, b"hello").into(); + assert_eq!(b"hello", b3); + assert_ne!(b1, b3); + }); + } + + #[test] + fn test_backed_bytes_hash() { + Python::with_gil(|py| { + let h = { + let mut hasher = DefaultHasher::new(); + b"abcde".hash(&mut hasher); + hasher.finish() + }; + + let b1: PyBackedBytes = PyBytes::new_bound(py, b"abcde").into(); + let h1 = { + let mut hasher = DefaultHasher::new(); + b1.hash(&mut hasher); + hasher.finish() + }; + + let b2: PyBackedBytes = PyByteArray::new_bound(py, b"abcde").into(); + let h2 = { + let mut hasher = DefaultHasher::new(); + b2.hash(&mut hasher); + hasher.finish() + }; + + assert_eq!(h, h1); + assert_eq!(h, h2); + }); + } + + #[test] + fn test_backed_bytes_ord() { + Python::with_gil(|py| { + let mut a = vec![b"a", b"c", b"d", b"b", b"f", b"g", b"e"]; + let mut b = a + .iter() + .map(|&b| PyBytes::new_bound(py, b).into()) + .collect::>(); + + a.sort(); + b.sort(); + + assert_eq!(a, b); + }) + } }