From c7a3e0baa637eef2c4fd5a6d5621e201691eebca Mon Sep 17 00:00:00 2001 From: Miles Granger Date: Fri, 1 Mar 2024 10:20:31 +0100 Subject: [PATCH] Trying doing things the right way, CPython happy, PyPy sad [skip ci] --- cramjam-python/src/io.rs | 65 +++++++++------------------ cramjam-python/src/lib.rs | 16 +++---- cramjam-python/src/lz4.rs | 4 +- cramjam-python/src/snappy.rs | 4 +- cramjam-python/tests/test_variants.py | 2 +- 5 files changed, 33 insertions(+), 58 deletions(-) diff --git a/cramjam-python/src/io.rs b/cramjam-python/src/io.rs index 123aaa5e..992e66d0 100644 --- a/cramjam-python/src/io.rs +++ b/cramjam-python/src/io.rs @@ -18,7 +18,7 @@ use std::path::PathBuf; pub(crate) trait AsBytes { fn as_bytes(&self) -> &[u8]; - fn as_bytes_mut(&mut self) -> &mut [u8]; + fn as_bytes_mut(&mut self) -> PyResult<&mut [u8]>; } /// A native Rust file-like object. Reading and writing takes place @@ -49,7 +49,7 @@ impl AsBytes for RustyFile { entire file into memory; consider using cramjam.Buffer" ) } - fn as_bytes_mut(&mut self) -> &mut [u8] { + fn as_bytes_mut(&mut self) -> PyResult<&mut [u8]> { unimplemented!( "Converting a File to bytes is not supported, as it'd require reading the \ entire file into memory; consider using cramjam.Buffer" @@ -194,7 +194,7 @@ impl PythonBuffer { } /// Is the Python buffer readonly pub fn readonly(&self) -> bool { - self.inner.readonly != 0 + self.inner.readonly == 1 } /// Get the underlying buffer as a slice of bytes pub fn as_slice(&self) -> &[u8] { @@ -202,13 +202,16 @@ impl PythonBuffer { } /// Get the underlying buffer as a mutable slice of bytes pub fn as_slice_mut(&mut self) -> PyResult<&mut [u8]> { - // TODO: For v3 release, add self.readonly check; bytes is readonly but - // v1 and v2 releases have not treated it as such. + if self.readonly() { + let repr = Python::with_gil(|py| unsafe { PyObject::from_borrowed_ptr(py, self.inner.obj) }.to_string()); + let msg = format!("The output buffer '{}' is readonly, refusing to overwrite.", repr); + return Err(pyo3::exceptions::PyTypeError::new_err(msg)); + } Ok(unsafe { std::slice::from_raw_parts_mut(self.buf_ptr() as *mut u8, self.len_bytes()) }) } /// If underlying buffer is c_contiguous pub fn is_c_contiguous(&self) -> bool { - unsafe { ffi::PyBuffer_IsContiguous(&*self.inner as *const ffi::Py_buffer, b'C' as std::os::raw::c_char) != 0 } + unsafe { ffi::PyBuffer_IsContiguous(&*self.inner as *const ffi::Py_buffer, b'C' as std::os::raw::c_char) == 1 } } /// Dimensions for buffer pub fn dimensions(&self) -> usize { @@ -244,46 +247,17 @@ impl<'py> FromPyObject<'py> for PythonBuffer { } } -#[cfg(not(PyPy))] -fn make_py_buffer(obj: &PyAny) -> PyResult { - let mut buf = mem::MaybeUninit::uninit(); - let rc = unsafe { ffi::PyObject_GetBuffer(obj.as_ptr(), buf.as_mut_ptr(), ffi::PyBUF_CONTIG_RO) }; - if rc != 0 { - return Err(exceptions::PyBufferError::new_err( - "Failed to get buffer, is it C contiguous, and shape is not null?", - )); - } - let buf = unsafe { mem::MaybeUninit::::assume_init(buf) }; - Ok(buf) -} - -#[cfg(PyPy)] -fn make_py_buffer(obj: &PyAny) -> PyResult { - let is_memview = unsafe { ffi::PyMemoryView_Check(obj.as_ptr()) } == 1; - - let mut object = Python::with_gil(|py| obj.to_object(py)); - - if !is_memview { - let ptr = unsafe { ffi::PyMemoryView_FromObject(obj.as_ptr()) }; - Python::with_gil(|py| { - object = unsafe { PyObject::from_owned_ptr(py, ptr) }; - }) - } - let mut buf = mem::MaybeUninit::uninit(); - let rc = unsafe { ffi::PyObject_GetBuffer(object.as_ptr(), buf.as_mut_ptr(), ffi::PyBUF_CONTIG_RO) }; - if rc != 0 { - return Err(exceptions::PyBufferError::new_err( - "Failed to get buffer, is it C contiguous, and shape is not null?", - )); - } - let buf = unsafe { mem::MaybeUninit::::assume_init(buf) }; - Ok(buf) -} - impl<'py> TryFrom<&'py PyAny> for PythonBuffer { type Error = PyErr; fn try_from(obj: &'py PyAny) -> Result { - let py_buffer = make_py_buffer(obj)?; + let mut buf = mem::MaybeUninit::uninit(); + let rc = unsafe { ffi::PyObject_GetBuffer(obj.as_ptr(), buf.as_mut_ptr(), ffi::PyBUF_CONTIG_RO) }; + if rc != 0 { + return Err(exceptions::PyBufferError::new_err( + "Failed to get buffer, is it C contiguous, and shape is not null?", + )); + } + let py_buffer = unsafe { mem::MaybeUninit::::assume_init(buf) }; let buf = Self { inner: std::pin::Pin::from(Box::new(py_buffer)), pos: 0, @@ -357,8 +331,9 @@ impl AsBytes for RustyBuffer { fn as_bytes(&self) -> &[u8] { self.inner.get_ref().as_slice() } - fn as_bytes_mut(&mut self) -> &mut [u8] { - self.inner.get_mut().as_mut_slice() + fn as_bytes_mut(&mut self) -> PyResult<&mut [u8]> { + let slice = self.inner.get_mut().as_mut_slice(); + Ok(slice) } } diff --git a/cramjam-python/src/lib.rs b/cramjam-python/src/lib.rs index c494db86..f5fea3ba 100644 --- a/cramjam-python/src/lib.rs +++ b/cramjam-python/src/lib.rs @@ -101,18 +101,18 @@ impl<'a> AsBytes for BytesType<'a> { } } } - fn as_bytes_mut(&mut self) -> &mut [u8] { + fn as_bytes_mut(&mut self) -> PyResult<&mut [u8]> { match self { BytesType::RustyBuffer(b) => { let mut py_ref = b.borrow_mut(); - let bytes = py_ref.as_bytes_mut(); - unsafe { std::slice::from_raw_parts_mut(bytes.as_mut_ptr(), bytes.len()) } + let bytes = py_ref.as_bytes_mut()?; + Ok(unsafe { std::slice::from_raw_parts_mut(bytes.as_mut_ptr(), bytes.len()) }) } - BytesType::PyBuffer(b) => b.as_slice_mut().unwrap(), + BytesType::PyBuffer(b) => b.as_slice_mut(), BytesType::RustyFile(b) => { let mut py_ref = b.borrow_mut(); - let bytes = py_ref.as_bytes_mut(); - unsafe { std::slice::from_raw_parts_mut(bytes.as_mut_ptr(), bytes.len()) } + let bytes = py_ref.as_bytes_mut()?; + Ok(unsafe { std::slice::from_raw_parts_mut(bytes.as_mut_ptr(), bytes.len()) }) } } } @@ -212,7 +212,7 @@ macro_rules! generic { }) }, _ => { - let bytes_out = $output.as_bytes_mut(); + let bytes_out = $output.as_bytes_mut()?; $py.allow_threads(|| { $op(f_in, &mut Cursor::new(bytes_out) $(, $level)? ) }) @@ -237,7 +237,7 @@ macro_rules! generic { }) }, _ => { - let bytes_out = $output.as_bytes_mut(); + let bytes_out = $output.as_bytes_mut()?; $py.allow_threads(|| { $op(bytes_in, &mut Cursor::new(bytes_out) $(, $level)?) }) diff --git a/cramjam-python/src/lz4.rs b/cramjam-python/src/lz4.rs index 016c4470..c7894a00 100644 --- a/cramjam-python/src/lz4.rs +++ b/cramjam-python/src/lz4.rs @@ -146,7 +146,7 @@ pub fn compress_block( #[pyfunction] pub fn decompress_block_into(py: Python, input: BytesType, mut output: BytesType) -> PyResult { let bytes = input.as_bytes(); - let out_bytes = output.as_bytes_mut(); + let out_bytes = output.as_bytes_mut()?; py.allow_threads(|| libcramjam::lz4::block::decompress_into(bytes, out_bytes, Some(true))) .map_err(DecompressionError::from_err) .map(|v| v as _) @@ -180,7 +180,7 @@ pub fn compress_block_into( store_size: Option, ) -> PyResult { let bytes = data.as_bytes(); - let out_bytes = output.as_bytes_mut(); + let out_bytes = output.as_bytes_mut()?; py.allow_threads(|| { libcramjam::lz4::block::compress_into(bytes, out_bytes, compression.map(|v| v as _), acceleration, store_size) }) diff --git a/cramjam-python/src/snappy.rs b/cramjam-python/src/snappy.rs index 1cd9ed83..9133c7af 100644 --- a/cramjam-python/src/snappy.rs +++ b/cramjam-python/src/snappy.rs @@ -100,7 +100,7 @@ pub fn decompress_into(py: Python, input: BytesType, mut output: BytesType) -> P #[pyfunction] pub fn compress_raw_into(py: Python, input: BytesType, mut output: BytesType) -> PyResult { let bytes_in = input.as_bytes(); - let bytes_out = output.as_bytes_mut(); + let bytes_out = output.as_bytes_mut()?; py.allow_threads(|| libcramjam::snappy::raw::compress(bytes_in, bytes_out)) .map_err(CompressionError::from_err) } @@ -109,7 +109,7 @@ pub fn compress_raw_into(py: Python, input: BytesType, mut output: BytesType) -> #[pyfunction] pub fn decompress_raw_into(py: Python, input: BytesType, mut output: BytesType) -> PyResult { let bytes_in = input.as_bytes(); - let bytes_out = output.as_bytes_mut(); + let bytes_out = output.as_bytes_mut()?; py.allow_threads(|| libcramjam::snappy::raw::decompress(bytes_in, bytes_out)) .map_err(DecompressionError::from_err) } diff --git a/cramjam-python/tests/test_variants.py b/cramjam-python/tests/test_variants.py index ecc9181b..bc93740b 100644 --- a/cramjam-python/tests/test_variants.py +++ b/cramjam-python/tests/test_variants.py @@ -124,7 +124,7 @@ def test_variants_compress_into( else: output = output_type(b"0" * compressed_len) - n_bytes = variant.compress_into(input, output) + n_bytes = variant.compress_into(input, memoryview(output)) assert n_bytes == compressed_len if hasattr(output, "read"):