Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support all buffer types where c-contiguous #105

Merged
merged 1 commit into from
May 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion cramjam-python/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "cramjam-python"
version = "2.7.0-rc2"
version = "2.7.0-rc3"
authors = ["Miles Granger <[email protected]>"]
edition = "2021"
license = "MIT"
Expand Down
90 changes: 69 additions & 21 deletions cramjam-python/src/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
use std::convert::TryFrom;
use std::fs::{File, OpenOptions};
use std::io::{copy, Cursor, Read, Seek, SeekFrom, Write};
use std::mem;
use std::os::raw::c_int;

use crate::exceptions::CompressionError;
use crate::BytesType;
use pyo3::buffer::{Element, PyBuffer};
use pyo3::exceptions::PyBufferError;
use pyo3::exceptions::{self, PyBufferError};
use pyo3::prelude::*;
use pyo3::types::PyBytes;
use pyo3::{ffi, AsPyPointer};
Expand Down Expand Up @@ -169,11 +169,17 @@ impl RustyFile {

/// Internal wrapper to PyBuffer, not exposed thru API
/// used only for impl of Read/Write
pub struct PythonBuffer<T: Element> {
pub(crate) inner: PyBuffer<T>,
// Inspired from pyo3 PyBuffer<T>, but here we don't want or care about T
pub struct PythonBuffer {
pub(crate) inner: std::pin::Pin<Box<ffi::Py_buffer>>,
pub(crate) pos: usize,
}
impl<T: Element> PythonBuffer<T> {
// SAFETY (author's stated justification for the two marker impls below):
// PyBuffer is thread-safe: the shape of the buffer is immutable while a Py_buffer exists.
// Accessing the buffer contents is protected using the GIL.
// NOTE(review): `Py_buffer` holds raw pointers, so these impls are a manual promise to the
// compiler; confirm every access to the buffer contents really happens with the GIL held.
unsafe impl Send for PythonBuffer {}
unsafe impl Sync for PythonBuffer {}

impl PythonBuffer {
/// Reset the read/write position of cursor
pub fn reset_position(&mut self) {
self.pos = 0;
Expand All @@ -188,41 +194,83 @@ impl<T: Element> PythonBuffer<T> {
}
/// Is the Python buffer readonly
pub fn readonly(&self) -> bool {
self.inner.readonly()
self.inner.readonly != 0
}
/// Get the underlying buffer as a slice of bytes
pub fn as_slice(&self) -> &[u8] {
unsafe { std::slice::from_raw_parts(self.inner.buf_ptr() as *const u8, self.inner.len_bytes()) }
unsafe { std::slice::from_raw_parts(self.buf_ptr() as *const u8, self.len_bytes()) }
}
/// Get the underlying buffer as a mutable slice of bytes
pub fn as_slice_mut(&mut self) -> PyResult<&mut [u8]> {
// TODO: For v3 release, add self.readonly check; bytes is readonly but
// v1 and v2 releases have not treated it as such.
Ok(unsafe { std::slice::from_raw_parts_mut(self.inner.buf_ptr() as *mut u8, self.inner.len_bytes()) })
Ok(unsafe { std::slice::from_raw_parts_mut(self.buf_ptr() as *mut u8, self.len_bytes()) })
}
/// Whether the underlying buffer is laid out C-contiguously.
///
/// Delegates to CPython's `PyBuffer_IsContiguous` with order `'C'` and
/// converts its integer result into a `bool`.
pub fn is_c_contiguous(&self) -> bool {
    let view: *const ffi::Py_buffer = &*self.inner;
    let order = b'C' as std::os::raw::c_char;
    // SAFETY: `view` points at the `Py_buffer` owned by `self`, which stays
    // alive (and pinned) for the duration of this call.
    let answer = unsafe { ffi::PyBuffer_IsContiguous(view, order) };
    answer != 0
}
/// Number of dimensions of the buffer, read from the `ndim` field of the
/// wrapped `Py_buffer`.
pub fn dimensions(&self) -> usize {
    let ndim = self.inner.ndim;
    ndim as usize
}
/// Raw pointer to the start of the buffer's memory (the `buf` field of the
/// wrapped `Py_buffer`). The pointer is returned as-is; no checks are done.
pub fn buf_ptr(&self) -> *mut std::os::raw::c_void {
    let start = self.inner.buf;
    start
}
/// Total length of the buffer in bytes (the `len` field of the wrapped
/// `Py_buffer`, cast to `usize`).
pub fn len_bytes(&self) -> usize {
    let byte_len = self.inner.len;
    byte_len as usize
}
/// Size in bytes of a single item in the buffer (the `itemsize` field of the
/// wrapped `Py_buffer`, cast to `usize`).
pub fn item_size(&self) -> usize {
    let per_item = self.inner.itemsize;
    per_item as usize
}
/// Number of items in the buffer, computed as `len / itemsize`.
///
/// Returns `0` when the exporter reports an `itemsize` of zero, instead of
/// panicking on an integer division by zero as the plain `/` would.
pub fn item_count(&self) -> usize {
    let total_bytes = self.inner.len as usize;
    let bytes_per_item = self.inner.itemsize as usize;
    // `checked_div` yields `None` only for a zero divisor; a buffer whose
    // items occupy no bytes is reported as holding no items.
    total_bytes.checked_div(bytes_per_item).unwrap_or(0)
}
}

impl<'py, T: Element> FromPyObject<'py> for PythonBuffer<T> {
/// Releases the wrapped `Py_buffer` back to its exporter when the wrapper
/// goes out of scope.
impl Drop for PythonBuffer {
    fn drop(&mut self) {
        let view: *mut ffi::Py_buffer = &mut *self.inner;
        // The GIL is acquired first because the release call goes into CPython.
        Python::with_gil(|_py| {
            // SAFETY: `view` points at the `Py_buffer` owned by `self`, which is
            // still alive here; it is released exactly once, on drop.
            unsafe { ffi::PyBuffer_Release(view) }
        })
    }
}

impl<'py> FromPyObject<'py> for PythonBuffer {
fn extract(obj: &'py PyAny) -> PyResult<Self> {
let buf = PyBuffer::get(obj)?;
PythonBuffer::try_from(buf)
Self::try_from(obj)
}
}

impl<T: Element> TryFrom<PyBuffer<T>> for PythonBuffer<T> {
impl<'py> TryFrom<&'py PyAny> for PythonBuffer {
type Error = PyErr;
fn try_from(buf: PyBuffer<T>) -> Result<Self, Self::Error> {
if !buf.is_c_contiguous() {
fn try_from(obj: &'py PyAny) -> Result<Self, Self::Error> {
let mut buf = Box::new(mem::MaybeUninit::uninit());
let rc = unsafe { ffi::PyObject_GetBuffer(obj.as_ptr(), buf.as_mut_ptr(), ffi::PyBUF_CONTIG_RO) };
if rc != 0 {
return Err(exceptions::PyBufferError::new_err(
"Failed to get buffer, is it C contiguous, and shape is not null?",
));
}
let buf = Box::new(unsafe { mem::MaybeUninit::<ffi::Py_buffer>::assume_init(*buf) });
let buf = Self {
inner: std::pin::Pin::from(buf),
pos: 0,
};
// sanity checks
if buf.inner.shape.is_null() {
Err(exceptions::PyBufferError::new_err("shape is null"))
} else if !buf.is_c_contiguous() {
Err(PyBufferError::new_err("Buffer is not C contiguous"))
} else if buf.dimensions() != 1 {
Err(PyBufferError::new_err("Buffer is not 1 dimensional"))
} else {
Ok(Self { inner: buf, pos: 0 })
Ok(buf)
}
}
}

impl<T: Element> Read for PythonBuffer<T> {
impl Read for PythonBuffer {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
let slice = self.as_slice();
if self.pos < slice.len() {
Expand All @@ -235,7 +283,7 @@ impl<T: Element> Read for PythonBuffer<T> {
}
}

impl<T: Element> Write for PythonBuffer<T> {
impl Write for PythonBuffer {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
let pos = self.position();
let slice = self
Expand Down Expand Up @@ -463,9 +511,9 @@ impl Seek for RustyFile {
self.inner.seek(pos)
}
}
impl<T: Element> Seek for PythonBuffer<T> {
impl Seek for PythonBuffer {
fn seek(&mut self, pos: SeekFrom) -> std::io::Result<u64> {
let len = self.inner.len_bytes();
let len = self.len_bytes();
let current = self.position();
match pos {
SeekFrom::Start(n) => self.set_position(n as usize),
Expand Down
2 changes: 1 addition & 1 deletion cramjam-python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ pub enum BytesType<'a> {
RustyFile(&'a PyCell<RustyFile>),
/// `object` implementing the Buffer Protocol
#[pyo3(transparent, annotation = "pybuffer")]
PyBuffer(PythonBuffer<u8>),
PyBuffer(PythonBuffer),
}

impl<'a> AsBytes for BytesType<'a> {
Expand Down
32 changes: 26 additions & 6 deletions cramjam-python/tests/test_variants.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,15 @@
import hashlib
from datetime import timedelta
from hypothesis import strategies as st, given, settings
from hypothesis.extra import numpy as st_np

VARIANTS = ("snappy", "brotli", "bzip2", "lz4", "gzip", "deflate", "zstd")

# Some OS can be slow or have higher variability in their runtimes on CI
settings.register_profile("local", deadline=timedelta(milliseconds=1000))
settings.register_profile("CI", deadline=None, max_examples=10)
settings.register_profile(
"local", deadline=timedelta(milliseconds=1000), max_examples=100
)
settings.register_profile("CI", deadline=None, max_examples=25)
if os.getenv("CI"):
settings.load_profile("CI")
else:
Expand All @@ -28,6 +31,23 @@ def test_has_version():
assert isinstance(__version__, str)


# Property test, run once per codec in VARIANTS: a numpy array of an arbitrary
# scalar dtype (1-D, length 0..1e5) is passed directly to `compress`, and the
# decompressed bytes must match `arr.tobytes()`.
@pytest.mark.parametrize("variant_str", VARIANTS)
@given(arr=st_np.arrays(st_np.scalar_dtypes(), shape=st.integers(0, int(1e5))))
def test_variants_different_dtypes(variant_str, arr):
# Look up the codec module (e.g. cramjam.snappy) by name.
variant = getattr(cramjam, variant_str)
compressed = variant.compress(arr)
decompressed = variant.decompress(compressed)
assert same_same(bytes(decompressed), arr.tobytes())

# And compress n dims > 1
# An even-length array is reshaped to two rows and round-tripped again.
# NOTE(review): indentation is lost in this rendering; presumably the four
# statements after the `if` all belong to its body -- confirm against the repo.
if arr.shape[0] % 2 == 0:
arr = arr.reshape((2, -1))
compressed = variant.compress(arr)
decompressed = variant.decompress(compressed)
assert same_same(bytes(decompressed), arr.tobytes())



@pytest.mark.parametrize("is_bytearray", (True, False))
@pytest.mark.parametrize("variant_str", VARIANTS)
@given(uncompressed=st.binary(min_size=1))
Expand All @@ -44,7 +64,7 @@ def test_variants_simple(variant_str, is_bytearray, uncompressed: bytes):
assert isinstance(compressed, cramjam.Buffer)

decompressed = variant.decompress(compressed, output_len=len(uncompressed))
assert decompressed.read() == uncompressed
assert same_same(decompressed.read(), uncompressed)
assert isinstance(decompressed, cramjam.Buffer)


Expand Down Expand Up @@ -262,7 +282,7 @@ def test_lz4_block(compress_kwargs):
lz4.compress_block(data, **compress_kwargs),
output_len=len(data) if not compress_kwargs["store_size"] else None,
)
assert bytes(out) == data
assert same_same(bytes(out), data)


@given(first=st.binary(), second=st.binary())
Expand All @@ -280,7 +300,7 @@ def test_gzip_multiple_streams(first: bytes, second: bytes):
o1 = bytes(cramjam.gzip.compress(first))
o2 = bytes(cramjam.gzip.compress(second))
out = bytes(cramjam.gzip.decompress(o1 + o2))
assert out == first + second
assert same_same(out, first + second)


@pytest.mark.parametrize(
Expand All @@ -307,7 +327,7 @@ def test_streams_compressor(mod, first: bytes, second: bytes):

out += bytes(compressor.finish())
decompressed = mod.decompress(out)
assert bytes(decompressed) == first + second
assert same_same(bytes(decompressed), first + second)

# just empty bytes after the first .finish()
# same behavior as brotli.Compressor()
Expand Down