Skip to content

Commit e304cbe

Browse files
igorban-inteligcbot
authored andcommitted
Add new shuffle combine optimization in GenXSimplify pass
Introduces a new optimization in the GenXSimplify pass that identifies and combines consecutive or overlapping shufflevector instructions into a single shuffle operation.
1 parent 1257007 commit e304cbe

File tree

3 files changed

+313
-0
lines changed

3 files changed

+313
-0
lines changed

IGC/VectorCompiler/lib/GenXCodeGen/GenXTargetMachine.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -1061,6 +1061,8 @@ void GenXTargetMachine::adjustPassManager(PassManagerBuilder &PMBuilder) {
10611061
PM.add(createGenXSimplifyPass());
10621062
};
10631063
PMBuilder.addExtension(PassManagerBuilder::EP_Peephole, AddGenXPeephole);
1064+
PMBuilder.addExtension(PassManagerBuilder::EP_EnabledOnOptLevel0,
1065+
AddGenXPeephole);
10641066
}
10651067

10661068
#else // LLVM_VERSION_MAJOR < 16

IGC/VectorCompiler/lib/GenXOpts/CMAnalysis/InstructionSimplifyGenX.cpp

+221
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,219 @@ static Value *simplifyMulDDQ(BinaryOperator &Mul) {
183183
return Result;
184184
}
185185

