From 1f67ba61a985c621cb10ae091aaa5c2b009aa721 Mon Sep 17 00:00:00 2001 From: Ben Kimock Date: Thu, 20 Apr 2023 23:11:47 -0400 Subject: [PATCH] Rewrite MemDecoder around pointers not a slice --- compiler/rustc_metadata/src/rmeta/decoder.rs | 25 +-- .../src/rmeta/def_path_hash_map.rs | 3 - compiler/rustc_middle/src/ty/codec.rs | 10 ++ .../rustc_query_impl/src/on_disk_cache.rs | 56 ++----- .../src/dep_graph/serialized.rs | 16 +- compiler/rustc_serialize/src/leb128.rs | 16 +- compiler/rustc_serialize/src/lib.rs | 1 + compiler/rustc_serialize/src/opaque.rs | 155 ++++++++++++++---- compiler/rustc_serialize/src/serialize.rs | 2 + compiler/rustc_serialize/tests/leb128.rs | 13 +- compiler/rustc_type_ir/src/codec.rs | 7 +- 11 files changed, 174 insertions(+), 130 deletions(-) diff --git a/compiler/rustc_metadata/src/rmeta/decoder.rs b/compiler/rustc_metadata/src/rmeta/decoder.rs index 951fb303e3cf9..f24ff41d63b5f 100644 --- a/compiler/rustc_metadata/src/rmeta/decoder.rs +++ b/compiler/rustc_metadata/src/rmeta/decoder.rs @@ -373,16 +373,6 @@ impl<'a, 'tcx> TyDecoder for DecodeContext<'a, 'tcx> { self.tcx() } - #[inline] - fn peek_byte(&self) -> u8 { - self.opaque.data[self.opaque.position()] - } - - #[inline] - fn position(&self) -> usize { - self.opaque.position() - } - fn cached_ty_for_shorthand(&mut self, shorthand: usize, or_insert_with: F) -> Ty<'tcx> where F: FnOnce(&mut Self) -> Ty<'tcx>, @@ -404,7 +394,7 @@ impl<'a, 'tcx> TyDecoder for DecodeContext<'a, 'tcx> { where F: FnOnce(&mut Self) -> R, { - let new_opaque = MemDecoder::new(self.opaque.data, pos); + let new_opaque = MemDecoder::new(self.opaque.data(), pos); let old_opaque = mem::replace(&mut self.opaque, new_opaque); let old_state = mem::replace(&mut self.lazy_state, LazyState::NoNode); let r = f(self); @@ -625,17 +615,12 @@ impl<'a, 'tcx> Decodable> for Symbol { SYMBOL_OFFSET => { // read str offset let pos = d.read_usize(); - let old_pos = d.opaque.position(); // move to str offset and read - d.opaque.set_position(pos); - let s = d.read_str(); - let sym = Symbol::intern(s); - - // restore position - d.opaque.set_position(old_pos); - - sym + d.opaque.with_position(pos, |d| { + let s = d.read_str(); + Symbol::intern(s) + }) } SYMBOL_PREINTERNED => { let symbol_index = d.read_u32(); diff --git a/compiler/rustc_metadata/src/rmeta/def_path_hash_map.rs b/compiler/rustc_metadata/src/rmeta/def_path_hash_map.rs index 02cab561b8f61..05402a58701f1 100644 --- a/compiler/rustc_metadata/src/rmeta/def_path_hash_map.rs +++ b/compiler/rustc_metadata/src/rmeta/def_path_hash_map.rs @@ -45,9 +45,6 @@ impl<'a, 'tcx> Encodable> for DefPathHashMapRef<'tcx> { impl<'a, 'tcx> Decodable> for DefPathHashMapRef<'static> { fn decode(d: &mut DecodeContext<'a, 'tcx>) -> DefPathHashMapRef<'static> { - // Import TyDecoder so we can access the DecodeContext::position() method - use crate::rustc_middle::ty::codec::TyDecoder; - let len = d.read_usize(); let pos = d.position(); let o = slice_owned(d.blob().clone(), |blob| &blob[pos..pos + len]); diff --git a/compiler/rustc_middle/src/ty/codec.rs b/compiler/rustc_middle/src/ty/codec.rs index 8ef4a46a733aa..26dd0fca349fd 100644 --- a/compiler/rustc_middle/src/ty/codec.rs +++ b/compiler/rustc_middle/src/ty/codec.rs @@ -519,6 +519,16 @@ macro_rules! implement_ty_decoder { fn read_raw_bytes(&mut self, len: usize) -> &[u8] { self.opaque.read_raw_bytes(len) } + + #[inline] + fn position(&self) -> usize { + self.opaque.position() + } + + #[inline] + fn peek_byte(&self) -> u8 { + self.opaque.peek_byte() + } } } } diff --git a/compiler/rustc_query_impl/src/on_disk_cache.rs b/compiler/rustc_query_impl/src/on_disk_cache.rs index 40869fdc467dd..d7ac8a0ad058a 100644 --- a/compiler/rustc_query_impl/src/on_disk_cache.rs +++ b/compiler/rustc_query_impl/src/on_disk_cache.rs @@ -169,13 +169,12 @@ impl<'sess> rustc_middle::ty::OnDiskCache<'sess> for OnDiskCache<'sess> { // Decode the *position* of the footer, which can be found in the // last 8 bytes of the file. - decoder.set_position(data.len() - IntEncodedWithFixedSize::ENCODED_SIZE); - let footer_pos = IntEncodedWithFixedSize::decode(&mut decoder).0 as usize; - + let footer_pos = decoder + .with_position(decoder.len() - IntEncodedWithFixedSize::ENCODED_SIZE, |decoder| { + IntEncodedWithFixedSize::decode(decoder).0 as usize + }); // Decode the file footer, which contains all the lookup tables, etc. - decoder.set_position(footer_pos); - - decode_tagged(&mut decoder, TAG_FILE_FOOTER) + decoder.with_position(footer_pos, |decoder| decode_tagged(decoder, TAG_FILE_FOOTER)) }; Self { @@ -522,29 +521,13 @@ impl<'a, 'tcx> CacheDecoder<'a, 'tcx> { } } -trait DecoderWithPosition: Decoder { - fn position(&self) -> usize; -} - -impl<'a> DecoderWithPosition for MemDecoder<'a> { - fn position(&self) -> usize { - self.position() - } -} - -impl<'a, 'tcx> DecoderWithPosition for CacheDecoder<'a, 'tcx> { - fn position(&self) -> usize { - self.opaque.position() - } -} - // Decodes something that was encoded with `encode_tagged()` and verify that the // tag matches and the correct amount of bytes was read. fn decode_tagged(decoder: &mut D, expected_tag: T) -> V where T: Decodable + Eq + std::fmt::Debug, V: Decodable, - D: DecoderWithPosition, + D: Decoder, { let start_pos = decoder.position(); @@ -568,16 +551,6 @@ impl<'a, 'tcx> TyDecoder for CacheDecoder<'a, 'tcx> { self.tcx } - #[inline] - fn position(&self) -> usize { - self.opaque.position() - } - - #[inline] - fn peek_byte(&self) -> u8 { - self.opaque.data[self.opaque.position()] - } - fn cached_ty_for_shorthand(&mut self, shorthand: usize, or_insert_with: F) -> Ty<'tcx> where F: FnOnce(&mut Self) -> Ty<'tcx>, @@ -600,9 +573,9 @@ impl<'a, 'tcx> TyDecoder for CacheDecoder<'a, 'tcx> { where F: FnOnce(&mut Self) -> R, { - debug_assert!(pos < self.opaque.data.len()); + debug_assert!(pos < self.opaque.len()); - let new_opaque = MemDecoder::new(self.opaque.data, pos); + let new_opaque = MemDecoder::new(self.opaque.data(), pos); let old_opaque = mem::replace(&mut self.opaque, new_opaque); let r = f(self); self.opaque = old_opaque; @@ -743,17 +716,12 @@ impl<'a, 'tcx> Decodable> for Symbol { SYMBOL_OFFSET => { // read str offset let pos = d.read_usize(); - let old_pos = d.opaque.position(); // move to str offset and read - d.opaque.set_position(pos); - let s = d.read_str(); - let sym = Symbol::intern(s); - - // restore position - d.opaque.set_position(old_pos); - - sym + d.opaque.with_position(pos, |d| { + let s = d.read_str(); + Symbol::intern(s) + }) } SYMBOL_PREINTERNED => { let symbol_index = d.read_u32(); diff --git a/compiler/rustc_query_system/src/dep_graph/serialized.rs b/compiler/rustc_query_system/src/dep_graph/serialized.rs index 3d19a84915aec..cbf89044915f8 100644 --- a/compiler/rustc_query_system/src/dep_graph/serialized.rs +++ b/compiler/rustc_query_system/src/dep_graph/serialized.rs @@ -94,21 +94,19 @@ impl<'a, K: DepKind + Decodable>> Decodable> { #[instrument(level = "debug", skip(d))] fn decode(d: &mut MemDecoder<'a>) -> SerializedDepGraph { - let start_position = d.position(); - // The last 16 bytes are the node count and edge count. debug!("position: {:?}", d.position()); - d.set_position(d.data.len() - 2 * IntEncodedWithFixedSize::ENCODED_SIZE); + let (node_count, edge_count) = + d.with_position(d.len() - 2 * IntEncodedWithFixedSize::ENCODED_SIZE, |d| { + debug!("position: {:?}", d.position()); + let node_count = IntEncodedWithFixedSize::decode(d).0 as usize; + let edge_count = IntEncodedWithFixedSize::decode(d).0 as usize; + (node_count, edge_count) + }); debug!("position: {:?}", d.position()); - let node_count = IntEncodedWithFixedSize::decode(d).0 as usize; - let edge_count = IntEncodedWithFixedSize::decode(d).0 as usize; debug!(?node_count, ?edge_count); - debug!("position: {:?}", d.position()); - d.set_position(start_position); - debug!("position: {:?}", d.position()); - let mut nodes = IndexVec::with_capacity(node_count); let mut fingerprints = IndexVec::with_capacity(node_count); let mut edge_list_indices = IndexVec::with_capacity(node_count); diff --git a/compiler/rustc_serialize/src/leb128.rs b/compiler/rustc_serialize/src/leb128.rs index 7dad9aa01fafd..e568b9e6786f9 100644 --- a/compiler/rustc_serialize/src/leb128.rs +++ b/compiler/rustc_serialize/src/leb128.rs @@ -1,3 +1,6 @@ +use crate::opaque::MemDecoder; +use crate::serialize::Decoder; + /// Returns the length of the longest LEB128 encoding for `T`, assuming `T` is an integer type pub const fn max_leb128_len() -> usize { // The longest LEB128 encoding for an integer uses 7 bits per byte. @@ -50,21 +53,19 @@ impl_write_unsigned_leb128!(write_usize_leb128, usize); macro_rules! impl_read_unsigned_leb128 { ($fn_name:ident, $int_ty:ty) => { #[inline] - pub fn $fn_name(slice: &[u8], position: &mut usize) -> $int_ty { + pub fn $fn_name(decoder: &mut MemDecoder<'_>) -> $int_ty { // The first iteration of this loop is unpeeled. This is a // performance win because this code is hot and integer values less // than 128 are very common, typically occurring 50-80% or more of // the time, even for u64 and u128. - let byte = slice[*position]; - *position += 1; + let byte = decoder.read_u8(); if (byte & 0x80) == 0 { return byte as $int_ty; } let mut result = (byte & 0x7F) as $int_ty; let mut shift = 7; loop { - let byte = slice[*position]; - *position += 1; + let byte = decoder.read_u8(); if (byte & 0x80) == 0 { result |= (byte as $int_ty) << shift; return result; @@ -127,14 +128,13 @@ impl_write_signed_leb128!(write_isize_leb128, isize); macro_rules! impl_read_signed_leb128 { ($fn_name:ident, $int_ty:ty) => { #[inline] - pub fn $fn_name(slice: &[u8], position: &mut usize) -> $int_ty { + pub fn $fn_name(decoder: &mut MemDecoder<'_>) -> $int_ty { let mut result = 0; let mut shift = 0; let mut byte; loop { - byte = slice[*position]; - *position += 1; + byte = decoder.read_u8(); result |= <$int_ty>::from(byte & 0x7F) << shift; shift += 7; diff --git a/compiler/rustc_serialize/src/lib.rs b/compiler/rustc_serialize/src/lib.rs index 1f8d2336c4e58..ce8503918b4f2 100644 --- a/compiler/rustc_serialize/src/lib.rs +++ b/compiler/rustc_serialize/src/lib.rs @@ -16,6 +16,7 @@ Core encoding and decoding interfaces. #![feature(maybe_uninit_slice)] #![feature(new_uninit)] #![feature(allocator_api)] +#![feature(ptr_sub_ptr)] #![cfg_attr(test, feature(test))] #![allow(rustc::internal)] #![deny(rustc::untranslatable_diagnostic)] diff --git a/compiler/rustc_serialize/src/opaque.rs b/compiler/rustc_serialize/src/opaque.rs index 53e5c89673652..b7976ea3b1c63 100644 --- a/compiler/rustc_serialize/src/opaque.rs +++ b/compiler/rustc_serialize/src/opaque.rs @@ -2,7 +2,9 @@ use crate::leb128::{self, largest_max_leb128_len}; use crate::serialize::{Decodable, Decoder, Encodable, Encoder}; use std::fs::File; use std::io::{self, Write}; +use std::marker::PhantomData; use std::mem::MaybeUninit; +use std::ops::Range; use std::path::Path; use std::ptr; @@ -510,38 +512,125 @@ impl Encoder for FileEncoder { // Decoder // ----------------------------------------------------------------------------- +// Conceptually, `MemDecoder` wraps a `&[u8]` with a cursor into it that is always valid. +// This is implemented with three pointers, two which represent the original slice and a +// third that is our cursor. +// It is an invariant of this type that start <= current <= end. +// Additionally, the implementation of this type never modifies start and end. pub struct MemDecoder<'a> { - pub data: &'a [u8], - position: usize, + start: *const u8, + current: *const u8, + end: *const u8, + _marker: PhantomData<&'a u8>, } impl<'a> MemDecoder<'a> { #[inline] pub fn new(data: &'a [u8], position: usize) -> MemDecoder<'a> { - MemDecoder { data, position } + let Range { start, end } = data.as_ptr_range(); + MemDecoder { start, current: data[position..].as_ptr(), end, _marker: PhantomData } } #[inline] - pub fn position(&self) -> usize { - self.position + pub fn data(&self) -> &'a [u8] { + // SAFETY: This recovers the original slice, only using members we never modify. + unsafe { std::slice::from_raw_parts(self.start, self.len()) } } #[inline] - pub fn set_position(&mut self, pos: usize) { - self.position = pos + pub fn len(&self) -> usize { + // SAFETY: This recovers the length of the original slice, only using members we never modify. + unsafe { self.end.sub_ptr(self.start) } + } + + #[inline] + pub fn remaining(&self) -> usize { + // SAFETY: This type guarantees current <= end. + unsafe { self.end.sub_ptr(self.current) } + } + + #[cold] + #[inline(never)] + fn decoder_exhausted() -> ! { + panic!("MemDecoder exhausted") } #[inline] - pub fn advance(&mut self, bytes: usize) { - self.position += bytes; + fn read_byte(&mut self) -> u8 { + if self.current == self.end { + Self::decoder_exhausted(); + } + // SAFETY: This type guarantees current <= end, and we just checked current == end. + unsafe { + let byte = *self.current; + self.current = self.current.add(1); + byte + } + } + + #[inline] + fn read_array(&mut self) -> [u8; N] { + self.read_raw_bytes(N).try_into().unwrap() + } + + // The trait method doesn't have a lifetime parameter, and we need a version of this + // that definitely returns a slice based on the underlying storage as opposed to + // the Decoder itself in order to implement read_str efficiently. + #[inline] + fn read_raw_bytes_inherent(&mut self, bytes: usize) -> &'a [u8] { + if bytes > self.remaining() { + Self::decoder_exhausted(); + } + // SAFETY: We just checked if this range is in-bounds above. + unsafe { + let slice = std::slice::from_raw_parts(self.current, bytes); + self.current = self.current.add(bytes); + slice + } + } + + /// While we could manually expose manipulation of the decoder position, + /// all current users of that method would need to reset the position later, + /// incurring the bounds check of set_position twice. + #[inline] + pub fn with_position(&mut self, pos: usize, func: F) -> T + where + F: Fn(&mut MemDecoder<'a>) -> T, + { + struct SetOnDrop<'a, 'guarded> { + decoder: &'guarded mut MemDecoder<'a>, + current: *const u8, + } + impl Drop for SetOnDrop<'_, '_> { + fn drop(&mut self) { + self.decoder.current = self.current; + } + } + + if pos >= self.len() { + Self::decoder_exhausted(); + } + let previous = self.current; + // SAFETY: We just checked if this add is in-bounds above. + unsafe { + self.current = self.start.add(pos); + } + let guard = SetOnDrop { current: previous, decoder: self }; + func(guard.decoder) } } macro_rules! read_leb128 { - ($dec:expr, $fun:ident) => {{ leb128::$fun($dec.data, &mut $dec.position) }}; + ($dec:expr, $fun:ident) => {{ leb128::$fun($dec) }}; } impl<'a> Decoder for MemDecoder<'a> { + #[inline] + fn position(&self) -> usize { + // SAFETY: This type guarantees start <= current + unsafe { self.current.sub_ptr(self.start) } + } + #[inline] fn read_u128(&mut self) -> u128 { read_leb128!(self, read_u128_leb128) @@ -559,17 +648,12 @@ impl<'a> Decoder for MemDecoder<'a> { #[inline] fn read_u16(&mut self) -> u16 { - let bytes = [self.data[self.position], self.data[self.position + 1]]; - let value = u16::from_le_bytes(bytes); - self.position += 2; - value + u16::from_le_bytes(self.read_array()) } #[inline] fn read_u8(&mut self) -> u8 { - let value = self.data[self.position]; - self.position += 1; - value + self.read_byte() } #[inline] @@ -594,17 +678,12 @@ impl<'a> Decoder for MemDecoder<'a> { #[inline] fn read_i16(&mut self) -> i16 { - let bytes = [self.data[self.position], self.data[self.position + 1]]; - let value = i16::from_le_bytes(bytes); - self.position += 2; - value + i16::from_le_bytes(self.read_array()) } #[inline] fn read_i8(&mut self) -> i8 { - let value = self.data[self.position]; - self.position += 1; - value as i8 + self.read_byte() as i8 } #[inline] @@ -625,22 +704,26 @@ impl<'a> Decoder for MemDecoder<'a> { } #[inline] - fn read_str(&mut self) -> &'a str { + fn read_str(&mut self) -> &str { let len = self.read_usize(); - let sentinel = self.data[self.position + len]; - assert!(sentinel == STR_SENTINEL); - let s = unsafe { - std::str::from_utf8_unchecked(&self.data[self.position..self.position + len]) - }; - self.position += len + 1; - s + let bytes = self.read_raw_bytes_inherent(len + 1); + assert!(bytes[len] == STR_SENTINEL); + unsafe { std::str::from_utf8_unchecked(&bytes[..len]) } } #[inline] - fn read_raw_bytes(&mut self, bytes: usize) -> &'a [u8] { - let start = self.position; - self.position += bytes; - &self.data[start..self.position] + fn read_raw_bytes(&mut self, bytes: usize) -> &[u8] { + self.read_raw_bytes_inherent(bytes) + } + + #[inline] + fn peek_byte(&self) -> u8 { + if self.current == self.end { + Self::decoder_exhausted(); + } + // SAFETY: This type guarantees current is inbounds or one-past-the-end, which is end. + // Since we just checked current == end, the current pointer must be inbounds. + unsafe { *self.current } } } diff --git a/compiler/rustc_serialize/src/serialize.rs b/compiler/rustc_serialize/src/serialize.rs index 527abc2372715..a6d9c7b7d4210 100644 --- a/compiler/rustc_serialize/src/serialize.rs +++ b/compiler/rustc_serialize/src/serialize.rs @@ -84,6 +84,8 @@ pub trait Decoder { fn read_char(&mut self) -> char; fn read_str(&mut self) -> &str; fn read_raw_bytes(&mut self, len: usize) -> &[u8]; + fn peek_byte(&self) -> u8; + fn position(&self) -> usize; } /// Trait for types that can be serialized diff --git a/compiler/rustc_serialize/tests/leb128.rs b/compiler/rustc_serialize/tests/leb128.rs index 314c07db981da..7872e7784311a 100644 --- a/compiler/rustc_serialize/tests/leb128.rs +++ b/compiler/rustc_serialize/tests/leb128.rs @@ -3,6 +3,7 @@ use rustc_serialize::leb128::*; use std::mem::MaybeUninit; +use rustc_serialize::Decoder; macro_rules! impl_test_unsigned_leb128 { ($test_name:ident, $write_fn_name:ident, $read_fn_name:ident, $int_ty:ident) => { @@ -28,12 +29,12 @@ macro_rules! impl_test_unsigned_leb128 { stream.extend($write_fn_name(&mut buf, x)); } - let mut position = 0; + let mut decoder = rustc_serialize::opaque::MemDecoder::new(&stream, 0); for &expected in &values { - let actual = $read_fn_name(&stream, &mut position); + let actual = $read_fn_name(&mut decoder); assert_eq!(expected, actual); } - assert_eq!(stream.len(), position); + assert_eq!(stream.len(), decoder.position()); } }; } @@ -74,12 +75,12 @@ macro_rules! impl_test_signed_leb128 { stream.extend($write_fn_name(&mut buf, x)); } - let mut position = 0; + let mut decoder = rustc_serialize::opaque::MemDecoder::new(&stream, 0); for &expected in &values { - let actual = $read_fn_name(&stream, &mut position); + let actual = $read_fn_name(&mut decoder); assert_eq!(expected, actual); } - assert_eq!(stream.len(), position); + assert_eq!(stream.len(), decoder.position()); } }; } diff --git a/compiler/rustc_type_ir/src/codec.rs b/compiler/rustc_type_ir/src/codec.rs index ee249050cc64e..3b638934629b5 100644 --- a/compiler/rustc_type_ir/src/codec.rs +++ b/compiler/rustc_type_ir/src/codec.rs @@ -27,10 +27,13 @@ pub trait TyEncoder: Encoder { const CLEAR_CROSS_CRATE: bool; fn position(&self) -> usize; + fn type_shorthands(&mut self) -> &mut FxHashMap<::Ty, usize>; + fn predicate_shorthands( &mut self, ) -> &mut FxHashMap<::PredicateKind, usize>; + fn encode_alloc_id(&mut self, alloc_id: &::AllocId); } @@ -40,10 +43,6 @@ pub trait TyDecoder: Decoder { fn interner(&self) -> Self::I; - fn peek_byte(&self) -> u8; - - fn position(&self) -> usize; - fn cached_ty_for_shorthand( &mut self, shorthand: usize,