Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add f16 inline ASM support for RISC-V #126530

Merged
merged 1 commit into from
Jun 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 49 additions & 6 deletions compiler/rustc_codegen_llvm/src/asm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use rustc_codegen_ssa::traits::*;
use rustc_data_structures::fx::FxHashMap;
use rustc_middle::ty::layout::TyAndLayout;
use rustc_middle::{bug, span_bug, ty::Instance};
use rustc_span::{Pos, Span};
use rustc_span::{sym, Pos, Span, Symbol};
use rustc_target::abi::*;
use rustc_target::asm::*;
use tracing::debug;
Expand Down Expand Up @@ -64,7 +64,7 @@ impl<'ll, 'tcx> AsmBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
let mut layout = None;
let ty = if let Some(ref place) = place {
layout = Some(&place.layout);
llvm_fixup_output_type(self.cx, reg.reg_class(), &place.layout)
llvm_fixup_output_type(self.cx, reg.reg_class(), &place.layout, instance)
} else if matches!(
reg.reg_class(),
InlineAsmRegClass::X86(
Expand Down Expand Up @@ -112,7 +112,7 @@ impl<'ll, 'tcx> AsmBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
// so we just use the type of the input.
&in_value.layout
};
let ty = llvm_fixup_output_type(self.cx, reg.reg_class(), layout);
let ty = llvm_fixup_output_type(self.cx, reg.reg_class(), layout, instance);
output_types.push(ty);
op_idx.insert(idx, constraints.len());
let prefix = if late { "=" } else { "=&" };
Expand All @@ -127,8 +127,13 @@ impl<'ll, 'tcx> AsmBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
for (idx, op) in operands.iter().enumerate() {
match *op {
InlineAsmOperandRef::In { reg, value } => {
let llval =
llvm_fixup_input(self, value.immediate(), reg.reg_class(), &value.layout);
let llval = llvm_fixup_input(
self,
value.immediate(),
reg.reg_class(),
&value.layout,
instance,
);
inputs.push(llval);
op_idx.insert(idx, constraints.len());
constraints.push(reg_to_llvm(reg, Some(&value.layout)));
Expand All @@ -139,6 +144,7 @@ impl<'ll, 'tcx> AsmBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
in_value.immediate(),
reg.reg_class(),
&in_value.layout,
instance,
);
inputs.push(value);

Expand Down Expand Up @@ -341,7 +347,8 @@ impl<'ll, 'tcx> AsmBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
} else {
self.extract_value(result, op_idx[&idx] as u64)
};
let value = llvm_fixup_output(self, value, reg.reg_class(), &place.layout);
let value =
llvm_fixup_output(self, value, reg.reg_class(), &place.layout, instance);
OperandValue::Immediate(value).store(self, place);
}
}
Expand Down Expand Up @@ -913,12 +920,22 @@ fn llvm_asm_scalar_type<'ll>(cx: &CodegenCx<'ll, '_>, scalar: Scalar) -> &'ll Ty
}
}

/// Returns `true` if at least one entry of `features` is contained in the
/// set produced by the `asm_target_features` query for `instance`'s
/// definition, and `false` otherwise (including when `features` is empty).
fn any_target_feature_enabled(
    cx: &CodegenCx<'_, '_>,
    instance: Instance<'_>,
    features: &[Symbol],
) -> bool {
    // Query the feature set once, then probe it for each requested feature.
    let active = cx.tcx.asm_target_features(instance.def_id());
    for wanted in features {
        if active.contains(wanted) {
            return true;
        }
    }
    false
}

