Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Miri stacked borrows improvements for inflate #237

Merged
merged 5 commits into from
Oct 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions zlib-rs/src/allocate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,16 @@ impl<'a> Allocator<'a> {
ptr
}

pub fn allocate_raw<T>(&self) -> Option<*mut T> {
let ptr = self.allocate_layout(Layout::new::<T>());

if ptr.is_null() {
None
} else {
Some(ptr as *mut T)
}
}

pub fn allocate<T>(&self) -> Option<&'a mut MaybeUninit<T>> {
let ptr = self.allocate_layout(Layout::new::<T>());

Expand Down
37 changes: 19 additions & 18 deletions zlib-rs/src/inflate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1836,11 +1836,12 @@ pub fn init(stream: &mut z_stream, config: InflateConfig) -> ReturnCode {
};

// allocated here to have the same order as zlib
let Some(state_allocation) = alloc.allocate::<State>() else {
let Some(state_allocation) = alloc.allocate_raw::<State>() else {
return ReturnCode::MemError;
};

stream.state = state_allocation.write(state) as *mut _ as *mut internal_state;
unsafe { state_allocation.write(state) };
stream.state = state_allocation as *mut internal_state;

// SAFETY: we've correctly initialized the stream to be an InflateStream
let ret = if let Some(stream) = unsafe { InflateStream::from_stream_mut(stream) } {
Expand Down Expand Up @@ -1889,9 +1890,9 @@ pub fn reset_with_config(stream: &mut InflateStream, config: InflateConfig) -> R
let mut window = Window::empty();
core::mem::swap(&mut window, &mut stream.state.window);

let window = window.into_inner();
assert!(!window.is_empty());
unsafe { stream.alloc.deallocate(window.as_mut_ptr(), window.len()) };
let (ptr, len) = window.into_raw_parts();
assert_ne!(len, 0);
unsafe { stream.alloc.deallocate(ptr, len) };
}

stream.state.wrap = wrap as u8;
Expand Down Expand Up @@ -1948,10 +1949,6 @@ pub unsafe fn inflate(stream: &mut InflateStream, flush: InflateFlush) -> Return
return ReturnCode::StreamError as _;
}

let source_slice = core::slice::from_raw_parts(stream.next_in, stream.avail_in as usize);
let dest_slice =
core::slice::from_raw_parts_mut(stream.next_out.cast(), stream.avail_out as usize);

let state = &mut stream.state;

// skip check
Expand All @@ -1961,8 +1958,12 @@ pub unsafe fn inflate(stream: &mut InflateStream, flush: InflateFlush) -> Return

state.flush = flush;

state.bit_reader.update_slice(source_slice);
state.writer = Writer::new_uninit(dest_slice);
unsafe {
state
.bit_reader
.update_slice(stream.next_in, stream.avail_in as usize)
};
state.writer = Writer::new_uninit(stream.next_out.cast(), stream.avail_out as usize);

state.in_available = stream.avail_in as _;
state.out_available = stream.avail_out as _;
Expand Down Expand Up @@ -2134,7 +2135,7 @@ pub unsafe fn copy<'a>(
}

