Skip to content

Commit

Permalink
x64: Implement SIMD fma
Browse files Browse the repository at this point in the history
  • Loading branch information
afonso360 committed Jul 20, 2022
1 parent f79671e commit c4cf522
Show file tree
Hide file tree
Showing 10 changed files with 251 additions and 41 deletions.
2 changes: 1 addition & 1 deletion cranelift/codegen/meta/src/isa/x86.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ fn define_settings(shared: &SettingGroup) -> SettingGroup {
settings.add_predicate("use_ssse3", predicate!(has_ssse3));
settings.add_predicate("use_sse41", predicate!(has_sse41));
settings.add_predicate("use_sse42", predicate!(has_sse41 && has_sse42));
settings.add_predicate("use_fma", predicate!(has_avx && has_fma));

settings.add_predicate(
"use_ssse3_simd",
Expand All @@ -138,7 +139,6 @@ fn define_settings(shared: &SettingGroup) -> SettingGroup {

settings.add_predicate("use_avx_simd", predicate!(shared_enable_simd && has_avx));
settings.add_predicate("use_avx2_simd", predicate!(shared_enable_simd && has_avx2));
settings.add_predicate("use_fma_simd", predicate!(shared_enable_simd && has_fma));
settings.add_predicate(
"use_avx512bitalg_simd",
predicate!(shared_enable_simd && has_avx512bitalg),
Expand Down
44 changes: 5 additions & 39 deletions cranelift/codegen/src/isa/x64/encoding/vex.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
//! Encodes VEX instructions. These instructions are those added by the Advanced Vector Extensions
//! (AVX).
// section 3.1.1.2

use super::evex::Register;
use super::rex::{LegacyPrefixes, OpcodeMap};
use super::ByteSink;
Expand All @@ -17,7 +15,6 @@ pub struct VexInstruction {
map: OpcodeMap,
opcode: u8,
w: bool,
wig: bool,
reg: u8,
rm: Register,
vvvv: Option<Register>,
Expand All @@ -32,7 +29,6 @@ impl Default for VexInstruction {
map: OpcodeMap::None,
opcode: 0x00,
w: false,
wig: false,
reg: 0x00,
rm: Register::default(),
vvvv: None,
Expand Down Expand Up @@ -87,13 +83,6 @@ impl VexInstruction {
self
}

/// Set the WIG bit, denoted by `.WIG` in the instruction string.
#[inline(always)]
pub fn wig(mut self, wig: bool) -> Self {
self.wig = wig;
self
}

/// Set the instruction opcode byte.
#[inline(always)]
pub fn opcode(mut self, opcode: u8) -> Self {
Expand All @@ -118,7 +107,7 @@ impl VexInstruction {

/// Set the register to use for the `rm` bits; many instructions use this as the "read from
/// register/memory" operand. Currently this does not support memory addressing (TODO).Setting
/// this affects both the ModRM byte (`rm` section) and the EVEX prefix (the extension bits for
/// this affects both the ModRM byte (`rm` section) and the VEX prefix (the extension bits for
/// register encodings > 8).
#[inline(always)]
pub fn rm(mut self, reg: impl Into<Register>) -> Self {
Expand Down Expand Up @@ -282,19 +271,18 @@ mod tests {
// VEX.128.66.0F 73 /7 ib
// VPSLLDQ xmm1, xmm2, imm8

let dst = regs::xmm1();
let src = regs::xmm2();
let dst = regs::xmm1().to_real_reg().unwrap().hw_enc();
let src = regs::xmm2().to_real_reg().unwrap().hw_enc();
let mut sink0 = Vec::new();

VexInstruction::new()
.length(VexVectorLength::V128)
.prefix(LegacyPrefixes::_66)
.map(OpcodeMap::_0F)
.wig(true)
.opcode(0x73)
.opcode_ext(7)
.vvvv(dst.to_real_reg().unwrap().hw_enc())
.rm(src.to_real_reg().unwrap().hw_enc())
.vvvv(dst)
.rm(src)
.imm(0x17)
.encode(&mut sink0);

Expand Down Expand Up @@ -327,26 +315,4 @@ mod tests {

assert_eq!(sink0, vec![0xc4, 0xe3, 0x69, 0x4b, 0xcb, 0x40]);
}

// #[test]
// fn vmovaps_mem_access() {
// // VEX.256.0F.WIG 29 /r
// // vmovaps [2 * edx + 4],ymm2
//
// let dst = regs::rdx().to_real_reg().unwrap().hw_enc();
// let src = regs::xmm2().to_real_reg().unwrap().hw_enc();
// let mut sink0 = Vec::new();
//
// VexInstruction::new()
// .length(VexVectorLength::V256)
// .map(OpcodeMap::_0F)
// .wig(true)
// .opcode(0x4B)
// .reg(src)
// .rm(dst)
// .imm(4)
// .encode(&mut sink0);
//
// assert_eq!(sink0, vec![0xc5, 0xfc, 0x29, 0x54, 0x12, 0x04]);
// }
}
26 changes: 26 additions & 0 deletions cranelift/codegen/src/isa/x64/inst.isle
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,12 @@
(src2 XmmMem)
(dst WritableXmm))

;; XMM (scalar or vector) binary op that relies on the VEX prefix.
(XmmRmRVex (op AvxOpcode)
(src1 XmmMem)
(src2 Xmm)
(dst WritableXmm))

;; XMM (scalar or vector) binary op that relies on the EVEX prefix.
(XmmRmREvex (op Avx512Opcode)
(src1 XmmMem)
Expand Down Expand Up @@ -1042,6 +1048,10 @@
(decl intcc_to_cc (IntCC) CC)
(extern constructor intcc_to_cc intcc_to_cc)

(type AvxOpcode extern
(enum Vfmadd213ps
Vfmadd213pd))

(type Avx512Opcode extern
(enum Vcvtudq2ps
Vpabsq
Expand Down Expand Up @@ -2805,6 +2815,22 @@
(_ Unit (emit (MInst.XmmRmR (SseOpcode.Maxpd) x y dst))))
dst))

;; Helper for creating `vfmadd213ps` instructions.
(decl x64_vfmadd213ps (Xmm Xmm XmmMem) Xmm)
(rule (x64_vfmadd213ps x y z)
(let ((dst WritableXmm (temp_writable_xmm))
(_1 Unit (emit (MInst.XmmUnaryRmR (SseOpcode.Movups) x dst)))
(_2 Unit (emit (MInst.XmmRmRVex (AvxOpcode.Vfmadd213ps) z y dst))))
dst))

;; Helper for creating `vfmadd213pd` instructions.
(decl x64_vfmadd213pd (Xmm Xmm XmmMem) Xmm)
(rule (x64_vfmadd213pd x y z)
(let ((dst WritableXmm (temp_writable_xmm))
(_1 Unit (emit (MInst.XmmUnaryRmR (SseOpcode.Movups) x dst)))
(_2 Unit (emit (MInst.XmmRmRVex (AvxOpcode.Vfmadd213pd) z y dst))))
dst))


;; Helper for creating `sqrtss` instructions.
(decl x64_sqrtss (Xmm) Xmm)
Expand Down
33 changes: 33 additions & 0 deletions cranelift/codegen/src/isa/x64/inst/args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -794,6 +794,7 @@ pub(crate) enum InstructionSet {
BMI1,
#[allow(dead_code)] // never constructed (yet).
BMI2,
FMA,
AVX512BITALG,
AVX512DQ,
AVX512F,
Expand Down Expand Up @@ -1386,6 +1387,38 @@ impl fmt::Display for SseOpcode {
}
}

#[derive(Clone, PartialEq)]
pub enum AvxOpcode {
Vfmadd213ps,
Vfmadd213pd,
}

impl AvxOpcode {
/// Which `InstructionSet`s support the opcode?
pub(crate) fn available_from(&self) -> SmallVec<[InstructionSet; 2]> {
match self {
AvxOpcode::Vfmadd213ps => smallvec![InstructionSet::FMA],
AvxOpcode::Vfmadd213pd => smallvec![InstructionSet::FMA],
}
}
}

impl fmt::Debug for AvxOpcode {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
let name = match self {
AvxOpcode::Vfmadd213ps => "vfmadd213ps",
AvxOpcode::Vfmadd213pd => "vfmadd213pd",
};
write!(fmt, "{}", name)
}
}

impl fmt::Display for AvxOpcode {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
fmt::Debug::fmt(self, f)
}
}

#[derive(Clone, PartialEq)]
pub enum Avx512Opcode {
Vcvtudq2ps,
Expand Down
32 changes: 32 additions & 0 deletions cranelift/codegen/src/isa/x64/inst/emit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use crate::isa::x64::encoding::rex::{
low8_will_sign_extend_to_32, low8_will_sign_extend_to_64, reg_enc, LegacyPrefixes, OpcodeMap,
RexFlags,
};
use crate::isa::x64::encoding::vex::{VexInstruction, VexVectorLength};
use crate::isa::x64::inst::args::*;
use crate::isa::x64::inst::*;
use crate::machinst::{inst_common, MachBuffer, MachInstEmit, MachLabel, Reg, Writable};
Expand Down Expand Up @@ -119,6 +120,7 @@ pub(crate) fn emit(
InstructionSet::Lzcnt => info.isa_flags.use_lzcnt(),
InstructionSet::BMI1 => info.isa_flags.use_bmi1(),
InstructionSet::BMI2 => info.isa_flags.has_bmi2(),
InstructionSet::FMA => info.isa_flags.has_fma(),
InstructionSet::AVX512BITALG => info.isa_flags.has_avx512bitalg(),
InstructionSet::AVX512DQ => info.isa_flags.has_avx512dq(),
InstructionSet::AVX512F => info.isa_flags.has_avx512f(),
Expand Down Expand Up @@ -1689,6 +1691,36 @@ pub(crate) fn emit(
}
}

Inst::XmmRmRVex {
op,
src1,
src2,
dst,
} => {
let dst = allocs.next(dst.to_reg().to_reg());
let src2 = allocs.next(src2.to_reg());
let src1 = src1.clone().to_reg_mem().with_allocs(allocs);

let (w, opcode) = match op {
AvxOpcode::Vfmadd213ps => (false, 0xA8),
AvxOpcode::Vfmadd213pd => (true, 0xA8),
};

match src1 {
RegMem::Reg { reg: src } => VexInstruction::new()
.length(VexVectorLength::V128)
.prefix(LegacyPrefixes::_66)
.map(OpcodeMap::_0F38)
.w(w)
.opcode(opcode)
.reg(dst.to_real_reg().unwrap().hw_enc())
.rm(src.to_real_reg().unwrap().hw_enc())
.vvvv(src2.to_real_reg().unwrap().hw_enc())
.encode(sink),
_ => todo!(),
};
}

Inst::XmmRmREvex {
op,
src1,
Expand Down
16 changes: 16 additions & 0 deletions cranelift/codegen/src/isa/x64/inst/emit_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3701,6 +3701,21 @@ fn test_x64_emit() {
"jmp *321(%r10,%rdx,4)",
));

// ========================================================
// XMM FMA

insns.push((
Inst::xmm_rm_r_fma(AvxOpcode::Vfmadd213ps, RegMem::reg(xmm2), xmm1, w_xmm0),
"C4E271A8C2",
"vfmadd213ps %xmm2, %xmm1, %xmm0",
));

insns.push((
Inst::xmm_rm_r_fma(AvxOpcode::Vfmadd213pd, RegMem::reg(xmm5), xmm4, w_xmm3),
"C4E2D9A8DD",
"vfmadd213pd %xmm5, %xmm4, %xmm3",
));

// ========================================================
// XMM_CMP_RM_R

Expand Down Expand Up @@ -4866,6 +4881,7 @@ fn test_x64_emit() {
let mut isa_flag_builder = x64::settings::builder();
isa_flag_builder.enable("has_ssse3").unwrap();
isa_flag_builder.enable("has_sse41").unwrap();
isa_flag_builder.enable("has_fma").unwrap();
isa_flag_builder.enable("has_avx512bitalg").unwrap();
isa_flag_builder.enable("has_avx512dq").unwrap();
isa_flag_builder.enable("has_avx512f").unwrap();
Expand Down
45 changes: 45 additions & 0 deletions cranelift/codegen/src/isa/x64/inst/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@ impl Inst {
| Inst::XmmUnaryRmR { op, .. } => smallvec![op.available_from()],

Inst::XmmUnaryRmREvex { op, .. } | Inst::XmmRmREvex { op, .. } => op.available_from(),

Inst::XmmRmRVex { op, .. } => op.available_from(),
}
}
}
Expand Down Expand Up @@ -316,6 +318,19 @@ impl Inst {
}
}

#[cfg(test)]
pub(crate) fn xmm_rm_r_vex(op: AvxOpcode, src1: RegMem, src2: Reg, dst: Writable<Reg>) -> Self {
src1.assert_regclass_is(RegClass::Float);
debug_assert!(src2.class() == RegClass::Float);
debug_assert!(dst.to_reg().class() == RegClass::Float);
Inst::XmmRmRVex {
op,
src1: XmmMem::new(src1).unwrap(),
src2: Xmm::new(src2).unwrap(),
dst: WritableXmm::from_writable_reg(dst).unwrap(),
}
}

pub(crate) fn xmm_rm_r_evex(
op: Avx512Opcode,
src1: RegMem,
Expand Down Expand Up @@ -1128,6 +1143,19 @@ impl PrettyPrint for Inst {
format!("{} {}, {}, {}", ljustify(op.to_string()), src1, src2, dst)
}

Inst::XmmRmRVex {
op,
src1,
src2,
dst,
..
} => {
let src2 = pretty_print_reg(src2.to_reg(), 8, allocs);
let dst = pretty_print_reg(dst.to_reg().to_reg(), 8, allocs);
let src1 = src1.pretty_print(8, allocs);
format!("{} {}, {}, {}", ljustify(op.to_string()), src1, src2, dst)
}

Inst::XmmRmREvex {
op,
src1,
Expand Down Expand Up @@ -1832,6 +1860,23 @@ fn x64_get_operands<F: Fn(VReg) -> VReg>(inst: &Inst, collector: &mut OperandCol
}
}
}
Inst::XmmRmRVex {
op,
src1,
src2,
dst,
..
} => {
// Vfmadd uses and defs the dst reg, that is not the case with all
// AVX's ops, if you're adding a new op, make sure to correctly define
// register uses.
assert!(*op == AvxOpcode::Vfmadd213ps || *op == AvxOpcode::Vfmadd213pd);

// We both use and def dst
collector.reg_mod(dst.to_writable_reg());
collector.reg_use(src2.to_reg());
src1.get_operands(collector);
}
Inst::XmmRmREvex {
op,
src1,
Expand Down
7 changes: 7 additions & 0 deletions cranelift/codegen/src/isa/x64/lower.isle
Original file line number Diff line number Diff line change
Expand Up @@ -2542,6 +2542,13 @@
(rule (lower (has_type $F64X2 (fmax_pseudo x y)))
(x64_maxpd y x))

;; Rules for `fma` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

(rule (lower (has_type $F32X4 (fma x y z)))
(x64_vfmadd213ps x y z))
(rule (lower (has_type $F64X2 (fma x y z)))
(x64_vfmadd213pd x y z))

;; Rules for `load*` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

;; In order to load a value from memory to a GPR register, we may need to extend
Expand Down
2 changes: 1 addition & 1 deletion cranelift/codegen/src/isa/x64/lower.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2917,7 +2917,7 @@ fn lower_insn_to_regs<C: LowerCtx<I = Inst>>(

Opcode::Cls => unimplemented!("Cls not supported"),

Opcode::Fma => unimplemented!("Fma not supported"),
Opcode::Fma => implemented_in_isle(ctx),

Opcode::BorNot | Opcode::BxorNot => {
unimplemented!("or-not / xor-not opcodes not implemented");
Expand Down
Loading

0 comments on commit c4cf522

Please sign in to comment.