Arm64: Implement support for emulated masked vector loadstores
This is required to support `vmaskmov{ps,pd}` without SVE128. It's pretty gnarly, but these instructions aren't used often, so it's fine from a compatibility perspective.

Example SVE128 implementation:
```json
    "vmaskmovps ymm0, ymm1, [rax]": {
      "ExpectedInstructionCount": 9,
      "Comment": [
        "Map 2 0b01 0x2c 256-bit"
      ],
      "ExpectedArm64ASM": [
        "ldr q2, [x28, #32]",
        "mrs x20, nzcv",
        "cmplt p0.s, p6/z, z17.s, #0",
        "ld1w {z16.s}, p0/z, [x4]",
        "add x21, x4, #0x10 (16)",
        "cmplt p0.s, p6/z, z2.s, #0",
        "ld1w {z2.s}, p0/z, [x21]",
        "str q2, [x28, #16]",
        "msr nzcv, x20"
      ]
    },
```

Example ASIMD implementation:
```json
    "vmaskmovps ymm0, ymm1, [rax]": {
      "ExpectedInstructionCount": 41,
      "Comment": [
        "Map 2 0b01 0x2c 256-bit"
      ],
      "ExpectedArm64ASM": [
        "ldr q2, [x28, #32]",
        "mrs x20, nzcv",
        "movi v0.2d, #0x0",
        "mov x1, x4",
        "mov w0, v17.s[0]",
        "tbz w0, #31, #+0x8",
        "ld1 {v0.s}[0], [x1]",
        "add x1, x1, #0x4 (4)",
        "mov w0, v17.s[1]",
        "tbz w0, #31, #+0x8",
        "ld1 {v0.s}[1], [x1]",
        "add x1, x1, #0x4 (4)",
        "mov w0, v17.s[2]",
        "tbz w0, #31, #+0x8",
        "ld1 {v0.s}[2], [x1]",
        "add x1, x1, #0x4 (4)",
        "mov w0, v17.s[3]",
        "tbz w0, #31, #+0x8",
        "ld1 {v0.s}[3], [x1]",
        "mov v16.16b, v0.16b",
        "add x21, x4, #0x10 (16)",
        "movi v0.2d, #0x0",
        "mov x1, x21",
        "mov w0, v2.s[0]",
        "tbz w0, #31, #+0x8",
        "ld1 {v0.s}[0], [x1]",
        "add x1, x1, #0x4 (4)",
        "mov w0, v2.s[1]",
        "tbz w0, #31, #+0x8",
        "ld1 {v0.s}[1], [x1]",
        "add x1, x1, #0x4 (4)",
        "mov w0, v2.s[2]",
        "tbz w0, #31, #+0x8",
        "ld1 {v0.s}[2], [x1]",
        "add x1, x1, #0x4 (4)",
        "mov w0, v2.s[3]",
        "tbz w0, #31, #+0x8",
        "ld1 {v0.s}[3], [x1]",
        "mov v2.16b, v0.16b",
        "str q2, [x28, #16]",
        "msr nzcv, x20"
      ]
    },
```

There's a small improvement still available: nzcv doesn't actually need to be touched in the ASIMD implementation, but I'll leave that for a future change.
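
For reference, here's a minimal C++ sketch of the per-element semantics the ASIMD fallback emulates (the function name and the fixed 32-bit element width are illustrative only, chosen to match the `vmaskmovps` example above):

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>

// Sketch of the emulated masked load: for each lane, the sign bit of the
// corresponding mask element decides whether that lane is loaded from memory
// or left as zero.
void EmulatedMaskedLoad32(uint32_t* Dst, const uint32_t* Mask, const uint8_t* Mem, size_t NumElements) {
  for (size_t i = 0; i < NumElements; ++i) {
    Dst[i] = 0; // Inactive lanes are zeroed, matching the `movi` in the ASIMD sequence.
    if (Mask[i] & 0x80000000u) { // Sign bit set -> lane is active.
      std::memcpy(&Dst[i], Mem + i * sizeof(uint32_t), sizeof(uint32_t));
    }
  }
}
```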
Sonicadvance1 committed Jun 21, 2024
1 parent 7f74c83 commit ad18bff
Showing 2 changed files with 172 additions and 45 deletions.
3 changes: 3 additions & 0 deletions FEXCore/Source/Interface/Core/JIT/Arm64/JITClass.h
@@ -259,6 +259,9 @@ class Arm64JITCore final : public CPUBackend, public Arm64Emitter {
ARMEmitter::ExtendedMemOperand GenerateMemOperand(uint8_t AccessSize, ARMEmitter::Register Base, IR::OrderedNodeWrapper Offset,
IR::MemOffsetType OffsetType, uint8_t OffsetScale);

void CalculateMemOperand(ARMEmitter::Register Destination, uint8_t AccessSize, ARMEmitter::Register Base, IR::OrderedNodeWrapper Offset,
IR::MemOffsetType OffsetType, uint8_t OffsetScale);

// NOTE: Will use TMP1 as a way to encode immediates that happen to fall outside
// the limits of the scalar plus immediate variant of SVE load/stores.
//
214 changes: 169 additions & 45 deletions FEXCore/Source/Interface/Core/JIT/Arm64/MemoryOps.cpp
@@ -578,6 +578,37 @@ ARMEmitter::ExtendedMemOperand Arm64JITCore::GenerateMemOperand(
FEX_UNREACHABLE;
}

void Arm64JITCore::CalculateMemOperand(ARMEmitter::Register Destination, uint8_t AccessSize, ARMEmitter::Register Base,
IR::OrderedNodeWrapper Offset, IR::MemOffsetType OffsetType, uint8_t OffsetScale) {
if (Offset.IsInvalid()) {
mov(Destination.X(), Base.X());
} else {
if (OffsetScale != 1 && OffsetScale != AccessSize) {
LOGMAN_MSG_A_FMT("Unhandled GenerateMemOperand OffsetScale: {}", OffsetScale);
}
uint64_t Const;
if (IsInlineConstant(Offset, &Const)) {
add(ARMEmitter::Size::i64Bit, Destination.X(), Base.X(), Const);
} else {
auto RegOffset = GetReg(Offset.ID());
switch (OffsetType.Val) {
case IR::MEM_OFFSET_SXTX.Val:
add(ARMEmitter::Size::i64Bit, Destination, Base.X(), RegOffset.X(), ARMEmitter::ShiftType::LSL, FEXCore::ilog2(OffsetScale));
break;
case IR::MEM_OFFSET_UXTW.Val:
mov(ARMEmitter::Size::i32Bit, Destination, RegOffset.W());
add(ARMEmitter::Size::i64Bit, Destination, Base.X(), Destination.X(), ARMEmitter::ShiftType::LSL, FEXCore::ilog2(OffsetScale));
break;
case IR::MEM_OFFSET_SXTW.Val:
sxtw(Destination.X(), RegOffset.W());
add(ARMEmitter::Size::i64Bit, Destination, Base.X(), Destination.X(), ARMEmitter::ShiftType::LSL, FEXCore::ilog2(OffsetScale));
break;
default: LOGMAN_MSG_A_FMT("Unhandled GenerateMemOperand OffsetType: {}", OffsetType.Val); break;
}
}
}
}

ARMEmitter::SVEMemOperand Arm64JITCore::GenerateSVEMemOperand(uint8_t AccessSize, ARMEmitter::Register Base, IR::OrderedNodeWrapper Offset,
IR::MemOffsetType OffsetType, [[maybe_unused]] uint8_t OffsetScale) {
if (Offset.IsInvalid()) {
@@ -752,7 +783,6 @@ DEF_OP(LoadMemTSO) {
}

DEF_OP(VLoadVectorMasked) {
LOGMAN_THROW_A_FMT(HostSupportsSVE128 || HostSupportsSVE256, "Need SVE support in order to use VLoadVectorMasked");

const auto Op = IROp->C<IR::IROp_VLoadVectorMasked>();
const auto OpSize = IROp->Size;
@@ -769,35 +799,84 @@
const auto Dst = GetVReg(Node);
const auto MaskReg = GetVReg(Op->Mask.ID());
const auto MemReg = GetReg(Op->Addr.ID());
const auto MemSrc = GenerateSVEMemOperand(OpSize, MemReg, Op->Offset, Op->OffsetType, Op->OffsetScale);

// Check if the sign bit is set for the given element size.
cmplt(SubRegSize, CMPPredicate, GoverningPredicate.Zeroing(), MaskReg.Z(), 0);
if (HostSupportsSVE128 || HostSupportsSVE256) {
const auto MemSrc = GenerateSVEMemOperand(OpSize, MemReg, Op->Offset, Op->OffsetType, Op->OffsetScale);

switch (IROp->ElementSize) {
case 1: {
ld1b<ARMEmitter::SubRegSize::i8Bit>(Dst.Z(), CMPPredicate.Zeroing(), MemSrc);
break;
}
case 2: {
ld1h<ARMEmitter::SubRegSize::i16Bit>(Dst.Z(), CMPPredicate.Zeroing(), MemSrc);
break;
}
case 4: {
ld1w<ARMEmitter::SubRegSize::i32Bit>(Dst.Z(), CMPPredicate.Zeroing(), MemSrc);
break;
}
case 8: {
ld1d(Dst.Z(), CMPPredicate.Zeroing(), MemSrc);
break;
}
default: break;
// Check if the sign bit is set for the given element size.
cmplt(SubRegSize, CMPPredicate, GoverningPredicate.Zeroing(), MaskReg.Z(), 0);

switch (IROp->ElementSize) {
case 1: {
ld1b<ARMEmitter::SubRegSize::i8Bit>(Dst.Z(), CMPPredicate.Zeroing(), MemSrc);
break;
}
case 2: {
ld1h<ARMEmitter::SubRegSize::i16Bit>(Dst.Z(), CMPPredicate.Zeroing(), MemSrc);
break;
}
case 4: {
ld1w<ARMEmitter::SubRegSize::i32Bit>(Dst.Z(), CMPPredicate.Zeroing(), MemSrc);
break;
}
case 8: {
ld1d(Dst.Z(), CMPPredicate.Zeroing(), MemSrc);
break;
}
default: break;
}
} else {
// Prepare yourself, adventurer: a masked load without instructions that implement it.
LOGMAN_THROW_A_FMT(OpSize == Core::CPUState::XMM_SSE_REG_SIZE, "Only supports 128-bit without SVE256");
size_t NumElements = IROp->Size / IROp->ElementSize;

const auto PerformMove = [this](size_t ElementSize, const ARMEmitter::Register Dst, const ARMEmitter::VRegister Vector, int index) {
switch (ElementSize) {
case 1: umov<ARMEmitter::SubRegSize::i8Bit>(Dst, Vector, index); break;
case 2: umov<ARMEmitter::SubRegSize::i16Bit>(Dst, Vector, index); break;
case 4: umov<ARMEmitter::SubRegSize::i32Bit>(Dst, Vector, index); break;
case 8: umov<ARMEmitter::SubRegSize::i64Bit>(Dst, Vector, index); break;
default: LOGMAN_MSG_A_FMT("Unhandled ExtractElementSize: {}", ElementSize); break;
}
};

// Use VTMP1 as the temporary destination
auto TempDst = VTMP1;
auto WorkingReg = TMP1;
auto TempMemReg = TMP2;
movi(ARMEmitter::SubRegSize::i64Bit, TempDst.Q(), 0);
CalculateMemOperand(TempMemReg, OpSize, MemReg, Op->Offset, Op->OffsetType, Op->OffsetScale);

for (size_t i = 0; i < NumElements; ++i) {
// Extract the mask element.
PerformMove(IROp->ElementSize, WorkingReg, MaskReg, i);

// If the sign bit is zero then skip the load
ARMEmitter::SingleUseForwardLabel Skip {};
tbz(WorkingReg, (IROp->ElementSize * 8) - 1, &Skip);
// Do the gather load for this element into the destination
switch (IROp->ElementSize) {
case 1: ld1<ARMEmitter::SubRegSize::i8Bit>(TempDst.Q(), i, TempMemReg); break;
case 2: ld1<ARMEmitter::SubRegSize::i16Bit>(TempDst.Q(), i, TempMemReg); break;
case 4: ld1<ARMEmitter::SubRegSize::i32Bit>(TempDst.Q(), i, TempMemReg); break;
case 8: ld1<ARMEmitter::SubRegSize::i64Bit>(TempDst.Q(), i, TempMemReg); break;
case 16: ldr(TempDst.Q(), TempMemReg, 0); break;
default: LOGMAN_MSG_A_FMT("Unhandled {} size: {}", __func__, IROp->ElementSize); return;
}

Bind(&Skip);

if ((i + 1) != NumElements) {
add(ARMEmitter::Size::i64Bit, TempMemReg, TempMemReg, IROp->ElementSize);
}
}

// Move result.
mov(Dst.Q(), TempDst.Q());
}
}

DEF_OP(VStoreVectorMasked) {
LOGMAN_THROW_A_FMT(HostSupportsSVE128 || HostSupportsSVE256, "Need SVE support in order to use VStoreVectorMasked");

const auto Op = IROp->C<IR::IROp_VStoreVectorMasked>();
const auto OpSize = IROp->Size;

Expand All @@ -813,29 +892,74 @@ DEF_OP(VStoreVectorMasked) {
const auto RegData = GetVReg(Op->Data.ID());
const auto MaskReg = GetVReg(Op->Mask.ID());
const auto MemReg = GetReg(Op->Addr.ID());
const auto MemDst = GenerateSVEMemOperand(OpSize, MemReg, Op->Offset, Op->OffsetType, Op->OffsetScale);
if (HostSupportsSVE128 || HostSupportsSVE256) {
const auto MemDst = GenerateSVEMemOperand(OpSize, MemReg, Op->Offset, Op->OffsetType, Op->OffsetScale);

// Check if the sign bit is set for the given element size.
cmplt(SubRegSize, CMPPredicate, GoverningPredicate.Zeroing(), MaskReg.Z(), 0);
// Check if the sign bit is set for the given element size.
cmplt(SubRegSize, CMPPredicate, GoverningPredicate.Zeroing(), MaskReg.Z(), 0);

switch (IROp->ElementSize) {
case 1: {
st1b<ARMEmitter::SubRegSize::i8Bit>(RegData.Z(), CMPPredicate.Zeroing(), MemDst);
break;
}
case 2: {
st1h<ARMEmitter::SubRegSize::i16Bit>(RegData.Z(), CMPPredicate.Zeroing(), MemDst);
break;
}
case 4: {
st1w<ARMEmitter::SubRegSize::i32Bit>(RegData.Z(), CMPPredicate.Zeroing(), MemDst);
break;
}
case 8: {
st1d(RegData.Z(), CMPPredicate.Zeroing(), MemDst);
break;
}
default: break;
switch (IROp->ElementSize) {
case 1: {
st1b<ARMEmitter::SubRegSize::i8Bit>(RegData.Z(), CMPPredicate.Zeroing(), MemDst);
break;
}
case 2: {
st1h<ARMEmitter::SubRegSize::i16Bit>(RegData.Z(), CMPPredicate.Zeroing(), MemDst);
break;
}
case 4: {
st1w<ARMEmitter::SubRegSize::i32Bit>(RegData.Z(), CMPPredicate.Zeroing(), MemDst);
break;
}
case 8: {
st1d(RegData.Z(), CMPPredicate.Zeroing(), MemDst);
break;
}
default: break;
}
} else {
// Prepare yourself, adventurer: a masked store without instructions that implement it.
LOGMAN_THROW_A_FMT(OpSize == Core::CPUState::XMM_SSE_REG_SIZE, "Only supports 128-bit without SVE256");
size_t NumElements = IROp->Size / IROp->ElementSize;

const auto PerformMove = [this](size_t ElementSize, const ARMEmitter::Register Dst, const ARMEmitter::VRegister Vector, int index) {
switch (ElementSize) {
case 1: umov<ARMEmitter::SubRegSize::i8Bit>(Dst, Vector, index); break;
case 2: umov<ARMEmitter::SubRegSize::i16Bit>(Dst, Vector, index); break;
case 4: umov<ARMEmitter::SubRegSize::i32Bit>(Dst, Vector, index); break;
case 8: umov<ARMEmitter::SubRegSize::i64Bit>(Dst, Vector, index); break;
default: LOGMAN_MSG_A_FMT("Unhandled ExtractElementSize: {}", ElementSize); break;
}
};

// Temporaries for the extracted mask element and the memory address
auto WorkingReg = TMP1;
auto TempMemReg = TMP2;
CalculateMemOperand(TempMemReg, OpSize, MemReg, Op->Offset, Op->OffsetType, Op->OffsetScale);

for (size_t i = 0; i < NumElements; ++i) {
// Extract the mask element.
PerformMove(IROp->ElementSize, WorkingReg, MaskReg, i);

// If the sign bit is zero then skip the store
ARMEmitter::SingleUseForwardLabel Skip {};
tbz(WorkingReg, (IROp->ElementSize * 8) - 1, &Skip);
// Do the store for this element from the source register
switch (IROp->ElementSize) {
case 1: st1<ARMEmitter::SubRegSize::i8Bit>(RegData.Q(), i, TempMemReg); break;
case 2: st1<ARMEmitter::SubRegSize::i16Bit>(RegData.Q(), i, TempMemReg); break;
case 4: st1<ARMEmitter::SubRegSize::i32Bit>(RegData.Q(), i, TempMemReg); break;
case 8: st1<ARMEmitter::SubRegSize::i64Bit>(RegData.Q(), i, TempMemReg); break;
case 16: str(RegData.Q(), TempMemReg, 0); break;
default: LOGMAN_MSG_A_FMT("Unhandled {} size: {}", __func__, IROp->ElementSize); return;
}

Bind(&Skip);

if ((i + 1) != NumElements) {
add(ARMEmitter::Size::i64Bit, TempMemReg, TempMemReg, IROp->ElementSize);
}
}
}
}
