From 5178e6aa718e39536316730a658ab58880e447bc Mon Sep 17 00:00:00 2001 From: oyvindln Date: Wed, 20 Nov 2024 00:39:42 +0100 Subject: [PATCH] fix(inflate): use inputwrapper struct instead of iter to simplify input reading and change some data types for performance --- miniz_oxide/src/inflate/core.rs | 112 +++++++++++------------ miniz_oxide/src/inflate/output_buffer.rs | 58 +++++++++++- 2 files changed, 105 insertions(+), 65 deletions(-) diff --git a/miniz_oxide/src/inflate/core.rs b/miniz_oxide/src/inflate/core.rs index 857e0d6e..595c81cd 100644 --- a/miniz_oxide/src/inflate/core.rs +++ b/miniz_oxide/src/inflate/core.rs @@ -4,10 +4,10 @@ use super::*; use crate::shared::{update_adler32, HUFFMAN_LENGTH_ORDER}; use ::core::cell::Cell; +use ::core::cmp; use ::core::convert::TryInto; -use ::core::{cmp, slice}; -use self::output_buffer::OutputBuffer; +use self::output_buffer::{InputWrapper, OutputBuffer}; pub const TINFL_LZ_DICT_SIZE: usize = 32_768; @@ -47,7 +47,7 @@ impl HuffmanTable { /// Get the symbol and the code length from the huffman tree. #[inline] - fn tree_lookup(&self, fast_symbol: i32, bit_buf: BitBuffer, mut code_len: u32) -> (i32, u32) { + fn tree_lookup(&self, fast_symbol: i32, bit_buf: BitBuffer, mut code_len: u8) -> (i32, u32) { let mut symbol = fast_symbol; // We step through the tree until we encounter a positive value, which indicates a // symbol. @@ -65,7 +65,9 @@ impl HuffmanTable { break; } } - (symbol, code_len) + // Note: Using a u8 for code_len inside this function seems to improve performance, but changing it + // in localvars seems to worsen things so we convert it to a u32 here. + (symbol, u32::from(code_len)) } #[inline] @@ -87,7 +89,7 @@ impl HuffmanTable { } } else { // We didn't get a symbol from the fast lookup table, so check the tree instead. 
- Some(self.tree_lookup(symbol, bit_buf, FAST_LOOKUP_BITS.into())) + Some(self.tree_lookup(symbol, bit_buf, FAST_LOOKUP_BITS)) } } } @@ -370,10 +372,12 @@ const DIST_BASE: [u16; 30] = [ /// Get the number of extra bits used for a distance code. /// (Code numbers above `NUM_DISTANCE_CODES` will give some garbage /// value.) +#[inline(always)] const fn num_extra_bits_for_distance_code(code: u8) -> u8 { + // TODO: Need to verify that this is faster on all platforms. // This can be easily calculated without a lookup. let c = code >> 1; - c - (c != 0) as u8 + c.saturating_sub(1) } /// The mask used when indexing the base/extra arrays. @@ -392,27 +396,12 @@ fn memset(slice: &mut [T], val: T) { /// # Panics /// Panics if there are less than two bytes left. #[inline] -fn read_u16_le(iter: &mut slice::Iter) -> u16 { +fn read_u16_le(iter: &mut InputWrapper) -> u16 { let ret = { - let two_bytes = iter.as_ref()[..2].try_into().unwrap(); + let two_bytes = iter.as_slice()[..2].try_into().unwrap_or_default(); u16::from_le_bytes(two_bytes) }; - iter.nth(1); - ret -} - -/// Read an le u32 value from the slice iterator. -/// -/// # Panics -/// Panics if there are less than four bytes left. -#[inline(always)] -#[cfg(target_pointer_width = "64")] -fn read_u32_le(iter: &mut slice::Iter) -> u32 { - let ret = { - let four_bytes: [u8; 4] = iter.as_ref()[..4].try_into().unwrap(); - u32::from_le_bytes(four_bytes) - }; - iter.nth(3); + iter.advance(2); ret } @@ -423,10 +412,10 @@ fn read_u32_le(iter: &mut slice::Iter) -> u32 { /// This function assumes that there is at least 4 bytes left in the input buffer. #[inline(always)] #[cfg(target_pointer_width = "64")] -fn fill_bit_buffer(l: &mut LocalVars, in_iter: &mut slice::Iter) { +fn fill_bit_buffer(l: &mut LocalVars, in_iter: &mut InputWrapper) { // Read four bytes into the buffer at once. 
if l.num_bits < 30 { - l.bit_buf |= BitBuffer::from(read_u32_le(in_iter)) << l.num_bits; + l.bit_buf |= BitBuffer::from(in_iter.read_u32_le()) << l.num_bits; l.num_bits += 32; } } @@ -491,7 +480,7 @@ fn decode_huffman_code( l: &mut LocalVars, table: usize, flags: u32, - in_iter: &mut slice::Iter, + in_iter: &mut InputWrapper, f: F, ) -> Action where @@ -501,7 +490,7 @@ where // ready in the bit buffer to start decoding the next huffman code. if l.num_bits < 15 { // First, make sure there is enough data in the bit buffer to decode a huffman code. - if in_iter.len() < 2 { + if in_iter.bytes_left() < 2 { // If there is less than 2 bytes left in the input buffer, we try to look up // the huffman code with what's available, and return if that doesn't succeed. // Original explanation in miniz: @@ -581,7 +570,7 @@ where // Mask out the length value. symbol &= 511; } else { - let res = r.tables[table].tree_lookup(symbol, l.bit_buf, u32::from(FAST_LOOKUP_BITS)); + let res = r.tables[table].tree_lookup(symbol, l.bit_buf, FAST_LOOKUP_BITS); symbol = res.0; code_len = res.1; }; @@ -599,13 +588,13 @@ where /// returning the result. 
/// If reading fails, `Action::End is returned` #[inline] -fn read_byte(in_iter: &mut slice::Iter, flags: u32, f: F) -> Action +fn read_byte(in_iter: &mut InputWrapper, flags: u32, f: F) -> Action where F: FnOnce(u8) -> Action, { - match in_iter.next() { + match in_iter.read_byte() { None => end_of_input(flags), - Some(&byte) => f(byte), + Some(byte) => f(byte), } } @@ -618,7 +607,7 @@ where fn read_bits( l: &mut LocalVars, amount: u32, - in_iter: &mut slice::Iter, + in_iter: &mut InputWrapper, flags: u32, f: F, ) -> Action @@ -647,7 +636,7 @@ where } #[inline] -fn pad_to_bytes(l: &mut LocalVars, in_iter: &mut slice::Iter, flags: u32, f: F) -> Action +fn pad_to_bytes(l: &mut LocalVars, in_iter: &mut InputWrapper, flags: u32, f: F) -> Action where F: FnOnce(&mut LocalVars) -> Action, { @@ -854,7 +843,7 @@ struct LocalVars { pub num_bits: u32, pub dist: u32, pub counter: u32, - pub num_extra: u32, + pub num_extra: u8, } #[inline] @@ -981,7 +970,7 @@ fn apply_match( /// and already improves decompression speed a fair bit. fn decompress_fast( r: &mut DecompressorOxide, - in_iter: &mut slice::Iter, + in_iter: &mut InputWrapper, out_buf: &mut OutputBuffer, flags: u32, local_vars: &mut LocalVars, @@ -1001,7 +990,7 @@ fn decompress_fast( // + 29 + 32 (left in bit buf, including last 13 dist extra) = 111 bits < 14 bytes // We need the one extra byte as we may write one length and one full match // before checking again. - if out_buf.bytes_left() < 259 || in_iter.len() < 14 { + if out_buf.bytes_left() < 259 || in_iter.bytes_left() < 14 { state = State::DecodeLitlen; break 'o TINFLStatus::Done; } @@ -1063,18 +1052,19 @@ fn decompress_fast( // The symbol was a length code. // # Optimization // Mask the value to avoid bounds checks - // We could use get_unchecked later if can statically verify that - // this will never go out of bounds. 
- l.num_extra = u32::from(LENGTH_EXTRA[(l.counter - 257) as usize & BASE_EXTRA_MASK]); + // While the maximum is checked, the compiler isn't able to know that the + // value won't wrap around here. + l.num_extra = LENGTH_EXTRA[(l.counter - 257) as usize & BASE_EXTRA_MASK]; l.counter = u32::from(LENGTH_BASE[(l.counter - 257) as usize & BASE_EXTRA_MASK]); // Length and distance codes have a number of extra bits depending on // the base, which together with the base gives us the exact value. + // We need to make sure we have at least 33 (so min 5 bytes) bits in the buffer at this spot. fill_bit_buffer(&mut l, in_iter); if l.num_extra != 0 { let extra_bits = l.bit_buf & ((1 << l.num_extra) - 1); l.bit_buf >>= l.num_extra; - l.num_bits -= l.num_extra; + l.num_bits -= u32::from(l.num_extra); l.counter += extra_bits as u32; } @@ -1093,7 +1083,7 @@ fn decompress_fast( break 'o TINFLStatus::Failed; } - l.num_extra = u32::from(num_extra_bits_for_distance_code(symbol as u8)); + l.num_extra = num_extra_bits_for_distance_code(symbol as u8); l.dist = u32::from(DIST_BASE[symbol as usize]); } else { state.begin(InvalidCodeLen); @@ -1104,7 +1094,7 @@ fn decompress_fast( fill_bit_buffer(&mut l, in_iter); let extra_bits = l.bit_buf & ((1 << l.num_extra) - 1); l.bit_buf >>= l.num_extra; - l.num_bits -= l.num_extra; + l.num_bits -= u32::from(l.num_extra); l.dist += extra_bits as u32; } @@ -1194,7 +1184,7 @@ pub fn decompress( return (TINFLStatus::BadParam, 0, 0); } - let mut in_iter = in_buf.iter(); + let mut in_iter = InputWrapper::from_slice(in_buf); let mut state = r.state; @@ -1206,7 +1196,7 @@ pub fn decompress( num_bits: r.num_bits, dist: r.dist, counter: r.counter, - num_extra: r.num_extra, + num_extra: r.num_extra as u8, }; let mut status = 'state_machine: loop { @@ -1351,20 +1341,20 @@ pub fn decompress( }), RawMemcpy2 => generate_state!(state, 'state_machine, { - if in_iter.len() > 0 { + if in_iter.bytes_left() > 0 { // Copy as many raw bytes as possible from the input to the 
output using memcpy. // Raw block lengths are limited to 64 * 1024, so casting through usize and u32 // is not an issue. let space_left = out_buf.bytes_left(); let bytes_to_copy = cmp::min(cmp::min( space_left, - in_iter.len()), + in_iter.bytes_left()), l.counter as usize ); out_buf.write_slice(&in_iter.as_slice()[..bytes_to_copy]); - in_iter.nth(bytes_to_copy - 1); + in_iter.advance(bytes_to_copy); l.counter -= bytes_to_copy as u32; Action::Jump(RawMemcpy1) } else { @@ -1456,7 +1446,7 @@ pub fn decompress( }), ReadExtraBitsCodeSize => generate_state!(state, 'state_machine, { - let num_extra = l.num_extra; + let num_extra = l.num_extra.into(); read_bits(&mut l, num_extra, &mut in_iter, flags, |l, mut extra_bits| { // Mask to avoid a bounds check. extra_bits += [3, 3, 11][(l.dist as usize - 16) & 3]; @@ -1478,7 +1468,7 @@ pub fn decompress( }), DecodeLitlen => generate_state!(state, 'state_machine, { - if in_iter.len() < 4 || out_buf.bytes_left() < 2 { + if in_iter.bytes_left() < 4 || out_buf.bytes_left() < 2 { // See if we can decode a literal with the data we have left. // Jumps to next state (WriteSymbol) if successful. decode_huffman_code( @@ -1496,7 +1486,7 @@ pub fn decompress( // If there is enough space, use the fast inner decompression // function. out_buf.bytes_left() >= 259 && - in_iter.len() >= 14 + in_iter.bytes_left() >= 14 { let (status, new_state) = decompress_fast( r, @@ -1587,7 +1577,7 @@ pub fn decompress( // We could use get_unchecked later if can statically verify that // this will never go out of bounds. l.num_extra = - u32::from(LENGTH_EXTRA[(l.counter - 257) as usize & BASE_EXTRA_MASK]); + LENGTH_EXTRA[(l.counter - 257) as usize & BASE_EXTRA_MASK]; l.counter = u32::from(LENGTH_BASE[(l.counter - 257) as usize & BASE_EXTRA_MASK]); // Length and distance codes have a number of extra bits depending on // the base, which together with the base gives us the exact value. 
@@ -1600,7 +1590,7 @@ pub fn decompress( }), ReadExtraBitsLitlen => generate_state!(state, 'state_machine, { - let num_extra = l.num_extra; + let num_extra = l.num_extra.into(); read_bits(&mut l, num_extra, &mut in_iter, flags, |l, extra_bits| { l.counter += extra_bits as u32; Action::Jump(DecodeDistance) @@ -1622,7 +1612,7 @@ pub fn decompress( // Invalid distance code. return Action::Jump(InvalidDist) } - l.num_extra = u32::from(num_extra_bits_for_distance_code(symbol as u8)); + l.num_extra = num_extra_bits_for_distance_code(symbol as u8); l.dist = u32::from(DIST_BASE[symbol]); if l.num_extra != 0 { // ReadEXTRA_BITS_DISTACNE @@ -1634,7 +1624,7 @@ pub fn decompress( }), ReadExtraBitsDistance => generate_state!(state, 'state_machine, { - let num_extra = l.num_extra; + let num_extra = l.num_extra.into(); read_bits(&mut l, num_extra, &mut in_iter, flags, |l, extra_bits| { l.dist += extra_bits as u32; Action::Jump(HuffDecodeOuterLoop2) @@ -1710,9 +1700,9 @@ pub fn decompress( if r.finish != 0 { pad_to_bytes(&mut l, &mut in_iter, flags, |_| Action::None); - let in_consumed = in_buf.len() - in_iter.len(); + let in_consumed = in_buf.len() - in_iter.bytes_left(); let undo = undo_bytes(&mut l, in_consumed as u32) as usize; - in_iter = in_buf[in_consumed - undo..].iter(); + in_iter = InputWrapper::from_slice(in_buf[in_consumed - undo..].iter().as_slice()); l.bit_buf &= ((1 as BitBuffer) << l.num_bits) - 1; debug_assert_eq!(l.num_bits, 0); @@ -1765,7 +1755,7 @@ pub fn decompress( let in_undo = if status != TINFLStatus::NeedsMoreInput && status != TINFLStatus::FailedCannotMakeProgress { - undo_bytes(&mut l, (in_buf.len() - in_iter.len()) as u32) as usize + undo_bytes(&mut l, (in_buf.len() - in_iter.bytes_left()) as u32) as usize } else { 0 }; @@ -1785,7 +1775,7 @@ pub fn decompress( r.num_bits = l.num_bits; r.dist = l.dist; r.counter = l.counter; - r.num_extra = l.num_extra; + r.num_extra = l.num_extra.into(); r.bit_buf &= ((1 as BitBuffer) << r.num_bits) - 1; @@ -1816,7 
+1806,7 @@ pub fn decompress( ( status, - in_buf.len() - in_iter.len() - in_undo, + in_buf.len() - in_iter.bytes_left() - in_undo, out_buf.position() - out_pos, ) } @@ -1911,7 +1901,7 @@ mod test { num_bits: d.num_bits, dist: d.dist, counter: d.counter, - num_extra: d.num_extra, + num_extra: d.num_extra as u8, }; init_tree(&mut d, &mut l).unwrap(); let llt = &d.tables[LITLEN_TABLE]; diff --git a/miniz_oxide/src/inflate/output_buffer.rs b/miniz_oxide/src/inflate/output_buffer.rs index 5218a807..49e54a5a 100644 --- a/miniz_oxide/src/inflate/output_buffer.rs +++ b/miniz_oxide/src/inflate/output_buffer.rs @@ -14,12 +14,12 @@ impl<'a> OutputBuffer<'a> { OutputBuffer { slice, position } } - #[inline] + #[inline(always)] pub const fn position(&self) -> usize { self.position } - #[inline] + #[inline(always)] pub fn set_position(&mut self, position: usize) { self.position = position; } @@ -48,13 +48,63 @@ impl<'a> OutputBuffer<'a> { self.slice.len() - self.position } - #[inline] + #[inline(always)] pub const fn get_ref(&self) -> &[u8] { self.slice } - #[inline] + #[inline(always)] pub fn get_mut(&mut self) -> &mut [u8] { self.slice } } + +/// A wrapper for the input slice used when decompressing. +/// +/// Using this rather than a slice iterator lets us implement the reading methods directly on +/// the wrapper and advance through the input by re-slicing, which simplifies input reading and +/// helps with performance.
+#[derive(Copy, Clone)] +pub struct InputWrapper<'a> { + slice: &'a [u8], +} + +impl<'a> InputWrapper<'a> { + #[inline(always)] + pub const fn as_slice(&self) -> &[u8] { + self.slice + } + + #[inline(always)] + pub const fn from_slice(slice: &'a [u8]) -> InputWrapper<'a> { + InputWrapper { slice } + } + + #[inline(always)] + pub fn advance(&mut self, steps: usize) { + self.slice = &self.slice[steps..]; + } + + #[inline] + pub fn read_byte(&mut self) -> Option<u8> { + self.slice.first().map(|n| { + self.advance(1); + *n + }) + } + + #[inline] + pub fn read_u32_le(&mut self) -> u32 { + let ret = { + let four_bytes: [u8; 4] = self.slice[..4].try_into().unwrap_or_default(); + u32::from_le_bytes(four_bytes) + }; + self.advance(4); + ret + } + + #[inline(always)] + pub const fn bytes_left(&self) -> usize { + self.slice.len() + } +}