Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prototype using const generic for simd_shuffle IDX array #115933

Merged
merged 2 commits into from
Sep 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 49 additions & 1 deletion compiler/rustc_codegen_cranelift/src/intrinsics/simd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ fn report_simd_type_validation_error(
pub(super) fn codegen_simd_intrinsic_call<'tcx>(
fx: &mut FunctionCx<'_, '_, 'tcx>,
intrinsic: Symbol,
_args: GenericArgsRef<'tcx>,
generic_args: GenericArgsRef<'tcx>,
args: &[mir::Operand<'tcx>],
ret: CPlace<'tcx>,
target: BasicBlock,
Expand Down Expand Up @@ -117,6 +117,54 @@ pub(super) fn codegen_simd_intrinsic_call<'tcx>(
});
}

// simd_shuffle_generic<T, U, const I: &[u32]>(x: T, y: T) -> U
sym::simd_shuffle_generic => {
let [x, y] = args else {
bug!("wrong number of args for intrinsic {intrinsic}");
};
let x = codegen_operand(fx, x);
let y = codegen_operand(fx, y);

if !x.layout().ty.is_simd() {
report_simd_type_validation_error(fx, intrinsic, span, x.layout().ty);
return;
}

let idx = generic_args[2]
.expect_const()
.eval(fx.tcx, ty::ParamEnv::reveal_all(), Some(span))
.unwrap()
.unwrap_branch();

assert_eq!(x.layout(), y.layout());
let layout = x.layout();

let (lane_count, lane_ty) = layout.ty.simd_size_and_type(fx.tcx);
let (ret_lane_count, ret_lane_ty) = ret.layout().ty.simd_size_and_type(fx.tcx);

assert_eq!(lane_ty, ret_lane_ty);
assert_eq!(idx.len() as u64, ret_lane_count);

let total_len = lane_count * 2;

let indexes =
idx.iter().map(|idx| idx.unwrap_leaf().try_to_u16().unwrap()).collect::<Vec<u16>>();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should have used u32. Will fix in https://github.com/bjorn3/rustc_codegen_cranelift


for &idx in &indexes {
assert!(u64::from(idx) < total_len, "idx {} out of range 0..{}", idx, total_len);
}

for (out_idx, in_idx) in indexes.into_iter().enumerate() {
let in_lane = if u64::from(in_idx) < lane_count {
x.value_lane(fx, in_idx.into())
} else {
y.value_lane(fx, u64::from(in_idx) - lane_count)
};
let out_lane = ret.place_lane(fx, u64::try_from(out_idx).unwrap());
out_lane.write_cvalue(fx, in_lane);
}
}

