From f73e6a4600fbfa795d500d45caef4d48f8c85eff Mon Sep 17 00:00:00 2001 From: oyvindln Date: Sun, 15 Dec 2024 21:13:54 +0100 Subject: [PATCH] fix(inflate): fill fast lookup table with invalid code value instead of zero so we can avoid check in hot code path givin a small performance boost --- miniz_oxide/src/inflate/core.rs | 178 +++++++++++++++++++------------- miniz_oxide/tests/test.rs | 14 +++ 2 files changed, 118 insertions(+), 74 deletions(-) diff --git a/miniz_oxide/src/inflate/core.rs b/miniz_oxide/src/inflate/core.rs index ac91201..146c4fd 100644 --- a/miniz_oxide/src/inflate/core.rs +++ b/miniz_oxide/src/inflate/core.rs @@ -77,17 +77,14 @@ impl HuffmanTable { /// /// It's possible we could avoid checking for 0 if we can guarantee a sane table. /// TODO: Check if a smaller type for code_len helps performance. - fn lookup(&self, bit_buf: BitBuffer) -> Option<(i32, u32)> { + fn lookup(&self, bit_buf: BitBuffer) -> (i32, u32) { let symbol = self.fast_lookup(bit_buf).into(); if symbol >= 0 { let length = (symbol >> 9) as u32; - match length { - 0 => None, - _ => Some((symbol, length)), - } + (symbol, length) } else { // We didn't get a symbol from the fast lookup table, so check the tree instead. - Some(self.tree_lookup(symbol, bit_buf, FAST_LOOKUP_BITS)) + self.tree_lookup(symbol, bit_buf, FAST_LOOKUP_BITS) } } } @@ -189,6 +186,8 @@ pub struct DecompressorOxide { /// 1 if the current block is the last block, 0 otherwise. finish: u8, /// The type of the current block. + /// or if in a dynamic block, which huffman table we are currently + // initializing. block_type: u8, /// 1 if the adler32 value should be checked. check_adler32: u32, @@ -337,7 +336,6 @@ enum State { BadCodeSizeDistPrevLookup, InvalidLitlen, InvalidDist, - InvalidCodeLen, } impl State { @@ -531,9 +529,9 @@ where // /* bit buffer contains >=15 bits (deflate's max. Huffman code size). */ loop { let mut temp = i32::from(r.tables[table].fast_lookup(l.bit_buf)); - if temp >= 0 { let code_len = (temp >> 9) as u32; + // TODO: Is there any point to check for code_len != 0 here still? if (code_len != 0) && (l.num_bits >= code_len) { break; } @@ -601,10 +599,6 @@ where code_len = res.1; }; - if code_len == 0 { - return Action::Jump(InvalidCodeLen); - } - l.bit_buf >>= code_len; l.num_bits -= code_len; f(r, l, symbol) @@ -739,16 +733,23 @@ fn init_tree(r: &mut DecompressorOxide, l: &mut LocalVars) -> Option { let bt = r.block_type as usize; let code_sizes = match bt { - 0 => &mut r.code_size_literal[..], - 1 => &mut r.code_size_dist, - 2 => &mut r.code_size_huffman, + LITLEN_TABLE => &mut r.code_size_literal[..], + DIST_TABLE => &mut r.code_size_dist, + HUFFLEN_TABLE => &mut r.code_size_huffman, _ => return None, }; let table = &mut r.tables[bt]; let mut total_symbols = [0u16; 16]; let mut next_code = [0u32; 17]; - memset(&mut table.look_up[..], 0); + const INVALID_CODE: i16 = 1 << 9 | 286; + // Set the values in the fast table to return a + // non-zero length and an invalid symbol instead of zero + // so that we do not have to have a check for a zero + // code length in the hot code path later + // and can instead error out on the invalid symbol check + // on bogus input. + memset(&mut table.look_up[..], INVALID_CODE); memset(&mut table.tree[..], 0); let table_size = r.table_sizes[bt] as usize; @@ -765,6 +766,7 @@ fn init_tree(r: &mut DecompressorOxide, l: &mut LocalVars) -> Option { let mut used_symbols = 0; let mut total = 0u32; + // Count up the total number of used lengths and check that the table is not under or over-subscribed. for (&ts, next) in total_symbols.iter().zip(next_code[1..].iter_mut()).skip(1) { used_symbols += ts; total += u32::from(ts); @@ -772,7 +774,20 @@ fn init_tree(r: &mut DecompressorOxide, l: &mut LocalVars) -> Option { *next = total; } - if total != 65_536 && used_symbols > 1 { + // + // While it's not explicitly stated in the spec, a hufflen table + // with a single length (or none) would be invalid as there needs to be + // at minimum a length for both a non-zero length huffman code for the end of block symbol + // and one of the codes to represent 0 to make sense - so just reject that here as well. + // + // The distance table is allowed to have a single distance code though according to the spect it is + // supposed to be accompanied by a second dummy code. It can also be empty indicating no used codes. + // + // The literal/length table can not be empty as there has to be an end of block symbol, + // The standard doesn't specify that there should be a dummy code in case of a single + // symbol (i.e an empty block). Normally that's not an issue though the code will have + // to take that into account later on in case of malformed input. + if total != 65_536 && (used_symbols > 1 || bt == HUFFLEN_TABLE) { return Some(Action::Jump(BadTotalSymbols)); } @@ -809,7 +824,7 @@ fn init_tree(r: &mut DecompressorOxide, l: &mut LocalVars) -> Option { } let mut tree_cur = table.look_up[(rev_code & (FAST_LOOKUP_SIZE - 1)) as usize]; - if tree_cur == 0 { + if tree_cur == INVALID_CODE { table.look_up[(rev_code & (FAST_LOOKUP_SIZE - 1)) as usize] = tree_next; tree_cur = tree_next; tree_next -= 2; @@ -841,12 +856,12 @@ fn init_tree(r: &mut DecompressorOxide, l: &mut LocalVars) -> Option { table.tree[tree_index] = symbol_index as i16; } - if r.block_type == 2 { + if r.block_type == HUFFLEN_TABLE as u8 { l.counter = 0; return Some(Action::Jump(ReadLitlenDistTablesCodeSize)); } - if r.block_type == 0 { + if r.block_type == LITLEN_TABLE as u8 { break; } r.block_type -= 1; @@ -1036,43 +1051,35 @@ fn decompress_fast( fill_bit_buffer(&mut l, in_iter); - if let Some((symbol, code_len)) = r.tables[LITLEN_TABLE].lookup(l.bit_buf) { - l.counter = symbol as u32; + let (symbol, code_len) = r.tables[LITLEN_TABLE].lookup(l.bit_buf); + l.counter = symbol as u32; + l.bit_buf >>= code_len; + l.num_bits -= code_len; + + if (l.counter & 256) != 0 { + // The symbol is not a literal. + break; + } else { + // If we have a 32-bit buffer we need to read another two bytes now + // to have enough bits to keep going. + if cfg!(not(target_pointer_width = "64")) { + fill_bit_buffer(&mut l, in_iter); + } + + let (symbol, code_len) = r.tables[LITLEN_TABLE].lookup(l.bit_buf); l.bit_buf >>= code_len; l.num_bits -= code_len; - - if (l.counter & 256) != 0 { - // The symbol is not a literal. + // The previous symbol was a literal, so write it directly and check + // the next one. + out_buf.write_byte(l.counter as u8); + if (symbol & 256) != 0 { + l.counter = symbol as u32; + // The symbol is a length value. break; } else { - // If we have a 32-bit buffer we need to read another two bytes now - // to have enough bits to keep going. - if cfg!(not(target_pointer_width = "64")) { - fill_bit_buffer(&mut l, in_iter); - } - - if let Some((symbol, code_len)) = r.tables[LITLEN_TABLE].lookup(l.bit_buf) { - l.bit_buf >>= code_len; - l.num_bits -= code_len; - // The previous symbol was a literal, so write it directly and check - // the next one. - out_buf.write_byte(l.counter as u8); - if (symbol & 256) != 0 { - l.counter = symbol as u32; - // The symbol is a length value. - break; - } else { - // The symbol is a literal, so write it directly and continue. - out_buf.write_byte(symbol as u8); - } - } else { - state.begin(InvalidCodeLen); - break 'o TINFLStatus::Failed; - } + // The symbol is a literal, so write it directly and continue. + out_buf.write_byte(symbol as u8); } - } else { - state.begin(InvalidCodeLen); - break 'o TINFLStatus::Failed; } } @@ -1113,22 +1120,18 @@ fn decompress_fast( fill_bit_buffer(&mut l, in_iter); } - if let Some((mut symbol, code_len)) = r.tables[DIST_TABLE].lookup(l.bit_buf) { - symbol &= 511; - l.bit_buf >>= code_len; - l.num_bits -= code_len; - if symbol > 29 { - state.begin(InvalidDist); - break 'o TINFLStatus::Failed; - } - - l.num_extra = num_extra_bits_for_distance_code(symbol as u8); - l.dist = u32::from(DIST_BASE[symbol as usize]); - } else { - state.begin(InvalidCodeLen); + let (mut symbol, code_len) = r.tables[DIST_TABLE].lookup(l.bit_buf); + symbol &= 511; + l.bit_buf >>= code_len; + l.num_bits -= code_len; + if symbol > 29 { + state.begin(InvalidDist); break 'o TINFLStatus::Failed; } + l.num_extra = num_extra_bits_for_distance_code(symbol as u8); + l.dist = u32::from(DIST_BASE[symbol as usize]); + if l.num_extra != 0 { fill_bit_buffer(&mut l, in_iter); let extra_bits = l.bit_buf & ((1 << l.num_extra) - 1); @@ -1544,7 +1547,7 @@ pub fn decompress( } else { fill_bit_buffer(&mut l, &mut in_iter); - if let Some((symbol, code_len)) = r.tables[LITLEN_TABLE].lookup(l.bit_buf) { + let (symbol, code_len) = r.tables[LITLEN_TABLE].lookup(l.bit_buf); l.counter = symbol as u32; l.bit_buf >>= code_len; @@ -1560,7 +1563,7 @@ pub fn decompress( fill_bit_buffer(&mut l, &mut in_iter); } - if let Some((symbol, code_len)) = r.tables[LITLEN_TABLE].lookup(l.bit_buf) { + let (symbol, code_len) = r.tables[LITLEN_TABLE].lookup(l.bit_buf); l.bit_buf >>= code_len; l.num_bits -= code_len; @@ -1576,13 +1579,9 @@ pub fn decompress( out_buf.write_byte(symbol as u8); Action::None } - } else { - Action::Jump(InvalidCodeLen) - } - } - } else { - Action::Jump(InvalidCodeLen) + } + } }), @@ -1925,7 +1924,7 @@ mod test { } fn masked_lookup(table: &HuffmanTable, bit_buf: BitBuffer) -> (i32, u32) { - let ret = table.lookup(bit_buf).unwrap(); + let ret = table.lookup(bit_buf); (ret.0 & 511, ret.1) } @@ -2097,4 +2096,35 @@ mod test { assert_eq!(dist, num_extra_bits_for_distance_code(i as u8)); } } + + #[test] + fn check_tree() { + let mut r = DecompressorOxide::new(); + let mut l = LocalVars { + bit_buf: 0, + num_bits: 0, + dist: 0, + counter: 0, + num_extra: 0, + }; + + r.code_size_huffman[0] = 1; + r.code_size_huffman[1] = 1; + //r.code_size_huffman[2] = 3; + //r.code_size_huffman[3] = 3; + //r.code_size_huffman[1] = 4; + r.block_type = HUFFLEN_TABLE as u8; + r.table_sizes[HUFFLEN_TABLE] = 4; + let res = init_tree(&mut r, &mut l).unwrap(); + + let status = match res { + Action::Jump(s) => s, + _ => { + //println!("issue"); + return; + } + }; + //println!("status {:?}", status); + assert!(status != BadTotalSymbols); + } } diff --git a/miniz_oxide/tests/test.rs b/miniz_oxide/tests/test.rs index a5117d9..d8ac979 100644 --- a/miniz_oxide/tests/test.rs +++ b/miniz_oxide/tests/test.rs @@ -250,6 +250,20 @@ fn issue_143_return_buf_error_on_finish_without_end_header() { assert_eq!(inflate_result.status.unwrap_err(), MZError::Buf) } +#[test] +fn decompress_empty_dynamic() { + // Empty block with dynamic huffman codes. + let enc = vec![5, 192, 129, 8, 0, 0, 0, 0, 32, 127, 235, 0b011, 0, 0, 0]; + + let res = decompress_to_vec(enc.as_slice()).unwrap(); + assert!(res.is_empty()); + + let enc = vec![5, 192, 129, 8, 0, 0, 0, 0, 32, 127, 235, 0b1111011, 0, 0, 0]; + + let res = decompress_to_vec(enc.as_slice()); + assert!(res.is_err()); +} + /* #[test] fn partial_decompression_imap_issue_158() {