Skip to content

Commit c0c7146

Browse files
authored
[InstCombine] Optimize sub(sext(add(x,y)),sext(add(x,z))). (#144174)
This pattern is often encountered in Flang-generated LLVM IR, for example, in the counts of the loops generated for array expressions like `a(x:x+y)` or `a(x+z:x+z)` or their variations. In order to compute the loop count, Flang needs to subtract the lower bound of the array slice from the upper bound of the array slice. To avoid signed wraps, it sign-extends the original values (which may be of any user data type) to `i64`.

This peephole is really helpful in CPU2017/548.exchange2, where we have multiple statements like the following:
```
block(row+1:row+2, 7:9, i7) = block(row+1:row+2, 7:9, i7) - 10
```
While this is just a 2x3-iteration loop nest, LLVM cannot figure it out and ends up aggressively vectorizing the inner loop (with a vector epilog and scalar remainder). This, in turn, causes problems for LSR, which ends up creating too many loop-carried values in the loop containing the above statement, which then cause too many spills/reloads.

Alive2: https://alive2.llvm.org/ce/z/gLgfYX

Related to #143219.
1 parent 6ce8653 commit c0c7146

File tree

5 files changed

+413
-10
lines changed

5 files changed

+413
-10
lines changed

llvm/include/llvm/IR/PatternMatch.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1323,6 +1323,14 @@ m_NSWAdd(const LHS &L, const RHS &R) {
13231323
R);
13241324
}
13251325
/// Builds an OverflowingBinaryOp_match for an Instruction::Add carrying the
/// NoSignedWrap flag, forwarding the sub-patterns \p L and \p R.
/// NOTE(review): the trailing `true` template argument is presumably the
/// "commutable" switch (the `m_c_` prefix and the visitSub call site that
/// passes m_Specific(X) first both suggest operands may match in either
/// order) -- confirm against the OverflowingBinaryOp_match declaration.
template <typename LHS, typename RHS>
1326+
inline OverflowingBinaryOp_match<LHS, RHS, Instruction::Add,
1327+
OverflowingBinaryOperator::NoSignedWrap, true>
1328+
m_c_NSWAdd(const LHS &L, const RHS &R) {
1329+
return OverflowingBinaryOp_match<LHS, RHS, Instruction::Add,
1330+
OverflowingBinaryOperator::NoSignedWrap,
1331+
true>(L, R);
1332+
}
1333+
template <typename LHS, typename RHS>
13261334
inline OverflowingBinaryOp_match<LHS, RHS, Instruction::Sub,
13271335
OverflowingBinaryOperator::NoSignedWrap>
13281336
m_NSWSub(const LHS &L, const RHS &R) {

llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1896,7 +1896,7 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
18961896
{Sub, Builder.getFalse()});
18971897
Value *Ret = Builder.CreateSub(
18981898
ConstantInt::get(A->getType(), A->getType()->getScalarSizeInBits()),
1899-
Ctlz, "", /*HasNUW*/ true, /*HasNSW*/ true);
1899+
Ctlz, "", /*HasNUW=*/true, /*HasNSW=*/true);
19001900
return replaceInstUsesWith(I, Builder.CreateZExtOrTrunc(Ret, I.getType()));
19011901
}
19021902

@@ -2363,8 +2363,8 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
23632363
OverflowingBinaryOperator *LHSSub = cast<OverflowingBinaryOperator>(Op0);
23642364
bool HasNUW = I.hasNoUnsignedWrap() && LHSSub->hasNoUnsignedWrap();
23652365
bool HasNSW = HasNUW && I.hasNoSignedWrap() && LHSSub->hasNoSignedWrap();
2366-
Value *Add = Builder.CreateAdd(Y, Op1, "", /* HasNUW */ HasNUW,
2367-
/* HasNSW */ HasNSW);
2366+
Value *Add = Builder.CreateAdd(Y, Op1, "", /*HasNUW=*/HasNUW,
2367+
/*HasNSW=*/HasNSW);
23682368
BinaryOperator *Sub = BinaryOperator::CreateSub(X, Add);
23692369
Sub->setHasNoUnsignedWrap(HasNUW);
23702370
Sub->setHasNoSignedWrap(HasNSW);
@@ -2835,6 +2835,51 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
28352835
if (Instruction *Res = foldBinOpOfSelectAndCastOfSelectCondition(I))
28362836
return Res;
28372837

