Skip to content

Commit

Permalink
rust: implement ZstdCompressionObj
Browse files Browse the repository at this point in the history
test_compressor.py::TestCompressor_compressobj now passes!

As part of this, we changed how the CCtx is stored in the compressor
so that we can share references to instances without having to share
Python object references, which necessitates implementing gc functions
on the Python types.
  • Loading branch information
indygreg committed Jun 21, 2020
1 parent 7f0a35f commit 03be2c4
Show file tree
Hide file tree
Showing 4 changed files with 240 additions and 5 deletions.
139 changes: 139 additions & 0 deletions rust-ext/src/compressionobj.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
// Copyright (c) 2020-present, Gregory Szorc
// All rights reserved.
//
// This software may be modified and distributed under the terms
// of the BSD license. See the LICENSE file for details.

use crate::compressor::CCtx;
use crate::constants::{COMPRESSOBJ_FLUSH_BLOCK, COMPRESSOBJ_FLUSH_FINISH};
use crate::ZstdError;
use cpython::buffer::PyBuffer;
use cpython::exc::ValueError;
use cpython::{py_class, PyBytes, PyErr, PyObject, PyResult, Python};
use std::cell::RefCell;
use std::sync::Arc;

pub struct CompressionObjState<'cctx> {
cctx: Arc<CCtx<'cctx>>,
finished: bool,
}

py_class!(pub class ZstdCompressionObj |py| {
data state: RefCell<CompressionObjState<'static>>;

def compress(&self, data: PyObject) -> PyResult<PyBytes> {
self.compress_impl(py, data)
}

def flush(&self, flush_mode: Option<i32> = None) -> PyResult<PyBytes> {
self.flush_impl(py, flush_mode)
}
});

impl ZstdCompressionObj {
pub fn new(py: Python, cctx: Arc<CCtx<'static>>) -> PyResult<ZstdCompressionObj> {
let state = CompressionObjState {
cctx,
finished: false,
};

Ok(ZstdCompressionObj::create_instance(
py,
RefCell::new(state),
)?)
}

fn compress_impl(&self, py: Python, data: PyObject) -> PyResult<PyBytes> {
let state: std::cell::Ref<CompressionObjState> = self.state(py).borrow();

if state.finished {
return Err(ZstdError::from_message(
py,
"cannot call compress() after compressor finished",
));
}

let buffer = PyBuffer::get(py, &data)?;

if !buffer.is_c_contiguous() || buffer.dimensions() > 1 {
return Err(PyErr::new::<ValueError, _>(
py,
"data buffer should be contiguous and have at most one dimension",
));
}

let mut source = unsafe {
std::slice::from_raw_parts::<u8>(buffer.buf_ptr() as *const _, buffer.len_bytes())
};

// TODO consider collecting chunks and joining
// TODO try to use zero copy into return value.
let mut compressed = Vec::new();

let cctx = &state.cctx;
while !source.is_empty() {
let result = py
.allow_threads(|| {
cctx.compress_chunk(source, zstd_sys::ZSTD_EndDirective::ZSTD_e_continue)
})
.or_else(|msg| {
Err(ZstdError::from_message(
py,
format!("zstd compress error: {}", msg).as_ref(),
))
})?;

compressed.extend(result.0);
source = result.1;
}

Ok(PyBytes::new(py, &compressed))
}

fn flush_impl(&self, py: Python, flush_mode: Option<i32>) -> PyResult<PyBytes> {
let mut state: std::cell::RefMut<CompressionObjState> = self.state(py).borrow_mut();

let flush_mode = if let Some(flush_mode) = flush_mode {
match flush_mode {
COMPRESSOBJ_FLUSH_FINISH => Ok(zstd_sys::ZSTD_EndDirective::ZSTD_e_end),
COMPRESSOBJ_FLUSH_BLOCK => Ok(zstd_sys::ZSTD_EndDirective::ZSTD_e_flush),
_ => Err(PyErr::new::<ValueError, _>(py, "flush mode not recognized")),
}
} else {
Ok(zstd_sys::ZSTD_EndDirective::ZSTD_e_end)
}?;

if state.finished {
return Err(ZstdError::from_message(
py,
"compressor object already finished",
));
}

if flush_mode == zstd_sys::ZSTD_EndDirective::ZSTD_e_end {
state.finished = true;
}

let cctx = &state.cctx;

// TODO avoid extra buffer copy.
let mut result = Vec::new();

loop {
let (chunk, _, call_again) = py
.allow_threads(|| cctx.compress_chunk(&[], flush_mode))
.or_else(|msg| {
Err(ZstdError::from_message(
py,
format!("error ending compression stream: {}", msg).as_ref(),
))
})?;

result.extend(&chunk);

if !call_again {
return Ok(PyBytes::new(py, &result));
}
}
}
}
98 changes: 95 additions & 3 deletions rust-ext/src/compressor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,16 @@

