Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for allowing direct VEXTRACT to 20-bit registers #233

Open
wants to merge 3 commits into
base: aie-public
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions llvm/lib/Target/AIE/AIE2InstrInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,9 @@ bool AIE2InstrInfo::verifyGenericInstruction(const MachineInstr &MI,
switch (MI.getOpcode()) {
case AIE2::G_AIE_ZEXT_EXTRACT_VECTOR_ELT:
case AIE2::G_AIE_SEXT_EXTRACT_VECTOR_ELT:
ErrInfo = "Expected 32bit scalar destination";
return MRI.getType(MI.getOperand(0).getReg()) == LLT::scalar(32);
ErrInfo = "Expected 32bit or 20bit scalar destination";
return (MRI.getType(MI.getOperand(0).getReg()) == LLT::scalar(32) ||
MRI.getType(MI.getOperand(0).getReg()) == LLT::scalar(20));
case AIE2::G_AIE_PAD_VECTOR_UNDEF:
return verifySameLaneTypes(MI, ErrInfo) &&
isLegalTypeToUnpad(MRI.getType(MI.getOperand(0).getReg()),
Expand Down
93 changes: 93 additions & 0 deletions llvm/lib/Target/AIE/AIE2PreLegalizerCombiner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,11 @@ class AIE2PreLegalizerCombinerImpl : public Combiner {

bool tryToCombineVectorInserts(MachineInstr &MI, unsigned SclSrcBits) const;

bool tryToCombineTruncExt(Register DstReg, bool SignVal,
unsigned SrcEltSize) const;

bool tryToCombineVExtractElt(MachineInstr &MI) const;

bool tryToCombineIntrinsic(MachineInstr &MI) const;

private:
Expand Down Expand Up @@ -288,6 +293,89 @@ bool AIE2PreLegalizerCombinerImpl::tryToCombineVectorInserts(
return true;
}

/// \returns true if it is possible to combine the below sequence of MIRs
/// From : %10:_(s32) = G_INTRINSIC intrinsic(@llvm.aie2.vextract.elem16.I512),
/// %2(<32 x s16>), %0(s32), %1(s32)
/// %20:_(s16) = G_TRUNC %10(s32)
/// %30:_(s20) = G_SEXT %20(s16)
/// To : %10:_(s32) = G_INTRINSIC intrinsic(@llvm.aie2.vextract.elem16.I512),
/// %2(<32 x s16>), %0(s32), %1(s32)
/// %30:_(s20) = G_TRUNC %10(s32)
/// Or even:
/// From : %10:_(s32) = G_INTRINSIC intrinsic(@llvm.aie2.vextract.elem8.I512),
/// %2(<64 x s8>), %0(s32), %1(s32)
/// %20:_(s8) = G_TRUNC %10(s32)
/// %30:_(s20) = G_SEXT %20(s8)
/// To : %10:_(s32) = G_INTRINSIC intrinsic(@llvm.aie2.vextract.elem8.I512),
/// %2(<64 x s8>), %0(s32), %1(s32)
/// %30:_(s20) = G_TRUNC %10(s32)
/// This also enables S20Narrowing for vextract
bool AIE2PreLegalizerCombinerImpl::tryToCombineTruncExt(
Copy link
Collaborator

@konstantinschwarz konstantinschwarz Nov 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of doing this combine here, why don't we just use G_ASSERT_ZEXT/G_ASSERT_SEXT?
In your previous commit 16445fb,
we should pre-select

%0:_(s32) = G_INTRINSIC intrinsic(@llvm.aie2.vextract.elem8.I512), ...

into

%new:_(s32) = G_AIE_ZEXT_EXTRACT_VECTOR_ELT
%0:_(s32) = G_ASSERT_ZEXT %new, 16

Then you should get the G_TRUNC + G_ZEXT combine for free.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should have something like:

    %0:_(s32) = G_CONSTANT i32 7
    %2:_(<32 x s16>) = COPY $x0
    %9:_(s20) = G_AIE_ZEXT_EXTRACT_VECTOR_ELT %2(<32 x s16>), %0(s32)
    %3:_(s20) = G_ASSERT_ZEXT %9, 16
    %4:_(s16) = G_TRUNC %3(s20)
    %5:_(s20) = G_ZEXT %4(s16)
    %6:_(p0) = G_CONSTANT i20 0
    %7:_(p0), %8:_(s20) = G_INTRINSIC intrinsic(@llvm.aie2.add.2d), %6(p0), %5(s20), %5(s20), %5(s20), %5(s20)
    PseudoRET implicit $lr, implicit %7(p0)

However, for SEXT case, there is no combine pattern AFAIK.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can try to match this pattern explicitly looking to ASSERT, something like:

bool CombinerHelper::matchCombineSextTrunc(MachineInstr &MI, Register &Reg) {
  assert(MI.getOpcode() == TargetOpcode::G_SEXT && "Expected a G_SEXT");
  Register DstReg = MI.getOperand(0).getReg();
  Register SrcReg = MI.getOperand(1).getReg();
  LLT DstTy = MRI.getType(DstReg);
  if (mi_match(SrcReg, MRI,
               m_GTrunc(m_all_of(m_Reg(Reg), m_SpecificType(DstTy))))) {
    unsigned DstSize = DstTy.getScalarSizeInBits();
    unsigned SrcSize = MRI.getType(SrcReg).getScalarSizeInBits();
    MachineInstr* MIDef = MRI.getVRegDef(Reg);
    if (MIDef->getOpcode() != TargetOpcode::G_ASSERT_SEXT)
      return false;
    unsigned ExtBits = MIDef->getOperand(2).getImm();
    return SrcSize == ExtBits;
  }
  return false;
}

And apply replaceSingleDefInstWithReg

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of hardcoding the G_ASSERT_SEXT, I'd suggest to use the KnownBits analysis to check if we have known sign bits: KB->computeNumSignBits(Reg)

Register DstReg, bool SignVal, unsigned SrcEltSize) const {
// Checks if a given register has non-debug user with a specific opcode and
// destination size, and return that user.
auto GetUseWithOpCode =
abhinay-anubola marked this conversation as resolved.
Show resolved Hide resolved
[&](const Register Reg, const unsigned OpcodeToCheck,
const unsigned DstSize) -> std::optional<MachineInstr *> {
for (auto &Use : MRI.use_nodbg_instructions(Reg)) {
if (Use.getOpcode() == OpcodeToCheck) {
const LLT DstRegTy = MRI.getType(Use.getOperand(0).getReg());
if (DstRegTy.getSizeInBits() == DstSize)
return &Use;
}
}
return std::nullopt;
};

if (auto Trunc =
GetUseWithOpCode(DstReg, TargetOpcode::G_TRUNC, SrcEltSize)) {
MachineInstr *TruncMI = Trunc.value();
const unsigned ExtOpcode =
SignVal ? TargetOpcode::G_SEXT : TargetOpcode::G_ZEXT;
const Register UseDstReg = TruncMI->getOperand(0).getReg();
// Ensure G_TRUNC has a single non-debug user before safely eliminating it.
if (!MRI.hasOneNonDBGUser(UseDstReg))
return false;
if (auto Ext = GetUseWithOpCode(UseDstReg, ExtOpcode, 20)) {
MachineInstr *ExtMI = Ext.value();
MachineIRBuilder MIRBuilder(*ExtMI);
MIRBuilder.buildInstr(TargetOpcode::G_TRUNC, {ExtMI->getOperand(0)},
{DstReg});
ExtMI->eraseFromParent();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suggest to not erase 2 instructions here, because we may corrupt the iterator. Let ExtMI to be removed by DCE.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Combiner engine may access freed memory:

    while (!WorkList.empty()) {
      MachineInstr *CurrInst = WorkList.pop_back_val();
      LLVM_DEBUG(dbgs() << "\nTry combining " << *CurrInst;);
      Changed |= tryCombineAll(*CurrInst);
      WLObserver->reportFullyCreatedInstrs();
    }

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Question is, should we refactor void makeDeadMI(MachineInstr &MI, MachineRegisterInfo &MRI); as a common utility function?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We cant erase just the TRUNC as it is input for EXT.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You will also erase MI (intrinsic) in the caller function.

TruncMI->eraseFromParent();
return true;
}
}
return false;
}

bool AIE2PreLegalizerCombinerImpl::tryToCombineVExtractElt(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can have a description of the combiner here.

MachineInstr &MI) const {
abhinay-anubola marked this conversation as resolved.
Show resolved Hide resolved
const Register DstReg = MI.getOperand(0).getReg();
// In this case of G_INTRINSIC operand 1 is target intrinsic
const Register SrcReg = MI.getOperand(2).getReg();
const Register IdxReg = MI.getOperand(3).getReg();
const Register SignReg = MI.getOperand(4).getReg();

const auto SignVal = getIConstantVRegSExtVal(SignReg, MRI);
if (!SignVal)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks like an assert. We always need constant signal to be able to handle this intrinsic.

return false;
const LLT SrcVecTy = MRI.getType(SrcReg);
const unsigned SrcEltSize = SrcVecTy.getScalarSizeInBits();
if (SrcEltSize == 8 || SrcEltSize == 16) {
tryToCombineTruncExt(DstReg, SignVal.value(), SrcEltSize);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can trigger the same multiple erased instructions problem here.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understood that this tryToCombineTruncExt is only safe because we know the range of the output of the defining instruction. However, it would be good to have this as a separated combiner (may looking to the input reg. def.). It will help to solve the erasing problem and we also can have different tests and, as a gift, less coupled code.

}

auto *TII = static_cast<const AIE2InstrInfo *>(STI.getInstrInfo());
const unsigned Opcode =
TII->getGenericExtractVectorEltOpcode(SignVal.value());
MachineIRBuilder MIRBuilder(MI);
MIRBuilder.buildInstr(Opcode, {DstReg}, {SrcReg, IdxReg});

MI.eraseFromParent();
return true;
}

bool AIE2PreLegalizerCombinerImpl::tryToCombineIntrinsic(
MachineInstr &MI) const {
const unsigned IntrinsicID = cast<GIntrinsic>(MI).getIntrinsicID();
Expand All @@ -306,6 +394,11 @@ bool AIE2PreLegalizerCombinerImpl::tryToCombineIntrinsic(
case Intrinsic::aie2_vinsert32_I512: {
return tryToCombineVectorInserts(MI, getVInsertScalarSize(IntrinsicID));
}
case Intrinsic::aie2_vextract_elem8_I512:
case Intrinsic::aie2_vextract_elem16_I512:
case Intrinsic::aie2_vextract_elem32_I512: {
return tryToCombineVExtractElt(MI);
}
default:
break;
}
Expand Down
48 changes: 39 additions & 9 deletions llvm/lib/Target/AIE/AIECombinerHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,18 @@ void llvm::applyGlobalValOffset(MachineInstr &MI, MachineRegisterInfo &MRI,
B.buildConstant(LLT::scalar(20), -static_cast<int64_t>(Offset)));
}

/// Determine if the instruction is a generic extract vector element operation
static bool IsGenericExtractVectorElt(const MachineInstr &MI) {
const AIEBaseSubtarget &STI = AIEBaseSubtarget::get(*MI.getMF());
const AIEBaseInstrInfo *TII = STI.getInstrInfo();
const unsigned Opcode = MI.getOpcode();

if (Opcode == TII->getGenericExtractVectorEltOpcode(false) ||
Opcode == TII->getGenericExtractVectorEltOpcode(true))
return true;
return false;
}

/// Checks whether the instruction produces or can be adapted to produce
/// a single S20 output.
static bool canProduceS20(const MachineRegisterInfo &MRI,
Expand All @@ -581,9 +593,12 @@ static bool canProduceS20(const MachineRegisterInfo &MRI,
case TargetOpcode::G_CONSTANT:
case TargetOpcode::G_IMPLICIT_DEF:
return true;
default:
default: {
if (IsGenericExtractVectorElt(MI))
return true;
abhinay-anubola marked this conversation as resolved.
Show resolved Hide resolved
return false;
}
}
}

/// The function checks if the node can be adapted to produce an S20 value, and
Expand Down Expand Up @@ -901,15 +916,19 @@ bool modifyToS20(InstrNode Start, MachineRegisterInfo &MRI, MachineIRBuilder &B,
return true;
}
default: {
LLVM_DEBUG(dbgs() << "Node :" << *StartNodeMI);
llvm_unreachable("Unexpected OpCode, while modifying IR");
if (IsGenericExtractVectorElt(*StartNodeMI)) {
abhinay-anubola marked this conversation as resolved.
Show resolved Hide resolved
Observer.changingInstr(*StartNodeMI);
MRI.setType(StartNodeMI->getOperand(0).getReg(), S20);
Observer.changedInstr(*StartNodeMI);
} else {
LLVM_DEBUG(dbgs() << "Node :" << *StartNodeMI);
llvm_unreachable("Unexpected OpCode, while modifying IR");
abhinay-anubola marked this conversation as resolved.
Show resolved Hide resolved
}
}
}

switch (StartNodeMI->getOpcode()) {
case TargetOpcode::COPY:
case TargetOpcode::G_LOAD:
case TargetOpcode::G_PHI: {
// Function to handle the modification of instructions
auto ModifyInstructionUses = [&](MachineInstr *StartNodeMI) {
const auto UseInstIter =
MRI.use_nodbg_instructions(StartNodeMI->getOperand(0).getReg());
std::vector<MachineInstr *> UseInstr;
Expand All @@ -924,11 +943,22 @@ bool modifyToS20(InstrNode Start, MachineRegisterInfo &MRI, MachineIRBuilder &B,
if (!modifyToS20(NextNodeToModify, MRI, B, Observer, Helper))
llvm_unreachable("All input nodes should have updated");
}
};

switch (StartNodeMI->getOpcode()) {
case TargetOpcode::COPY:
case TargetOpcode::G_LOAD:
case TargetOpcode::G_PHI: {
ModifyInstructionUses(StartNodeMI);
break;
}
default: {
LLVM_DEBUG(dbgs() << "Node :" << *StartNodeMI);
llvm_unreachable("Unexpected OpCode, while modifying IR");
if (IsGenericExtractVectorElt(*StartNodeMI)) {
ModifyInstructionUses(StartNodeMI);
} else {
LLVM_DEBUG(dbgs() << "Node :" << *StartNodeMI);
llvm_unreachable("Unexpected OpCode, while modifying IR");
}
}
}
return true;
Expand Down
Loading
Loading