Fixed worst-case miri performance with lossy string decoding #98592

Closed
133 changes: 88 additions & 45 deletions library/core/src/str/lossy.rs
@@ -51,19 +51,43 @@ impl<'a> Iterator for Utf8LossyChunksIter<'a> {
}

const TAG_CONT_U8: u8 = 128;
fn safe_get(xs: &[u8], i: usize) -> u8 {
*xs.get(i).unwrap_or(&0)

/// Gets the byte at `current`, returning zero if `current` is past `end`
///
/// # Safety
///
/// `current` must be a valid pointer to an initialized byte
///
unsafe fn try_get(current: *const u8, end: *const u8) -> u8 {
// SAFETY: current is a valid pointer
if current < end { unsafe { *current } } else { 0 }
}

/// Checks if the byte at `current` bitand 192 is `TAG_CONT_U8`, returning
/// false if not or if `current` is past `end`
///
/// # Safety
///
/// `current` must be a valid pointer to an initialized byte
///
unsafe fn shouldnt_continue(current: *const u8, end: *const u8) -> bool {
// SAFETY: current is a valid pointer
unsafe { current < end && *current & 192 != TAG_CONT_U8 }
}

let mut i = 0;
let mut valid_up_to = 0;
while i < self.source.len() {
// SAFETY: `i < self.source.len()` per previous line.
let length = self.source.len();
let mut current = self.source.as_ptr();
// SAFETY: current + length is one past the end of the allocation
let (start, end, mut valid_up_to) = unsafe { (current, current.add(length), current) };

while current < end {
// SAFETY: `current < end` per previous line.
// For some reason the following are both significantly slower:
// while let Some(&byte) = self.source.get(i) {
// while let Some(byte) = self.source.get(i).copied() {
let byte = unsafe { *self.source.get_unchecked(i) };
Comment on lines -60 to -65

This PR as written adds a lot of unsafe. I do not think that we should do that without sufficient justification.

How much of a perf impact do we see on Miri backtrace printing and the standard library benchmarks from only adjusting these lines? The specific pattern we want to avoid is this: https://github.com/rust-lang/miri/blob/8fdb720329d7674a878a8252fe4b79ef93d6ffec/bench-cargo-miri/slice-get-unchecked/src/main.rs#L8-L9

    while i < x.len() {
        let _element = unsafe { *x.get_unchecked(i) };

I'm not completely opposed to the pervasive sort of changes that you've implemented in the rest of this function, but we need benchmarking that makes the case for those changes, as opposed to just tweaking the ASCII fast path.
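
As a rough illustration of the narrower change this comment asks about, the sketch below contrasts the indexed `get_unchecked` loop quoted above with a safe slice iterator over the same bytes. The helper names `ascii_prefix_unchecked` and `ascii_prefix_iter` are hypothetical and appear neither in the PR nor in `library/core`; they only model an ASCII fast path that stops at the first non-ASCII byte. Whether the iterator form is actually faster under Miri or in the standard library benchmarks is not shown here and would still have to be measured.

```rust
/// Hypothetical helper: length of the leading ASCII run, written with the
/// indexed `get_unchecked` pattern quoted in the comment above.
fn ascii_prefix_unchecked(x: &[u8]) -> usize {
    let mut i = 0;
    while i < x.len() {
        // SAFETY: `i < x.len()` per the loop condition.
        let byte = unsafe { *x.get_unchecked(i) };
        if byte >= 128 {
            break;
        }
        i += 1;
    }
    i
}

/// Hypothetical helper: the same result computed with a safe slice
/// iterator, so no `unsafe` block or manual index is needed.
fn ascii_prefix_iter(x: &[u8]) -> usize {
    x.iter().take_while(|&&byte| byte < 128).count()
}
```

Both helpers return the same count for any input, so a benchmark could swap one for the other and isolate the cost of the loop shape from the pointer-based rewrite of the multi-byte branches.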

i += 1;
let byte = unsafe { *current };
// SAFETY: This will be at most one past the end of the slice (and then equal to end)
current = unsafe { current.add(1) };

if byte < 128 {
// This could be a `1 => ...` case in the match below, but for
@@ -72,51 +96,70 @@ impl<'a> Iterator for Utf8LossyChunksIter<'a> {
} else {
let w = utf8_char_width(byte);

match w {
2 => {
if safe_get(self.source, i) & 192 != TAG_CONT_U8 {
break;
}
i += 1;
}
3 => {
match (byte, safe_get(self.source, i)) {
(0xE0, 0xA0..=0xBF) => (),
(0xE1..=0xEC, 0x80..=0xBF) => (),
(0xED, 0x80..=0x9F) => (),
(0xEE..=0xEF, 0x80..=0xBF) => (),
_ => break,
}
i += 1;
if safe_get(self.source, i) & 192 != TAG_CONT_U8 {
break;
}
i += 1;
}
4 => {
match (byte, safe_get(self.source, i)) {
(0xF0, 0x90..=0xBF) => (),
(0xF1..=0xF3, 0x80..=0xBF) => (),
(0xF4, 0x80..=0x8F) => (),
_ => break,
// SAFETY: All pointers will be at most one past the end of
// the slice (and then equal to end)
unsafe {
match w {
2 => {
if shouldnt_continue(current, end) {
break;
}

current = current.add(1);
}
i += 1;
if safe_get(self.source, i) & 192 != TAG_CONT_U8 {
break;

3 => {
match (byte, try_get(current, end)) {
(0xE0, 0xA0..=0xBF)
| (0xE1..=0xEC, 0x80..=0xBF)
| (0xED, 0x80..=0x9F)
| (0xEE..=0xEF, 0x80..=0xBF) => {}
_ => break,
}
current = current.add(1);

if shouldnt_continue(current, end) {
break;
}
current = current.add(1);
}
i += 1;
if safe_get(self.source, i) & 192 != TAG_CONT_U8 {
break;

4 => {
match (byte, try_get(current, end)) {
(0xF0, 0x90..=0xBF)
| (0xF1..=0xF3, 0x80..=0xBF)
| (0xF4, 0x80..=0x8F) => {}
_ => break,
}
current = current.add(1);

if shouldnt_continue(current, end) {
break;
}
current = current.add(1);

if shouldnt_continue(current, end) {
break;
}
current = current.add(1);
}
i += 1;

_ => break,
}
_ => break,
}
}

valid_up_to = i;
valid_up_to = current;
}

// SAFETY: Both pointers come from the same allocation
let idx = unsafe { current.offset_from(start) as usize };
debug_assert!(idx <= length);

// SAFETY: Both pointers come from the same allocation
let valid_up_to = unsafe { valid_up_to.offset_from(start) as usize };
debug_assert!(valid_up_to <= length);

// SAFETY: `i <= self.source.len()` because it is only ever incremented
// via `i += 1` and in between every single one of those increments, `i`
// is compared against `self.source.len()`. That happens either
@@ -125,7 +168,7 @@ impl<'a> Iterator for Utf8LossyChunksIter<'a> {
// loop is terminated as soon as the latest `i += 1` has made `i` no
// longer less than `self.source.len()`, which means it'll be at most
// equal to `self.source.len()`.
let (inspected, remaining) = unsafe { self.source.split_at_unchecked(i) };
let (inspected, remaining) = unsafe { self.source.split_at_unchecked(idx) };
self.source = remaining;

// SAFETY: `valid_up_to <= i` because it is only ever assigned via