Skip to content

Commit

Permalink
rust: implement ZstdCompressor::copy_stream()
Browse files Browse the repository at this point in the history
test_compressor.py::TestCompressor_copy_stream now passes!
  • Loading branch information
indygreg committed Jun 21, 2020
1 parent 03be2c4 commit 8286b51
Show file tree
Hide file tree
Showing 2 changed files with 151 additions and 4 deletions.
10 changes: 8 additions & 2 deletions rust-ext/src/compressionobj.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,17 @@ impl ZstdCompressionObj {
// TODO consider collecting chunks and joining
// TODO try to use zero copy into return value.
let mut compressed = Vec::new();
let write_size = zstd_safe::cstream_out_size();

let cctx = &state.cctx;
while !source.is_empty() {
let result = py
.allow_threads(|| {
cctx.compress_chunk(source, zstd_sys::ZSTD_EndDirective::ZSTD_e_continue)
cctx.compress_chunk(
source,
zstd_sys::ZSTD_EndDirective::ZSTD_e_continue,
write_size,
)
})
.or_else(|msg| {
Err(ZstdError::from_message(
Expand Down Expand Up @@ -114,14 +119,15 @@ impl ZstdCompressionObj {
state.finished = true;
}

let write_size = zstd_safe::cstream_out_size();
let cctx = &state.cctx;

// TODO avoid extra buffer copy.
let mut result = Vec::new();

loop {
let (chunk, _, call_again) = py
.allow_threads(|| cctx.compress_chunk(&[], flush_mode))
.allow_threads(|| cctx.compress_chunk(&[], flush_mode, write_size))
.or_else(|msg| {
Err(ZstdError::from_message(
py,
Expand Down
145 changes: 143 additions & 2 deletions rust-ext/src/compressor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ use crate::compressionobj::ZstdCompressionObj;
use crate::ZstdError;
use cpython::buffer::PyBuffer;
use cpython::exc::ValueError;
use cpython::{py_class, PyBytes, PyErr, PyModule, PyObject, PyResult, Python, PythonObject};
use cpython::{
py_class, ObjectProtocol, PyBytes, PyErr, PyModule, PyObject, PyResult, Python, PythonObject,
};
use std::cell::RefCell;
use std::marker::PhantomData;
use std::sync::Arc;
Expand Down Expand Up @@ -126,14 +128,15 @@ impl<'a> CCtx<'a> {
&self,
source: &'a [u8],
end_mode: zstd_sys::ZSTD_EndDirective,
output_size: usize,
) -> Result<(Vec<u8>, &'a [u8], bool), &'static str> {
let mut in_buffer = zstd_sys::ZSTD_inBuffer {
src: source.as_ptr() as *const _,
size: source.len() as _,
pos: 0,
};

let mut dest: Vec<u8> = Vec::with_capacity(zstd_safe::cstream_out_size());
let mut dest: Vec<u8> = Vec::with_capacity(output_size);

let mut out_buffer = zstd_sys::ZSTD_outBuffer {
dst: dest.as_mut_ptr() as *mut _,
Expand Down Expand Up @@ -225,6 +228,17 @@ py_class!(class ZstdCompressor |py| {
def compressobj(&self, size: Option<u64> = None) -> PyResult<ZstdCompressionObj> {
self.compressobj_impl(py, size)
}

def copy_stream(
&self,
ifh: PyObject,
ofh: PyObject,
size: Option<u64> = None,
read_size: Option<usize> = None,
write_size: Option<usize> = None
) -> PyResult<(usize, usize)> {
self.copy_stream_impl(py, ifh, ofh, size, read_size, write_size)
}
});

impl ZstdCompressor {
Expand Down Expand Up @@ -392,6 +406,133 @@ impl ZstdCompressor {

ZstdCompressionObj::new(py, state.cctx.clone())
}

fn copy_stream_impl(
&self,
py: Python,
source: PyObject,
dest: PyObject,
source_size: Option<u64>,
read_size: Option<usize>,
write_size: Option<usize>,
) -> PyResult<(usize, usize)> {
let state: std::cell::Ref<CompressorState> = self.state(py).borrow();

let source_size = if let Some(source_size) = source_size {
source_size
} else {
zstd_safe::CONTENTSIZE_UNKNOWN
};

let read_size = read_size.unwrap_or_else(|| zstd_safe::cstream_in_size());
let write_size = write_size.unwrap_or_else(|| zstd_safe::cstream_out_size());

if !source.hasattr(py, "read")? {
return Err(PyErr::new::<ValueError, _>(
py,
"first argument must have a read() method",
));
}

if !dest.hasattr(py, "write")? {
return Err(PyErr::new::<ValueError, _>(
py,
"second argument must have a write() method",
));
}

state.cctx.reset();
state
.cctx
.set_pledged_source_size(source_size)
.or_else(|msg| {
Err(ZstdError::from_message(
py,
format!("error setting source size: {}", msg).as_ref(),
))
})?;

let mut total_read = 0;
let mut total_write = 0;

loop {
// Try to read from source stream.
let read_object = source
.call_method(py, "read", (read_size,), None)
.or_else(|_| Err(ZstdError::from_message(py, "could not read() from source")))?;

let read_bytes = read_object.cast_into::<PyBytes>(py)?;
let read_data = read_bytes.data(py);

// If no data was read we are at EOF.
if read_data.len() == 0 {
break;
}

total_read += read_data.len();

// Send data to compressor.

let mut source = read_data;
let cctx = &state.cctx;

while !source.is_empty() {
let result = py
.allow_threads(|| {
cctx.compress_chunk(
source,
zstd_sys::ZSTD_EndDirective::ZSTD_e_continue,
write_size,
)
})
.or_else(|msg| {
Err(ZstdError::from_message(
py,
format!("zstd compress error: {}", msg).as_ref(),
))
})?;

source = result.1;

let chunk = &result.0;

if !chunk.is_empty() {
// TODO avoid buffer copy.
let data = PyBytes::new(py, chunk);
dest.call_method(py, "write", (data,), None)?;
total_write += chunk.len();
}
}
}

// We've finished reading. Now flush the compressor stream.
loop {
let result = state
.cctx
.compress_chunk(&[], zstd_sys::ZSTD_EndDirective::ZSTD_e_end, write_size)
.or_else(|msg| {
Err(ZstdError::from_message(
py,
format!("error ending compression stream: {}", msg).as_ref(),
))
})?;

let chunk = &result.0;

if !chunk.is_empty() {
// TODO avoid buffer copy.
let data = PyBytes::new(py, &chunk);
dest.call_method(py, "write", (&data,), None)?;
total_write += chunk.len();
}

if !result.2 {
break;
}
}

Ok((total_read, total_write))
}
}

pub(crate) fn init_module(py: Python, module: &PyModule) -> PyResult<()> {
Expand Down

0 comments on commit 8286b51

Please sign in to comment.