use crate::compression_dict::ZstdCompressionDict;
use crate::compression_parameters::{CCtxParams, ZstdCompressionParameters};
use crate::compressionobj::ZstdCompressionObj;
use crate::ZstdError;
use cpython::buffer::PyBuffer;
use cpython::exc::ValueError;
use cpython::{py_class, PyBytes, PyErr, PyModule, PyObject, PyResult, Python, PythonObject};
use std::cell::RefCell;
use std::marker::PhantomData;
use std::sync::Arc;

pub(crate) struct CCtx<'a>(*mut zstd_sys::ZSTD_CCtx, PhantomData<&'a ()>);
pub struct CCtx<'a>(*mut zstd_sys::ZSTD_CCtx, PhantomData<&'a ()>);

impl<'a> Drop for CCtx<'a> {
fn drop(&mut self) {
Expand Down Expand Up @@ -69,6 +71,10 @@ impl<'a> CCtx<'a> {
}
}

pub fn get_frame_progression(&self) -> zstd_sys::ZSTD_frameProgression {
unsafe { zstd_sys::ZSTD_getFrameProgression(self.0) }
}

pub fn compress(&self, source: &[u8]) -> Result<Vec<u8>, &'static str> {
self.reset();

Expand Down Expand Up @@ -111,13 +117,58 @@ impl<'a> CCtx<'a> {
Ok(dest)
}
}

/// Compress input data as part of a stream.
///
/// Returns a tuple of the emitted compressed data, a slice of unconsumed input,
/// and whether there is more work to be done.
pub fn compress_chunk(
&self,
source: &'a [u8],
end_mode: zstd_sys::ZSTD_EndDirective,
) -> Result<(Vec<u8>, &'a [u8], bool), &'static str> {
let mut in_buffer = zstd_sys::ZSTD_inBuffer {
src: source.as_ptr() as *const _,
size: source.len() as _,
pos: 0,
};

let mut dest: Vec<u8> = Vec::with_capacity(zstd_safe::cstream_out_size());

let mut out_buffer = zstd_sys::ZSTD_outBuffer {
dst: dest.as_mut_ptr() as *mut _,
size: dest.capacity(),
pos: 0,
};

let zresult = unsafe {
zstd_sys::ZSTD_compressStream2(
self.0,
&mut out_buffer as *mut _,
&mut in_buffer as *mut _,
end_mode,
)
};

if unsafe { zstd_sys::ZSTD_isError(zresult) } != 0 {
return Err(zstd_safe::get_error_name(zresult));
}

unsafe {
dest.set_len(out_buffer.pos);
}

let remaining = &source[in_buffer.pos..source.len()];

Ok((dest, remaining, zresult != 0))
}
}

struct CompressorState<'params, 'cctx> {
threads: i32,
dict: Option<ZstdCompressionDict>,
params: CCtxParams<'params>,
cctx: CCtx<'cctx>,
cctx: Arc<CCtx<'cctx>>,
}