// allocated here to have the same order as zlib
let Some(state_allocation) = source.alloc.allocate::<State>() else {
let Some(state_allocation) = source.alloc.allocate_raw::<State>() else {
return ReturnCode::MemError;
};

Expand Down Expand Up @@ -2183,19 +2184,19 @@ pub unsafe fn copy<'a>(

if !state.window.is_empty() {
let Some(window) = state.window.clone_in(&source.alloc) else {
source.alloc.deallocate(state_allocation.as_mut_ptr(), 1);
source.alloc.deallocate(state_allocation, 1);
return ReturnCode::MemError;
};

copy.window = window;
}

// write the cloned state into state_ptr
let state_ptr = state_allocation.write(copy);
unsafe { state_allocation.write(copy) };

// insert the state_ptr into `dest`
let field_ptr = unsafe { core::ptr::addr_of_mut!((*dest.as_mut_ptr()).state) };
unsafe { core::ptr::write(field_ptr as *mut *mut State, state_ptr) };
unsafe { core::ptr::write(field_ptr as *mut *mut State, state_allocation) };

// update the writer; it cannot be cloned so we need to use some shennanigans
let field_ptr = unsafe { core::ptr::addr_of_mut!((*dest.as_mut_ptr()).state.writer) };
Expand Down Expand Up @@ -2282,8 +2283,8 @@ pub fn end<'a>(stream: &'a mut InflateStream<'a>) -> &'a mut z_stream {

// safety: window is not used again
if !window.is_empty() {
let window = window.into_inner();
unsafe { alloc.deallocate(window.as_mut_ptr(), window.len()) };
let (ptr, len) = window.into_raw_parts();
unsafe { alloc.deallocate(ptr, len) };
}

let stream = stream.as_z_stream_mut();
Expand All @@ -2301,7 +2302,7 @@ pub fn end<'a>(stream: &'a mut InflateStream<'a>) -> &'a mut z_stream {
/// The caller must guarantee:
///
/// * If `head` is `Some`:
// - If `head.extra` is not NULL, it must be writable for at least `head.extra_max` bytes
/// - If `head.extra` is not NULL, it must be writable for at least `head.extra_max` bytes
/// - if `head.name` is not NULL, it must be writable for at least `head.name_max` bytes
/// - if `head.comment` is not NULL, it must be writable for at least `head.comm_max` bytes
pub unsafe fn get_header<'a>(
Expand Down
8 changes: 4 additions & 4 deletions zlib-rs/src/inflate/bitreader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@ impl<'a> BitReader<'a> {
}

#[inline(always)]
pub fn update_slice(&mut self, slice: &[u8]) {
let range = slice.as_ptr_range();
pub unsafe fn update_slice(&mut self, ptr: *const u8, len: usize) {
let end = ptr.wrapping_add(len);

*self = Self {
ptr: range.start,
end: range.end,
ptr,
end,
bit_buffer: self.bit_buffer,
bits_used: self.bits_used,
_marker: PhantomData,
Expand Down
72 changes: 40 additions & 32 deletions zlib-rs/src/inflate/window.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::{
adler32::{adler32, adler32_fold_copy},
allocate::Allocator,
crc32::Crc32Fold,
weak_slice::WeakSliceMut,
};

// translation guide:
Expand All @@ -11,30 +12,26 @@ use crate::{
// whave -> buf.filled.len()
#[derive(Debug)]
pub struct Window<'a> {
buf: &'a mut [u8],
buf: WeakSliceMut<'a, u8>,

have: usize, // number of bytes logically written to the window. this can be higher than
// buf.len() if we run out of space in the window
next: usize, // write head
}

impl<'a> Window<'a> {
pub fn into_inner(self) -> &'a mut [u8] {
self.buf
pub fn into_raw_parts(self) -> (*mut u8, usize) {
self.buf.into_raw_parts()
}

pub fn is_empty(&self) -> bool {
self.size() == 0
}

pub fn size(&self) -> usize {
if self.buf.is_empty() {
// an empty `buf` is used when the window has not yet been allocated,
// or when it has been deallocated.
0
} else {
self.buf.len() - Self::padding()
}
// `self.len == 0` is used for uninitialized buffers
assert!(self.buf.len() == 0 || self.buf.len() >= Self::padding());
self.buf.len().saturating_sub(Self::padding())
}

/// number of bytes in the window. Saturates at `Self::capacity`.
Expand All @@ -49,7 +46,7 @@ impl<'a> Window<'a> {

pub fn empty() -> Self {
Self {
buf: &mut [],
buf: WeakSliceMut::empty(),
have: 0,
next: 0,
}
Expand All @@ -61,7 +58,7 @@ impl<'a> Window<'a> {
}

pub fn as_slice(&self) -> &[u8] {
&self.buf[..self.have]
&self.buf.as_slice()[..self.have]
}

pub fn as_ptr(&self) -> *const u8 {
Expand Down Expand Up @@ -92,13 +89,13 @@ impl<'a> Window<'a> {
if update_checksum {
if flags != 0 {
crc_fold.fold(non_window_slice, 0);
crc_fold.fold_copy(&mut self.buf[..wsize], window_slice);
crc_fold.fold_copy(&mut self.buf.as_mut_slice()[..wsize], window_slice);
} else {
*checksum = adler32(*checksum, non_window_slice);
*checksum = adler32_fold_copy(*checksum, self.buf, window_slice);
*checksum = adler32_fold_copy(*checksum, self.buf.as_mut_slice(), window_slice);
}
} else {
self.buf[..wsize].copy_from_slice(window_slice);
self.buf.as_mut_slice()[..wsize].copy_from_slice(window_slice);
}

self.next = 0;
Expand All @@ -111,18 +108,18 @@ impl<'a> Window<'a> {
let (end_part, start_part) = slice.split_at(dist);

if update_checksum {
let dst = &mut self.buf[self.next..][..end_part.len()];
let dst = &mut self.buf.as_mut_slice()[self.next..][..end_part.len()];
if flags != 0 {
crc_fold.fold_copy(dst, end_part);
} else {
*checksum = adler32_fold_copy(*checksum, dst, end_part);
}
} else {
self.buf[self.next..][..end_part.len()].copy_from_slice(end_part);
self.buf.as_mut_slice()[self.next..][..end_part.len()].copy_from_slice(end_part);
}

if !start_part.is_empty() {
let dst = &mut self.buf[..start_part.len()];
let dst = &mut self.buf.as_mut_slice()[..start_part.len()];

if update_checksum {
if flags != 0 {
Expand Down Expand Up @@ -156,10 +153,8 @@ impl<'a> Window<'a> {
return None;
}

let buf = unsafe { core::slice::from_raw_parts_mut(ptr, len) };

Some(Self {
buf,
buf: unsafe { WeakSliceMut::from_raw_parts_mut(ptr, len) },
have: 0,
next: 0,
})
Expand All @@ -173,10 +168,8 @@ impl<'a> Window<'a> {
return None;
}

let buf = unsafe { core::slice::from_raw_parts_mut(ptr, len) };

Some(Self {
buf,
buf: unsafe { WeakSliceMut::from_raw_parts_mut(ptr, len) },
have: self.have,
next: self.next,
})
Expand Down Expand Up @@ -211,19 +204,24 @@ mod test {
assert_eq!(window.have, 5);
assert_eq!(window.next, 5);

let slice = &window.buf[..window.size()];
let slice = &window.buf.as_slice()[..window.size()];
assert_eq!(&[1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], slice);

window.extend_adler32(&[2; 7], &mut checksum);
assert_eq!(window.have, 12);
assert_eq!(window.next, 12);

let slice = &window.buf[..window.size()];
let slice = &window.buf.as_slice()[..window.size()];
assert_eq!(&[1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0], slice);

assert_eq!(checksum, 6946835);

unsafe { Allocator::RUST.deallocate(window.buf.as_mut_ptr(), window.buf.len()) }
unsafe {
Allocator::RUST.deallocate(
window.buf.as_mut_slice().as_mut_ptr(),
window.buf.as_slice().len(),
)
}
}

#[test]
Expand All @@ -236,19 +234,24 @@ mod test {
assert_eq!(window.have, 3);
assert_eq!(window.next, 3);

let slice = &window.buf[..window.size()];
let slice = &window.buf.as_slice()[..window.size()];
assert_eq!(&[1, 1, 1, 0], slice);

window.extend_adler32(&[2; 3], &mut checksum);
assert_eq!(window.have, 4);
assert_eq!(window.next, 2);

let slice = &window.buf[..window.size()];
let slice = &window.buf.as_slice()[..window.size()];
assert_eq!(&[2, 2, 1, 2], slice);

assert_eq!(checksum, 1769481);

unsafe { Allocator::RUST.deallocate(window.buf.as_mut_ptr(), window.buf.len()) }
unsafe {
Allocator::RUST.deallocate(
window.buf.as_mut_slice().as_mut_ptr(),
window.buf.as_slice().len(),
)
}
}

#[test]
Expand All @@ -262,11 +265,16 @@ mod test {
assert_eq!(window.have, 8);
assert_eq!(window.next, 0);

let slice = &window.buf[..window.size()];
let slice = &window.as_slice()[..window.size()];
assert_eq!(&[2, 3, 4, 5, 6, 7, 8, 9], slice);

assert_eq!(checksum, 10813485);

unsafe { Allocator::RUST.deallocate(window.buf.as_mut_ptr(), window.buf.len()) }
unsafe {
Allocator::RUST.deallocate(
window.buf.as_mut_slice().as_mut_ptr(),
window.as_slice().len(),
)
}
}
}
Loading
Loading