Skip to content

Commit

Permalink
[InstSimplify] Make simplifyWithOpReplaced() recursive (PR63104)
Browse files Browse the repository at this point in the history
Support replacement of operands not only in the immediate
instruction, but also instructions it uses.

To the most part, this extension is straightforward, but there are
two bits worth highlighting:

First, we can now no longer assume that if the Op is a vector, the
instruction also returns a vector. If Op is a vector and the
instruction returns a scalar, we should consider it as a cross-lane
operation.

Second, for the x ^ x special case, we can no longer assume that
the operand is RepOp, as we might have a replacement higher up the
instruction chain.

There is one optimization regression, but it is in a fuzzer-generated
test case.

Fixes llvm/llvm-project#63104.
  • Loading branch information
nikic authored and veselypeta committed Sep 6, 2024
2 parents 5483a8d + 3d199d0 commit d99f061
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 73 deletions.
34 changes: 23 additions & 11 deletions llvm/lib/Analysis/InstructionSimplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4258,12 +4258,15 @@ static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
if (V == Op)
return RepOp;

if (!MaxRecurse--)
return nullptr;

// We cannot replace a constant, and shouldn't even try.
if (isa<Constant>(Op))
return nullptr;

auto *I = dyn_cast<Instruction>(V);
if (!I || !is_contained(I->operands(), Op))
if (!I)
return nullptr;

// The arguments of a phi node might refer to a value from a previous
Expand All @@ -4274,15 +4277,26 @@ static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
if (Op->getType()->isVectorTy()) {
// For vector types, the simplification must hold per-lane, so forbid
// potentially cross-lane operations like shufflevector.
assert(I->getType()->isVectorTy() && "Vector type mismatch");
if (isa<ShuffleVectorInst>(I) || isa<CallBase>(I))
if (!I->getType()->isVectorTy() || isa<ShuffleVectorInst>(I) ||
isa<CallBase>(I))
return nullptr;
}

// Replace Op with RepOp in instruction operands.
SmallVector<Value *, 8> NewOps(I->getNumOperands());
transform(I->operands(), NewOps.begin(),
[&](Value *V) { return V == Op ? RepOp : V; });
SmallVector<Value *, 8> NewOps;
bool AnyReplaced = false;
for (Value *InstOp : I->operands()) {
if (Value *NewInstOp = simplifyWithOpReplaced(
InstOp, Op, RepOp, Q, AllowRefinement, MaxRecurse)) {
NewOps.push_back(NewInstOp);
AnyReplaced = InstOp != NewInstOp;
} else {
NewOps.push_back(InstOp);
}
}

if (!AnyReplaced)
return nullptr;

if (!AllowRefinement) {
// General InstSimplify functions may refine the result, e.g. by returning
Expand All @@ -4307,10 +4321,8 @@ static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
// by assumption and this case never wraps, so nowrap flags can be
// ignored.
if ((Opcode == Instruction::Sub || Opcode == Instruction::Xor) &&
NewOps[0] == NewOps[1]) {
assert(NewOps[0] == RepOp && "Precondition for non-poison assumption");
NewOps[0] == RepOp && NewOps[1] == RepOp)
return Constant::getNullValue(I->getType());
}

// If we are substituting an absorber constant into a binop and extra
// poison can't leak if we remove the select -- because both operands of
Expand All @@ -4330,7 +4342,7 @@ static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
if (NewOps.size() == 2 && match(NewOps[1], m_Zero()))
return NewOps[0];
}
} else if (MaxRecurse) {
} else {
// The simplification queries below may return the original value. Consider:
// %div = udiv i32 %arg, %arg2
// %mul = mul nsw i32 %div, %arg2
Expand All @@ -4345,7 +4357,7 @@ static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
};

return PreventSelfSimplify(
::simplifyInstructionWithOperands(I, NewOps, Q, MaxRecurse - 1));
::simplifyInstructionWithOperands(I, NewOps, Q, MaxRecurse));
}

