Skip to content

Commit

Permalink
Auto merge of #3429 - eduardosm:shift, r=RalfJung
Browse files Browse the repository at this point in the history
De-duplicate SSE2 sll/srl/sra code
  • Loading branch information
bors committed Apr 13, 2024
2 parents 788a1db + 474a047 commit c3136b2
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 159 deletions.
80 changes: 80 additions & 0 deletions src/tools/miri/src/shims/x86/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,86 @@ fn unary_op_ps<'tcx>(
Ok(())
}

enum ShiftOp {
/// Shift left, logically (shift in zeros) -- same as shift left, arithmetically
Left,
/// Shift right, logically (shift in zeros)
RightLogic,
/// Shift right, arithmetically (shift in sign)
RightArith,
}

/// Shifts each element of `left` by a scalar amount. The shift amount
/// is determined by the lowest 64 bits of `right` (which is a 128-bit vector).
///
/// For logic shifts, when right is larger than BITS - 1, zero is produced.
/// For arithmetic right-shifts, when right is larger than BITS - 1, the sign
/// bit is copied to remaining bits.
fn shift_simd_by_scalar<'tcx>(
this: &mut crate::MiriInterpCx<'_, 'tcx>,
left: &OpTy<'tcx, Provenance>,
right: &OpTy<'tcx, Provenance>,
which: ShiftOp,
dest: &MPlaceTy<'tcx, Provenance>,
) -> InterpResult<'tcx, ()> {
let (left, left_len) = this.operand_to_simd(left)?;
let (dest, dest_len) = this.mplace_to_simd(dest)?;

assert_eq!(dest_len, left_len);
// `right` may have a different length, and we only care about its
// lowest 64bit anyway.

// Get the 64-bit shift operand and convert it to the type expected
// by checked_{shl,shr} (u32).
// It is ok to saturate the value to u32::MAX because any value
// above BITS - 1 will produce the same result.
let shift = u32::try_from(extract_first_u64(this, right)?).unwrap_or(u32::MAX);

for i in 0..dest_len {
let left = this.read_scalar(&this.project_index(&left, i)?)?;
let dest = this.project_index(&dest, i)?;

let res = match which {
ShiftOp::Left => {
let left = left.to_uint(dest.layout.size)?;
let res = left.checked_shl(shift).unwrap_or(0);
// `truncate` is needed as left-shift can make the absolute value larger.
Scalar::from_uint(dest.layout.size.truncate(res), dest.layout.size)
}
ShiftOp::RightLogic => {
let left = left.to_uint(dest.layout.size)?;
let res = left.checked_shr(shift).unwrap_or(0);
// No `truncate` needed as right-shift can only make the absolute value smaller.
Scalar::from_uint(res, dest.layout.size)
}
ShiftOp::RightArith => {
let left = left.to_int(dest.layout.size)?;
// On overflow, copy the sign bit to the remaining bits
let res = left.checked_shr(shift).unwrap_or(left >> 127);
// No `truncate` needed as right-shift can only make the absolute value smaller.
Scalar::from_int(res, dest.layout.size)
}
};
this.write_scalar(res, &dest)?;
}

Ok(())
}

/// Takes a 128-bit vector, transmutes it to `[u64; 2]` and extracts
/// the first value.
fn extract_first_u64<'tcx>(
this: &crate::MiriInterpCx<'_, 'tcx>,
op: &OpTy<'tcx, Provenance>,
) -> InterpResult<'tcx, u64> {
// Transmute vector to `[u64; 2]`
let array_layout = this.layout_of(Ty::new_array(this.tcx.tcx, this.tcx.types.u64, 2))?;
let op = op.transmute(array_layout, this)?;

// Get the first u64 from the array
this.read_scalar(&this.project_index(&op, 0)?)?.to_u64()
}

// Rounds the first element of `right` according to `rounding`
// and copies the remaining elements from `left`.
fn round_first<'tcx, F: rustc_apfloat::Float>(
Expand Down
176 changes: 17 additions & 159 deletions src/tools/miri/src/shims/x86/sse2.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use rustc_apfloat::ieee::Double;
use rustc_middle::ty::layout::LayoutOf as _;
use rustc_middle::ty::Ty;
use rustc_span::Symbol;
use rustc_target::spec::abi::Abi;