/// Fix up an input value to work around LLVM bugs.
fn llvm_fixup_input<'ll, 'tcx>(
bx: &mut Builder<'_, 'll, 'tcx>,
mut value: &'ll Value,
reg: InlineAsmRegClass,
layout: &TyAndLayout<'tcx>,
instance: Instance<'_>,
) -> &'ll Value {
let dl = &bx.tcx.data_layout;
match (reg, layout.abi) {
Expand Down Expand Up @@ -1029,6 +1046,16 @@ fn llvm_fixup_input<'ll, 'tcx>(
_ => value,
}
}
(InlineAsmRegClass::RiscV(RiscVInlineAsmRegClass::freg), Abi::Scalar(s))
if s.primitive() == Primitive::Float(Float::F16)
&& !any_target_feature_enabled(bx, instance, &[sym::zfhmin, sym::zfh]) =>
{
// Smaller floats are always "NaN-boxed" inside larger floats on RISC-V.
let value = bx.bitcast(value, bx.type_i16());
let value = bx.zext(value, bx.type_i32());
let value = bx.or(value, bx.const_u32(0xFFFF_0000));
bx.bitcast(value, bx.type_f32())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does LLVM not support using f16 directly here? I expect this would result in much better codegen.

Copy link
Contributor Author

@beetrees beetrees Jun 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately not. Regardless, LLVM's codegen in cases where NaN-boxing needs to occur isn't very good at the moment anyway; for example, an identity function will needlessly re-NaN-box an argument that is already guaranteed to be NaN-boxed.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My reading of the LLVM code is that it will accept f16 directly if the zfhmin feature is enabled for the current function. Can you use f16 directly in this case and only do the manual conversion to/from f32 if zfhmin is not available?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

}
_ => value,
}
}
Expand All @@ -1039,6 +1066,7 @@ fn llvm_fixup_output<'ll, 'tcx>(
mut value: &'ll Value,
reg: InlineAsmRegClass,
layout: &TyAndLayout<'tcx>,
instance: Instance<'_>,
) -> &'ll Value {
match (reg, layout.abi) {
(InlineAsmRegClass::AArch64(AArch64InlineAsmRegClass::vreg), Abi::Scalar(s)) => {
Expand Down Expand Up @@ -1140,6 +1168,14 @@ fn llvm_fixup_output<'ll, 'tcx>(
_ => value,
}
}
(InlineAsmRegClass::RiscV(RiscVInlineAsmRegClass::freg), Abi::Scalar(s))
if s.primitive() == Primitive::Float(Float::F16)
&& !any_target_feature_enabled(bx, instance, &[sym::zfhmin, sym::zfh]) =>
{
let value = bx.bitcast(value, bx.type_i32());
let value = bx.trunc(value, bx.type_i16());
bx.bitcast(value, bx.type_f16())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here.

}
_ => value,
}
}
Expand All @@ -1149,6 +1185,7 @@ fn llvm_fixup_output_type<'ll, 'tcx>(
cx: &CodegenCx<'ll, 'tcx>,
reg: InlineAsmRegClass,
layout: &TyAndLayout<'tcx>,
instance: Instance<'_>,
) -> &'ll Type {
match (reg, layout.abi) {
(InlineAsmRegClass::AArch64(AArch64InlineAsmRegClass::vreg), Abi::Scalar(s)) => {
Expand Down Expand Up @@ -1242,6 +1279,12 @@ fn llvm_fixup_output_type<'ll, 'tcx>(
_ => layout.llvm_type(cx),
}
}
(InlineAsmRegClass::RiscV(RiscVInlineAsmRegClass::freg), Abi::Scalar(s))
if s.primitive() == Primitive::Float(Float::F16)
&& !any_target_feature_enabled(cx, instance, &[sym::zfhmin, sym::zfh]) =>
{
cx.type_f32()
}
_ => layout.llvm_type(cx),
}
}
2 changes: 2 additions & 0 deletions compiler/rustc_span/src/symbol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2054,6 +2054,8 @@ symbols! {
yes,
yield_expr,
ymm_reg,
zfh,
zfhmin,
zmm_reg,
}
}
Expand Down
7 changes: 4 additions & 3 deletions compiler/rustc_target/src/asm/riscv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,13 @@ impl RiscVInlineAsmRegClass {
match self {
Self::reg => {
if arch == InlineAsmArch::RiscV64 {
types! { _: I8, I16, I32, I64, F32, F64; }
types! { _: I8, I16, I32, I64, F16, F32, F64; }
} else {
types! { _: I8, I16, I32, F32; }
types! { _: I8, I16, I32, F16, F32; }
}
}
Self::freg => types! { f: F32; d: F64; },
// FIXME(f16_f128): Add `q: F128;` once LLVM supports the `Q` extension.
Self::freg => types! { f: F16, F32; d: F64; },
Amanieu marked this conversation as resolved.
Show resolved Hide resolved
Self::vreg => &[],
Amanieu marked this conversation as resolved.
Show resolved Hide resolved
}
}
Expand Down
55 changes: 53 additions & 2 deletions tests/assembly/asm/riscv-types.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,34 @@
//@ revisions: riscv64 riscv32
//@ revisions: riscv64 riscv32 riscv64-zfhmin riscv32-zfhmin riscv64-zfh riscv32-zfh
//@ assembly-output: emit-asm