// simd_shuffle<T, I, U>(x: T, y: T, idx: I) -> U
sym::simd_shuffle => {
let (x, y, idx) = match args {
Expand Down
57 changes: 55 additions & 2 deletions compiler/rustc_codegen_llvm/src/intrinsic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use rustc_codegen_ssa::mir::place::PlaceRef;
use rustc_codegen_ssa::traits::*;
use rustc_hir as hir;
use rustc_middle::ty::layout::{FnAbiOf, HasTyCtxt, LayoutOf};
use rustc_middle::ty::{self, Ty};
use rustc_middle::ty::{self, GenericArgsRef, Ty};
use rustc_middle::{bug, span_bug};
use rustc_span::{sym, symbol::kw, Span, Symbol};
use rustc_target::abi::{self, Align, HasDataLayout, Primitive};
Expand Down Expand Up @@ -376,7 +376,9 @@ impl<'ll, 'tcx> IntrinsicCallMethods<'tcx> for Builder<'_, 'll, 'tcx> {
}

_ if name.as_str().starts_with("simd_") => {
match generic_simd_intrinsic(self, name, callee_ty, args, ret_ty, llret_ty, span) {
match generic_simd_intrinsic(
self, name, callee_ty, fn_args, args, ret_ty, llret_ty, span,
) {
Ok(llval) => llval,
Err(()) => return,
}
Expand Down Expand Up @@ -911,6 +913,7 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
bx: &mut Builder<'_, 'll, 'tcx>,
name: Symbol,
callee_ty: Ty<'tcx>,
fn_args: GenericArgsRef<'tcx>,
args: &[OperandRef<'tcx, &'ll Value>],
ret_ty: Ty<'tcx>,
llret_ty: &'ll Type,
Expand Down Expand Up @@ -1030,6 +1033,56 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
));
}

if name == sym::simd_shuffle_generic {
let idx = fn_args[2]
.expect_const()
.eval(tcx, ty::ParamEnv::reveal_all(), Some(span))
.unwrap()
.unwrap_branch();
let n = idx.len() as u64;

require_simd!(ret_ty, InvalidMonomorphization::SimdReturn { span, name, ty: ret_ty });
let (out_len, out_ty) = ret_ty.simd_size_and_type(bx.tcx());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(ret_len, ret_elem) would be more consistent with the in/out variables, I think?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's copy paste from the simd_shuffle impl, but yes

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are 8 occurences now (there were 7) of this same code. We should probably deduplicate something here and reconsider the names, but not in this PR ^^

require!(
out_len == n,
InvalidMonomorphization::ReturnLength { span, name, in_len: n, ret_ty, out_len }
);
require!(
in_elem == out_ty,
InvalidMonomorphization::ReturnElement { span, name, in_elem, in_ty, ret_ty, out_ty }
);

let total_len = in_len * 2;

let indices: Option<Vec<_>> = idx
.iter()
.enumerate()
.map(|(arg_idx, val)| {
let idx = val.unwrap_leaf().try_to_i32().unwrap();
if idx >= i32::try_from(total_len).unwrap() {
bx.sess().emit_err(InvalidMonomorphization::ShuffleIndexOutOfBounds {
span,
name,
arg_idx: arg_idx as u64,
total_len: total_len.into(),
});
None
} else {
Some(bx.const_i32(idx))
}
})
.collect();
let Some(indices) = indices else {
return Ok(bx.const_null(llret_ty));
};

return Ok(bx.shuffle_vector(
args[0].immediate(),
args[1].immediate(),
bx.const_vector(&indices),
));
}

if name == sym::simd_shuffle {
// Make sure this is actually an array, since typeck only checks the length-suffixed
// version of this intrinsic.
Expand Down
44 changes: 23 additions & 21 deletions compiler/rustc_hir_analysis/src/check/intrinsic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ fn equate_intrinsic_type<'tcx>(
it: &hir::ForeignItem<'_>,
n_tps: usize,
n_lts: usize,
n_cts: usize,
sig: ty::PolyFnSig<'tcx>,
) {
let (own_counts, span) = match &it.kind {
Expand Down Expand Up @@ -51,7 +52,7 @@ fn equate_intrinsic_type<'tcx>(

if gen_count_ok(own_counts.lifetimes, n_lts, "lifetime")
&& gen_count_ok(own_counts.types, n_tps, "type")
&& gen_count_ok(own_counts.consts, 0, "const")
&& gen_count_ok(own_counts.consts, n_cts, "const")
{
let fty = Ty::new_fn_ptr(tcx, sig);
let it_def_id = it.owner_id.def_id;
Expand Down Expand Up @@ -492,7 +493,7 @@ pub fn check_intrinsic_type(tcx: TyCtxt<'_>, it: &hir::ForeignItem<'_>) {
};
let sig = tcx.mk_fn_sig(inputs, output, false, unsafety, Abi::RustIntrinsic);
let sig = ty::Binder::bind_with_vars(sig, bound_vars);
equate_intrinsic_type(tcx, it, n_tps, n_lts, sig)
equate_intrinsic_type(tcx, it, n_tps, n_lts, 0, sig)
}

/// Type-check `extern "platform-intrinsic" { ... }` functions.
Expand All @@ -504,9 +505,9 @@ pub fn check_platform_intrinsic_type(tcx: TyCtxt<'_>, it: &hir::ForeignItem<'_>)

let name = it.ident.name;

let (n_tps, inputs, output) = match name {
let (n_tps, n_cts, inputs, output) = match name {
sym::simd_eq | sym::simd_ne | sym::simd_lt | sym::simd_le | sym::simd_gt | sym::simd_ge => {
(2, vec![param(0), param(0)], param(1))
(2, 0, vec![param(0), param(0)], param(1))
}
sym::simd_add
| sym::simd_sub
Expand All @@ -522,8 +523,8 @@ pub fn check_platform_intrinsic_type(tcx: TyCtxt<'_>, it: &hir::ForeignItem<'_>)
| sym::simd_fmax
| sym::simd_fpow
| sym::simd_saturating_add
| sym::simd_saturating_sub => (1, vec![param(0), param(0)], param(0)),
sym::simd_arith_offset => (2, vec![param(0), param(1)], param(0)),
| sym::simd_saturating_sub => (1, 0, vec![param(0), param(0)], param(0)),
sym::simd_arith_offset => (2, 0, vec![param(0), param(1)], param(0)),
sym::simd_neg
| sym::simd_bswap
| sym::simd_bitreverse
Expand All @@ -541,25 +542,25 @@ pub fn check_platform_intrinsic_type(tcx: TyCtxt<'_>, it: &hir::ForeignItem<'_>)
| sym::simd_ceil
| sym::simd_floor
| sym::simd_round
| sym::simd_trunc => (1, vec![param(0)], param(0)),
sym::simd_fpowi => (1, vec![param(0), tcx.types.i32], param(0)),
sym::simd_fma => (1, vec![param(0), param(0), param(0)], param(0)),
sym::simd_gather => (3, vec![param(0), param(1), param(2)], param(0)),
sym::simd_scatter => (3, vec![param(0), param(1), param(2)], Ty::new_unit(tcx)),
sym::simd_insert => (2, vec![param(0), tcx.types.u32, param(1)], param(0)),
sym::simd_extract => (2, vec![param(0), tcx.types.u32], param(1)),
| sym::simd_trunc => (1, 0, vec![param(0)], param(0)),
sym::simd_fpowi => (1, 0, vec![param(0), tcx.types.i32], param(0)),
sym::simd_fma => (1, 0, vec![param(0), param(0), param(0)], param(0)),
sym::simd_gather => (3, 0, vec![param(0), param(1), param(2)], param(0)),
sym::simd_scatter => (3, 0, vec![param(0), param(1), param(2)], Ty::new_unit(tcx)),
sym::simd_insert => (2, 0, vec![param(0), tcx.types.u32, param(1)], param(0)),
sym::simd_extract => (2, 0, vec![param(0), tcx.types.u32], param(1)),
sym::simd_cast
| sym::simd_as
| sym::simd_cast_ptr
| sym::simd_expose_addr
| sym::simd_from_exposed_addr => (2, vec![param(0)], param(1)),
sym::simd_bitmask => (2, vec![param(0)], param(1)),
| sym::simd_from_exposed_addr => (2, 0, vec![param(0)], param(1)),
sym::simd_bitmask => (2, 0, vec![param(0)], param(1)),
sym::simd_select | sym::simd_select_bitmask => {
(2, vec![param(0), param(1), param(1)], param(1))
(2, 0, vec![param(0), param(1), param(1)], param(1))
}
sym::simd_reduce_all | sym::simd_reduce_any => (1, vec![param(0)], tcx.types.bool),
sym::simd_reduce_all | sym::simd_reduce_any => (1, 0, vec![param(0)], tcx.types.bool),
sym::simd_reduce_add_ordered | sym::simd_reduce_mul_ordered => {
(2, vec![param(0), param(1)], param(1))
(2, 0, vec![param(0), param(1)], param(1))
}
sym::simd_reduce_add_unordered
| sym::simd_reduce_mul_unordered
Expand All @@ -569,8 +570,9 @@ pub fn check_platform_intrinsic_type(tcx: TyCtxt<'_>, it: &hir::ForeignItem<'_>)
| sym::simd_reduce_min
| sym::simd_reduce_max
| sym::simd_reduce_min_nanless
| sym::simd_reduce_max_nanless => (2, vec![param(0)], param(1)),
sym::simd_shuffle => (3, vec![param(0), param(0), param(1)], param(2)),
| sym::simd_reduce_max_nanless => (2, 0, vec![param(0)], param(1)),
sym::simd_shuffle => (3, 0, vec![param(0), param(0), param(1)], param(2)),
sym::simd_shuffle_generic => (2, 1, vec![param(0), param(0)], param(1)),
_ => {
let msg = format!("unrecognized platform-specific intrinsic function: `{name}`");
tcx.sess.struct_span_err(it.span, msg).emit();
Expand All @@ -580,5 +582,5 @@ pub fn check_platform_intrinsic_type(tcx: TyCtxt<'_>, it: &hir::ForeignItem<'_>)

let sig = tcx.mk_fn_sig(inputs, output, false, hir::Unsafety::Unsafe, Abi::PlatformIntrinsic);
let sig = ty::Binder::dummy(sig);
equate_intrinsic_type(tcx, it, n_tps, 0, sig)
equate_intrinsic_type(tcx, it, n_tps, 0, n_cts, sig)
}
1 change: 1 addition & 0 deletions compiler/rustc_span/src/symbol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1465,6 +1465,7 @@ symbols! {
simd_shl,
simd_shr,
simd_shuffle,
simd_shuffle_generic,
simd_sub,
simd_trunc,
simd_xor,
Expand Down
5 changes: 3 additions & 2 deletions src/tools/miri/src/shims/intrinsics/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
}

// The rest jumps to `ret` immediately.
this.emulate_intrinsic_by_name(intrinsic_name, args, dest)?;
this.emulate_intrinsic_by_name(intrinsic_name, instance.args, args, dest)?;

trace!("{:?}", this.dump_place(dest));
this.go_to_block(ret);
Expand All @@ -71,6 +71,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
fn emulate_intrinsic_by_name(
&mut self,
intrinsic_name: &str,
generic_args: ty::GenericArgsRef<'tcx>,
args: &[OpTy<'tcx, Provenance>],
dest: &PlaceTy<'tcx, Provenance>,
) -> InterpResult<'tcx> {
Expand All @@ -80,7 +81,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
return this.emulate_atomic_intrinsic(name, args, dest);
}
if let Some(name) = intrinsic_name.strip_prefix("simd_") {
return this.emulate_simd_intrinsic(name, args, dest);
return this.emulate_simd_intrinsic(name, generic_args, args, dest);
}

match intrinsic_name {
Expand Down
33 changes: 33 additions & 0 deletions src/tools/miri/src/shims/intrinsics/simd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
fn emulate_simd_intrinsic(
&mut self,
intrinsic_name: &str,
generic_args: ty::GenericArgsRef<'tcx>,
args: &[OpTy<'tcx, Provenance>],
dest: &PlaceTy<'tcx, Provenance>,
) -> InterpResult<'tcx> {
Expand Down Expand Up @@ -490,6 +491,38 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
this.write_immediate(val, &dest)?;
}
}
"shuffle_generic" => {
let [left, right] = check_arg_count(args)?;
let (left, left_len) = this.operand_to_simd(left)?;
let (right, right_len) = this.operand_to_simd(right)?;
let (dest, dest_len) = this.place_to_simd(dest)?;

let index = generic_args[2].expect_const().eval(*this.tcx, this.param_env(), Some(this.tcx.span)).unwrap().unwrap_branch();
let index_len = index.len();

assert_eq!(left_len, right_len);
assert_eq!(index_len as u64, dest_len);

for i in 0..dest_len {
let src_index: u64 = index[i as usize].unwrap_leaf()
.try_to_u32().unwrap()
.into();
let dest = this.project_index(&dest, i)?;

let val = if src_index < left_len {
this.read_immediate(&this.project_index(&left, src_index)?)?
} else if src_index < left_len.checked_add(right_len).unwrap() {
let right_idx = src_index.checked_sub(left_len).unwrap();
this.read_immediate(&this.project_index(&right, right_idx)?)?
} else {
span_bug!(
this.cur_span(),
"simd_shuffle index {src_index} is out of bounds for 2 vectors of size {left_len}",
);
};
this.write_immediate(*val, &dest)?;
}
}
"shuffle" => {
let [left, right, index] = check_arg_count(args)?;
let (left, left_len) = this.operand_to_simd(left)?;
Expand Down
21 changes: 20 additions & 1 deletion src/tools/miri/tests/pass/portable-simd.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
//@compile-flags: -Zmiri-strict-provenance
#![feature(portable_simd, platform_intrinsics)]
#![feature(portable_simd, platform_intrinsics, adt_const_params, inline_const)]
#![allow(incomplete_features)]
use std::simd::*;

extern "platform-intrinsic" {
Expand Down Expand Up @@ -390,6 +391,8 @@ fn simd_intrinsics() {
fn simd_reduce_any<T>(x: T) -> bool;
fn simd_reduce_all<T>(x: T) -> bool;
fn simd_select<M, T>(m: M, yes: T, no: T) -> T;
fn simd_shuffle_generic<T, U, const IDX: &'static [u32]>(x: T, y: T) -> U;
fn simd_shuffle<T, IDX, U>(x: T, y: T, idx: IDX) -> U;
}
unsafe {
// Make sure simd_eq returns all-1 for `true`
Expand All @@ -413,6 +416,22 @@ fn simd_intrinsics() {
simd_select(i8x4::from_array([0, -1, -1, 0]), b, a),
i32x4::from_array([10, 2, 10, 10])
);
assert_eq!(
simd_shuffle_generic::<_, i32x4, {&[3, 1, 0, 2]}>(a, b),
a,
);
assert_eq!(
simd_shuffle::<_, _, i32x4>(a, b, const {[3, 1, 0, 2]}),
a,
);
assert_eq!(
simd_shuffle_generic::<_, i32x4, {&[7, 5, 4, 6]}>(a, b),
i32x4::from_array([4, 2, 1, 10]),
);
assert_eq!(
simd_shuffle::<_, _, i32x4>(a, b, const {[7, 5, 4, 6]}),
i32x4::from_array([4, 2, 1, 10]),
);
}
}

Expand Down
Loading
Loading