Skip to content

Commit

Permalink
Merge pull request #4034 from alyssarosenzweig/fix-tied-fma
Browse files Browse the repository at this point in the history
IR: fix scalar FMA tied sources
  • Loading branch information
Sonicadvance1 authored Sep 4, 2024
2 parents 46a2a06 + 6d4693c commit a66fac6
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 25 deletions.
3 changes: 2 additions & 1 deletion FEXCore/Source/Interface/Core/JIT/Arm64/JITClass.h
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,8 @@ class Arm64JITCore final : public CPUBackend, public Arm64Emitter {
using ScalarFMAOpCaller =
std::function<void(ARMEmitter::VRegister Dst, ARMEmitter::VRegister Src1, ARMEmitter::VRegister Src2, ARMEmitter::VRegister Src3)>;
void VFScalarFMAOperation(uint8_t OpSize, uint8_t ElementSize, ScalarFMAOpCaller ScalarEmit, ARMEmitter::VRegister Dst,
ARMEmitter::VRegister Vector1, ARMEmitter::VRegister Vector2, ARMEmitter::VRegister Addend);
ARMEmitter::VRegister Upper, ARMEmitter::VRegister Vector1, ARMEmitter::VRegister Vector2,
ARMEmitter::VRegister Addend);
using ScalarBinaryOpCaller = std::function<void(ARMEmitter::VRegister Dst, ARMEmitter::VRegister Src1, ARMEmitter::VRegister Src2)>;
void VFScalarOperation(uint8_t OpSize, uint8_t ElementSize, bool ZeroUpperBits, ScalarBinaryOpCaller ScalarEmit,
ARMEmitter::VRegister Dst, ARMEmitter::VRegister Vector1, ARMEmitter::VRegister Vector2);
Expand Down
15 changes: 8 additions & 7 deletions FEXCore/Source/Interface/Core/JIT/Arm64/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,11 +211,12 @@ namespace FEXCore::CPU {
}; \
\
const auto Dst = GetVReg(Node); \
const auto Upper = GetVReg(Op->Upper.ID()); \
const auto Vector1 = GetVReg(Op->Vector1.ID()); \
const auto Vector2 = GetVReg(Op->Vector2.ID()); \
const auto Addend = GetVReg(Op->Addend.ID()); \
\
VFScalarFMAOperation(IROp->Size, ElementSize, ScalarEmit, Dst, Vector1, Vector2, Addend); \
VFScalarFMAOperation(IROp->Size, ElementSize, ScalarEmit, Dst, Upper, Vector1, Vector2, Addend); \
}

DEF_UNOP(VAbs, abs, true)
Expand Down Expand Up @@ -260,25 +261,25 @@ DEF_FMAOP_SCALAR_INSERT(VFNMLAScalarInsert, fmsub)
DEF_FMAOP_SCALAR_INSERT(VFNMLSScalarInsert, fnmadd)

