From 474a0473b4a09f936ef2adf0ac9f22fc6b1488bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eduardo=20S=C3=A1nchez=20Mu=C3=B1oz?= Date: Fri, 12 Apr 2024 14:28:33 +0200 Subject: [PATCH] De-duplicate SSE2 sll/srl/sra code --- src/tools/miri/src/shims/x86/mod.rs | 80 ++++++++++++ src/tools/miri/src/shims/x86/sse2.rs | 176 +++------------------------ 2 files changed, 97 insertions(+), 159 deletions(-) diff --git a/src/tools/miri/src/shims/x86/mod.rs b/src/tools/miri/src/shims/x86/mod.rs index 7b7921219e616..2a663d300a736 100644 --- a/src/tools/miri/src/shims/x86/mod.rs +++ b/src/tools/miri/src/shims/x86/mod.rs @@ -468,6 +468,86 @@ fn unary_op_ps<'tcx>( Ok(()) } +enum ShiftOp { + /// Shift left, logically (shift in zeros) -- same as shift left, arithmetically + Left, + /// Shift right, logically (shift in zeros) + RightLogic, + /// Shift right, arithmetically (shift in sign) + RightArith, +} + +/// Shifts each element of `left` by a scalar amount. The shift amount +/// is determined by the lowest 64 bits of `right` (which is a 128-bit vector). +/// +/// For logic shifts, when right is larger than BITS - 1, zero is produced. +/// For arithmetic right-shifts, when right is larger than BITS - 1, the sign +/// bit is copied to remaining bits. +fn shift_simd_by_scalar<'tcx>( + this: &mut crate::MiriInterpCx<'_, 'tcx>, + left: &OpTy<'tcx, Provenance>, + right: &OpTy<'tcx, Provenance>, + which: ShiftOp, + dest: &MPlaceTy<'tcx, Provenance>, +) -> InterpResult<'tcx, ()> { + let (left, left_len) = this.operand_to_simd(left)?; + let (dest, dest_len) = this.mplace_to_simd(dest)?; + + assert_eq!(dest_len, left_len); + // `right` may have a different length, and we only care about its + // lowest 64bit anyway. + + // Get the 64-bit shift operand and convert it to the type expected + // by checked_{shl,shr} (u32). + // It is ok to saturate the value to u32::MAX because any value + // above BITS - 1 will produce the same result. + let shift = u32::try_from(extract_first_u64(this, right)?).unwrap_or(u32::MAX); + + for i in 0..dest_len { + let left = this.read_scalar(&this.project_index(&left, i)?)?; + let dest = this.project_index(&dest, i)?; + + let res = match which { + ShiftOp::Left => { + let left = left.to_uint(dest.layout.size)?; + let res = left.checked_shl(shift).unwrap_or(0); + // `truncate` is needed as left-shift can make the absolute value larger. + Scalar::from_uint(dest.layout.size.truncate(res), dest.layout.size) + } + ShiftOp::RightLogic => { + let left = left.to_uint(dest.layout.size)?; + let res = left.checked_shr(shift).unwrap_or(0); + // No `truncate` needed as right-shift can only make the absolute value smaller. + Scalar::from_uint(res, dest.layout.size) + } + ShiftOp::RightArith => { + let left = left.to_int(dest.layout.size)?; + // On overflow, copy the sign bit to the remaining bits + let res = left.checked_shr(shift).unwrap_or(left >> 127); + // No `truncate` needed as right-shift can only make the absolute value smaller. + Scalar::from_int(res, dest.layout.size) + } + }; + this.write_scalar(res, &dest)?; + } + + Ok(()) +} + +/// Takes a 128-bit vector, transmutes it to `[u64; 2]` and extracts +/// the first value. +fn extract_first_u64<'tcx>( + this: &crate::MiriInterpCx<'_, 'tcx>, + op: &OpTy<'tcx, Provenance>, +) -> InterpResult<'tcx, u64> { + // Transmute vector to `[u64; 2]` + let array_layout = this.layout_of(Ty::new_array(this.tcx.tcx, this.tcx.types.u64, 2))?; + let op = op.transmute(array_layout, this)?; + + // Get the first u64 from the array + this.read_scalar(&this.project_index(&op, 0)?)?.to_u64() +} + // Rounds the first element of `right` according to `rounding` // and copies the remaining elements from `left`. fn round_first<'tcx, F: rustc_apfloat::Float>( diff --git a/src/tools/miri/src/shims/x86/sse2.rs b/src/tools/miri/src/shims/x86/sse2.rs index eb2cc9d37c826..9db30d7ddca21 100644 --- a/src/tools/miri/src/shims/x86/sse2.rs +++ b/src/tools/miri/src/shims/x86/sse2.rs @@ -1,10 +1,11 @@ use rustc_apfloat::ieee::Double; -use rustc_middle::ty::layout::LayoutOf as _; -use rustc_middle::ty::Ty; use rustc_span::Symbol; use rustc_target::spec::abi::Abi; -use super::{bin_op_simd_float_all, bin_op_simd_float_first, convert_float_to_int, FloatBinOp}; +use super::{ + bin_op_simd_float_all, bin_op_simd_float_first, convert_float_to_int, shift_simd_by_scalar, + FloatBinOp, ShiftOp, +}; use crate::*; use shims::foreign_items::EmulateForeignItemResult; @@ -109,156 +110,27 @@ pub(super) trait EvalContextExt<'mir, 'tcx: 'mir>: this.write_scalar(Scalar::from_u64(res.into()), &dest)?; } } - // Used to implement the _mm_{sll,srl,sra}_epi16 functions. - // Shifts 16-bit packed integers in left by the amount in right. - // Both operands are vectors of 16-bit integers. However, right is - // interpreted as a single 64-bit integer (remaining bits are ignored). - // For logic shifts, when right is larger than 15, zero is produced. - // For arithmetic shifts, when right is larger than 15, the sign bit + // Used to implement the _mm_{sll,srl,sra}_epi{16,32,64} functions + // (except _mm_sra_epi64, which is not available in SSE2). + // Shifts N-bit packed integers in left by the amount in right. + // Both operands are 128-bit vectors. However, right is interpreted as + // a single 64-bit integer (remaining bits are ignored). + // For logic shifts, when right is larger than N - 1, zero is produced. + // For arithmetic shifts, when right is larger than N - 1, the sign bit // is copied to remaining bits. - "psll.w" | "psrl.w" | "psra.w" => { + "psll.w" | "psrl.w" | "psra.w" | "psll.d" | "psrl.d" | "psra.d" | "psll.q" + | "psrl.q" => { let [left, right] = this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?; - let (left, left_len) = this.operand_to_simd(left)?; - let (right, right_len) = this.operand_to_simd(right)?; - let (dest, dest_len) = this.mplace_to_simd(dest)?; - - assert_eq!(dest_len, left_len); - assert_eq!(dest_len, right_len); - - enum ShiftOp { - Sll, - Srl, - Sra, - } let which = match unprefixed_name { - "psll.w" => ShiftOp::Sll, - "psrl.w" => ShiftOp::Srl, - "psra.w" => ShiftOp::Sra, + "psll.w" | "psll.d" | "psll.q" => ShiftOp::Left, + "psrl.w" | "psrl.d" | "psrl.q" => ShiftOp::RightLogic, + "psra.w" | "psra.d" => ShiftOp::RightArith, _ => unreachable!(), }; - // Get the 64-bit shift operand and convert it to the type expected - // by checked_{shl,shr} (u32). - // It is ok to saturate the value to u32::MAX because any value - // above 15 will produce the same result. - let shift = extract_first_u64(this, &right)?.try_into().unwrap_or(u32::MAX); - - for i in 0..dest_len { - let left = this.read_scalar(&this.project_index(&left, i)?)?.to_u16()?; - let dest = this.project_index(&dest, i)?; - - let res = match which { - ShiftOp::Sll => left.checked_shl(shift).unwrap_or(0), - ShiftOp::Srl => left.checked_shr(shift).unwrap_or(0), - #[allow(clippy::cast_possible_wrap, clippy::cast_sign_loss)] - ShiftOp::Sra => { - // Convert u16 to i16 to use arithmetic shift - let left = left as i16; - // Copy the sign bit to the remaining bits - left.checked_shr(shift).unwrap_or(left >> 15) as u16 - } - }; - - this.write_scalar(Scalar::from_u16(res), &dest)?; - } - } - // Used to implement the _mm_{sll,srl,sra}_epi32 functions. - // 32-bit equivalent to the shift functions above. - "psll.d" | "psrl.d" | "psra.d" => { - let [left, right] = - this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?; - - let (left, left_len) = this.operand_to_simd(left)?; - let (right, right_len) = this.operand_to_simd(right)?; - let (dest, dest_len) = this.mplace_to_simd(dest)?; - - assert_eq!(dest_len, left_len); - assert_eq!(dest_len, right_len); - - enum ShiftOp { - Sll, - Srl, - Sra, - } - let which = match unprefixed_name { - "psll.d" => ShiftOp::Sll, - "psrl.d" => ShiftOp::Srl, - "psra.d" => ShiftOp::Sra, - _ => unreachable!(), - }; - - // Get the 64-bit shift operand and convert it to the type expected - // by checked_{shl,shr} (u32). - // It is ok to saturate the value to u32::MAX because any value - // above 31 will produce the same result. - let shift = extract_first_u64(this, &right)?.try_into().unwrap_or(u32::MAX); - - for i in 0..dest_len { - let left = this.read_scalar(&this.project_index(&left, i)?)?.to_u32()?; - let dest = this.project_index(&dest, i)?; - - let res = match which { - ShiftOp::Sll => left.checked_shl(shift).unwrap_or(0), - ShiftOp::Srl => left.checked_shr(shift).unwrap_or(0), - #[allow(clippy::cast_possible_wrap, clippy::cast_sign_loss)] - ShiftOp::Sra => { - // Convert u32 to i32 to use arithmetic shift - let left = left as i32; - // Copy the sign bit to the remaining bits - left.checked_shr(shift).unwrap_or(left >> 31) as u32 - } - }; - - this.write_scalar(Scalar::from_u32(res), &dest)?; - } - } - // Used to implement the _mm_{sll,srl}_epi64 functions. - // 64-bit equivalent to the shift functions above, except _mm_sra_epi64, - // which is not available in SSE2. - "psll.q" | "psrl.q" => { - let [left, right] = - this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?; - - let (left, left_len) = this.operand_to_simd(left)?; - let (right, right_len) = this.operand_to_simd(right)?; - let (dest, dest_len) = this.mplace_to_simd(dest)?; - - assert_eq!(dest_len, left_len); - assert_eq!(dest_len, right_len); - - enum ShiftOp { - Sll, - Srl, - } - let which = match unprefixed_name { - "psll.q" => ShiftOp::Sll, - "psrl.q" => ShiftOp::Srl, - _ => unreachable!(), - }; - - // Get the 64-bit shift operand and convert it to the type expected - // by checked_{shl,shr} (u32). - // It is ok to saturate the value to u32::MAX because any value - // above 63 will produce the same result. - let shift = this - .read_scalar(&this.project_index(&right, 0)?)? - .to_u64()? - .try_into() - .unwrap_or(u32::MAX); - - for i in 0..dest_len { - let left = this.read_scalar(&this.project_index(&left, i)?)?.to_u64()?; - let dest = this.project_index(&dest, i)?; - - let res = match which { - ShiftOp::Sll => left.checked_shl(shift).unwrap_or(0), - ShiftOp::Srl => left.checked_shr(shift).unwrap_or(0), - }; - - this.write_scalar(Scalar::from_u64(res), &dest)?; - } + shift_simd_by_scalar(this, left, right, which, dest)?; } // Used to implement the _mm_cvtps_epi32, _mm_cvttps_epi32, _mm_cvtpd_epi32 // and _mm_cvttpd_epi32 functions. @@ -585,17 +457,3 @@ pub(super) trait EvalContextExt<'mir, 'tcx: 'mir>: Ok(EmulateForeignItemResult::NeedsJumping) } } - -/// Takes a 128-bit vector, transmutes it to `[u64; 2]` and extracts -/// the first value. -fn extract_first_u64<'tcx>( - this: &crate::MiriInterpCx<'_, 'tcx>, - op: &MPlaceTy<'tcx, Provenance>, -) -> InterpResult<'tcx, u64> { - // Transmute vector to `[u64; 2]` - let u64_array_layout = this.layout_of(Ty::new_array(this.tcx.tcx, this.tcx.types.u64, 2))?; - let op = op.transmute(u64_array_layout, this)?; - - // Get the first u64 from the array - this.read_scalar(&this.project_index(&op, 0)?)?.to_u64() -}