Skip to content

Commit

Permalink
Auto merge of #130215 - RalfJung:interpret-simd, r=compiler-errors
Browse files Browse the repository at this point in the history
interpret: simplify SIMD type handling

This is possible as a follow-up to rust-lang/rust#129403
  • Loading branch information
bors committed Sep 13, 2024
2 parents 30e8618 + d79ea9e commit f4d49d6
Show file tree
Hide file tree
Showing 10 changed files with 150 additions and 147 deletions.
82 changes: 41 additions & 41 deletions src/intrinsics/simd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
| "bitreverse"
=> {
let [op] = check_arg_count(args)?;
let (op, op_len) = this.operand_to_simd(op)?;
let (dest, dest_len) = this.mplace_to_simd(dest)?;
let (op, op_len) = this.project_to_simd(op)?;
let (dest, dest_len) = this.project_to_simd(dest)?;

assert_eq!(dest_len, op_len);

Expand Down Expand Up @@ -200,9 +200,9 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
use mir::BinOp;

let [left, right] = check_arg_count(args)?;
let (left, left_len) = this.operand_to_simd(left)?;
let (right, right_len) = this.operand_to_simd(right)?;
let (dest, dest_len) = this.mplace_to_simd(dest)?;
let (left, left_len) = this.project_to_simd(left)?;
let (right, right_len) = this.project_to_simd(right)?;
let (dest, dest_len) = this.project_to_simd(dest)?;

assert_eq!(dest_len, left_len);
assert_eq!(dest_len, right_len);
Expand Down Expand Up @@ -291,10 +291,10 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
}
"fma" => {
let [a, b, c] = check_arg_count(args)?;
let (a, a_len) = this.operand_to_simd(a)?;
let (b, b_len) = this.operand_to_simd(b)?;
let (c, c_len) = this.operand_to_simd(c)?;
let (dest, dest_len) = this.mplace_to_simd(dest)?;
let (a, a_len) = this.project_to_simd(a)?;
let (b, b_len) = this.project_to_simd(b)?;
let (c, c_len) = this.project_to_simd(c)?;
let (dest, dest_len) = this.project_to_simd(dest)?;

assert_eq!(dest_len, a_len);
assert_eq!(dest_len, b_len);
Expand Down Expand Up @@ -345,7 +345,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
use mir::BinOp;

let [op] = check_arg_count(args)?;
let (op, op_len) = this.operand_to_simd(op)?;
let (op, op_len) = this.project_to_simd(op)?;

let imm_from_bool =
|b| ImmTy::from_scalar(Scalar::from_bool(b), this.machine.layouts.bool);
Expand Down Expand Up @@ -408,7 +408,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
use mir::BinOp;

let [op, init] = check_arg_count(args)?;
let (op, op_len) = this.operand_to_simd(op)?;
let (op, op_len) = this.project_to_simd(op)?;
let init = this.read_immediate(init)?;

let mir_op = match intrinsic_name {
Expand All @@ -426,10 +426,10 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
}
"select" => {
let [mask, yes, no] = check_arg_count(args)?;
let (mask, mask_len) = this.operand_to_simd(mask)?;
let (yes, yes_len) = this.operand_to_simd(yes)?;
let (no, no_len) = this.operand_to_simd(no)?;
let (dest, dest_len) = this.mplace_to_simd(dest)?;
let (mask, mask_len) = this.project_to_simd(mask)?;
let (yes, yes_len) = this.project_to_simd(yes)?;
let (no, no_len) = this.project_to_simd(no)?;
let (dest, dest_len) = this.project_to_simd(dest)?;

assert_eq!(dest_len, mask_len);
assert_eq!(dest_len, yes_len);
Expand All @@ -448,9 +448,9 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
// Variant of `select` that takes a bitmask rather than a "vector of bool".
"select_bitmask" => {
let [mask, yes, no] = check_arg_count(args)?;
let (yes, yes_len) = this.operand_to_simd(yes)?;
let (no, no_len) = this.operand_to_simd(no)?;
let (dest, dest_len) = this.mplace_to_simd(dest)?;
let (yes, yes_len) = this.project_to_simd(yes)?;
let (no, no_len) = this.project_to_simd(no)?;
let (dest, dest_len) = this.project_to_simd(dest)?;
let bitmask_len = dest_len.next_multiple_of(8);
if bitmask_len > 64 {
throw_unsup_format!(
Expand Down Expand Up @@ -522,7 +522,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
// Converts a "vector of bool" into a bitmask.
"bitmask" => {
let [op] = check_arg_count(args)?;
let (op, op_len) = this.operand_to_simd(op)?;
let (op, op_len) = this.project_to_simd(op)?;
let bitmask_len = op_len.next_multiple_of(8);
if bitmask_len > 64 {
throw_unsup_format!(
Expand Down Expand Up @@ -570,8 +570,8 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
}
"cast" | "as" | "cast_ptr" | "expose_provenance" | "with_exposed_provenance" => {
let [op] = check_arg_count(args)?;
let (op, op_len) = this.operand_to_simd(op)?;
let (dest, dest_len) = this.mplace_to_simd(dest)?;
let (op, op_len) = this.project_to_simd(op)?;
let (dest, dest_len) = this.project_to_simd(dest)?;

assert_eq!(dest_len, op_len);

Expand Down Expand Up @@ -627,9 +627,9 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
}
"shuffle_generic" => {
let [left, right] = check_arg_count(args)?;
let (left, left_len) = this.operand_to_simd(left)?;
let (right, right_len) = this.operand_to_simd(right)?;
let (dest, dest_len) = this.mplace_to_simd(dest)?;
let (left, left_len) = this.project_to_simd(left)?;
let (right, right_len) = this.project_to_simd(right)?;
let (dest, dest_len) = this.project_to_simd(dest)?;

let index = generic_args[2]
.expect_const()
Expand Down Expand Up @@ -662,15 +662,15 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
}
"shuffle" => {
let [left, right, index] = check_arg_count(args)?;
let (left, left_len) = this.operand_to_simd(left)?;
let (right, right_len) = this.operand_to_simd(right)?;
let (dest, dest_len) = this.mplace_to_simd(dest)?;
let (left, left_len) = this.project_to_simd(left)?;
let (right, right_len) = this.project_to_simd(right)?;
let (dest, dest_len) = this.project_to_simd(dest)?;

// `index` is an array or a SIMD type
let (index, index_len) = match index.layout.ty.kind() {
// FIXME: remove this once `index` must always be a SIMD vector.
ty::Array(..) => (index.assert_mem_place(), index.len(this)?),
_ => this.operand_to_simd(index)?,
ty::Array(..) => (index.clone(), index.len(this)?),
_ => this.project_to_simd(index)?,
};

assert_eq!(left_len, right_len);
Expand Down Expand Up @@ -699,10 +699,10 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
}
"gather" => {
let [passthru, ptrs, mask] = check_arg_count(args)?;
let (passthru, passthru_len) = this.operand_to_simd(passthru)?;
let (ptrs, ptrs_len) = this.operand_to_simd(ptrs)?;
let (mask, mask_len) = this.operand_to_simd(mask)?;
let (dest, dest_len) = this.mplace_to_simd(dest)?;
let (passthru, passthru_len) = this.project_to_simd(passthru)?;
let (ptrs, ptrs_len) = this.project_to_simd(ptrs)?;
let (mask, mask_len) = this.project_to_simd(mask)?;
let (dest, dest_len) = this.project_to_simd(dest)?;

assert_eq!(dest_len, passthru_len);
assert_eq!(dest_len, ptrs_len);
Expand All @@ -725,9 +725,9 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
}
"scatter" => {
let [value, ptrs, mask] = check_arg_count(args)?;
let (value, value_len) = this.operand_to_simd(value)?;
let (ptrs, ptrs_len) = this.operand_to_simd(ptrs)?;
let (mask, mask_len) = this.operand_to_simd(mask)?;
let (value, value_len) = this.project_to_simd(value)?;
let (ptrs, ptrs_len) = this.project_to_simd(ptrs)?;
let (mask, mask_len) = this.project_to_simd(mask)?;

assert_eq!(ptrs_len, value_len);
assert_eq!(ptrs_len, mask_len);
Expand All @@ -745,10 +745,10 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
}
"masked_load" => {
let [mask, ptr, default] = check_arg_count(args)?;
let (mask, mask_len) = this.operand_to_simd(mask)?;
let (mask, mask_len) = this.project_to_simd(mask)?;
let ptr = this.read_pointer(ptr)?;
let (default, default_len) = this.operand_to_simd(default)?;
let (dest, dest_len) = this.mplace_to_simd(dest)?;
let (default, default_len) = this.project_to_simd(default)?;
let (dest, dest_len) = this.project_to_simd(dest)?;

assert_eq!(dest_len, mask_len);
assert_eq!(dest_len, default_len);
Expand All @@ -772,9 +772,9 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
}
"masked_store" => {
let [mask, ptr, vals] = check_arg_count(args)?;
let (mask, mask_len) = this.operand_to_simd(mask)?;
let (mask, mask_len) = this.project_to_simd(mask)?;
let ptr = this.read_pointer(ptr)?;
let (vals, vals_len) = this.operand_to_simd(vals)?;
let (vals, vals_len) = this.project_to_simd(vals)?;

assert_eq!(mask_len, vals_len);

Expand Down
4 changes: 2 additions & 2 deletions src/shims/foreign_items.rs
Original file line number Diff line number Diff line change
Expand Up @@ -903,8 +903,8 @@ trait EvalContextExtPriv<'tcx>: crate::MiriInterpCxExt<'tcx> {
name if name.starts_with("llvm.ctpop.v") => {
let [op] = this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;

let (op, op_len) = this.operand_to_simd(op)?;
let (dest, dest_len) = this.mplace_to_simd(dest)?;
let (op, op_len) = this.project_to_simd(op)?;
let (dest, dest_len) = this.project_to_simd(dest)?;

assert_eq!(dest_len, op_len);

Expand Down
12 changes: 6 additions & 6 deletions src/shims/x86/avx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,9 @@ pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
let [data, control] =
this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;

let (data, data_len) = this.operand_to_simd(data)?;
let (control, control_len) = this.operand_to_simd(control)?;
let (dest, dest_len) = this.mplace_to_simd(dest)?;
let (data, data_len) = this.project_to_simd(data)?;
let (control, control_len) = this.project_to_simd(control)?;
let (dest, dest_len) = this.project_to_simd(dest)?;

assert_eq!(dest_len, data_len);
assert_eq!(dest_len, control_len);
Expand Down Expand Up @@ -193,9 +193,9 @@ pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
let [data, control] =
this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;

let (data, data_len) = this.operand_to_simd(data)?;
let (control, control_len) = this.operand_to_simd(control)?;
let (dest, dest_len) = this.mplace_to_simd(dest)?;
let (data, data_len) = this.project_to_simd(data)?;
let (control, control_len) = this.project_to_simd(control)?;
let (dest, dest_len) = this.project_to_simd(dest)?;

assert_eq!(dest_len, data_len);
assert_eq!(dest_len, control_len);
Expand Down
38 changes: 19 additions & 19 deletions src/shims/x86/avx2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,10 @@ pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {

assert_eq!(dest.layout, src.layout);

let (src, _) = this.operand_to_simd(src)?;
let (offsets, offsets_len) = this.operand_to_simd(offsets)?;
let (mask, mask_len) = this.operand_to_simd(mask)?;
let (dest, dest_len) = this.mplace_to_simd(dest)?;
let (src, _) = this.project_to_simd(src)?;
let (offsets, offsets_len) = this.project_to_simd(offsets)?;
let (mask, mask_len) = this.project_to_simd(mask)?;
let (dest, dest_len) = this.project_to_simd(dest)?;

// There are cases like dest: i32x4, offsets: i64x2
// If dest has more elements than offset, extra dest elements are filled with zero.
Expand Down Expand Up @@ -118,9 +118,9 @@ pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
let [left, right] =
this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;

let (left, left_len) = this.operand_to_simd(left)?;
let (right, right_len) = this.operand_to_simd(right)?;
let (dest, dest_len) = this.mplace_to_simd(dest)?;
let (left, left_len) = this.project_to_simd(left)?;
let (right, right_len) = this.project_to_simd(right)?;
let (dest, dest_len) = this.project_to_simd(dest)?;

assert_eq!(left_len, right_len);
assert_eq!(dest_len.strict_mul(2), left_len);
Expand Down Expand Up @@ -155,9 +155,9 @@ pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
let [left, right] =
this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;

let (left, left_len) = this.operand_to_simd(left)?;
let (right, right_len) = this.operand_to_simd(right)?;
let (dest, dest_len) = this.mplace_to_simd(dest)?;
let (left, left_len) = this.project_to_simd(left)?;
let (right, right_len) = this.project_to_simd(right)?;
let (dest, dest_len) = this.project_to_simd(dest)?;

assert_eq!(left_len, right_len);
assert_eq!(dest_len.strict_mul(2), left_len);
Expand Down Expand Up @@ -271,9 +271,9 @@ pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
let [left, right] =
this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;

let (left, left_len) = this.operand_to_simd(left)?;
let (right, right_len) = this.operand_to_simd(right)?;
let (dest, dest_len) = this.mplace_to_simd(dest)?;
let (left, left_len) = this.project_to_simd(left)?;
let (right, right_len) = this.project_to_simd(right)?;
let (dest, dest_len) = this.project_to_simd(dest)?;

assert_eq!(dest_len, left_len);
assert_eq!(dest_len, right_len);
Expand Down Expand Up @@ -330,9 +330,9 @@ pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
let [left, right] =
this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;

let (left, left_len) = this.operand_to_simd(left)?;
let (right, right_len) = this.operand_to_simd(right)?;
let (dest, dest_len) = this.mplace_to_simd(dest)?;
let (left, left_len) = this.project_to_simd(left)?;
let (right, right_len) = this.project_to_simd(right)?;
let (dest, dest_len) = this.project_to_simd(dest)?;

assert_eq!(left_len, right_len);
assert_eq!(left_len, dest_len.strict_mul(8));
Expand Down Expand Up @@ -363,9 +363,9 @@ pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
let [left, right] =
this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;

let (left, left_len) = this.operand_to_simd(left)?;
let (right, right_len) = this.operand_to_simd(right)?;
let (dest, dest_len) = this.mplace_to_simd(dest)?;
let (left, left_len) = this.project_to_simd(left)?;
let (right, right_len) = this.project_to_simd(right)?;
let (dest, dest_len) = this.project_to_simd(dest)?;

assert_eq!(dest_len, left_len);
assert_eq!(dest_len, right_len);
Expand Down
Loading

0 comments on commit f4d49d6

Please sign in to comment.