impl<'params, 'cctx> CompressorState<'params, 'cctx> {
Expand Down Expand Up @@ -163,9 +214,17 @@ py_class!(class ZstdCompressor |py| {
Ok(self.state(py).borrow().cctx.memory_size())
}

def frame_progression(&self) -> PyResult<(usize, usize, usize)> {
self.frame_progression_impl(py)
}

def compress(&self, data: PyObject) -> PyResult<PyBytes> {
self.compress_impl(py, data)
}

def compressobj(&self, size: Option<u64> = None) -> PyResult<ZstdCompressionObj> {
self.compressobj_impl(py, size)
}
});

impl ZstdCompressor {
Expand Down Expand Up @@ -195,7 +254,7 @@ impl ZstdCompressor {
threads
};

let cctx = CCtx::new().or_else(|msg| Err(PyErr::new::<ZstdError, _>(py, msg)))?;
let cctx = Arc::new(CCtx::new().or_else(|msg| Err(PyErr::new::<ZstdError, _>(py, msg)))?);
let params = CCtxParams::create(py)?;

if let Some(ref compression_params) = compression_params {
Expand Down Expand Up @@ -273,6 +332,18 @@ impl ZstdCompressor {
Ok(ZstdCompressor::create_instance(py, RefCell::new(state))?.into_object())
}

fn frame_progression_impl(&self, py: Python) -> PyResult<(usize, usize, usize)> {
let state: std::cell::Ref<CompressorState> = self.state(py).borrow();

let progression = state.cctx.get_frame_progression();

Ok((
progression.ingested as usize,
progression.consumed as usize,
progression.produced as usize,
))
}

fn compress_impl(&self, py: Python, data: PyObject) -> PyResult<PyBytes> {
let state: std::cell::Ref<CompressorState> = self.state(py).borrow();

Expand Down Expand Up @@ -300,6 +371,27 @@ impl ZstdCompressor {

Ok(PyBytes::new(py, &data))
}

fn compressobj_impl(&self, py: Python, size: Option<u64>) -> PyResult<ZstdCompressionObj> {
let state: std::cell::Ref<CompressorState> = self.state(py).borrow();

state.cctx.reset();

let size = if let Some(size) = size {
size
} else {
zstd_safe::CONTENTSIZE_UNKNOWN
};

state.cctx.set_pledged_source_size(size).or_else(|msg| {
Err(ZstdError::from_message(
py,
format!("error setting source size: {}", msg).as_ref(),
))
})?;

ZstdCompressionObj::new(py, state.cctx.clone())
}
}

pub(crate) fn init_module(py: Python, module: &PyModule) -> PyResult<()> {
Expand Down
7 changes: 5 additions & 2 deletions rust-ext/src/constants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,18 @@

use cpython::{PyBytes, PyModule, PyResult, Python};

pub(crate) const COMPRESSOBJ_FLUSH_FINISH: i32 = 0;
pub(crate) const COMPRESSOBJ_FLUSH_BLOCK: i32 = 1;

pub(crate) fn init_module(py: Python, module: &PyModule) -> PyResult<()> {
module.add(py, "__version", super::VERSION)?;
module.add(py, "__doc__", "Rust backend for zstandard bindings")?;

module.add(py, "FLUSH_BLOCK", 0)?;
module.add(py, "FLUSH_FRAME", 1)?;

module.add(py, "COMPRESSOBJ_FLUSH_FINISH", 0)?;
module.add(py, "COMPRESSOBJ_FLUSH_BLOCK", 1)?;
module.add(py, "COMPRESSOBJ_FLUSH_FINISH", COMPRESSOBJ_FLUSH_FINISH)?;
module.add(py, "COMPRESSOBJ_FLUSH_BLOCK", COMPRESSOBJ_FLUSH_BLOCK)?;

module.add(
py,
Expand Down
1 change: 1 addition & 0 deletions rust-ext/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use cpython::{py_module_initializer, PyModule, PyResult, Python};

mod compression_dict;
mod compression_parameters;
mod compressionobj;
mod compressor;
mod constants;
mod exceptions;
Expand Down

0 comments on commit 03be2c4

Please sign in to comment.