186+
static bool usesBothOperands(const ArrayRef<int> &Mask, int InputLen) {
187+
int OperandIdx = -1;
188+
for (int Index : Mask) {
189+
if (Index < 0)
190+
continue; // skip undef
191+
192+
int CurrentIdx = (Index >= InputLen) ? 1 : 0;
193+
if (OperandIdx == -1) {
194+
OperandIdx = CurrentIdx;
195+
} else if (OperandIdx != CurrentIdx) {
196+
return true; // Both operands are used
197+
}
198+
}
199+
return false;
200+
}
201+
202+
static void checkInputsForMask(ArrayRef<int> &Mask1, ArrayRef<int> &Mask2,
203+
ArrayRef<int> &CurrentMask, bool &UsesInput1,
204+
bool &UsesInput2) {
205+
for (auto Index : CurrentMask) {
206+
if (Index < 0) {
207+
continue; // Undef index.
208+
} else if (static_cast<unsigned>(Index) < Mask1.size()) {
209+
UsesInput1 = true;
210+
} else if (static_cast<unsigned>(Index) < Mask1.size() + Mask2.size()) {
211+
UsesInput2 = true;
212+
} else {
213+
IGC_ASSERT_EXIT(false && "Unexpected index in mask");
214+
}
215+
}
216+
}
217+
218+
static int getInputLen(ShuffleVectorInst *CheckShuffle) {
219+
auto *Type =
220+
cast<IGCLLVM::FixedVectorType>(CheckShuffle->getOperand(0)->getType());
221+
return static_cast<int>(Type->getNumElements());
222+
}
223+
224+
// Simplify ShuffleVector one-instruction chain
225+
// Transform from:
226+
// %shuffle1 = shufflevector <16 x i1> %input1, <16 x i1> poison,
227+
// <256 x i32> <i32 undef (224 times), i32 0-15 (16), i32 undef (16 times)>
228+
// %combinedShuffle = shufflevector <256 x i1> %shuffle1, <256 x i1> poison,
229+
// <256 x i32> <i32 undef (224 times), i32 224-239 & 256-271 (32)>
230+
// To:
231+
// %finalShuffle = shufflevector <16 x i1> %input1, <16 x i1> poison,
232+
// <32 x i32> <i32 undef, ..., i32 0-15>
233+
static Value *propagateShuffleVector(ShuffleVectorInst *Shuffle) {
234+
LLVM_DEBUG(dbgs() << "Simplifying shufflevector: " << *Shuffle << "\n");
235+
236+
auto *Input1 = dyn_cast<ShuffleVectorInst>(Shuffle->getOperand(0));
237+
auto *Input2 = dyn_cast<ShuffleVectorInst>(Shuffle->getOperand(1));
238+
239+
if (!Input1 && !Input2) {
240+
LLVM_DEBUG(
241+
dbgs()
242+
<< "propagateShuffleVector: No chain detected, nothing to optimize.\n");
243+
return Shuffle;
244+
}
245+
246+
bool UseInput1 = !Input2;
247+
auto *CheckShuffle = Input1 ? Input1 : Input2;
248+
ArrayRef<int> Mask = CheckShuffle->getShuffleMask();
249+
ArrayRef<int> CurrentMask = Shuffle->getShuffleMask();
250+
251+
bool UsesInput1 = false;
252+
bool UsesInput2 = false;
253+
checkInputsForMask(Mask, Mask, CurrentMask, UsesInput1, UsesInput2);
254+
255+
if (UsesInput1 && UsesInput2 || (UsesInput1 && !UseInput1) ||
256+
(UsesInput2 && UseInput1)) {
257+
LLVM_DEBUG(dbgs() << "Expected only one use in shuffle.\n");
258+
return Shuffle;
259+
}
260+
261+
auto InputLen = getInputLen(CheckShuffle);
262+
263+
SmallVector<int, 32> CombinedMask;
264+
// Combine the masks into a single mask.
265+
for (auto Index : CurrentMask) {
266+
if (Index < 0) {
267+
CombinedMask.push_back(-1); // Undef index.
268+
} else if (static_cast<unsigned>(Index) < Mask.size()) {
269+
CombinedMask.push_back(Mask[Index]);
270+
} else {
271+
CombinedMask.push_back(Mask[Index - Mask.size()] + InputLen);
272+
}
273+
}
274+
275+
LLVM_DEBUG(dbgs() << "Combined mask: "; for (int Val
276+
: CombinedMask) dbgs()
277+
<< Val << " ";
278+
dbgs() << "\n");
279+
280+
// Create the final shuffle vector with the combined mask.
281+
IRBuilder<> Builder(Shuffle);
282+
auto *NewShuffle = Builder.CreateShuffleVector(
283+
CheckShuffle->getOperand(0), CheckShuffle->getOperand(1), CombinedMask);
284+
285+
LLVM_DEBUG(dbgs() << "Created new shufflevector: " << *NewShuffle << "\n");
286+
287+
return NewShuffle;
288+
}
289+
290+
static bool checkInputsForMaskIndex(int InputLen, ArrayRef<int> Mask,
291+
int &InputOperandIdx) {
292+
for (auto Index : Mask) {
293+
if (Index < 0)
294+
continue; // skip undef
295+
int OperandIdx = 0;
296+
if (Index >= InputLen) {
297+
OperandIdx = 1;
298+
}
299+
if (InputOperandIdx == -1) {
300+
InputOperandIdx = OperandIdx;
301+
} else if (InputOperandIdx != OperandIdx) {
302+
// Both operands are used in the first shuffle.
303+
return false;
304+
}
305+
}
306+
return true;
307+
}
308+
309+
// Simplify ShuffleVector multi-instructions chain
310+
// Transform from:
311+
// %shuffle1 = shufflevector <16 x i1> %input1, <16 x i1> poison,
312+
// <256 x i32> <i32 undef (224 times), i32 0-15 (16), i32 undef (16 times)>
313+
// %shuffle2 = shufflevector <16 x i1> %input2, <16 x i1> poison,
314+
// <256 x i32> <i32 0-15 (16), i32 undef (240 times)>
315+
// %combinedShuffle = shufflevector <256 x i1> %shuffle1, <256 x i1>
316+
// %shuffle2,
317+
// <256 x i32> <i32 undef (224 times), i32 224-239 & 256-271 (32)>
318+
// To:
319+
// %finalShuffle = shufflevector <16 x i1> %input1, <16 x i1> %input2,
320+
// <32 x i32> <i32 undef, ..., i32 0-15, i32 16-31>
321+
static Value *simplifyShuffleVectorChain(ShuffleVectorInst *Shuffle) {
322+
LLVM_DEBUG(dbgs() << "Simplifying shufflevector: " << *Shuffle << "\n");
323+
324+
auto *Input1 = dyn_cast<ShuffleVectorInst>(Shuffle->getOperand(0));
325+
auto *Input2 = dyn_cast<ShuffleVectorInst>(Shuffle->getOperand(1));
326+
327+
if (!Input1 || !Input2) {
328+
LLVM_DEBUG(dbgs() << "simplifyShuffleVectorChain: No chain detected, "
329+
"nothing to optimize.\n");
330+
return Shuffle;
331+
}
332+
333+
ArrayRef<int> Mask1 = Input1->getShuffleMask();
334+
ArrayRef<int> Mask2 = Input2->getShuffleMask();
335+
ArrayRef<int> CurrentMask = Shuffle->getShuffleMask();
336+
337+
bool UsesInput1 = false;
338+
bool UsesInput2 = false;
339+
340+
checkInputsForMask(Mask1, Mask2, CurrentMask, UsesInput1, UsesInput2);
341+
342+
if (!UsesInput1 || !UsesInput2) {
343+
LLVM_DEBUG(dbgs() << "Only one input used in shuffle.\n");
344+
return Shuffle;
345+
}
346+
347+
int Input1OperandIdx = -1;
348+
auto Input1Len = getInputLen(Input1);
349+
if (!checkInputsForMaskIndex(Input1Len, Mask1, Input1OperandIdx))
350+
return Shuffle;
351+
352+
int Input2OperandIdx = -1;
353+
auto Input2Len = getInputLen(Input2);
354+
if (!checkInputsForMaskIndex(Input2Len, Mask2, Input2OperandIdx))
355+
return Shuffle;
356+
357+
// If we have only one operand used in the first shuffle, we need to
358+
// set the other operand to 0.
359+
if (Input1OperandIdx == -1)
360+
Input1OperandIdx = 0;
361+
if (Input2OperandIdx == -1)
362+
Input2OperandIdx = 0;
363+
364+
if (Input1->getOperand(Input1OperandIdx)->getType() !=
365+
Input2->getOperand(Input2OperandIdx)->getType()) {
366+
// The types of the operands are different.
367+
return Shuffle;
368+
}
369+
370+
SmallVector<int, 32> CombinedMask;
371+
// Combine the masks into a single mask.
372+
for (auto Index : CurrentMask) {
373+
if (Index < 0) {
374+
CombinedMask.push_back(-1); // Undef index.
375+
} else if (static_cast<unsigned>(Index) < Mask1.size()) {
376+
CombinedMask.push_back(Mask1[Index] - Input1Len * Input1OperandIdx);
377+
} else {
378+
CombinedMask.push_back(Mask2[Index - Mask1.size()] +
379+
Input1Len * (1 - Input2OperandIdx));
380+
}
381+
}
382+
383+
LLVM_DEBUG(dbgs() << "Combined mask: "; for (int Val
384+
: CombinedMask) dbgs()
385+
<< Val << " ";
386+
dbgs() << "\n");
387+
388+
// Create the final shuffle vector with the combined mask.
389+
IRBuilder<> Builder(Shuffle);
390+
auto *NewShuffle = Builder.CreateShuffleVector(
391+
Input1->getOperand(Input1OperandIdx),
392+
Input2->getOperand(Input2OperandIdx), CombinedMask);
393+
394+
LLVM_DEBUG(dbgs() << "Created new shufflevector: " << *NewShuffle << "\n");
395+
396+
return NewShuffle;
397+
}
398+
186399
static inline bool isBitcastFits(BitCastInst *BC) {
187400
return BC->getSrcTy()->isVectorTy() && BC->getDestTy()->isIntegerTy();
188401
}
@@ -251,6 +464,14 @@ static Value *GenXSimplifyInstruction(llvm::Instruction *Inst) {
251464
if (Inst->getOpcode() == Instruction::Mul)
252465
return simplifyMulDDQ(*cast<BinaryOperator>(Inst));
253466

467+
if (auto *Shuffle = dyn_cast<ShuffleVectorInst>(Inst)) {
468+
auto *Chain = simplifyShuffleVectorChain(Shuffle);
469+
if (Chain != Shuffle) {
470+
return Chain;
471+
}
472+
return propagateShuffleVector(Shuffle);
473+
}
474+
254475
return nullptr;
255476
}
256477

0 commit comments

Comments
 (0)