Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: architecture refactoring for greater similarity #1

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 13 additions & 26 deletions src/avx2/deser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@ use std::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

use std::mem;

pub use crate::error::{Error, ErrorType};
pub use crate::Deserializer;
pub use crate::Result;
pub use crate::avx2::stage1::*;
pub use crate::avx2::utf8check::*;
pub use crate::stringparse::*;

Expand All @@ -27,7 +26,7 @@ impl<'de> Deserializer<'de> {
let mut src_i: usize = 0;
let mut len = src_i;
loop {
let v: __m256i = if src.len() >= src_i + 32 {
let srcx: __m256i = if src.len() >= src_i + 32 {
// This is safe since we ensure src is at least 32 wide
#[allow(clippy::cast_ptr_alignment)]
unsafe {
Expand All @@ -44,16 +43,8 @@ impl<'de> Deserializer<'de> {
}
};

// store to dest unconditionally - we can overwrite the bits we don't like
// later
let bs_bits: u32 = unsafe {
static_cast_u32!(_mm256_movemask_epi8(_mm256_cmpeq_epi8(
v,
_mm256_set1_epi8(b'\\' as i8)
)))
};
let quote_mask = unsafe { _mm256_cmpeq_epi8(v, _mm256_set1_epi8(b'"' as i8)) };
let quote_bits = unsafe { static_cast_u32!(_mm256_movemask_epi8(quote_mask)) };
let ParseStringHelper { bs_bits, quote_bits } = find_bs_bits_and_quote_bits(srcx);

if (bs_bits.wrapping_sub(1) & quote_bits) != 0 {
// we encountered quotes first. Move dst to point to quotes and exit
// find out where the quote is...
Expand Down Expand Up @@ -94,7 +85,7 @@ impl<'de> Deserializer<'de> {
let dst: &mut [u8] = &mut self.strings;

loop {
let v: __m256i = if src.len() >= src_i + 32 {
let srcx: __m256i = if src.len() >= src_i + 32 {
// This is safe since we ensure src is at least 32 wide
#[allow(clippy::cast_ptr_alignment)]
unsafe {
Expand All @@ -113,19 +104,13 @@ impl<'de> Deserializer<'de> {

#[allow(clippy::cast_ptr_alignment)]
unsafe {
_mm256_storeu_si256(dst.as_mut_ptr().add(dst_i) as *mut __m256i, v)
_mm256_storeu_si256(dst.as_mut_ptr().add(dst_i) as *mut __m256i, srcx)
};

// store to dest unconditionally - we can overwrite the bits we don't like
// later
let bs_bits: u32 = unsafe {
static_cast_u32!(_mm256_movemask_epi8(_mm256_cmpeq_epi8(
v,
_mm256_set1_epi8(b'\\' as i8)
)))
};
let quote_mask = unsafe { _mm256_cmpeq_epi8(v, _mm256_set1_epi8(b'"' as i8)) };
let quote_bits = unsafe { static_cast_u32!(_mm256_movemask_epi8(quote_mask)) };
let ParseStringHelper { bs_bits, quote_bits } = find_bs_bits_and_quote_bits(srcx);

if (bs_bits.wrapping_sub(1) & quote_bits) != 0 {
// we encountered quotes first. Move dst to point to quotes and exit
// find out where the quote is...
Expand Down Expand Up @@ -164,9 +149,11 @@ impl<'de> Deserializer<'de> {
src_i += bs_dist as usize;
dst_i += bs_dist as usize;
let (o, s) = if let Ok(r) =
handle_unicode_codepoint(unsafe { src.get_unchecked(src_i..) }, unsafe {
dst.get_unchecked_mut(dst_i..)
}) {
handle_unicode_codepoint(
unsafe { src.get_unchecked(src_i..) },
unsafe { dst.get_unchecked_mut(dst_i..) },
)
{
r
} else {
return Err(self.error(ErrorType::InvlaidUnicodeCodepoint));
Expand Down
172 changes: 109 additions & 63 deletions src/avx2/stage1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,45 +10,86 @@ use std::mem;

pub const SIMDJSON_PADDING: usize = mem::size_of::<__m256i>();

/// Turns a bitmap of quote positions into a bitmap of the regions they open.
///
/// The value is carry-less multiplied by an all-ones operand and the low
/// 64 bits of the product are returned. In that product, bit `i` is the XOR
/// of bits `0..=i` of `quote_bits`, so the bits flip at every set bit of the
/// input.
///
/// # Safety
///
/// Executes `_mm_clmulepi64_si128` and SSE2 intrinsics; the caller must make
/// sure the running CPU supports them.
unsafe fn compute_quote_mask(quote_bits: u64) -> u64 {
    // Place the bitmap in the low 64-bit lane; the high lane is unused.
    let quotes = _mm_set_epi64x(0, static_cast_i64!(quote_bits));
    // Every byte 0xFF, i.e. an all-ones multiplier.
    let all_ones = _mm_set1_epi8(-1);
    // Immediate 0 selects the low 64-bit lane of both operands.
    let product = _mm_clmulepi64_si128(quotes, all_ones, 0);
    _mm_cvtsi128_si64(product) as u64
}

/// Reports whether no byte in `input.v0` or `input.v1` has its top bit set,
/// i.e. whether both vectors contain only ASCII bytes.
///
/// # Safety
///
/// Executes AVX/AVX2 intrinsics; the caller must make sure the running CPU
/// supports them.
#[cfg_attr(not(feature = "no-inline"), inline(always))]
unsafe fn check_ascii(input: &SimdInput) -> bool {
    // OR-ing both halves lets a single test cover every byte.
    let combined = _mm256_or_si256(input.v0, input.v1);
    let top_bits: __m256i = _mm256_set1_epi8(static_cast_i8!(0x80u8));
    // testz yields 1 exactly when `combined & top_bits` is all zero.
    _mm256_testz_si256(combined, top_bits) == 1
}

#[derive(Debug)]
struct SimdInput {
lo: __m256i,
hi: __m256i,
v0: __m256i,
v1: __m256i,
}

fn fill_input(ptr: &[u8]) -> SimdInput {
unsafe {
#[allow(clippy::cast_ptr_alignment)]
SimdInput {
lo: _mm256_loadu_si256(ptr.as_ptr() as *const __m256i),
hi: _mm256_loadu_si256(ptr.as_ptr().add(32) as *const __m256i),
v0: _mm256_loadu_si256(ptr.as_ptr() as *const __m256i),
v1: _mm256_loadu_si256(ptr.as_ptr().add(32) as *const __m256i),
}
}
}

/// State threaded through successive `check_utf8` calls.
struct Utf8CheckingState {
/// Error accumulator: `check_utf8` ORs findings into it, and
/// `is_utf8_status_ok` treats an all-zero vector as success.
has_error: __m256i,
/// Result of the most recent `check_utf8_bytes` call, handed back to the
/// next call so it can see the preceding bytes' state.
previous: ProcessedUtfBytes,
}

impl Default for Utf8CheckingState {
    /// Builds the starting state: an all-zero error vector and a
    /// default-constructed `ProcessedUtfBytes`.
    #[cfg_attr(not(feature = "no-inline"), inline)]
    fn default() -> Self {
        // NOTE(review): assumes this module is only compiled/run where AVX is
        // available — confirm against the crate's cfg gating.
        let no_errors = unsafe { _mm256_setzero_si256() };
        Self {
            has_error: no_errors,
            previous: ProcessedUtfBytes::default(),
        }
    }
}

/// Returns `true` when `has_error` is entirely zero, i.e. no error bit was
/// ever recorded in it.
#[inline]
fn is_utf8_status_ok(has_error: __m256i) -> bool {
    // testz(a, a) yields 1 exactly when `a & a` (that is, `a`) is all zero.
    // NOTE(review): assumes AVX is available on the running CPU — confirm
    // against the crate's cfg gating.
    let all_zero = unsafe { _mm256_testz_si256(has_error, has_error) };
    all_zero != 0
}

#[cfg_attr(not(feature = "no-inline"), inline(always))]
unsafe fn check_utf8(
input: &SimdInput,
has_error: &mut __m256i,
previous: &mut AvxProcessedUtfBytes,
state: &mut Utf8CheckingState,
) {
let highbit: __m256i = _mm256_set1_epi8(static_cast_i8!(0x80u8));
if (_mm256_testz_si256(_mm256_or_si256(input.lo, input.hi), highbit)) == 1 {
// it is ascii, we just check continuation
*has_error = _mm256_or_si256(
if check_ascii(input) {
// All bytes are ascii. Therefore the byte that was just before must be
// ascii too. We only check the byte that was just before simd_input. Nines
// are arbitrary values.
let verror: __m256i = _mm256_setr_epi8(
9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9,
9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 1,
);
state.has_error = _mm256_or_si256(
_mm256_cmpgt_epi8(
previous.carried_continuations,
_mm256_setr_epi8(
9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9,
9, 9, 9, 9, 9, 1,
),
state.previous.carried_continuations,
verror,
),
*has_error,
state.has_error,
);
} else {
// it is not ascii so we have to do heavy work
*previous = avxcheck_utf8_bytes(input.lo, &previous, has_error);
*previous = avxcheck_utf8_bytes(input.hi, &previous, has_error);
state.previous = check_utf8_bytes(input.v0, &mut state.previous, &mut state.has_error);
state.previous = check_utf8_bytes(input.v1, &mut state.previous, &mut state.has_error);
}
}

Expand All @@ -58,9 +99,9 @@ unsafe fn check_utf8(
fn cmp_mask_against_input(input: &SimdInput, m: u8) -> u64 {
unsafe {
let mask: __m256i = _mm256_set1_epi8(m as i8);
let cmp_res_0: __m256i = _mm256_cmpeq_epi8(input.lo, mask);
let cmp_res_0: __m256i = _mm256_cmpeq_epi8(input.v0, mask);
let res_0: u64 = u64::from(static_cast_u32!(_mm256_movemask_epi8(cmp_res_0)));
let cmp_res_1: __m256i = _mm256_cmpeq_epi8(input.hi, mask);
let cmp_res_1: __m256i = _mm256_cmpeq_epi8(input.v1, mask);
let res_1: u64 = _mm256_movemask_epi8(cmp_res_1) as u64;
res_0 | (res_1 << 32)
}
Expand All @@ -70,10 +111,9 @@ fn cmp_mask_against_input(input: &SimdInput, m: u8) -> u64 {
#[cfg_attr(not(feature = "no-inline"), inline(always))]
fn unsigned_lteq_against_input(input: &SimdInput, maxval: __m256i) -> u64 {
unsafe {
let cmp_res_0: __m256i = _mm256_cmpeq_epi8(_mm256_max_epu8(maxval, input.lo), maxval);
// TODO: c++ uses static cast, here what are the implications?
let cmp_res_0: __m256i = _mm256_cmpeq_epi8(_mm256_max_epu8(maxval, input.v0), maxval);
let res_0: u64 = u64::from(static_cast_u32!(_mm256_movemask_epi8(cmp_res_0)));
let cmp_res_1: __m256i = _mm256_cmpeq_epi8(_mm256_max_epu8(maxval, input.hi), maxval);
let cmp_res_1: __m256i = _mm256_cmpeq_epi8(_mm256_max_epu8(maxval, input.v1), maxval);
let res_1: u64 = _mm256_movemask_epi8(cmp_res_1) as u64;
res_0 | (res_1 << 32)
}
Expand Down Expand Up @@ -107,9 +147,10 @@ fn find_odd_backslash_sequences(input: &SimdInput, prev_iter_ends_odd_backslash:
// should be flipped
let (mut odd_carries, iter_ends_odd_backslash) = bs_bits.overflowing_add(odd_starts);

odd_carries |= *prev_iter_ends_odd_backslash; // push in bit zero as a potential end
// if we had an odd-numbered run at the
// end of the previous iteration
odd_carries |= *prev_iter_ends_odd_backslash;
// push in bit zero as a potential end
// if we had an odd-numbered run at the
// end of the previous iteration
*prev_iter_ends_odd_backslash = if iter_ends_odd_backslash { 0x1 } else { 0x0 };
let even_carry_ends: u64 = even_carries & !bs_bits;
let odd_carry_ends: u64 = odd_carries & !bs_bits;
Expand Down Expand Up @@ -141,12 +182,8 @@ unsafe fn find_quote_mask_and_bits(
*quote_bits = cmp_mask_against_input(&input, b'"');
*quote_bits &= !odd_ends;
// remove from the valid quoted region the unescaped characters.
#[allow(overflowing_literals)]
let mut quote_mask: u64 = _mm_cvtsi128_si64(_mm_clmulepi64_si128(
_mm_set_epi64x(0, static_cast_i64!(*quote_bits)),
_mm_set1_epi8(0xFF),
0,
)) as u64;
let mut quote_mask: u64 = compute_quote_mask(*quote_bits);

quote_mask ^= *prev_iter_inside_quote;
// All Unicode characters may be placed within the
// quotation marks, except for the characters that MUST be escaped:
Expand Down Expand Up @@ -187,57 +224,57 @@ unsafe fn find_whitespace_and_structurals(

// TODO: const?
let low_nibble_mask: __m256i = _mm256_setr_epi8(
16, 0, 0, 0, 0, 0, 0, 0, 0, 8, 12, 1, 2, 9, 0, 0, 16, 0, 0, 0, 0, 0, 0, 0, 0, 8, 12, 1, 2,
9, 0, 0,
16, 0, 0, 0, 0, 0, 0, 0, 0, 8, 12, 1, 2, 9, 0, 0,
16, 0, 0, 0, 0, 0, 0, 0, 0, 8, 12, 1, 2, 9, 0, 0,
);
// TODO: const?
let high_nibble_mask: __m256i = _mm256_setr_epi8(
8, 0, 18, 4, 0, 1, 0, 1, 0, 0, 0, 3, 2, 1, 0, 0, 8, 0, 18, 4, 0, 1, 0, 1, 0, 0, 0, 3, 2, 1,
0, 0,
8, 0, 18, 4, 0, 1, 0, 1, 0, 0, 0, 3, 2, 1, 0, 0,
8, 0, 18, 4, 0, 1, 0, 1, 0, 0, 0, 3, 2, 1, 0, 0,
);

let structural_shufti_mask: __m256i = _mm256_set1_epi8(0x7);
let whitespace_shufti_mask: __m256i = _mm256_set1_epi8(0x18);

let v_lo: __m256i = _mm256_and_si256(
_mm256_shuffle_epi8(low_nibble_mask, input.lo),
let v_v0: __m256i = _mm256_and_si256(
_mm256_shuffle_epi8(low_nibble_mask, input.v0),
_mm256_shuffle_epi8(
high_nibble_mask,
_mm256_and_si256(_mm256_srli_epi32(input.lo, 4), _mm256_set1_epi8(0x7f)),
_mm256_and_si256(_mm256_srli_epi32(input.v0, 4), _mm256_set1_epi8(0x7f)),
),
);

let v_hi: __m256i = _mm256_and_si256(
_mm256_shuffle_epi8(low_nibble_mask, input.hi),
let v_v1: __m256i = _mm256_and_si256(
_mm256_shuffle_epi8(low_nibble_mask, input.v1),
_mm256_shuffle_epi8(
high_nibble_mask,
_mm256_and_si256(_mm256_srli_epi32(input.hi, 4), _mm256_set1_epi8(0x7f)),
_mm256_and_si256(_mm256_srli_epi32(input.v1, 4), _mm256_set1_epi8(0x7f)),
),
);
let tmp_lo: __m256i = _mm256_cmpeq_epi8(
_mm256_and_si256(v_lo, structural_shufti_mask),

let tmp_v0: __m256i = _mm256_cmpeq_epi8(
_mm256_and_si256(v_v0, structural_shufti_mask),
_mm256_set1_epi8(0),
);
let tmp_hi: __m256i = _mm256_cmpeq_epi8(
_mm256_and_si256(v_hi, structural_shufti_mask),
let tmp_v1: __m256i = _mm256_cmpeq_epi8(
_mm256_and_si256(v_v1, structural_shufti_mask),
_mm256_set1_epi8(0),
);

let structural_res_0: u64 = u64::from(static_cast_u32!(_mm256_movemask_epi8(tmp_lo)));
let structural_res_1: u64 = _mm256_movemask_epi8(tmp_hi) as u64;
let structural_res_0: u64 = u64::from(static_cast_u32!(_mm256_movemask_epi8(tmp_v0)));
let structural_res_1: u64 = _mm256_movemask_epi8(tmp_v1) as u64;
*structurals = !(structural_res_0 | (structural_res_1 << 32));

let tmp_ws_lo: __m256i = _mm256_cmpeq_epi8(
_mm256_and_si256(v_lo, whitespace_shufti_mask),
let tmp_ws_v0: __m256i = _mm256_cmpeq_epi8(
_mm256_and_si256(v_v0, whitespace_shufti_mask),
_mm256_set1_epi8(0),
);
let tmp_ws_hi: __m256i = _mm256_cmpeq_epi8(
_mm256_and_si256(v_hi, whitespace_shufti_mask),
let tmp_ws_v1: __m256i = _mm256_cmpeq_epi8(
_mm256_and_si256(v_v1, whitespace_shufti_mask),
_mm256_set1_epi8(0),
);

let ws_res_0: u64 = u64::from(static_cast_u32!(_mm256_movemask_epi8(tmp_ws_lo)));
let ws_res_1: u64 = _mm256_movemask_epi8(tmp_ws_hi) as u64;
let ws_res_0: u64 = u64::from(static_cast_u32!(_mm256_movemask_epi8(tmp_ws_v0)));
let ws_res_1: u64 = _mm256_movemask_epi8(tmp_ws_v1) as u64;
*whitespace = !(ws_res_0 | (ws_res_1 << 32));
}

Expand Down Expand Up @@ -345,9 +382,18 @@ fn finalize_structurals(
structurals
}

//WARN_UNUSED
/*never_inline*/
//#[inline(never)]
/// Locates backslashes and double quotes in a 32-byte vector.
///
/// Each byte of `v` is compared against `\` and `"`. The per-byte results are
/// packed into 32-bit masks (one bit per byte, passed through
/// `static_cast_u32!`) and returned as `bs_bits` and `quote_bits`.
pub fn find_bs_bits_and_quote_bits(v: __m256i) -> ParseStringHelper {
    // NOTE(review): assumes AVX2 is available on the running CPU — confirm
    // against the crate's cfg gating.
    unsafe {
        let quote_hits = _mm256_cmpeq_epi8(v, _mm256_set1_epi8(b'"' as i8));
        let backslash_hits = _mm256_cmpeq_epi8(v, _mm256_set1_epi8(b'\\' as i8));

        // movemask gathers the top bit of every byte into one integer.
        ParseStringHelper {
            bs_bits: static_cast_u32!(_mm256_movemask_epi8(backslash_hits)),
            quote_bits: static_cast_u32!(_mm256_movemask_epi8(quote_hits)),
        }
    }
}

impl<'de> Deserializer<'de> {
//#[inline(never)]
pub unsafe fn find_structural_bits(input: &[u8]) -> std::result::Result<Vec<u32>, ErrorType> {
Expand All @@ -357,8 +403,8 @@ impl<'de> Deserializer<'de> {
let mut structural_indexes = Vec::with_capacity(len / 6);
structural_indexes.push(0); // push extra root element

let mut has_error: __m256i = _mm256_setzero_si256();
let mut previous = AvxProcessedUtfBytes::default();
let mut utf8_state: Utf8CheckingState = Utf8CheckingState::default();

// we have padded the input out to 64 byte multiple with the remainder being
// zeros

Expand Down Expand Up @@ -394,7 +440,7 @@ impl<'de> Deserializer<'de> {
#endif
*/
let input: SimdInput = fill_input(input.get_unchecked(idx as usize..));
check_utf8(&input, &mut has_error, &mut previous);
check_utf8(&input, &mut utf8_state);
// detect odd sequences of backslashes
let odd_ends: u64 =
find_odd_backslash_sequences(&input, &mut prev_iter_ends_odd_backslash);
Expand Down Expand Up @@ -438,7 +484,7 @@ impl<'de> Deserializer<'de> {
.copy_from(input.as_ptr().add(idx), len as usize - idx);
let input: SimdInput = fill_input(&tmpbuf);

check_utf8(&input, &mut has_error, &mut previous);
check_utf8(&input, &mut utf8_state);

// detect odd sequences of backslashes
let odd_ends: u64 =
Expand Down Expand Up @@ -493,7 +539,7 @@ impl<'de> Deserializer<'de> {
return Err(ErrorType::Syntax);
}

if _mm256_testz_si256(has_error, has_error) != 0 {
if is_utf8_status_ok(utf8_state.has_error) {
Ok(structural_indexes)
} else {
Err(ErrorType::InvalidUTF8)
Expand Down
Loading