2838+
// (sub (sext (add nsw (X, Y)), sext (X))) --> (sext (Y))
2839+
if (match(Op1, m_SExtLike(m_Value(X))) &&
2840+
match(Op0, m_SExtLike(m_c_NSWAdd(m_Specific(X), m_Value(Y))))) {
2841+
Value *SExtY = Builder.CreateSExt(Y, I.getType());
2842+
return replaceInstUsesWith(I, SExtY);
2843+
}
2844+
2845+
// (sub[ nsw] (sext (add nsw (X, Y)), sext (add nsw (X, Z)))) -->
2846+
// --> (sub[ nsw] (sext (Y), sext (Z)))
2847+
{
2848+
Value *Z, *Add0, *Add1;
2849+
if (match(Op0, m_SExtLike(m_Value(Add0))) &&
2850+
match(Op1, m_SExtLike(m_Value(Add1))) &&
2851+
((match(Add0, m_NSWAdd(m_Value(X), m_Value(Y))) &&
2852+
match(Add1, m_c_NSWAdd(m_Specific(X), m_Value(Z)))) ||
2853+
(match(Add0, m_NSWAdd(m_Value(Y), m_Value(X))) &&
2854+
match(Add1, m_c_NSWAdd(m_Specific(X), m_Value(Z)))))) {
2855+
unsigned NumOfNewInstrs = 0;
2856+
// Non-constant Y, Z require new SExt.
2857+
NumOfNewInstrs += !isa<Constant>(Y) ? 1 : 0;
2858+
NumOfNewInstrs += !isa<Constant>(Z) ? 1 : 0;
2859+
// Check if we can trade some of the old instructions for the new ones.
2860+
unsigned NumOfDeadInstrs = 0;
2861+
if (Op0->hasOneUse()) {
2862+
// If Op0 (sext) has multiple uses, then we keep it
2863+
// and the add that it uses, otherwise, we can remove
2864+
// the sext and probably the add (depending on the number of its uses).
2865+
++NumOfDeadInstrs;
2866+
NumOfDeadInstrs += Add0->hasOneUse() ? 1 : 0;
2867+
}
2868+
if (Op1->hasOneUse()) {
2869+
++NumOfDeadInstrs;
2870+
NumOfDeadInstrs += Add1->hasOneUse() ? 1 : 0;
2871+
}
2872+
if (NumOfDeadInstrs >= NumOfNewInstrs) {
2873+
Value *SExtY = Builder.CreateSExt(Y, I.getType());
2874+
Value *SExtZ = Builder.CreateSExt(Z, I.getType());
2875+
Value *Sub = Builder.CreateSub(SExtY, SExtZ, "",
2876+
/*HasNUW=*/false,
2877+
/*HasNSW=*/I.hasNoSignedWrap());
2878+
return replaceInstUsesWith(I, Sub);
2879+
}
2880+
}
2881+
}
2882+
28382883
return TryToNarrowDeduceFlags();
28392884
}
28402885

llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) {
290290
auto *Op1C = cast<Constant>(Op1);
291291
return replaceInstUsesWith(
292292
I, Builder.CreateMul(NegOp0, ConstantExpr::getNeg(Op1C), "",
293-
/* HasNUW */ false,
293+
/*HasNUW=*/false,
294294
HasNSW && Op1C->isNotMinSignedValue()));
295295
}
296296

@@ -1255,8 +1255,8 @@ static Value *foldIDivShl(BinaryOperator &I, InstCombiner::BuilderTy &Builder) {
12551255
// or divisor has nsw and operator is sdiv.
12561256
Value *Dividend = Builder.CreateShl(
12571257
One, Y, "shl.dividend",
1258-
/*HasNUW*/ true,
1259-
/*HasNSW*/
1258+
/*HasNUW=*/true,
1259+
/*HasNSW=*/
12601260
IsSigned ? (Shl0->hasNoUnsignedWrap() || Shl1->hasNoUnsignedWrap())
12611261
: Shl0->hasNoSignedWrap());
12621262
return Builder.CreateLShr(Dividend, Z, "", I.isExact());

llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) {
233233
// However, only do this either if the old `sub` doesn't stick around, or
234234
// it was subtracting from a constant. Otherwise, this isn't profitable.
235235
return Builder.CreateSub(I->getOperand(1), I->getOperand(0),
236-
I->getName() + ".neg", /* HasNUW */ false,
236+
I->getName() + ".neg", /*HasNUW=*/false,
237237
IsNSW && I->hasNoSignedWrap());
238238
}
239239

@@ -404,15 +404,15 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) {
404404
IsNSW &= I->hasNoSignedWrap();
405405
if (Value *NegOp0 = negate(I->getOperand(0), IsNSW, Depth + 1))
406406
return Builder.CreateShl(NegOp0, I->getOperand(1), I->getName() + ".neg",
407-
/* HasNUW */ false, IsNSW);
407+
/*HasNUW=*/false, IsNSW);
408408
// Otherwise, `shl %x, C` can be interpreted as `mul %x, 1<<C`.
409409
Constant *Op1C;
410410
if (!match(I->getOperand(1), m_ImmConstant(Op1C)) || !IsTrulyNegation)
411411
return nullptr;
412412
return Builder.CreateMul(
413413
I->getOperand(0),
414414
Builder.CreateShl(Constant::getAllOnesValue(Op1C->getType()), Op1C),
415-
I->getName() + ".neg", /* HasNUW */ false, IsNSW);
415+
I->getName() + ".neg", /*HasNUW=*/false, IsNSW);
416416
}
417417
case Instruction::Or: {
418418
if (!cast<PossiblyDisjointInst>(I)->isDisjoint())
@@ -483,7 +483,7 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) {
483483
// Can't negate either of them.
484484
return nullptr;
485485
return Builder.CreateMul(NegatedOp, OtherOp, I->getName() + ".neg",
486-
/* HasNUW */ false, IsNSW && I->hasNoSignedWrap());
486+
/*HasNUW=*/false, IsNSW && I->hasNoSignedWrap());
487487
}
488488
default:
489489
return nullptr; // Don't know, likely not negatible for free.

0 commit comments

Comments
 (0)