Skip to content

Commit

Permalink
Merge pull request #141 from Alexhuszagh/optunsafe
Browse files Browse the repository at this point in the history
Remove more local safety invariants with dedicated functions.

These changes primarily change the code from have local safety invariants with dedicated functions. This was limited in scope but still reduced the total number of unsafe statements to under 250 and almost the entirety of what's left is in dedicated structs and traits that are fully encapsulated for safety. There are 3 total unsafe functions now.
  • Loading branch information
Alexhuszagh authored Sep 15, 2024
2 parents b40a7fe + 222009e commit 4313694
Show file tree
Hide file tree
Showing 18 changed files with 682 additions and 872 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Changed

- Fixes a correctness regression

## [1.0.0] 2024-09-14

### Added
Expand Down
4 changes: 2 additions & 2 deletions lexical-parse-float/src/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
use lexical_parse_integer::algorithm;
use lexical_util::digit::char_to_valid_digit_const;
use lexical_util::format::NumberFormat;
use lexical_util::iterator::{AsBytes, BytesIter};
use lexical_util::iterator::{AsBytes, DigitsIter};
use lexical_util::step::u64_step;

use crate::float::{ExtendedFloat80, RawFloat};
Expand Down Expand Up @@ -104,7 +104,7 @@ pub fn parse_u64_digits<'a, Iter, const FORMAT: u128>(
overflowed: &mut bool,
zero: &mut bool,
) where
Iter: BytesIter<'a>,
Iter: DigitsIter<'a>,
{
let format = NumberFormat::<{ FORMAT }> {};
let radix = format.radix() as u64;
Expand Down
3 changes: 3 additions & 0 deletions lexical-parse-float/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,9 @@ mod table_lemire;
mod table_radix;
mod table_small;

#[macro_use(parse_sign)]
extern crate lexical_parse_integer;

// Re-exports
#[cfg(feature = "f16")]
pub use lexical_util::bf16::bf16;
Expand Down
160 changes: 43 additions & 117 deletions lexical-parse-float/src/parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,12 @@
use lexical_parse_integer::algorithm;
#[cfg(feature = "f16")]
use lexical_util::bf16::bf16;
use lexical_util::buffer::Buffer;
use lexical_util::digit::{char_to_digit_const, char_to_valid_digit_const};
use lexical_util::error::Error;
#[cfg(feature = "f16")]
use lexical_util::f16::f16;
use lexical_util::format::NumberFormat;
use lexical_util::iterator::{AsBytes, Bytes, BytesIter};
use lexical_util::iterator::{AsBytes, Bytes, DigitsIter, Iter};
use lexical_util::result::Result;
use lexical_util::step::u64_step;

Expand Down Expand Up @@ -166,91 +165,37 @@ parse_float_as_f32! { bf16 f16 }
// code is only like 30 lines.

/// Parse the sign from the leading digits.
///
/// This routine does the following:
///
/// 1. Parses the sign digit.
/// 2. Handles if positive signs before integers are not allowed.
/// 3. Handles negative signs if the type is unsigned.
/// 4. Handles if the sign is required, but missing.
/// 5. Handles if the iterator is empty, before or after parsing the sign.
/// 6. Handles if the iterator has invalid, leading zeros.
///
/// Returns if the value is negative, or any values detected when
/// validating the input.
#[cfg_attr(not(feature = "compact"), inline(always))]
pub fn parse_mantissa_sign<const FORMAT: u128>(byte: &mut Bytes<'_, FORMAT>) -> Result<bool> {
let format = NumberFormat::<{ FORMAT }> {};

// NOTE: Using `read_if` with a predicate compiles badly and is very slow.
// Also, it's better to do the step_unchecked inside rather than get the step
// count and do `step_by_unchecked`. The compiler knows what to do better.
match byte.integer_iter().peek() {
Some(&b'+') if !format.no_positive_mantissa_sign() => {
// SAFETY: Safe, we peeked 1 byte.
unsafe { byte.step_unchecked() };
Ok(false)
},
Some(&b'+') if format.no_positive_mantissa_sign() => {
Err(Error::InvalidPositiveSign(byte.cursor()))
},
Some(&b'-') => {
// SAFETY: Safe, we peeked 1 byte.
unsafe { byte.step_unchecked() };
Ok(true)
},
Some(_) if format.required_mantissa_sign() => Err(Error::MissingSign(byte.cursor())),
_ if format.required_mantissa_sign() => Err(Error::MissingSign(byte.cursor())),
_ => Ok(false),
}
parse_sign!(
byte,
true,
format.no_positive_mantissa_sign(),
format.required_mantissa_sign(),
InvalidPositiveSign,
MissingSign
)
}

/// Parse the sign from the leading digits.
///
/// This routine does the following:
///
/// 1. Parses the sign digit.
/// 2. Handles if positive signs before integers are not allowed.
/// 3. Handles negative signs if the type is unsigned.
/// 4. Handles if the sign is required, but missing.
/// 5. Handles if the iterator is empty, before or after parsing the sign.
/// 6. Handles if the iterator has invalid, leading zeros.
///
/// Returns if the value is negative, or any values detected when
/// validating the input.
#[cfg_attr(not(feature = "compact"), inline(always))]
pub fn parse_exponent_sign<const FORMAT: u128>(byte: &mut Bytes<'_, FORMAT>) -> Result<bool> {
let format = NumberFormat::<{ FORMAT }> {};

// NOTE: Using `read_if` with a predicate compiles badly and is very slow.
// Also, it's better to do the step_unchecked inside rather than get the step
// count and do `step_by_unchecked`. The compiler knows what to do better.
match byte.integer_iter().peek() {
Some(&b'+') if !format.no_positive_exponent_sign() => {
// SAFETY: Safe, we peeked 1 byte.
unsafe { byte.step_unchecked() };
Ok(false)
},
Some(&b'+') if format.no_positive_exponent_sign() => {
Err(Error::InvalidPositiveExponentSign(byte.cursor()))
},
Some(&b'-') => {
// SAFETY: Safe, we peeked 1 byte.
unsafe { byte.step_unchecked() };
Ok(true)
},
Some(_) if format.required_exponent_sign() => {
Err(Error::MissingExponentSign(byte.cursor()))
},
_ if format.required_exponent_sign() => Err(Error::MissingExponentSign(byte.cursor())),
_ => Ok(false),
}
parse_sign!(
byte,
true,
format.no_positive_exponent_sign(),
format.required_exponent_sign(),
InvalidPositiveExponentSign,
MissingExponentSign
)
}

/// Utility to extract the result and handle any errors from parsing a `Number`.
///
/// - `format` - The numberical format as a packed integer
/// - `byte` - The BytesIter iterator
/// - `byte` - The DigitsIter iterator
/// - `is_negative` - If the final value is negative
/// - `parse_normal` - The function to parse non-special numbers with
/// - `parse_special` - The function to parse special numbers with
Expand Down Expand Up @@ -547,31 +492,20 @@ pub fn parse_partial_number<'a, const FORMAT: u128>(
// INTEGER

// Check to see if we have a valid base prefix.
// NOTE: `read_if` compiles poorly so use `peek` and then `step_unchecked`.
#[allow(unused_variables)]
let mut is_prefix = false;
#[cfg(feature = "format")]
{
let base_prefix = format.base_prefix();
let mut iter = byte.integer_iter();
if base_prefix != 0 && iter.peek() == Some(&b'0') {
// SAFETY: safe since `byte.len() >= 1`.
unsafe { iter.step_unchecked() };
if base_prefix != 0 && iter.read_if_value_cased(b'0').is_some() {
// Check to see if the next character is the base prefix.
// We must have a format like `0x`, `0d`, `0o`. Note:
if let Some(&c) = iter.peek() {
is_prefix = if format.case_sensitive_base_prefix() {
c == base_prefix
} else {
c.to_ascii_lowercase() == base_prefix.to_ascii_lowercase()
};
if is_prefix {
// SAFETY: safe since `byte.len() >= 1`.
unsafe { iter.step_unchecked() };
if iter.is_done() {
return Err(Error::Empty(iter.cursor()));
}
}
is_prefix = true;
if iter.read_if_value(base_prefix, format.case_sensitive_base_prefix()).is_some()
&& iter.is_done()
{
return Err(Error::Empty(iter.cursor()));
}
}
}
Expand Down Expand Up @@ -619,7 +553,8 @@ pub fn parse_partial_number<'a, const FORMAT: u128>(
let mut implicit_exponent: i64;
let int_end = n_digits as i64;
let mut fraction_digits = None;
if byte.first_is(decimal_point) {
// TODO: Change this to something different from read_if_value but same idea
if byte.first_is_cased(decimal_point) {
// SAFETY: byte cannot be empty due to first_is
unsafe { byte.step_unchecked() };
let before = byte.clone();
Expand Down Expand Up @@ -658,11 +593,8 @@ pub fn parse_partial_number<'a, const FORMAT: u128>(

// Handle scientific notation.
let mut explicit_exponent = 0_i64;
let is_exponent = if cfg!(feature = "format") && format.case_sensitive_exponent() {
byte.first_is(exponent_character)
} else {
byte.case_insensitive_first_is(exponent_character)
};
let is_exponent = byte
.first_is(exponent_character, format.case_sensitive_exponent() && cfg!(feature = "format"));
if is_exponent {
// Check float format syntax checks.
#[cfg(feature = "format")]
Expand All @@ -677,6 +609,7 @@ pub fn parse_partial_number<'a, const FORMAT: u128>(
}

// SAFETY: byte cannot be empty due to `first_is` from `is_exponent`.`
// TODO: Fix: we need a read_if for bytes themselves?
unsafe { byte.step_unchecked() };
let is_negative = parse_exponent_sign(&mut byte)?;

Expand Down Expand Up @@ -708,12 +641,7 @@ pub fn parse_partial_number<'a, const FORMAT: u128>(
let base_suffix = format.base_suffix();
#[cfg(feature = "format")]
if base_suffix != 0 {
let is_suffix: bool = if format.case_sensitive_base_suffix() {
byte.first_is(base_suffix)
} else {
byte.case_insensitive_first_is(base_suffix)
};
if is_suffix {
if byte.first_is(base_suffix, format.case_sensitive_base_suffix()) {
// SAFETY: safe since `byte.len() >= 1`.
unsafe { byte.step_unchecked() };
}
Expand Down Expand Up @@ -744,24 +672,20 @@ pub fn parse_partial_number<'a, const FORMAT: u128>(
}

// Check for leading zeros, and to see if we had a false overflow.
// NOTE: Once again, `read_if` is slow: do peek and step
n_digits -= step;
let mut zeros = start.clone();
let mut zeros_integer = zeros.integer_iter();
while zeros_integer.peek_is(b'0') {
while zeros_integer.read_if_value_cased(b'0').is_some() {
n_digits = n_digits.saturating_sub(1);
// SAFETY: safe since zeros cannot be empty due to peek_is
unsafe { zeros_integer.step_unchecked() };
}
if zeros.first_is(decimal_point) {
if zeros.first_is_cased(decimal_point) {
// TODO: Fix with some read_if like logic
// SAFETY: safe since zeros cannot be empty due to first_is
unsafe { zeros.step_unchecked() };
}
let mut zeros_fraction = zeros.fraction_iter();
while zeros_fraction.peek_is(b'0') {
while zeros_fraction.read_if_value_cased(b'0').is_some() {
n_digits = n_digits.saturating_sub(1);
// SAFETY: safe since zeros cannot be empty due to peek_is
unsafe { zeros_fraction.step_unchecked() };
}

// OVERFLOW
Expand Down Expand Up @@ -824,7 +748,7 @@ pub fn parse_number<'a, const FORMAT: u128>(
is_negative: bool,
options: &Options,
) -> Result<Number<'a>> {
let length = byte.length();
let length = byte.buffer_length();
let (float, count) = parse_partial_number::<FORMAT>(byte, is_negative, options)?;
if count == length {
Ok(float)
Expand All @@ -840,7 +764,7 @@ pub fn parse_number<'a, const FORMAT: u128>(
#[inline(always)]
pub fn parse_digits<'a, Iter, Cb, const FORMAT: u128>(mut iter: Iter, mut cb: Cb)
where
Iter: BytesIter<'a>,
Iter: DigitsIter<'a>,
Cb: FnMut(u32),
{
let format = NumberFormat::<{ FORMAT }> {};
Expand All @@ -851,6 +775,8 @@ where
None => break,
}
// SAFETY: iter cannot be empty due to `iter.peek()`.
// NOTE: Because of the match statement, this would optimize poorly with
// read_if.
unsafe { iter.step_unchecked() };
}
}
Expand All @@ -860,7 +786,7 @@ where
#[cfg(not(feature = "compact"))]
pub fn parse_8digits<'a, Iter, const FORMAT: u128>(mut iter: Iter, mantissa: &mut u64)
where
Iter: BytesIter<'a>,
Iter: DigitsIter<'a>,
{
let format = NumberFormat::<{ FORMAT }> {};
let radix: u64 = format.radix() as u64;
Expand All @@ -886,7 +812,7 @@ pub fn parse_u64_digits<'a, Iter, const FORMAT: u128>(
mantissa: &mut u64,
step: &mut usize,
) where
Iter: BytesIter<'a>,
Iter: DigitsIter<'a>,
{
let format = NumberFormat::<{ FORMAT }> {};
let radix = format.radix() as u64;
Expand Down Expand Up @@ -934,7 +860,7 @@ pub fn is_special_eq<const FORMAT: u128>(mut byte: Bytes<FORMAT>, string: &'stat
byte.special_iter().peek();
return byte.cursor();
}
} else if shared::case_insensitive_starts_with(byte.special_iter(), string.iter()) {
} else if shared::starts_with_uncased(byte.special_iter(), string.iter()) {
// Trim the iterator afterwards.
byte.special_iter().peek();
return byte.cursor();
Expand All @@ -957,7 +883,7 @@ where
}

let cursor = byte.cursor();
let length = byte.length() - cursor;
let length = byte.buffer_length() - cursor;
if let Some(nan_string) = options.nan_string() {
if length >= nan_string.len() {
let count = is_special_eq::<FORMAT>(byte.clone(), nan_string);
Expand Down Expand Up @@ -1013,7 +939,7 @@ pub fn parse_special<F, const FORMAT: u128>(
where
F: LemireFloat,
{
let length = byte.length();
let length = byte.buffer_length();
if let Some((float, count)) = parse_partial_special::<F, FORMAT>(byte, is_negative, options) {
if count == length {
return Some(float);
Expand Down
4 changes: 2 additions & 2 deletions lexical-parse-float/src/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ where
/// This optimizes decently well, to the following ASM for pure slices:
///
/// ```text
/// case_insensitive_starts_with_slc:
/// starts_with_uncased:
/// xor eax, eax
/// .LBB1_1:
/// cmp rcx, rax
Expand All @@ -134,7 +134,7 @@ where
/// ret
/// ```
#[cfg_attr(not(feature = "compact"), inline(always))]
pub fn case_insensitive_starts_with<'a, 'b, Iter1, Iter2>(mut x: Iter1, mut y: Iter2) -> bool
pub fn starts_with_uncased<'a, 'b, Iter1, Iter2>(mut x: Iter1, mut y: Iter2) -> bool
where
Iter1: Iterator<Item = &'a u8>,
Iter2: Iterator<Item = &'b u8>,
Expand Down
7 changes: 3 additions & 4 deletions lexical-parse-float/src/slow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,11 @@ use core::cmp;

#[cfg(not(feature = "compact"))]
use lexical_parse_integer::algorithm;
use lexical_util::buffer::Buffer;
use lexical_util::digit::char_to_valid_digit_const;
#[cfg(feature = "radix")]
use lexical_util::digit::digit_to_char_const;
use lexical_util::format::NumberFormat;
use lexical_util::iterator::{AsBytes, BytesIter};
use lexical_util::iterator::{AsBytes, DigitsIter, Iter};
use lexical_util::num::{AsPrimitive, Integer};

#[cfg(feature = "radix")]
Expand Down Expand Up @@ -358,12 +357,12 @@ macro_rules! round_up_truncated {
/// - `count` - The total number of parsed digits
macro_rules! round_up_nonzero {
($format:ident, $iter:expr, $result:ident, $count:ident) => {{
// NOTE: All digits must be valid.
// NOTE: All digits must already be valid.
let mut iter = $iter;

// First try reading 8-digits at a time.
if iter.is_contiguous() {
while let Some(value) = iter.read_u64() {
while let Some(value) = iter.peek_u64() {
// SAFETY: safe since we have at least 8 bytes in the buffer.
unsafe { iter.step_by_unchecked(8) };
if value != 0x3030_3030_3030_3030 {
Expand Down
Loading

0 comments on commit 4313694

Please sign in to comment.