diff --git a/perf/src/sigverify.rs b/perf/src/sigverify.rs index 9b797d3bde5829..44f0cf226b73df 100644 --- a/perf/src/sigverify.rs +++ b/perf/src/sigverify.rs @@ -131,7 +131,8 @@ fn do_get_packet_offsets( } // read the length of Transaction.signatures (serialized with short_vec) - let (sig_len_untrusted, sig_size) = decode_len(&packet.data)?; + let (sig_len_untrusted, sig_size) = + decode_len(&packet.data).map_err(|_| PacketError::InvalidShortVec)?; // Using msg_start_offset which is based on sig_len_untrusted introduces uncertainty. // Ultimately, the actual sigverify will determine the uncertainty. @@ -156,8 +157,8 @@ fn do_get_packet_offsets( } // read the length of Message.account_keys (serialized with short_vec) - let (pubkey_len, pubkey_len_size) = - decode_len(&packet.data[message_account_keys_len_offset..])?; + let (pubkey_len, pubkey_len_size) = decode_len(&packet.data[message_account_keys_len_offset..]) + .map_err(|_| PacketError::InvalidShortVec)?; if (message_account_keys_len_offset + pubkey_len * size_of::() + pubkey_len_size) > packet.meta.size diff --git a/sdk/src/short_vec.rs b/sdk/src/short_vec.rs index 3f9e94e3dc769c..4e0c59dee2e6d7 100644 --- a/sdk/src/short_vec.rs +++ b/sdk/src/short_vec.rs @@ -38,6 +38,26 @@ impl Serialize for ShortU16 { } } +enum VisitResult { + Done(usize, usize), + More(usize, usize), + Err, +} + +fn visit_byte(elem: u8, len: usize, size: usize) -> VisitResult { + let len = len | (elem as usize & 0x7f) << (size * 7); + let size = size + 1; + let more = elem as usize & 0x80 == 0x80; + + if size > size_of::() + 1 { + VisitResult::Err + } else if more { + VisitResult::More(len, size) + } else { + VisitResult::Done(len, size) + } +} + struct ShortLenVisitor; impl<'de> Visitor<'de> for ShortLenVisitor { @@ -58,15 +78,16 @@ impl<'de> Visitor<'de> for ShortLenVisitor { .next_element()? .ok_or_else(|| de::Error::invalid_length(size, &self))?; - len |= (elem as usize & 0x7f) << (size * 7); - size += 1; - - if elem as usize & 0x80 == 0 { - break; - } - - if size > size_of::() + 1 { - return Err(de::Error::invalid_length(size, &self)); + match visit_byte(elem, len, size) { + VisitResult::Done(l, _) => { + len = l; + break; + } + VisitResult::More(l, s) => { + len = l; + size = s; + } + VisitResult::Err => return Err(de::Error::invalid_length(size + 1, &self)), } } @@ -178,10 +199,20 @@ impl<'de, T: Deserialize<'de>> Deserialize<'de> for ShortVec { } /// Return the decoded value and how many bytes it consumed. -pub fn decode_len(bytes: &[u8]) -> Result<(usize, usize), Box> { - let short_len: ShortU16 = bincode::deserialize(bytes)?; - let num_bytes = bincode::serialized_size(&short_len)?; - Ok((short_len.0 as usize, num_bytes as usize)) +pub fn decode_len(bytes: &[u8]) -> Result<(usize, usize), ()> { + let mut len = 0; + let mut size = 0; + for byte in bytes.iter() { + match visit_byte(*byte, len, size) { + VisitResult::More(l, s) => { + len = l; + size = s; + } + VisitResult::Done(len, size) => return Ok((len, size)), + VisitResult::Err => return Err(()), + } + } + Err(()) } #[cfg(test)] @@ -246,4 +277,17 @@ mod tests { let s = serde_json::to_string(&vec).unwrap(); assert_eq!(s, "[[3],0,1,2]"); } + + #[test] + fn test_decode_len_aliased_values() { + let one1 = [0x01]; + let one2 = [0x81, 0x00]; + let one3 = [0x81, 0x80, 0x00]; + let one4 = [0x81, 0x80, 0x80, 0x00]; + + assert_eq!(decode_len(&one1).unwrap(), (1, 1)); + assert_eq!(decode_len(&one2).unwrap(), (1, 2)); + assert_eq!(decode_len(&one3).unwrap(), (1, 3)); + assert!(decode_len(&one4).is_err()); + } }