use super::{bin_op_simd_float_all, bin_op_simd_float_first, convert_float_to_int, FloatBinOp};
use super::{
bin_op_simd_float_all, bin_op_simd_float_first, convert_float_to_int, shift_simd_by_scalar,
FloatBinOp, ShiftOp,
};
use crate::*;
use shims::foreign_items::EmulateForeignItemResult;

Expand Down Expand Up @@ -109,156 +110,27 @@ pub(super) trait EvalContextExt<'mir, 'tcx: 'mir>:
this.write_scalar(Scalar::from_u64(res.into()), &dest)?;
}
}
// Used to implement the _mm_{sll,srl,sra}_epi16 functions.
// Shifts 16-bit packed integers in left by the amount in right.
// Both operands are vectors of 16-bit integers. However, right is
// interpreted as a single 64-bit integer (remaining bits are ignored).
// For logic shifts, when right is larger than 15, zero is produced.
// For arithmetic shifts, when right is larger than 15, the sign bit
// Used to implement the _mm_{sll,srl,sra}_epi{16,32,64} functions
// (except _mm_sra_epi64, which is not available in SSE2).
// Shifts N-bit packed integers in left by the amount in right.
// Both operands are 128-bit vectors. However, right is interpreted as
// a single 64-bit integer (remaining bits are ignored).
// For logic shifts, when right is larger than N - 1, zero is produced.
// For arithmetic shifts, when right is larger than N - 1, the sign bit
// is copied to remaining bits.
"psll.w" | "psrl.w" | "psra.w" => {
"psll.w" | "psrl.w" | "psra.w" | "psll.d" | "psrl.d" | "psra.d" | "psll.q"
| "psrl.q" => {
let [left, right] =
this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;

let (left, left_len) = this.operand_to_simd(left)?;
let (right, right_len) = this.operand_to_simd(right)?;
let (dest, dest_len) = this.mplace_to_simd(dest)?;

assert_eq!(dest_len, left_len);
assert_eq!(dest_len, right_len);

enum ShiftOp {
Sll,
Srl,
Sra,
}
let which = match unprefixed_name {
"psll.w" => ShiftOp::Sll,
"psrl.w" => ShiftOp::Srl,
"psra.w" => ShiftOp::Sra,
"psll.w" | "psll.d" | "psll.q" => ShiftOp::Left,
"psrl.w" | "psrl.d" | "psrl.q" => ShiftOp::RightLogic,
"psra.w" | "psra.d" => ShiftOp::RightArith,
_ => unreachable!(),
};

// Get the 64-bit shift operand and convert it to the type expected
// by checked_{shl,shr} (u32).
// It is ok to saturate the value to u32::MAX because any value
// above 15 will produce the same result.
let shift = extract_first_u64(this, &right)?.try_into().unwrap_or(u32::MAX);

for i in 0..dest_len {
let left = this.read_scalar(&this.project_index(&left, i)?)?.to_u16()?;
let dest = this.project_index(&dest, i)?;

let res = match which {
ShiftOp::Sll => left.checked_shl(shift).unwrap_or(0),
ShiftOp::Srl => left.checked_shr(shift).unwrap_or(0),
#[allow(clippy::cast_possible_wrap, clippy::cast_sign_loss)]
ShiftOp::Sra => {
// Convert u16 to i16 to use arithmetic shift
let left = left as i16;
// Copy the sign bit to the remaining bits
left.checked_shr(shift).unwrap_or(left >> 15) as u16
}
};

this.write_scalar(Scalar::from_u16(res), &dest)?;
}
}
// Used to implement the _mm_{sll,srl,sra}_epi32 functions.
// 32-bit equivalent to the shift functions above.
"psll.d" | "psrl.d" | "psra.d" => {
let [left, right] =
this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;

let (left, left_len) = this.operand_to_simd(left)?;
let (right, right_len) = this.operand_to_simd(right)?;
let (dest, dest_len) = this.mplace_to_simd(dest)?;

assert_eq!(dest_len, left_len);
assert_eq!(dest_len, right_len);

enum ShiftOp {
Sll,
Srl,
Sra,
}
let which = match unprefixed_name {
"psll.d" => ShiftOp::Sll,
"psrl.d" => ShiftOp::Srl,
"psra.d" => ShiftOp::Sra,
_ => unreachable!(),
};

// Get the 64-bit shift operand and convert it to the type expected
// by checked_{shl,shr} (u32).
// It is ok to saturate the value to u32::MAX because any value
// above 31 will produce the same result.
let shift = extract_first_u64(this, &right)?.try_into().unwrap_or(u32::MAX);

for i in 0..dest_len {
let left = this.read_scalar(&this.project_index(&left, i)?)?.to_u32()?;
let dest = this.project_index(&dest, i)?;

let res = match which {
ShiftOp::Sll => left.checked_shl(shift).unwrap_or(0),
ShiftOp::Srl => left.checked_shr(shift).unwrap_or(0),
#[allow(clippy::cast_possible_wrap, clippy::cast_sign_loss)]
ShiftOp::Sra => {
// Convert u32 to i32 to use arithmetic shift
let left = left as i32;
// Copy the sign bit to the remaining bits
left.checked_shr(shift).unwrap_or(left >> 31) as u32
}
};

this.write_scalar(Scalar::from_u32(res), &dest)?;
}
}
// Used to implement the _mm_{sll,srl}_epi64 functions.
// 64-bit equivalent to the shift functions above, except _mm_sra_epi64,
// which is not available in SSE2.
"psll.q" | "psrl.q" => {
let [left, right] =
this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;

let (left, left_len) = this.operand_to_simd(left)?;
let (right, right_len) = this.operand_to_simd(right)?;
let (dest, dest_len) = this.mplace_to_simd(dest)?;

assert_eq!(dest_len, left_len);
assert_eq!(dest_len, right_len);

enum ShiftOp {
Sll,
Srl,
}
let which = match unprefixed_name {
"psll.q" => ShiftOp::Sll,
"psrl.q" => ShiftOp::Srl,
_ => unreachable!(),
};

// Get the 64-bit shift operand and convert it to the type expected
// by checked_{shl,shr} (u32).
// It is ok to saturate the value to u32::MAX because any value
// above 63 will produce the same result.
let shift = this
.read_scalar(&this.project_index(&right, 0)?)?
.to_u64()?
.try_into()
.unwrap_or(u32::MAX);

for i in 0..dest_len {
let left = this.read_scalar(&this.project_index(&left, i)?)?.to_u64()?;
let dest = this.project_index(&dest, i)?;

let res = match which {
ShiftOp::Sll => left.checked_shl(shift).unwrap_or(0),
ShiftOp::Srl => left.checked_shr(shift).unwrap_or(0),
};

this.write_scalar(Scalar::from_u64(res), &dest)?;
}
shift_simd_by_scalar(this, left, right, which, dest)?;
}
// Used to implement the _mm_cvtps_epi32, _mm_cvttps_epi32, _mm_cvtpd_epi32
// and _mm_cvttpd_epi32 functions.
Expand Down Expand Up @@ -585,17 +457,3 @@ pub(super) trait EvalContextExt<'mir, 'tcx: 'mir>:
Ok(EmulateForeignItemResult::NeedsJumping)
}
}

/// Takes a 128-bit vector, transmutes it to `[u64; 2]` and extracts
/// the first value.
fn extract_first_u64<'tcx>(
this: &crate::MiriInterpCx<'_, 'tcx>,
op: &MPlaceTy<'tcx, Provenance>,
) -> InterpResult<'tcx, u64> {
// Transmute vector to `[u64; 2]`
let u64_array_layout = this.layout_of(Ty::new_array(this.tcx.tcx, this.tcx.types.u64, 2))?;
let op = op.transmute(u64_array_layout, this)?;

// Get the first u64 from the array
this.read_scalar(&this.project_index(&op, 0)?)?.to_u64()
}

0 comments on commit c3136b2

Please sign in to comment.