added various std traits for PyBackedStr and PyBackedBytes #4020

Merged (3 commits) on Apr 1, 2024
1 change: 1 addition & 0 deletions newsfragments/4020.added.md
@@ -0,0 +1 @@
Adds `Clone`, `Debug`, `PartialEq`, `Eq`, `PartialOrd`, `Ord` and `Hash` implementation for `PyBackedBytes` and `PyBackedStr`, and `Display` for `PyBackedStr`.
263 changes: 260 additions & 3 deletions src/pybacked.rs
@@ -1,6 +1,6 @@
//! Contains types for working with Python objects that own the underlying data.

-use std::{ops::Deref, ptr::NonNull};
+use std::{ops::Deref, ptr::NonNull, sync::Arc};

use crate::{
types::{
@@ -13,6 +13,7 @@ use crate::{
/// A wrapper around `str` where the storage is owned by a Python `bytes` or `str` object.
///
/// This type gives access to the underlying data via a `Deref` implementation.
#[derive(Clone)]
pub struct PyBackedStr {
#[allow(dead_code)] // only held so that the storage is not dropped
storage: Py<PyAny>,
@@ -44,6 +45,14 @@ impl AsRef<[u8]> for PyBackedStr {
unsafe impl Send for PyBackedStr {}
unsafe impl Sync for PyBackedStr {}

impl std::fmt::Display for PyBackedStr {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.deref().fmt(f)
}
}

impl_traits!(PyBackedStr, str);

impl TryFrom<Bound<'_, PyString>> for PyBackedStr {
type Error = PyErr;
fn try_from(py_string: Bound<'_, PyString>) -> Result<Self, Self::Error> {
@@ -79,16 +88,18 @@ impl FromPyObject<'_> for PyBackedStr {
/// A wrapper around `[u8]` where the storage is either owned by a Python `bytes` object, or a Rust `Box<[u8]>`.
///
/// This type gives access to the underlying data via a `Deref` implementation.
#[derive(Clone)]
Member

Isn't the derived Clone impl incorrect for the Rust variant? Cloning Box<[u8]> will yield a new allocation, but data will still point to the old one, meaning this is a use-after-free in waiting.

So we either need to manually fix up the data pointer or use a shared ownership pointer for the Rust side as well, e.g. Arc<[u8]>.

Member

(And so yes, these do need tests. ;-))

Member

Yes I was thinking similarly about the Box, an Arc seems reasonable to me.

I also half wonder if it was a mistake to have the Rust variant here given the name PyBackedBytes. I added it to support bytearrays in the same way we already do for Cow<[u8]>, but on further reflection it does seem a bit at odds with the name that we do the copy.

I think it's probably OK, as I would expect most uses to be with bytes, but we might want to document this copy-from-bytearray more clearly and justify it for the sake of safety.

Contributor Author

> Isn't the derived Clone impl incorrect for the Rust variant? Cloning Box<[u8]> will yield a new allocation, but data will still point to the old one, meaning this is a use-after-free in waiting.

Big oops, thanks for catching.

Using Arc<[u8]> sounds good to me as well; that way this stays a cheap clone either way.
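
The issue raised in this thread can be reproduced without PyO3. What follows is a minimal std-only sketch, not the PR's code: `Backed` is a hypothetical stand-in for the Rust-owned variant of `PyBackedBytes`, and the two helper functions are illustrative names. The `Box` case only compares addresses, so the sketch never dereferences a dangling pointer.

```rust
use std::ptr::NonNull;
use std::sync::Arc;

/// Hypothetical stand-in for the Rust-owned variant of `PyBackedBytes`.
#[derive(Clone)]
struct Backed {
    #[allow(dead_code)] // only held so that the allocation stays alive
    storage: Arc<[u8]>,
    data: NonNull<[u8]>,
}

impl Backed {
    fn new(bytes: Vec<u8>) -> Self {
        let storage = Arc::<[u8]>::from(bytes);
        let data = NonNull::from(storage.as_ref());
        Self { storage, data }
    }

    fn as_slice(&self) -> &[u8] {
        // Safety: `storage` keeps the allocation alive and it is never mutated.
        unsafe { self.data.as_ref() }
    }
}

/// Returns true if cloning a `Box<[u8]>` puts the bytes at a new address.
fn box_clone_moves_data() -> bool {
    let original: Box<[u8]> = vec![1u8, 2, 3].into_boxed_slice();
    let copy = original.clone();
    original.as_ptr() != copy.as_ptr()
}

/// Returns true if cloning an `Arc<[u8]>` keeps the bytes at the same address.
fn arc_clone_shares_data() -> bool {
    let original: Arc<[u8]> = Arc::from(vec![1u8, 2, 3]);
    let copy = Arc::clone(&original);
    original.as_ptr() == copy.as_ptr()
}

fn main() {
    // A derived `Clone` over `Box<[u8]>` would leave `data` pointing at the
    // original box, which is freed once the original value is dropped.
    assert!(box_clone_moves_data());
    // With `Arc<[u8]>` both clones share one allocation, so `data` stays valid.
    assert!(arc_clone_shares_data());

    let a = Backed::new(b"abcde".to_vec());
    let b = a.clone();
    drop(a); // the clone still holds a reference to the shared allocation
    assert_eq!(b.as_slice(), b"abcde");
}
```

Because the `Arc` clone only bumps a reference count, the copied `data` pointer remains correct and cloning stays cheap for both storage variants.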

pub struct PyBackedBytes {
#[allow(dead_code)] // only held so that the storage is not dropped
storage: PyBackedBytesStorage,
data: NonNull<[u8]>,
}

#[allow(dead_code)]
#[derive(Clone)]
enum PyBackedBytesStorage {
Python(Py<PyBytes>),
-Rust(Box<[u8]>),
+Rust(Arc<[u8]>),
}

impl Deref for PyBackedBytes {
@@ -110,6 +121,32 @@ impl AsRef<[u8]> for PyBackedBytes {
unsafe impl Send for PyBackedBytes {}
unsafe impl Sync for PyBackedBytes {}

impl<const N: usize> PartialEq<[u8; N]> for PyBackedBytes {
fn eq(&self, other: &[u8; N]) -> bool {
self.deref() == other
}
}

impl<const N: usize> PartialEq<PyBackedBytes> for [u8; N] {
fn eq(&self, other: &PyBackedBytes) -> bool {
self == other.deref()
}
}

impl<const N: usize> PartialEq<&[u8; N]> for PyBackedBytes {
fn eq(&self, other: &&[u8; N]) -> bool {
self.deref() == *other
}
}

impl<const N: usize> PartialEq<PyBackedBytes> for &[u8; N] {
fn eq(&self, other: &PyBackedBytes) -> bool {
self == &other.deref()
}
}

impl_traits!(PyBackedBytes, [u8]);
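
A likely reason for the const-generic impls above is that byte string literals such as `b"abcde"` have type `&[u8; 5]` rather than `&[u8]`, so impls written against slices alone do not cover comparisons with literals. The std-only sketch below illustrates this under that assumption; `Bytes` is a hypothetical stand-in for `PyBackedBytes`, not PyO3 code.

```rust
use std::ops::Deref;

/// Hypothetical stand-in for `PyBackedBytes`: owns its data and derefs to `[u8]`.
struct Bytes(Vec<u8>);

impl Deref for Bytes {
    type Target = [u8];
    fn deref(&self) -> &[u8] {
        &self.0
    }
}

// Comparison with a slice reference.
impl PartialEq<&[u8]> for Bytes {
    fn eq(&self, other: &&[u8]) -> bool {
        self.deref() == *other
    }
}

// Comparison with a reference to a fixed-size array of any length `N`,
// which is the type of a byte string literal.
impl<const N: usize> PartialEq<&[u8; N]> for Bytes {
    fn eq(&self, other: &&[u8; N]) -> bool {
        self.deref() == *other
    }
}

fn main() {
    let b = Bytes(b"abcde".to_vec());
    let literal: &[u8; 5] = b"abcde";

    assert!(b == literal); // uses the array impl
    assert!(b == &literal[..]); // uses the slice impl
    assert!(!(b == b"abcd")); // a different `N` resolves to the same generic impl
}
```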

impl From<Bound<'_, PyBytes>> for PyBackedBytes {
fn from(py_bytes: Bound<'_, PyBytes>) -> Self {
let b = py_bytes.as_bytes();
@@ -123,7 +160,7 @@ impl From<Bound<'_, PyBytes>> for PyBackedBytes {

impl From<Bound<'_, PyByteArray>> for PyBackedBytes {
fn from(py_bytearray: Bound<'_, PyByteArray>) -> Self {
-let s = py_bytearray.to_vec().into_boxed_slice();
+let s = Arc::<[u8]>::from(py_bytearray.to_vec());
let data = NonNull::from(s.as_ref());
Self {
storage: PyBackedBytesStorage::Rust(s),
@@ -144,10 +181,85 @@ impl FromPyObject<'_> for PyBackedBytes {
}
}

macro_rules! impl_traits {
($slf:ty, $equiv:ty) => {
impl std::fmt::Debug for $slf {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.deref().fmt(f)
}
}

impl PartialEq for $slf {
fn eq(&self, other: &Self) -> bool {
self.deref() == other.deref()
}
}

impl PartialEq<$equiv> for $slf {
fn eq(&self, other: &$equiv) -> bool {
self.deref() == other
}
}

impl PartialEq<&$equiv> for $slf {
fn eq(&self, other: &&$equiv) -> bool {
self.deref() == *other
}
}

impl PartialEq<$slf> for $equiv {
fn eq(&self, other: &$slf) -> bool {
self == other.deref()
}
}

impl PartialEq<$slf> for &$equiv {
fn eq(&self, other: &$slf) -> bool {
self == &other.deref()
}
}

impl Eq for $slf {}

impl PartialOrd for $slf {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}

impl PartialOrd<$equiv> for $slf {
fn partial_cmp(&self, other: &$equiv) -> Option<std::cmp::Ordering> {
self.deref().partial_cmp(other)
}
}

impl PartialOrd<$slf> for $equiv {
fn partial_cmp(&self, other: &$slf) -> Option<std::cmp::Ordering> {
self.partial_cmp(other.deref())
}
}

impl Ord for $slf {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.deref().cmp(other.deref())
}
}

impl std::hash::Hash for $slf {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.deref().hash(state)
}
}
};
}
use impl_traits;
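
The macro above forwards each trait to the dereferenced value, which keeps `Eq`, `Ord` and `Hash` consistent with the underlying `str` or `[u8]`. A reduced, std-only sketch of the same pattern follows; `Wrapper`, `forward_traits!` and `hash_of` are hypothetical names used for illustration and are not part of PyO3.

```rust
use std::cmp::Ordering;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::ops::Deref;

/// Hypothetical stand-in for `PyBackedStr`: owns its data and derefs to `str`.
struct Wrapper(String);

impl Deref for Wrapper {
    type Target = str;
    fn deref(&self) -> &str {
        &self.0
    }
}

// Reduced version of the forwarding macro: every impl delegates to the
// dereferenced value.
macro_rules! forward_traits {
    ($slf:ty, $equiv:ty) => {
        impl PartialEq for $slf {
            fn eq(&self, other: &Self) -> bool {
                self.deref() == other.deref()
            }
        }
        impl PartialEq<&$equiv> for $slf {
            fn eq(&self, other: &&$equiv) -> bool {
                self.deref() == *other
            }
        }
        impl PartialEq<$slf> for &$equiv {
            fn eq(&self, other: &$slf) -> bool {
                self == &other.deref()
            }
        }
        impl Eq for $slf {}
        impl PartialOrd for $slf {
            fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
                Some(self.cmp(other))
            }
        }
        impl Ord for $slf {
            fn cmp(&self, other: &Self) -> Ordering {
                self.deref().cmp(other.deref())
            }
        }
        impl Hash for $slf {
            fn hash<H: Hasher>(&self, state: &mut H) {
                self.deref().hash(state)
            }
        }
    };
}

forward_traits!(Wrapper, str);

/// Hashes any `Hash` value with the standard library's default hasher.
fn hash_of<T: Hash + ?Sized>(value: &T) -> u64 {
    let mut hasher = DefaultHasher::new();
    value.hash(&mut hasher);
    hasher.finish()
}

fn main() {
    let a = Wrapper("abcde".to_string());
    assert!(a == "abcde"); // Wrapper compared with &str
    assert!("abcde" == a); // and the reverse direction
    assert_eq!(hash_of(&a), hash_of("abcde")); // same hash as the plain str

    let mut v = vec![Wrapper("c".into()), Wrapper("a".into()), Wrapper("b".into())];
    v.sort(); // ordering follows the underlying str
    assert!(v[0] == "a" && v[1] == "b" && v[2] == "c");
}
```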

#[cfg(test)]
mod test {
use super::*;
use crate::Python;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

#[test]
fn py_backed_str_empty() {
@@ -223,4 +335,149 @@ mod test {
is_send::<PyBackedBytes>();
is_sync::<PyBackedBytes>();
}

#[test]
fn test_backed_str_clone() {
Python::with_gil(|py| {
let s1: PyBackedStr = PyString::new_bound(py, "hello").try_into().unwrap();
let s2 = s1.clone();
assert_eq!(s1, s2);

drop(s1);
assert_eq!(s2, "hello");
});
}

#[test]
fn test_backed_str_eq() {
Python::with_gil(|py| {
let s1: PyBackedStr = PyString::new_bound(py, "hello").try_into().unwrap();
let s2: PyBackedStr = PyString::new_bound(py, "hello").try_into().unwrap();
assert_eq!(s1, "hello");
assert_eq!(s1, s2);

let s3: PyBackedStr = PyString::new_bound(py, "abcde").try_into().unwrap();
assert_eq!("abcde", s3);
assert_ne!(s1, s3);
});
}

#[test]
fn test_backed_str_hash() {
Python::with_gil(|py| {
let h = {
let mut hasher = DefaultHasher::new();
"abcde".hash(&mut hasher);
hasher.finish()
};

let s1: PyBackedStr = PyString::new_bound(py, "abcde").try_into().unwrap();
let h1 = {
let mut hasher = DefaultHasher::new();
s1.hash(&mut hasher);
hasher.finish()
};

assert_eq!(h, h1);
});
}

#[test]
fn test_backed_str_ord() {
Python::with_gil(|py| {
let mut a = vec!["a", "c", "d", "b", "f", "g", "e"];
let mut b = a
.iter()
.map(|s| PyString::new_bound(py, s).try_into().unwrap())
.collect::<Vec<PyBackedStr>>();

a.sort();
b.sort();

assert_eq!(a, b);
})
}

#[test]
fn test_backed_bytes_from_bytes_clone() {
Python::with_gil(|py| {
let b1: PyBackedBytes = PyBytes::new_bound(py, b"abcde").into();
let b2 = b1.clone();
assert_eq!(b1, b2);

drop(b1);
assert_eq!(b2, b"abcde");
});
}

#[test]
fn test_backed_bytes_from_bytearray_clone() {
Python::with_gil(|py| {
let b1: PyBackedBytes = PyByteArray::new_bound(py, b"abcde").into();
let b2 = b1.clone();
assert_eq!(b1, b2);

drop(b1);
assert_eq!(b2, b"abcde");
});
}

#[test]
fn test_backed_bytes_eq() {
Python::with_gil(|py| {
let b1: PyBackedBytes = PyBytes::new_bound(py, b"abcde").into();
let b2: PyBackedBytes = PyByteArray::new_bound(py, b"abcde").into();

assert_eq!(b1, b"abcde");
assert_eq!(b1, b2);

let b3: PyBackedBytes = PyBytes::new_bound(py, b"hello").into();
assert_eq!(b"hello", b3);
assert_ne!(b1, b3);
});
}

#[test]
fn test_backed_bytes_hash() {
Python::with_gil(|py| {
let h = {
let mut hasher = DefaultHasher::new();
b"abcde".hash(&mut hasher);
hasher.finish()
};

let b1: PyBackedBytes = PyBytes::new_bound(py, b"abcde").into();
let h1 = {
let mut hasher = DefaultHasher::new();
b1.hash(&mut hasher);
hasher.finish()
};

let b2: PyBackedBytes = PyByteArray::new_bound(py, b"abcde").into();
let h2 = {
let mut hasher = DefaultHasher::new();
b2.hash(&mut hasher);
hasher.finish()
};

assert_eq!(h, h1);
assert_eq!(h, h2);
});
}

#[test]
fn test_backed_bytes_ord() {
Python::with_gil(|py| {
let mut a = vec![b"a", b"c", b"d", b"b", b"f", b"g", b"e"];
let mut b = a
.iter()
.map(|&b| PyBytes::new_bound(py, b).into())
.collect::<Vec<PyBackedBytes>>();

a.sort();
b.sort();

assert_eq!(a, b);
})
}
}