void Arm64JITCore::VFScalarFMAOperation(uint8_t OpSize, uint8_t ElementSize, ScalarFMAOpCaller ScalarEmit, ARMEmitter::VRegister Dst,
ARMEmitter::VRegister Vector1, ARMEmitter::VRegister Vector2, ARMEmitter::VRegister Addend) {
ARMEmitter::VRegister Upper, ARMEmitter::VRegister Vector1, ARMEmitter::VRegister Vector2,
ARMEmitter::VRegister Addend) {
LOGMAN_THROW_A_FMT(OpSize == Core::CPUState::XMM_SSE_REG_SIZE, "256-bit unsupported", __func__);

LOGMAN_THROW_AA_FMT(ElementSize == 2 || ElementSize == 4 || ElementSize == 8, "Invalid size");
const auto SubRegSize = ARMEmitter::ToVectorSizePair(ElementSize == 2 ? ARMEmitter::SubRegSize::i16Bit :
ElementSize == 4 ? ARMEmitter::SubRegSize::i32Bit :
ARMEmitter::SubRegSize::i64Bit);
if (Dst != Vector1 && Dst != Vector2 && Dst != Addend && HostSupportsAFP) {
// If destination doesnt overlap any incoming register then move the adder to the destination first.
mov(Dst.Q(), Addend.Q());
Dst = Addend;
if (Dst != Upper) {
// If destination is not tied, move the upper bits to the destination first.
mov(Dst.Q(), Upper.Q());
}

if (HostSupportsAFP && Dst == Addend) {
///< Exactly matches ARM scalar FMA semantics
// If the host CPU supports AFP then scalar does an insert without modifying upper bits.
ScalarEmit(Dst, Vector1, Vector2, Addend);
} else {
// No overlap between addr and destination or host doesn't support AFP, need to emit in to a temporary then insert.
// Host doesn't support AFP, need to emit in to a temporary then insert.
ScalarEmit(VTMP1, Vector1, Vector2, Addend);
ins(SubRegSize.Vector, Dst.Q(), 0, VTMP1.Q(), 0);
}
Expand Down
10 changes: 5 additions & 5 deletions FEXCore/Source/Interface/Core/OpcodeDispatcher/AVX_128.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2486,14 +2486,14 @@ void OpDispatchBuilder::AVX128_VFMAScalarImpl(OpcodeArgs, IROps IROp, uint8_t Sr

const OpSize ElementSize = Op->Flags & X86Tables::DecodeFlags::FLAG_OPTION_AVX_W ? OpSize::i64Bit : OpSize::i32Bit;

auto Dest = AVX128_LoadSource_WithOpSize(Op, Op->Dest, Op->Flags, !Is128Bit);
auto Src1 = AVX128_LoadSource_WithOpSize(Op, Op->Src[0], Op->Flags, !Is128Bit);
auto Src2 = AVX128_LoadSource_WithOpSize(Op, Op->Src[1], Op->Flags, !Is128Bit);
auto Dest = AVX128_LoadSource_WithOpSize(Op, Op->Dest, Op->Flags, !Is128Bit).Low;
auto Src1 = AVX128_LoadSource_WithOpSize(Op, Op->Src[0], Op->Flags, !Is128Bit).Low;
auto Src2 = AVX128_LoadSource_WithOpSize(Op, Op->Src[1], Op->Flags, !Is128Bit).Low;

RefPair Sources[3] = {Dest, Src1, Src2};
Ref Sources[3] = {Dest, Src1, Src2};

DeriveOp(Result_Low, IROp,
_VFMLAScalarInsert(OpSize::i128Bit, ElementSize, Sources[Src1Idx - 1].Low, Sources[Src2Idx - 1].Low, Sources[AddendIdx - 1].Low));
_VFMLAScalarInsert(OpSize::i128Bit, ElementSize, Dest, Sources[Src1Idx - 1], Sources[Src2Idx - 1], Sources[AddendIdx - 1]));
AVX128_StoreResult_WithOpSize(Op, Op->Dest, AVX128_Zext(Result_Low));
}

Expand Down
28 changes: 16 additions & 12 deletions FEXCore/Source/Interface/IR/IR.json
Original file line number Diff line number Diff line change
Expand Up @@ -1790,41 +1790,45 @@
"DestSize": "RegisterSize",
"NumElements": "RegisterSize / ElementSize"
},
"FPR = VFMLAScalarInsert u8:#RegisterSize, u8:#ElementSize, FPR:$Vector1, FPR:$Vector2, FPR:$Addend": {
"FPR = VFMLAScalarInsert u8:#RegisterSize, u8:#ElementSize, FPR:$Upper, FPR:$Vector1, FPR:$Vector2, FPR:$Addend": {
"Desc": [
"Dest = (Vector1 * Vector2) + Addend",
"This explicitly matches x86 FMA semantics because ARM semantics are mind-bending."
"This explicitly matches x86 FMA semantics because ARM semantics are mind-bending.",
"Upper elements copied from Upper"
],
"DestSize": "RegisterSize",
"NumElements": "RegisterSize / ElementSize",
"TiedSource": 2
"TiedSource": 0
},
"FPR = VFMLSScalarInsert u8:#RegisterSize, u8:#ElementSize, FPR:$Vector1, FPR:$Vector2, FPR:$Addend": {
"FPR = VFMLSScalarInsert u8:#RegisterSize, u8:#ElementSize, FPR:$Upper, FPR:$Vector1, FPR:$Vector2, FPR:$Addend": {
"Desc": [
"Dest = (Vector1 * Vector2) - Addend",
"This explicitly matches x86 FMA semantics because ARM semantics are mind-bending."
"This explicitly matches x86 FMA semantics because ARM semantics are mind-bending.",
"Upper elements copied from Upper"
],
"DestSize": "RegisterSize",
"NumElements": "RegisterSize / ElementSize",
"TiedSource": 2
"TiedSource": 0
},
"FPR = VFNMLAScalarInsert u8:#RegisterSize, u8:#ElementSize, FPR:$Vector1, FPR:$Vector2, FPR:$Addend": {
"FPR = VFNMLAScalarInsert u8:#RegisterSize, u8:#ElementSize, FPR:$Upper, FPR:$Vector1, FPR:$Vector2, FPR:$Addend": {
"Desc": [
"Dest = (-Vector1 * Vector2) + Addend",
"This explicitly matches x86 FMA semantics because ARM semantics are mind-bending."
"This explicitly matches x86 FMA semantics because ARM semantics are mind-bending.",
"Upper elements copied from Upper"
],
"DestSize": "RegisterSize",
"NumElements": "RegisterSize / ElementSize",
"TiedSource": 2
"TiedSource": 0
},
"FPR = VFNMLSScalarInsert u8:#RegisterSize, u8:#ElementSize, FPR:$Vector1, FPR:$Vector2, FPR:$Addend": {
"FPR = VFNMLSScalarInsert u8:#RegisterSize, u8:#ElementSize, FPR:$Upper, FPR:$Vector1, FPR:$Vector2, FPR:$Addend": {
"Desc": [
"Dest = (-Vector1 * Vector2) - Addend",
"This explicitly matches x86 FMA semantics because ARM semantics are mind-bending."
"This explicitly matches x86 FMA semantics because ARM semantics are mind-bending.",
"Upper elements copied from Upper"
],
"DestSize": "RegisterSize",
"NumElements": "RegisterSize / ElementSize",
"TiedSource": 2
"TiedSource": 0
}
},
"Vector": {
Expand Down

0 comments on commit a66fac6

Please sign in to comment.