diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h index e1d828f038f9a2..d3cd1a9b128048 100644 --- a/llvm/lib/Transforms/Vectorize/VPlan.h +++ b/llvm/lib/Transforms/Vectorize/VPlan.h @@ -1220,6 +1220,7 @@ class VPInstruction : public VPRecipeWithIRFlags, CalculateTripCountMinusVF, // Increment the canonical IV separately for each unrolled part. CanonicalIVIncrementForPart, + WideIVStep, BranchOnCount, BranchOnCond, ComputeReductionResult, diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp index 5903ad29af7602..8bf9b5194932b4 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp @@ -661,7 +661,8 @@ bool VPInstruction::isFPMathOp() const { return Opcode == Instruction::FAdd || Opcode == Instruction::FMul || Opcode == Instruction::FNeg || Opcode == Instruction::FSub || Opcode == Instruction::FDiv || Opcode == Instruction::FRem || - Opcode == Instruction::FCmp || Opcode == Instruction::Select; + Opcode == Instruction::FCmp || Opcode == Instruction::Select || + Opcode == VPInstruction::WideIVStep; } #endif diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp index 6d77173735c9b8..41ed8b65b00fd6 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp @@ -1821,20 +1821,61 @@ void VPlanTransforms::createInterleaveGroups( } void VPlanTransforms::convertToConcreteRecipes(VPlan &Plan) { + Type *CanonicalIVType = Plan.getCanonicalIV()->getScalarType(); + VPTypeAnalysis TypeInfo(CanonicalIVType); + for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly( vp_depth_first_deep(Plan.getEntry()))) { - for (VPRecipeBase &R : make_early_inc_range(VPBB->phis())) { - if (!isa(&R)) + for (VPRecipeBase &R : make_early_inc_range(*VPBB)) { + if (isa(&R)) { + auto *PhiR = cast(&R); + StringRef Name = + isa(PhiR) ? "index" : "evl.based.iv"; + auto *ScalarR = new VPScalarPHIRecipe(PhiR->getStartValue(), + PhiR->getBackedgeValue(), + PhiR->getDebugLoc(), Name); + ScalarR->insertBefore(PhiR); + PhiR->replaceAllUsesWith(ScalarR); + PhiR->eraseFromParent(); continue; - auto *PhiR = cast(&R); - StringRef Name = - isa(PhiR) ? "index" : "evl.based.iv"; - auto *ScalarR = - new VPScalarPHIRecipe(PhiR->getStartValue(), PhiR->getBackedgeValue(), - PhiR->getDebugLoc(), Name); - ScalarR->insertBefore(PhiR); - PhiR->replaceAllUsesWith(ScalarR); - PhiR->eraseFromParent(); + } + + auto *VPI = dyn_cast(&R); + if (VPI && VPI->getOpcode() == VPInstruction::WideIVStep) { + VPBuilder Builder(VPI->getParent(), VPI->getIterator()); + VPValue *VectorStep = VPI->getOperand(0); + Type *IVTy = TypeInfo.inferScalarType(VPI->getOperand(2)); + if (TypeInfo.inferScalarType(VectorStep) != IVTy) { + Instruction::CastOps CastOp = IVTy->isFloatingPointTy() + ? Instruction::UIToFP + : Instruction::Trunc; + VectorStep = Builder.createWidenCast(CastOp, VectorStep, IVTy); + } + + VPValue *ScalarStep = VPI->getOperand(1); + auto *ConstStep = + ScalarStep->isLiveIn() + ? dyn_cast(ScalarStep->getLiveInIRValue()) + : nullptr; + if (!ConstStep || ConstStep->getValue() != 1) { + if (TypeInfo.inferScalarType(ScalarStep) != IVTy) { + ScalarStep = + Builder.createWidenCast(Instruction::Trunc, ScalarStep, IVTy); + } + + std::optional FMFs; + if (IVTy->isFloatingPointTy()) + FMFs = VPI->getFastMathFlags(); + + unsigned MulOpc = + IVTy->isFloatingPointTy() ? Instruction::FMul : Instruction::Mul; + VPInstruction *Mul = Builder.createNaryOp( + MulOpc, {VectorStep, ScalarStep}, FMFs, R.getDebugLoc()); + VectorStep = Mul; + } + VPI->replaceAllUsesWith(VectorStep); + VPI->eraseFromParent(); + } } } } diff --git a/llvm/lib/Transforms/Vectorize/VPlanUnroll.cpp b/llvm/lib/Transforms/Vectorize/VPlanUnroll.cpp index ff6c9295ee2057..7c1bb98c1a021f 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanUnroll.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanUnroll.cpp @@ -155,33 +155,15 @@ void UnrollState::unrollWidenInductionByUF( if (isa_and_present(ID.getInductionBinOp())) FMFs = ID.getInductionBinOp()->getFastMathFlags(); - VPValue *VectorStep = &Plan.getVF(); - VPBuilder Builder(PH); - if (TypeInfo.inferScalarType(VectorStep) != IVTy) { - Instruction::CastOps CastOp = - IVTy->isFloatingPointTy() ? Instruction::UIToFP : Instruction::Trunc; - VectorStep = Builder.createWidenCast(CastOp, VectorStep, IVTy); - ToSkip.insert(VectorStep->getDefiningRecipe()); - } - VPValue *ScalarStep = IV->getStepValue(); - auto *ConstStep = ScalarStep->isLiveIn() - ? dyn_cast(ScalarStep->getLiveInIRValue()) - : nullptr; - if (!ConstStep || ConstStep->getValue() != 1) { - if (TypeInfo.inferScalarType(ScalarStep) != IVTy) { - ScalarStep = - Builder.createWidenCast(Instruction::Trunc, ScalarStep, IVTy); - ToSkip.insert(ScalarStep->getDefiningRecipe()); - } + VPBuilder Builder(PH); + VPInstruction *VectorStep = + Builder.createNaryOp(VPInstruction::WideIVStep, + {&Plan.getVF(), ScalarStep, + Plan.getOrAddLiveIn(Constant::getNullValue(IVTy))}, + FMFs, IV->getDebugLoc()); - unsigned MulOpc = - IVTy->isFloatingPointTy() ? Instruction::FMul : Instruction::Mul; - VPInstruction *Mul = Builder.createNaryOp(MulOpc, {VectorStep, ScalarStep}, - FMFs, IV->getDebugLoc()); - VectorStep = Mul; - ToSkip.insert(Mul); - } + ToSkip.insert(VectorStep); // Now create recipes to compute the induction steps for part 1 .. UF. Part 0 // remains the header phi. Parts > 0 are computed by adding Step to the