// If all operands are constant after substituting Op for RepOp then we can
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,7 @@ define i1 @n2_wrong_size(i4 %size0, i4 %size1, i4 %nmemb) {

define i1 @n3_wrong_pred(i4 %size, i4 %nmemb) {
; CHECK-LABEL: @n3_wrong_pred(
; CHECK-NEXT: [[CMP:%.*]] = icmp ne i4 [[SIZE:%.*]], 0
; CHECK-NEXT: [[SMUL:%.*]] = tail call { i4, i1 } @llvm.smul.with.overflow.i4(i4 [[SIZE]], i4 [[NMEMB:%.*]])
; CHECK-NEXT: [[SMUL_OV:%.*]] = extractvalue { i4, i1 } [[SMUL]], 1
; CHECK-NEXT: [[PHITMP:%.*]] = xor i1 [[SMUL_OV]], true
; CHECK-NEXT: [[OR:%.*]] = select i1 [[CMP]], i1 true, i1 [[PHITMP]]
; CHECK-NEXT: ret i1 [[OR]]
; CHECK-NEXT: ret i1 true
;
%cmp = icmp ne i4 %size, 0 ; not 'eq'
%smul = tail call { i4, i1 } @llvm.smul.with.overflow.i4(i4 %size, i4 %nmemb)
Expand All @@ -71,11 +66,7 @@ define i1 @n3_wrong_pred(i4 %size, i4 %nmemb) {
define i1 @n4_not_and(i4 %size, i4 %nmemb) {
; CHECK-LABEL: @n4_not_and(
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i4 [[SIZE:%.*]], 0
; CHECK-NEXT: [[SMUL:%.*]] = tail call { i4, i1 } @llvm.smul.with.overflow.i4(i4 [[SIZE]], i4 [[NMEMB:%.*]])
; CHECK-NEXT: [[SMUL_OV:%.*]] = extractvalue { i4, i1 } [[SMUL]], 1
; CHECK-NEXT: [[PHITMP:%.*]] = xor i1 [[SMUL_OV]], true
; CHECK-NEXT: [[OR:%.*]] = select i1 [[CMP]], i1 [[PHITMP]], i1 false
; CHECK-NEXT: ret i1 [[OR]]
; CHECK-NEXT: ret i1 [[CMP]]
;
%cmp = icmp eq i4 %size, 0
%smul = tail call { i4, i1 } @llvm.smul.with.overflow.i4(i4 %size, i4 %nmemb)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,7 @@ define i1 @n2_wrong_size(i4 %size0, i4 %size1, i4 %nmemb) {

define i1 @n3_wrong_pred(i4 %size, i4 %nmemb) {
; CHECK-LABEL: @n3_wrong_pred(
; CHECK-NEXT: [[CMP:%.*]] = icmp ne i4 [[SIZE:%.*]], 0
; CHECK-NEXT: [[UMUL:%.*]] = tail call { i4, i1 } @llvm.umul.with.overflow.i4(i4 [[SIZE]], i4 [[NMEMB:%.*]])
; CHECK-NEXT: [[UMUL_OV:%.*]] = extractvalue { i4, i1 } [[UMUL]], 1
; CHECK-NEXT: [[PHITMP:%.*]] = xor i1 [[UMUL_OV]], true
; CHECK-NEXT: [[OR:%.*]] = select i1 [[CMP]], i1 true, i1 [[PHITMP]]
; CHECK-NEXT: ret i1 [[OR]]
; CHECK-NEXT: ret i1 true
;
%cmp = icmp ne i4 %size, 0 ; not 'eq'
%umul = tail call { i4, i1 } @llvm.umul.with.overflow.i4(i4 %size, i4 %nmemb)
Expand All @@ -71,11 +66,7 @@ define i1 @n3_wrong_pred(i4 %size, i4 %nmemb) {
define i1 @n4_not_and(i4 %size, i4 %nmemb) {
; CHECK-LABEL: @n4_not_and(
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i4 [[SIZE:%.*]], 0
; CHECK-NEXT: [[UMUL:%.*]] = tail call { i4, i1 } @llvm.umul.with.overflow.i4(i4 [[SIZE]], i4 [[NMEMB:%.*]])
; CHECK-NEXT: [[UMUL_OV:%.*]] = extractvalue { i4, i1 } [[UMUL]], 1
; CHECK-NEXT: [[PHITMP:%.*]] = xor i1 [[UMUL_OV]], true
; CHECK-NEXT: [[OR:%.*]] = select i1 [[CMP]], i1 [[PHITMP]], i1 false
; CHECK-NEXT: ret i1 [[OR]]
; CHECK-NEXT: ret i1 [[CMP]]
;
%cmp = icmp eq i4 %size, 0
%umul = tail call { i4, i1 } @llvm.umul.with.overflow.i4(i4 %size, i4 %nmemb)
Expand Down
44 changes: 18 additions & 26 deletions llvm/test/Transforms/InstCombine/select-ctlz-to-cttz.ll
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ declare void @use2(i1)

define i32 @select_clz_to_ctz(i32 %a) {
; CHECK-LABEL: @select_clz_to_ctz(
; CHECK-NEXT: [[COND:%.*]] = call i32 @llvm.cttz.i32(i32 [[A:%.*]], i1 true), !range [[RNG0:![0-9]+]]
; CHECK-NEXT: ret i32 [[COND]]
; CHECK-NEXT: [[SUB1:%.*]] = call i32 @llvm.cttz.i32(i32 [[A:%.*]], i1 true), !range [[RNG0:![0-9]+]]
; CHECK-NEXT: ret i32 [[SUB1]]
;
%sub = sub i32 0, %a
%and = and i32 %sub, %a
Expand Down Expand Up @@ -74,8 +74,7 @@ define i32 @select_clz_to_ctz_extra_use(i32 %a) {
; CHECK-LABEL: @select_clz_to_ctz_extra_use(
; CHECK-NEXT: [[SUB1:%.*]] = call i32 @llvm.cttz.i32(i32 [[A:%.*]], i1 true), !range [[RNG0]]
; CHECK-NEXT: call void @use(i32 [[SUB1]])
; CHECK-NEXT: [[COND:%.*]] = call i32 @llvm.cttz.i32(i32 [[A]], i1 true), !range [[RNG0]]
; CHECK-NEXT: ret i32 [[COND]]
; CHECK-NEXT: ret i32 [[SUB1]]
;
%sub = sub i32 0, %a
%and = and i32 %sub, %a
Expand All @@ -89,8 +88,8 @@ define i32 @select_clz_to_ctz_extra_use(i32 %a) {

define i32 @select_clz_to_ctz_and_commuted(i32 %a) {
; CHECK-LABEL: @select_clz_to_ctz_and_commuted(
; CHECK-NEXT: [[COND:%.*]] = call i32 @llvm.cttz.i32(i32 [[A:%.*]], i1 true), !range [[RNG0]]
; CHECK-NEXT: ret i32 [[COND]]
; CHECK-NEXT: [[SUB1:%.*]] = call i32 @llvm.cttz.i32(i32 [[A:%.*]], i1 true), !range [[RNG0]]
; CHECK-NEXT: ret i32 [[SUB1]]
;
%sub = sub i32 0, %a
%and = and i32 %a, %sub
Expand All @@ -105,8 +104,8 @@ define i32 @select_clz_to_ctz_icmp_ne(i32 %a) {
; CHECK-LABEL: @select_clz_to_ctz_icmp_ne(
; CHECK-NEXT: [[TOBOOL:%.*]] = icmp ne i32 [[A:%.*]], 0
; CHECK-NEXT: call void @use2(i1 [[TOBOOL]])
; CHECK-NEXT: [[COND:%.*]] = call i32 @llvm.cttz.i32(i32 [[A]], i1 true), !range [[RNG0]]
; CHECK-NEXT: ret i32 [[COND]]
; CHECK-NEXT: [[SUB1:%.*]] = call i32 @llvm.cttz.i32(i32 [[A]], i1 true), !range [[RNG0]]
; CHECK-NEXT: ret i32 [[SUB1]]
;
%sub = sub i32 0, %a
%and = and i32 %sub, %a
Expand All @@ -120,8 +119,8 @@ define i32 @select_clz_to_ctz_icmp_ne(i32 %a) {

define i64 @select_clz_to_ctz_i64(i64 %a) {
; CHECK-LABEL: @select_clz_to_ctz_i64(
; CHECK-NEXT: [[COND:%.*]] = call i64 @llvm.cttz.i64(i64 [[A:%.*]], i1 true), !range [[RNG1:![0-9]+]]
; CHECK-NEXT: ret i64 [[COND]]
; CHECK-NEXT: [[SUB1:%.*]] = call i64 @llvm.cttz.i64(i64 [[A:%.*]], i1 true), !range [[RNG1:![0-9]+]]
; CHECK-NEXT: ret i64 [[SUB1]]
;
%sub = sub i64 0, %a
%and = and i64 %sub, %a
Expand All @@ -139,10 +138,8 @@ define i32 @select_clz_to_ctz_wrong_sub(i32 %a) {
; CHECK-NEXT: [[SUB:%.*]] = sub i32 1, [[A:%.*]]
; CHECK-NEXT: [[AND:%.*]] = and i32 [[SUB]], [[A]]
; CHECK-NEXT: [[LZ:%.*]] = tail call i32 @llvm.ctlz.i32(i32 [[AND]], i1 true), !range [[RNG0]]
; CHECK-NEXT: [[TOBOOL:%.*]] = icmp eq i32 [[A]], 0
; CHECK-NEXT: [[SUB1:%.*]] = xor i32 [[LZ]], 31
; CHECK-NEXT: [[COND:%.*]] = select i1 [[TOBOOL]], i32 [[LZ]], i32 [[SUB1]]
; CHECK-NEXT: ret i32 [[COND]]
; CHECK-NEXT: ret i32 [[SUB1]]
;
%sub = sub i32 1, %a
%and = and i32 %sub, %a
Expand All @@ -159,10 +156,8 @@ define i64 @select_clz_to_ctz_i64_wrong_xor(i64 %a) {
; CHECK-NEXT: [[SUB:%.*]] = sub i64 0, [[A:%.*]]
; CHECK-NEXT: [[AND:%.*]] = and i64 [[SUB]], [[A]]
; CHECK-NEXT: [[LZ:%.*]] = tail call i64 @llvm.ctlz.i64(i64 [[AND]], i1 true), !range [[RNG1]]
; CHECK-NEXT: [[TOBOOL:%.*]] = icmp eq i64 [[A]], 0
; CHECK-NEXT: [[SUB11:%.*]] = or i64 [[LZ]], 64
; CHECK-NEXT: [[COND:%.*]] = select i1 [[TOBOOL]], i64 [[LZ]], i64 [[SUB11]]
; CHECK-NEXT: ret i64 [[COND]]
; CHECK-NEXT: ret i64 [[SUB11]]
;
%sub = sub i64 0, %a
%and = and i64 %sub, %a
Expand All @@ -175,12 +170,9 @@ define i64 @select_clz_to_ctz_i64_wrong_xor(i64 %a) {

define i64 @select_clz_to_ctz_i64_wrong_icmp_cst(i64 %a) {
; CHECK-LABEL: @select_clz_to_ctz_i64_wrong_icmp_cst(
; CHECK-NEXT: [[SUB:%.*]] = sub i64 0, [[A:%.*]]
; CHECK-NEXT: [[AND:%.*]] = and i64 [[SUB]], [[A]]
; CHECK-NEXT: [[LZ:%.*]] = tail call i64 @llvm.ctlz.i64(i64 [[AND]], i1 true), !range [[RNG1]]
; CHECK-NEXT: [[TOBOOL:%.*]] = icmp eq i64 [[A]], 1
; CHECK-NEXT: [[SUB1:%.*]] = xor i64 [[LZ]], 63
; CHECK-NEXT: [[COND:%.*]] = select i1 [[TOBOOL]], i64 [[LZ]], i64 [[SUB1]]
; CHECK-NEXT: [[TOBOOL:%.*]] = icmp eq i64 [[A:%.*]], 1
; CHECK-NEXT: [[SUB1:%.*]] = call i64 @llvm.cttz.i64(i64 [[A]], i1 true), !range [[RNG1]]
; CHECK-NEXT: [[COND:%.*]] = select i1 [[TOBOOL]], i64 63, i64 [[SUB1]]
; CHECK-NEXT: ret i64 [[COND]]
;
%sub = sub i64 0, %a
Expand Down Expand Up @@ -255,8 +247,8 @@ define i4 @PR45762(i3 %x4) {
; CHECK-NEXT: [[T7:%.*]] = zext i3 [[T4]] to i4
; CHECK-NEXT: [[ONE_HOT_16:%.*]] = shl nuw i4 1, [[T7]]
; CHECK-NEXT: [[OR_69_NOT:%.*]] = icmp eq i3 [[X4]], 0
; CHECK-NEXT: [[UMUL_231:%.*]] = select i1 [[OR_69_NOT]], i4 0, i4 [[T7]]
; CHECK-NEXT: [[SEL_71:%.*]] = shl i4 [[ONE_HOT_16]], [[UMUL_231]]
; CHECK-NEXT: [[UMUL_231:%.*]] = shl i4 [[ONE_HOT_16]], [[T7]]
; CHECK-NEXT: [[SEL_71:%.*]] = select i1 [[OR_69_NOT]], i4 -8, i4 [[UMUL_231]]
; CHECK-NEXT: ret i4 [[SEL_71]]
;
%t4 = call i3 @llvm.cttz.i3(i3 %x4, i1 false)
Expand Down Expand Up @@ -284,8 +276,8 @@ define i4 @PR45762_logical(i3 %x4) {
; CHECK-NEXT: [[T7:%.*]] = zext i3 [[T4]] to i4
; CHECK-NEXT: [[ONE_HOT_16:%.*]] = shl nuw i4 1, [[T7]]
; CHECK-NEXT: [[OR_69_NOT:%.*]] = icmp eq i3 [[X4]], 0
; CHECK-NEXT: [[UMUL_231:%.*]] = select i1 [[OR_69_NOT]], i4 0, i4 [[T7]]
; CHECK-NEXT: [[SEL_71:%.*]] = shl i4 [[ONE_HOT_16]], [[UMUL_231]]
; CHECK-NEXT: [[UMUL_231:%.*]] = shl i4 [[ONE_HOT_16]], [[T7]]
; CHECK-NEXT: [[SEL_71:%.*]] = select i1 [[OR_69_NOT]], i4 -8, i4 [[UMUL_231]]
; CHECK-NEXT: ret i4 [[SEL_71]]
;
%t4 = call i3 @llvm.cttz.i3(i3 %x4, i1 false)
Expand Down
13 changes: 13 additions & 0 deletions llvm/test/Transforms/InstCombine/shift.ll
Original file line number Diff line number Diff line change
Expand Up @@ -1747,6 +1747,19 @@ define void @ashr_out_of_range(ptr %A) {
; https://bugs.chromium.org/p/oss-fuzz/issues/detail?id=26135
define void @ashr_out_of_range_1(ptr %A) {
; CHECK-LABEL: @ashr_out_of_range_1(
; CHECK-NEXT: [[L:%.*]] = load i177, ptr [[A:%.*]], align 4
; CHECK-NEXT: [[L_FROZEN:%.*]] = freeze i177 [[L]]
; CHECK-NEXT: [[TMP1:%.*]] = icmp eq i177 [[L_FROZEN]], -1
; CHECK-NEXT: [[B:%.*]] = select i1 [[TMP1]], i177 0, i177 [[L_FROZEN]]
; CHECK-NEXT: [[TMP2:%.*]] = trunc i177 [[B]] to i64
; CHECK-NEXT: [[TMP3:%.*]] = add i64 [[TMP2]], -1
; CHECK-NEXT: [[G11:%.*]] = getelementptr i177, ptr [[A]], i64 [[TMP3]]
; CHECK-NEXT: [[C17:%.*]] = icmp sgt i177 [[B]], [[L_FROZEN]]
; CHECK-NEXT: [[TMP4:%.*]] = sext i1 [[C17]] to i64
; CHECK-NEXT: [[G62:%.*]] = getelementptr i177, ptr [[G11]], i64 [[TMP4]]
; CHECK-NEXT: [[TMP5:%.*]] = icmp eq i177 [[L_FROZEN]], -1
; CHECK-NEXT: [[B28:%.*]] = select i1 [[TMP5]], i177 0, i177 [[L_FROZEN]]
; CHECK-NEXT: store i177 [[B28]], ptr [[G62]], align 4
; CHECK-NEXT: ret void
;
%L = load i177, ptr %A, align 4
Expand Down
18 changes: 4 additions & 14 deletions llvm/test/Transforms/InstSimplify/select.ll
Original file line number Diff line number Diff line change
Expand Up @@ -1094,9 +1094,7 @@ define i8 @select_eq_xor_recursive(i8 %a, i8 %b) {
; CHECK-LABEL: @select_eq_xor_recursive(
; CHECK-NEXT: [[XOR:%.*]] = xor i8 [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: [[INV:%.*]] = xor i8 [[XOR]], -1
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[A]], [[B]]
; CHECK-NEXT: [[SEL:%.*]] = select i1 [[CMP]], i8 -1, i8 [[INV]]
; CHECK-NEXT: ret i8 [[SEL]]
; CHECK-NEXT: ret i8 [[INV]]
;
%xor = xor i8 %a, %b
%inv = xor i8 %xor, -1
Expand All @@ -1110,9 +1108,7 @@ define i8 @select_eq_xor_recursive2(i8 %a, i8 %b) {
; CHECK-NEXT: [[XOR:%.*]] = xor i8 [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: [[INV:%.*]] = xor i8 [[XOR]], -1
; CHECK-NEXT: [[ADD:%.*]] = add i8 [[INV]], 10
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[A]], [[B]]
; CHECK-NEXT: [[SEL:%.*]] = select i1 [[CMP]], i8 9, i8 [[ADD]]
; CHECK-NEXT: ret i8 [[SEL]]
; CHECK-NEXT: ret i8 [[ADD]]
;
%xor = xor i8 %a, %b
%inv = xor i8 %xor, -1
Expand Down Expand Up @@ -1162,9 +1158,7 @@ define i8 @select_eq_and_recursive(i8 %a) {
; CHECK-NEXT: [[NEG:%.*]] = sub i8 0, [[A:%.*]]
; CHECK-NEXT: [[AND:%.*]] = and i8 [[NEG]], [[A]]
; CHECK-NEXT: [[ADD:%.*]] = add i8 [[AND]], 1
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[A]], 0
; CHECK-NEXT: [[SEL:%.*]] = select i1 [[CMP]], i8 1, i8 [[ADD]]
; CHECK-NEXT: ret i8 [[SEL]]
; CHECK-NEXT: ret i8 [[ADD]]
;
%neg = sub i8 0, %a
%and = and i8 %neg, %a
Expand Down Expand Up @@ -1194,11 +1188,7 @@ define i8 @select_eq_and_recursive_propagates_poison(i8 %a, i8 %b) {

define i8 @select_eq_xor_recursive_allow_refinement(i8 %a, i8 %b) {
; CHECK-LABEL: @select_eq_xor_recursive_allow_refinement(
; CHECK-NEXT: [[XOR1:%.*]] = add i8 [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: [[XOR2:%.*]] = xor i8 [[A]], [[XOR1]]
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[B]], 0
; CHECK-NEXT: [[SEL:%.*]] = select i1 [[CMP]], i8 [[XOR2]], i8 0
; CHECK-NEXT: ret i8 [[SEL]]
; CHECK-NEXT: ret i8 0
;
%xor1 = add i8 %a, %b
%xor2 = xor i8 %a, %xor1
Expand Down

0 comments on commit d99f061

Please sign in to comment.