Skip to content

Commit 0068078

Browse files
authored
[NVPTX] Remove NVPTX::IMAD opcode, and rely on instruction selection only (#121724)
I noticed that NVPTX will sometimes emit `mad.lo` to multiply by 1, e.g. in https://gcc.godbolt.org/z/4j47Y9W4c. This happens when DAGCombiner operates on the add before the mul, so the imad contraction happens regardless of whether the mul could have been simplified. To fix this, I remove `NVPTXISD::IMAD` and only combine to mad during selection. This allows the default DAGCombiner patterns to simplify the graph without any NVPTX-specific intervention.
1 parent 5a90168 commit 0068078

File tree

6 files changed

+194
-180
lines changed

6 files changed

+194
-180
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1046,7 +1046,6 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
10461046
MAKE_CASE(NVPTXISD::StoreV4)
10471047
MAKE_CASE(NVPTXISD::FSHL_CLAMP)
10481048
MAKE_CASE(NVPTXISD::FSHR_CLAMP)
1049-
MAKE_CASE(NVPTXISD::IMAD)
10501049
MAKE_CASE(NVPTXISD::BFE)
10511050
MAKE_CASE(NVPTXISD::BFI)
10521051
MAKE_CASE(NVPTXISD::PRMT)
@@ -4451,14 +4450,8 @@ PerformADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
44514450
if (!N0.getNode()->hasOneUse())
44524451
return SDValue();
44534452

4454-
// fold (add (mul a, b), c) -> (mad a, b, c)
4455-
//
4456-
if (N0.getOpcode() == ISD::MUL)
4457-
return DCI.DAG.getNode(NVPTXISD::IMAD, SDLoc(N), VT, N0.getOperand(0),
4458-
N0.getOperand(1), N1);
4459-
44604453
// fold (add (select cond, 0, (mul a, b)), c)
4461-
// -> (select cond, c, (mad a, b, c))
4454+
// -> (select cond, c, (add (mul a, b), c))
44624455
//
44634456
if (N0.getOpcode() == ISD::SELECT) {
44644457
unsigned ZeroOpNum;
@@ -4473,8 +4466,10 @@ PerformADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
44734466
if (M->getOpcode() != ISD::MUL || !M.getNode()->hasOneUse())
44744467
return SDValue();
44754468

4476-
SDValue MAD = DCI.DAG.getNode(NVPTXISD::IMAD, SDLoc(N), VT,
4477-
M->getOperand(0), M->getOperand(1), N1);
4469+
SDLoc DL(N);
4470+
SDValue Mul =
4471+
DCI.DAG.getNode(ISD::MUL, DL, VT, M->getOperand(0), M->getOperand(1));
4472+
SDValue MAD = DCI.DAG.getNode(ISD::ADD, DL, VT, Mul, N1);
44784473
return DCI.DAG.getSelect(SDLoc(N), VT, N0->getOperand(0),
44794474
((ZeroOpNum == 1) ? N1 : MAD),
44804475
((ZeroOpNum == 1) ? MAD : N1));
@@ -4911,8 +4906,10 @@ static SDValue matchMADConstOnePattern(SDValue Add) {
49114906
static SDValue combineMADConstOne(SDValue X, SDValue Add, EVT VT, SDLoc DL,
49124907
TargetLowering::DAGCombinerInfo &DCI) {
49134908

4914-
if (SDValue Y = matchMADConstOnePattern(Add))
4915-
return DCI.DAG.getNode(NVPTXISD::IMAD, DL, VT, X, Y, X);
4909+
if (SDValue Y = matchMADConstOnePattern(Add)) {
4910+
SDValue Mul = DCI.DAG.getNode(ISD::MUL, DL, VT, X, Y);
4911+
return DCI.DAG.getNode(ISD::ADD, DL, VT, Mul, X);
4912+
}
49164913

49174914
return SDValue();
49184915
}
@@ -4959,7 +4956,7 @@ PerformMULCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
49594956

49604957
SDLoc DL(N);
49614958

4962-
// (mul x, (add y, 1)) -> (mad x, y, x)
4959+
// (mul x, (add y, 1)) -> (add (mul x, y), x)
49634960
if (SDValue Res = combineMADConstOne(N0, N1, VT, DL, DCI))
49644961
return Res;
49654962
if (SDValue Res = combineMADConstOne(N1, N0, VT, DL, DCI))

llvm/lib/Target/NVPTX/NVPTXISelLowering.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,6 @@ enum NodeType : unsigned {
5555
FSHR_CLAMP,
5656
MUL_WIDE_SIGNED,
5757
MUL_WIDE_UNSIGNED,
58-
IMAD,
5958
SETP_F16X2,
6059
SETP_BF16X2,
6160
BFE,

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 34 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ def hasLDG : Predicate<"Subtarget->hasLDG()">;
141141
def hasLDU : Predicate<"Subtarget->hasLDU()">;
142142
def hasPTXASUnreachableBug : Predicate<"Subtarget->hasPTXASUnreachableBug()">;
143143
def noPTXASUnreachableBug : Predicate<"!Subtarget->hasPTXASUnreachableBug()">;
144+
def hasOptEnabled : Predicate<"TM.getOptLevel() != CodeGenOptLevel::None">;
144145

145146
def doF32FTZ : Predicate<"useF32FTZ()">;
146147
def doNoF32FTZ : Predicate<"!useF32FTZ()">;
@@ -1092,73 +1093,39 @@ def : Pat<(mul (zext i16:$a), (i32 UInt16Const:$b)),
10921093
//
10931094
// Integer multiply-add
10941095
//
1095-
def SDTIMAD :
1096-
SDTypeProfile<1, 3, [SDTCisSameAs<0, 1>, SDTCisInt<0>, SDTCisInt<2>,
1097-
SDTCisSameAs<0, 2>, SDTCisSameAs<0, 3>]>;
1098-
def imad : SDNode<"NVPTXISD::IMAD", SDTIMAD>;
1099-
1100-
def MAD16rrr :
1101-
NVPTXInst<(outs Int16Regs:$dst),
1102-
(ins Int16Regs:$a, Int16Regs:$b, Int16Regs:$c),
1103-
"mad.lo.s16 \t$dst, $a, $b, $c;",
1104-
[(set i16:$dst, (imad i16:$a, i16:$b, i16:$c))]>;
1105-
def MAD16rri :
1106-
NVPTXInst<(outs Int16Regs:$dst),
1107-
(ins Int16Regs:$a, Int16Regs:$b, i16imm:$c),
1108-
"mad.lo.s16 \t$dst, $a, $b, $c;",
1109-
[(set i16:$dst, (imad i16:$a, i16:$b, imm:$c))]>;
1110-
def MAD16rir :
1111-
NVPTXInst<(outs Int16Regs:$dst),
1112-
(ins Int16Regs:$a, i16imm:$b, Int16Regs:$c),
1113-
"mad.lo.s16 \t$dst, $a, $b, $c;",
1114-
[(set i16:$dst, (imad i16:$a, imm:$b, i16:$c))]>;
1115-
def MAD16rii :
1116-
NVPTXInst<(outs Int16Regs:$dst),
1117-
(ins Int16Regs:$a, i16imm:$b, i16imm:$c),
1118-
"mad.lo.s16 \t$dst, $a, $b, $c;",
1119-
[(set i16:$dst, (imad i16:$a, imm:$b, imm:$c))]>;
1120-
1121-
def MAD32rrr :
1122-
NVPTXInst<(outs Int32Regs:$dst),
1123-
(ins Int32Regs:$a, Int32Regs:$b, Int32Regs:$c),
1124-
"mad.lo.s32 \t$dst, $a, $b, $c;",
1125-
[(set i32:$dst, (imad i32:$a, i32:$b, i32:$c))]>;
1126-
def MAD32rri :
1127-
NVPTXInst<(outs Int32Regs:$dst),
1128-
(ins Int32Regs:$a, Int32Regs:$b, i32imm:$c),
1129-
"mad.lo.s32 \t$dst, $a, $b, $c;",
1130-
[(set i32:$dst, (imad i32:$a, i32:$b, imm:$c))]>;
1131-
def MAD32rir :
1132-
NVPTXInst<(outs Int32Regs:$dst),
1133-
(ins Int32Regs:$a, i32imm:$b, Int32Regs:$c),
1134-
"mad.lo.s32 \t$dst, $a, $b, $c;",
1135-
[(set i32:$dst, (imad i32:$a, imm:$b, i32:$c))]>;
1136-
def MAD32rii :
1137-
NVPTXInst<(outs Int32Regs:$dst),
1138-
(ins Int32Regs:$a, i32imm:$b, i32imm:$c),
1139-
"mad.lo.s32 \t$dst, $a, $b, $c;",
1140-
[(set i32:$dst, (imad i32:$a, imm:$b, imm:$c))]>;
1141-
1142-
def MAD64rrr :
1143-
NVPTXInst<(outs Int64Regs:$dst),
1144-
(ins Int64Regs:$a, Int64Regs:$b, Int64Regs:$c),
1145-
"mad.lo.s64 \t$dst, $a, $b, $c;",
1146-
[(set i64:$dst, (imad i64:$a, i64:$b, i64:$c))]>;
1147-
def MAD64rri :
1148-
NVPTXInst<(outs Int64Regs:$dst),
1149-
(ins Int64Regs:$a, Int64Regs:$b, i64imm:$c),
1150-
"mad.lo.s64 \t$dst, $a, $b, $c;",
1151-
[(set i64:$dst, (imad i64:$a, i64:$b, imm:$c))]>;
1152-
def MAD64rir :
1153-
NVPTXInst<(outs Int64Regs:$dst),
1154-
(ins Int64Regs:$a, i64imm:$b, Int64Regs:$c),
1155-
"mad.lo.s64 \t$dst, $a, $b, $c;",
1156-
[(set i64:$dst, (imad i64:$a, imm:$b, i64:$c))]>;
1157-
def MAD64rii :
1158-
NVPTXInst<(outs Int64Regs:$dst),
1159-
(ins Int64Regs:$a, i64imm:$b, i64imm:$c),
1160-
"mad.lo.s64 \t$dst, $a, $b, $c;",
1161-
[(set i64:$dst, (imad i64:$a, imm:$b, imm:$c))]>;
1096+
def mul_oneuse : PatFrag<(ops node:$a, node:$b), (mul node:$a, node:$b), [{
1097+
return N->hasOneUse();
1098+
}]>;
1099+
1100+
multiclass MAD<string Ptx, ValueType VT, NVPTXRegClass Reg, Operand Imm> {
1101+
def rrr:
1102+
NVPTXInst<(outs Reg:$dst),
1103+
(ins Reg:$a, Reg:$b, Reg:$c),
1104+
Ptx # " \t$dst, $a, $b, $c;",
1105+
[(set VT:$dst, (add (mul_oneuse VT:$a, VT:$b), VT:$c))]>;
1106+
1107+
def rir:
1108+
NVPTXInst<(outs Reg:$dst),
1109+
(ins Reg:$a, Imm:$b, Reg:$c),
1110+
Ptx # " \t$dst, $a, $b, $c;",
1111+
[(set VT:$dst, (add (mul_oneuse VT:$a, imm:$b), VT:$c))]>;
1112+
def rri:
1113+
NVPTXInst<(outs Reg:$dst),
1114+
(ins Reg:$a, Reg:$b, Imm:$c),
1115+
Ptx # " \t$dst, $a, $b, $c;",
1116+
[(set VT:$dst, (add (mul_oneuse VT:$a, VT:$b), imm:$c))]>;
1117+
def rii:
1118+
NVPTXInst<(outs Reg:$dst),
1119+
(ins Reg:$a, Imm:$b, Imm:$c),
1120+
Ptx # " \t$dst, $a, $b, $c;",
1121+
[(set VT:$dst, (add (mul_oneuse VT:$a, imm:$b), imm:$c))]>;
1122+
}
1123+
1124+
let Predicates = [hasOptEnabled] in {
1125+
defm MAD16 : MAD<"mad.lo.s16", i16, Int16Regs, i16imm>;
1126+
defm MAD32 : MAD<"mad.lo.s32", i32, Int32Regs, i32imm>;
1127+
defm MAD64 : MAD<"mad.lo.s64", i64, Int64Regs, i64imm>;
1128+
}
11621129

11631130
def INEG16 :
11641131
NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$src),

llvm/test/CodeGen/NVPTX/combine-mad.ll

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,3 +183,58 @@ define i32 @test4_rev(i32 %a, i32 %b, i32 %c, i1 %p) {
183183
%add = add i32 %c, %sel
184184
ret i32 %add
185185
}
186+
187+
declare i32 @use(i32 %0, i32 %1)
188+
189+
define i32 @test_mad_multi_use(i32 %a, i32 %b, i32 %c) {
190+
; CHECK-LABEL: test_mad_multi_use(
191+
; CHECK: {
192+
; CHECK-NEXT: .reg .b32 %r<8>;
193+
; CHECK-EMPTY:
194+
; CHECK-NEXT: // %bb.0:
195+
; CHECK-NEXT: ld.param.u32 %r1, [test_mad_multi_use_param_0];
196+
; CHECK-NEXT: ld.param.u32 %r2, [test_mad_multi_use_param_1];
197+
; CHECK-NEXT: mul.lo.s32 %r3, %r1, %r2;
198+
; CHECK-NEXT: ld.param.u32 %r4, [test_mad_multi_use_param_2];
199+
; CHECK-NEXT: add.s32 %r5, %r3, %r4;
200+
; CHECK-NEXT: { // callseq 0, 0
201+
; CHECK-NEXT: .param .b32 param0;
202+
; CHECK-NEXT: st.param.b32 [param0], %r3;
203+
; CHECK-NEXT: .param .b32 param1;
204+
; CHECK-NEXT: st.param.b32 [param1], %r5;
205+
; CHECK-NEXT: .param .b32 retval0;
206+
; CHECK-NEXT: call.uni (retval0),
207+
; CHECK-NEXT: use,
208+
; CHECK-NEXT: (
209+
; CHECK-NEXT: param0,
210+
; CHECK-NEXT: param1
211+
; CHECK-NEXT: );
212+
; CHECK-NEXT: ld.param.b32 %r6, [retval0];
213+
; CHECK-NEXT: } // callseq 0
214+
; CHECK-NEXT: st.param.b32 [func_retval0], %r6;
215+
; CHECK-NEXT: ret;
216+
%mul = mul i32 %a, %b
217+
%add = add i32 %mul, %c
218+
%res = call i32 @use(i32 %mul, i32 %add)
219+
ret i32 %res
220+
}
221+
222+
;; This case relies on mad x 1 y => add x y; previously we emitted:
223+
;; mad.lo.s32 %r3, %r1, 1, %r2;
224+
define i32 @test_mad_fold(i32 %x) {
225+
; CHECK-LABEL: test_mad_fold(
226+
; CHECK: {
227+
; CHECK-NEXT: .reg .b32 %r<7>;
228+
; CHECK-EMPTY:
229+
; CHECK-NEXT: // %bb.0:
230+
; CHECK-NEXT: ld.param.u32 %r1, [test_mad_fold_param_0];
231+
; CHECK-NEXT: mul.hi.s32 %r2, %r1, -2147221471;
232+
; CHECK-NEXT: add.s32 %r3, %r2, %r1;
233+
; CHECK-NEXT: shr.u32 %r4, %r3, 31;
234+
; CHECK-NEXT: shr.s32 %r5, %r3, 12;
235+
; CHECK-NEXT: add.s32 %r6, %r5, %r4;
236+
; CHECK-NEXT: st.param.b32 [func_retval0], %r6;
237+
; CHECK-NEXT: ret;
238+
%div = sdiv i32 %x, 8191
239+
ret i32 %div
240+
}

llvm/test/CodeGen/NVPTX/dynamic_stackalloc.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
; CHECK-NOT: __local_depot
1313

1414
; CHECK-32: ld.param.u32 %r[[SIZE:[0-9]]], [test_dynamic_stackalloc_param_0];
15-
; CHECK-32-NEXT: mad.lo.s32 %r[[SIZE2:[0-9]]], %r[[SIZE]], 1, 7;
15+
; CHECK-32-NEXT: add.s32 %r[[SIZE2:[0-9]]], %r[[SIZE]], 7;
1616
; CHECK-32-NEXT: and.b32 %r[[SIZE3:[0-9]]], %r[[SIZE2]], -8;
1717
; CHECK-32-NEXT: alloca.u32 %r[[ALLOCA:[0-9]]], %r[[SIZE3]], 16;
1818
; CHECK-32-NEXT: cvta.local.u32 %r[[ALLOCA]], %r[[ALLOCA]];

0 commit comments

Comments
 (0)