Skip to content

Commit

Permalink
[VPlan] Dispatch to multiple exit blocks via middle blocks.
Browse files Browse the repository at this point in the history
A more lightweight variant of #109193,
which dispatches to multiple exit blocks via the middle blocks.
  • Loading branch information
fhahn committed Oct 22, 2024
1 parent 245b56a commit 47258de
Show file tree
Hide file tree
Showing 12 changed files with 614 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,9 @@ class LoopVectorizationLegality {
/// we can use in-order reductions.
bool canVectorizeFPMath(bool EnableStrictReductions);

/// Returns true if the loop has an early exit that we can vectorize.
bool canVectorizeEarlyExit() const;

/// Return true if we can vectorize this loop while folding its tail by
/// masking.
bool canFoldTailByMasking() const;
Expand Down
29 changes: 29 additions & 0 deletions llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ AllowStridedPointerIVs("lv-strided-pointer-ivs", cl::init(false), cl::Hidden,
cl::desc("Enable recognition of non-constant strided "
"pointer induction variables."));

static cl::opt<bool>
EnableEarlyExitVectorization("enable-early-exit-vectorization",
cl::init(false), cl::Hidden, cl::desc(""));

namespace llvm {
cl::opt<bool>
HintsAllowReordering("hints-allow-reordering", cl::init(true), cl::Hidden,
Expand Down Expand Up @@ -1378,6 +1382,10 @@ bool LoopVectorizationLegality::isFixedOrderRecurrence(
}

bool LoopVectorizationLegality::blockNeedsPredication(BasicBlock *BB) const {
// When vectorizing early exits, create predicates for all blocks, except the
// header.
if (canVectorizeEarlyExit() && BB != TheLoop->getHeader())
return true;
return LoopAccessInfo::blockNeedsPredication(BB, TheLoop, DT);
}

Expand Down Expand Up @@ -1514,6 +1522,27 @@ bool LoopVectorizationLegality::canVectorizeWithIfConvert() {
return true;
}

bool LoopVectorizationLegality::canVectorizeEarlyExit() const {
// Currently only allow vectorizing loops with early exits, if early-exit
// vectorization is explicitly enabled and the loop has metadata to force
// vectorization.
if (!EnableEarlyExitVectorization)
return false;

SmallVector<BasicBlock *> Exiting;
TheLoop->getExitingBlocks(Exiting);
if (Exiting.size() == 1)
return false;

LoopVectorizeHints Hints(TheLoop, true, *ORE);
if (Hints.getForce() == LoopVectorizeHints::FK_Undefined)
return false;

Function *Fn = TheLoop->getHeader()->getParent();
return Hints.allowVectorization(Fn, TheLoop,
true /*VectorizeOnlyWhenForced*/);
}

// Helper function to canVectorizeLoopNestCFG.
bool LoopVectorizationLegality::canVectorizeLoopCFG(Loop *Lp,
bool UseVPlanNativePath) {
Expand Down
82 changes: 51 additions & 31 deletions llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1363,9 +1363,11 @@ class LoopVectorizationCostModel {
// If we might exit from anywhere but the latch, must run the exiting
// iteration in scalar form.
if (TheLoop->getExitingBlock() != TheLoop->getLoopLatch()) {
LLVM_DEBUG(
dbgs() << "LV: Loop requires scalar epilogue: multiple exits\n");
return true;
if (!Legal->canVectorizeEarlyExit()) {
LLVM_DEBUG(
dbgs() << "LV: Loop requires scalar epilogue: multiple exits\n");
return true;
}
}
if (IsVectorizing && InterleaveInfo.requiresScalarEpilogue()) {
LLVM_DEBUG(dbgs() << "LV: Loop requires scalar epilogue: "
Expand Down Expand Up @@ -2575,7 +2577,8 @@ void InnerLoopVectorizer::createVectorLoopSkeleton(StringRef Prefix) {
LoopVectorPreHeader = OrigLoop->getLoopPreheader();
assert(LoopVectorPreHeader && "Invalid loop structure");
LoopExitBlock = OrigLoop->getUniqueExitBlock(); // may be nullptr
assert((LoopExitBlock || Cost->requiresScalarEpilogue(VF.isVector())) &&
assert((LoopExitBlock || Cost->requiresScalarEpilogue(VF.isVector()) ||
Legal->canVectorizeEarlyExit()) &&
"multiple exit loop without required epilogue?");

LoopMiddleBlock =
Expand Down Expand Up @@ -2758,8 +2761,6 @@ void InnerLoopVectorizer::fixupIVUsers(PHINode *OrigPhi,
// value (the value that feeds into the phi from the loop latch).
// We allow both, but they, obviously, have different values.

assert(OrigLoop->getUniqueExitBlock() && "Expected a single exit block");

DenseMap<Value *, Value *> MissingVals;

// An external user of the last iteration's value should see the value that
Expand Down Expand Up @@ -2819,6 +2820,9 @@ void InnerLoopVectorizer::fixupIVUsers(PHINode *OrigPhi,
if (PHI->getBasicBlockIndex(MiddleBlock) == -1)
PHI->addIncoming(I.second, MiddleBlock);
}

assert((MissingVals.empty() || OrigLoop->getUniqueExitBlock()) &&
"Expected a single exit block");
}

namespace {
Expand Down Expand Up @@ -3599,7 +3603,8 @@ void LoopVectorizationCostModel::collectLoopUniforms(ElementCount VF) {
TheLoop->getExitingBlocks(Exiting);
for (BasicBlock *E : Exiting) {
auto *Cmp = dyn_cast<Instruction>(E->getTerminator()->getOperand(0));
if (Cmp && TheLoop->contains(Cmp) && Cmp->hasOneUse())
if (Cmp && TheLoop->contains(Cmp) && Cmp->hasOneUse() &&
(TheLoop->getLoopLatch() == E || !Legal->canVectorizeEarlyExit()))
AddToWorklistIfAllowed(Cmp);
}

Expand Down Expand Up @@ -7692,12 +7697,15 @@ DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan(
BestVPlan.execute(&State);

// 2.5 Collect reduction resume values.
auto *ExitVPBB =
cast<VPBasicBlock>(BestVPlan.getVectorLoopRegion()->getSingleSuccessor());
for (VPRecipeBase &R : *ExitVPBB) {
createAndCollectMergePhiForReduction(
dyn_cast<VPInstruction>(&R), State, OrigLoop,
State.CFG.VPBB2IRBB[ExitVPBB], ExpandedSCEVs);
VPBasicBlock *ExitVPBB = nullptr;
if (BestVPlan.getVectorLoopRegion()->getSingleSuccessor()) {
ExitVPBB = cast<VPBasicBlock>(
BestVPlan.getVectorLoopRegion()->getSingleSuccessor());
for (VPRecipeBase &R : *ExitVPBB) {
createAndCollectMergePhiForReduction(
dyn_cast<VPInstruction>(&R), State, OrigLoop,
State.CFG.VPBB2IRBB[ExitVPBB], ExpandedSCEVs);
}
}

// 2.6. Maintain Loop Hints
Expand All @@ -7723,6 +7731,7 @@ DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan(
LoopVectorizeHints Hints(L, true, *ORE);
Hints.setAlreadyVectorized();
}

TargetTransformInfo::UnrollingPreferences UP;
TTI.getUnrollingPreferences(L, *PSE.getSE(), UP, ORE);
if (!UP.UnrollVectorizedLoop || CanonicalIVStartValue)
Expand All @@ -7735,15 +7744,17 @@ DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan(
ILV.printDebugTracesAtEnd();

// 4. Adjust branch weight of the branch in the middle block.
auto *MiddleTerm =
cast<BranchInst>(State.CFG.VPBB2IRBB[ExitVPBB]->getTerminator());
if (MiddleTerm->isConditional() &&
hasBranchWeightMD(*OrigLoop->getLoopLatch()->getTerminator())) {
// Assume that `Count % VectorTripCount` is equally distributed.
unsigned TripCount = BestVPlan.getUF() * State.VF.getKnownMinValue();
assert(TripCount > 0 && "trip count should not be zero");
const uint32_t Weights[] = {1, TripCount - 1};
setBranchWeights(*MiddleTerm, Weights, /*IsExpected=*/false);
if (ExitVPBB) {
auto *MiddleTerm =
cast<BranchInst>(State.CFG.VPBB2IRBB[ExitVPBB]->getTerminator());
if (MiddleTerm->isConditional() &&
hasBranchWeightMD(*OrigLoop->getLoopLatch()->getTerminator())) {
// Assume that `Count % VectorTripCount` is equally distributed.
unsigned TripCount = BestVPlan.getUF() * State.VF.getKnownMinValue();
assert(TripCount > 0 && "trip count should not be zero");
const uint32_t Weights[] = {1, TripCount - 1};
setBranchWeights(*MiddleTerm, Weights, /*IsExpected=*/false);
}
}

return State.ExpandedSCEVs;
Expand Down Expand Up @@ -8128,7 +8139,7 @@ VPValue *VPRecipeBuilder::createEdgeMask(BasicBlock *Src, BasicBlock *Dst) {
// If source is an exiting block, we know the exit edge is dynamically dead
// in the vector loop, and thus we don't need to restrict the mask. Avoid
// adding uses of an otherwise potentially dead instruction.
if (OrigLoop->isLoopExiting(Src))
if (!Legal->canVectorizeEarlyExit() && OrigLoop->isLoopExiting(Src))
return EdgeMaskCache[Edge] = SrcMask;

VPValue *EdgeMask = getVPValueOrAddLiveIn(BI->getCondition());
Expand Down Expand Up @@ -8778,6 +8789,8 @@ static void addCanonicalIVRecipes(VPlan &Plan, Type *IdxTy, bool HasNUW,
static SetVector<VPIRInstruction *> collectUsersInExitBlock(
Loop *OrigLoop, VPRecipeBuilder &Builder, VPlan &Plan,
const MapVector<PHINode *, InductionDescriptor> &Inductions) {
if (!Plan.getVectorLoopRegion()->getSingleSuccessor())
return {};
auto *MiddleVPBB =
cast<VPBasicBlock>(Plan.getVectorLoopRegion()->getSingleSuccessor());
// No edge from the middle block to the unique exit block has been inserted
Expand Down Expand Up @@ -8863,6 +8876,8 @@ static void addLiveOutsForFirstOrderRecurrences(
// TODO: Should be replaced by
// Plan->getScalarLoopRegion()->getSinglePredecessor() in the future once the
// scalar region is modeled as well.
if (!VectorRegion->getSingleSuccessor())
return;
auto *MiddleVPBB = cast<VPBasicBlock>(VectorRegion->getSingleSuccessor());
VPBasicBlock *ScalarPHVPBB = nullptr;
if (MiddleVPBB->getNumSuccessors() == 2) {
Expand Down Expand Up @@ -9146,10 +9161,15 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
"VPBasicBlock");
RecipeBuilder.fixHeaderPhis();

SetVector<VPIRInstruction *> ExitUsersToFix = collectUsersInExitBlock(
OrigLoop, RecipeBuilder, *Plan, Legal->getInductionVars());
addLiveOutsForFirstOrderRecurrences(*Plan, ExitUsersToFix);
addUsersInExitBlock(*Plan, ExitUsersToFix);
if (Legal->canVectorizeEarlyExit()) {
VPlanTransforms::convertToMultiCond(*Plan, *PSE.getSE(), OrigLoop,
RecipeBuilder);
} else {
SetVector<VPIRInstruction *> ExitUsersToFix = collectUsersInExitBlock(
OrigLoop, RecipeBuilder, *Plan, Legal->getInductionVars());
addLiveOutsForFirstOrderRecurrences(*Plan, ExitUsersToFix);
addUsersInExitBlock(*Plan, ExitUsersToFix);
}

// ---------------------------------------------------------------------------
// Transform initial VPlan: Apply previously taken decisions, in order, to
Expand Down Expand Up @@ -9277,8 +9297,6 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
using namespace VPlanPatternMatch;
VPRegionBlock *VectorLoopRegion = Plan->getVectorLoopRegion();
VPBasicBlock *Header = VectorLoopRegion->getEntryBasicBlock();
VPBasicBlock *MiddleVPBB =
cast<VPBasicBlock>(VectorLoopRegion->getSingleSuccessor());
for (VPRecipeBase &R : Header->phis()) {
auto *PhiR = dyn_cast<VPReductionPHIRecipe>(&R);
if (!PhiR || !PhiR->isInLoop() || (MinVF.isScalar() && !PhiR->isOrdered()))
Expand All @@ -9297,8 +9315,6 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
for (VPUser *U : Cur->users()) {
auto *UserRecipe = cast<VPSingleDefRecipe>(U);
if (!UserRecipe->getParent()->getEnclosingLoopRegion()) {
assert(UserRecipe->getParent() == MiddleVPBB &&
"U must be either in the loop region or the middle block.");
continue;
}
Worklist.insert(UserRecipe);
Expand Down Expand Up @@ -9403,6 +9419,10 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
}
VPBasicBlock *LatchVPBB = VectorLoopRegion->getExitingBasicBlock();
Builder.setInsertPoint(&*LatchVPBB->begin());
if (!VectorLoopRegion->getSingleSuccessor())
return;
VPBasicBlock *MiddleVPBB =
cast<VPBasicBlock>(VectorLoopRegion->getSingleSuccessor());
VPBasicBlock::iterator IP = MiddleVPBB->getFirstNonPhi();
for (VPRecipeBase &R :
Plan->getVectorLoopRegion()->getEntryBasicBlock()->phis()) {
Expand Down
39 changes: 32 additions & 7 deletions llvm/lib/Transforms/Vectorize/VPlan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,14 @@ void VPIRBasicBlock::execute(VPTransformState *State) {
// backedges. A backward successor is set when the branch is created.
const auto &PredVPSuccessors = PredVPBB->getHierarchicalSuccessors();
unsigned idx = PredVPSuccessors.front() == this ? 0 : 1;
if (TermBr->getSuccessor(idx) &&
PredVPBlock == getPlan()->getVectorLoopRegion() &&
PredVPBlock->getNumSuccessors()) {
// Update PRedBB and TermBr for BranchOnMultiCond in predecessor.
PredBB = TermBr->getSuccessor(1);
TermBr = cast<BranchInst>(PredBB->getTerminator());
idx = 0;
}
assert(!TermBr->getSuccessor(idx) &&
"Trying to reset an existing successor block.");
TermBr->setSuccessor(idx, IRBB);
Expand Down Expand Up @@ -908,8 +916,8 @@ VPlanPtr VPlan::createInitialVPlan(Type *InductionTy,
VPBasicBlock *MiddleVPBB = new VPBasicBlock("middle.block");
VPBlockUtils::insertBlockAfter(MiddleVPBB, TopRegion);

VPBasicBlock *ScalarPH = new VPBasicBlock("scalar.ph");
if (!RequiresScalarEpilogueCheck) {
VPBasicBlock *ScalarPH = new VPBasicBlock("scalar.ph");
VPBlockUtils::connectBlocks(MiddleVPBB, ScalarPH);
return Plan;
}
Expand All @@ -923,10 +931,14 @@ VPlanPtr VPlan::createInitialVPlan(Type *InductionTy,
// we unconditionally branch to the scalar preheader. Do nothing.
// 3) Otherwise, construct a runtime check.
BasicBlock *IRExitBlock = TheLoop->getUniqueExitBlock();
auto *VPExitBlock = VPIRBasicBlock::fromBasicBlock(IRExitBlock);
// The connection order corresponds to the operands of the conditional branch.
VPBlockUtils::insertBlockAfter(VPExitBlock, MiddleVPBB);
VPBlockUtils::connectBlocks(MiddleVPBB, ScalarPH);
if (IRExitBlock) {
auto *VPExitBlock = VPIRBasicBlock::fromBasicBlock(IRExitBlock);
// The connection order corresponds to the operands of the conditional
// branch.
VPBlockUtils::insertBlockAfter(VPExitBlock, MiddleVPBB);
VPBasicBlock *ScalarPH = new VPBasicBlock("scalar.ph");
VPBlockUtils::connectBlocks(MiddleVPBB, ScalarPH);
}

auto *ScalarLatchTerm = TheLoop->getLoopLatch()->getTerminator();
// Here we use the same DebugLoc as the scalar loop latch terminator instead
Expand Down Expand Up @@ -1031,7 +1043,9 @@ void VPlan::execute(VPTransformState *State) {
// VPlan execution rather than earlier during VPlan construction.
BasicBlock *MiddleBB = State->CFG.ExitBB;
VPBasicBlock *MiddleVPBB =
cast<VPBasicBlock>(getVectorLoopRegion()->getSingleSuccessor());
getVectorLoopRegion()->getNumSuccessors() == 1
? cast<VPBasicBlock>(getVectorLoopRegion()->getSuccessors()[0])
: cast<VPBasicBlock>(getVectorLoopRegion()->getSuccessors()[1]);
// Find the VPBB for the scalar preheader, relying on the current structure
// when creating the middle block and its successrs: if there's a single
// predecessor, it must be the scalar preheader. Otherwise, the second
Expand All @@ -1044,6 +1058,10 @@ void VPlan::execute(VPTransformState *State) {
MiddleSuccs.size() == 1 ? MiddleSuccs[0] : MiddleSuccs[1]);
assert(!isa<VPIRBasicBlock>(ScalarPhVPBB) &&
"scalar preheader cannot be wrapped already");
if (ScalarPhVPBB->getNumSuccessors() != 0) {
ScalarPhVPBB = cast<VPBasicBlock>(ScalarPhVPBB->getSuccessors()[1]);
MiddleVPBB = cast<VPBasicBlock>(MiddleVPBB->getSuccessors()[1]);
}
replaceVPBBWithIRVPBB(ScalarPhVPBB, ScalarPh);
replaceVPBBWithIRVPBB(MiddleVPBB, MiddleBB);

Expand All @@ -1065,6 +1083,10 @@ void VPlan::execute(VPTransformState *State) {
VPBasicBlock *LatchVPBB = getVectorLoopRegion()->getExitingBasicBlock();
BasicBlock *VectorLatchBB = State->CFG.VPBB2IRBB[LatchVPBB];

if (!getVectorLoopRegion()->getSingleSuccessor())
VectorLatchBB =
cast<BranchInst>(VectorLatchBB->getTerminator())->getSuccessor(1);

// Fix the latch value of canonical, reduction and first-order recurrences
// phis in the vector loop.
VPBasicBlock *Header = getVectorLoopRegion()->getEntryBasicBlock();
Expand All @@ -1091,7 +1113,10 @@ void VPlan::execute(VPTransformState *State) {
// Move the last step to the end of the latch block. This ensures
// consistent placement of all induction updates.
Instruction *Inc = cast<Instruction>(Phi->getIncomingValue(1));
Inc->moveBefore(VectorLatchBB->getTerminator()->getPrevNode());
if (VectorLatchBB->getTerminator() == &*VectorLatchBB->getFirstNonPHI())
Inc->moveBefore(VectorLatchBB->getTerminator());
else
Inc->moveBefore(VectorLatchBB->getTerminator()->getPrevNode());

// Use the steps for the last part as backedge value for the induction.
if (auto *IV = dyn_cast<VPWidenIntOrFpInductionRecipe>(&R))
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Transforms/Vectorize/VPlan.h
Original file line number Diff line number Diff line change
Expand Up @@ -1274,6 +1274,7 @@ class VPInstruction : public VPRecipeWithIRFlags,
// operand). Only generates scalar values (either for the first lane only or
// for all lanes, depending on its uses).
PtrAdd,
AnyOf,
};

private:
Expand Down
Loading

0 comments on commit 47258de

Please sign in to comment.