//@[riscv64] compile-flags: --target riscv64imac-unknown-none-elf
//@[riscv64] needs-llvm-components: riscv

//@[riscv32] compile-flags: --target riscv32imac-unknown-none-elf
//@[riscv32] needs-llvm-components: riscv

//@[riscv64-zfhmin] compile-flags: --target riscv64imac-unknown-none-elf --cfg riscv64
//@[riscv64-zfhmin] needs-llvm-components: riscv
//@[riscv64-zfhmin] compile-flags: -C target-feature=+zfhmin
//@[riscv64-zfhmin] filecheck-flags: --check-prefix riscv64

//@[riscv32-zfhmin] compile-flags: --target riscv32imac-unknown-none-elf
//@[riscv32-zfhmin] needs-llvm-components: riscv
//@[riscv32-zfhmin] compile-flags: -C target-feature=+zfhmin

//@[riscv64-zfh] compile-flags: --target riscv64imac-unknown-none-elf --cfg riscv64
//@[riscv64-zfh] needs-llvm-components: riscv
//@[riscv64-zfh] compile-flags: -C target-feature=+zfh
//@[riscv64-zfh] filecheck-flags: --check-prefix riscv64 --check-prefix zfhmin

//@[riscv32-zfh] compile-flags: --target riscv32imac-unknown-none-elf
//@[riscv32-zfh] needs-llvm-components: riscv
//@[riscv32-zfh] compile-flags: -C target-feature=+zfh
//@[riscv32-zfh] filecheck-flags: --check-prefix zfhmin

//@ compile-flags: -C target-feature=+d

#![feature(no_core, lang_items, rustc_attrs)]
#![feature(no_core, lang_items, rustc_attrs, f16)]
#![crate_type = "rlib"]
#![no_core]
#![allow(asm_sub_register)]
Expand All @@ -33,6 +55,7 @@ type ptr = *mut u8;

impl Copy for i8 {}
impl Copy for i16 {}
impl Copy for f16 {}
impl Copy for i32 {}
impl Copy for f32 {}
impl Copy for i64 {}
Expand Down Expand Up @@ -103,6 +126,12 @@ macro_rules! check_reg {
// CHECK: #NO_APP
check!(reg_i8 i8 reg "mv");

// CHECK-LABEL: reg_f16:
// CHECK: #APP
// CHECK: mv {{[a-z0-9]+}}, {{[a-z0-9]+}}
// CHECK: #NO_APP
check!(reg_f16 f16 reg "mv");

// CHECK-LABEL: reg_i16:
// CHECK: #APP
// CHECK: mv {{[a-z0-9]+}}, {{[a-z0-9]+}}
Expand Down Expand Up @@ -141,6 +170,14 @@ check!(reg_f64 f64 reg "mv");
// CHECK: #NO_APP
check!(reg_ptr ptr reg "mv");

// CHECK-LABEL: freg_f16:
// zfhmin-NOT: or
// CHECK: #APP
// CHECK: fmv.s f{{[a-z0-9]+}}, f{{[a-z0-9]+}}
// CHECK: #NO_APP
// zfhmin-NOT: or
check!(freg_f16 f16 freg "fmv.s");

// CHECK-LABEL: freg_f32:
// CHECK: #APP
// CHECK: fmv.s f{{[a-z0-9]+}}, f{{[a-z0-9]+}}
Expand All @@ -165,6 +202,12 @@ check_reg!(a0_i8 i8 "a0" "mv");
// CHECK: #NO_APP
check_reg!(a0_i16 i16 "a0" "mv");

// CHECK-LABEL: a0_f16:
// CHECK: #APP
// CHECK: mv a0, a0
// CHECK: #NO_APP
check_reg!(a0_f16 f16 "a0" "mv");

// CHECK-LABEL: a0_i32:
// CHECK: #APP
// CHECK: mv a0, a0
Expand Down Expand Up @@ -197,6 +240,14 @@ check_reg!(a0_f64 f64 "a0" "mv");
// CHECK: #NO_APP
check_reg!(a0_ptr ptr "a0" "mv");

// CHECK-LABEL: fa0_f16:
// zfhmin-NOT: or
// CHECK: #APP
// CHECK: fmv.s fa0, fa0
// CHECK: #NO_APP
// zfhmin-NOT: or
check_reg!(fa0_f16 f16 "fa0" "fmv.s");

// CHECK-LABEL: fa0_f32:
// CHECK: #APP
// CHECK: fmv.s fa0, fa0
Expand Down
Loading