@@ -183,6 +183,219 @@ static Value *simplifyMulDDQ(BinaryOperator &Mul) {
183
183
return Result;
184
184
}
185
185
186
+ static bool usesBothOperands (const ArrayRef<int > &Mask, int InputLen) {
187
+ int OperandIdx = -1 ;
188
+ for (int Index : Mask) {
189
+ if (Index < 0 )
190
+ continue ; // skip undef
191
+
192
+ int CurrentIdx = (Index >= InputLen) ? 1 : 0 ;
193
+ if (OperandIdx == -1 ) {
194
+ OperandIdx = CurrentIdx;
195
+ } else if (OperandIdx != CurrentIdx) {
196
+ return true ; // Both operands are used
197
+ }
198
+ }
199
+ return false ;
200
+ }
201
+
202
+ static void checkInputsForMask (ArrayRef<int > &Mask1, ArrayRef<int > &Mask2,
203
+ ArrayRef<int > &CurrentMask, bool &UsesInput1,
204
+ bool &UsesInput2) {
205
+ for (auto Index : CurrentMask) {
206
+ if (Index < 0 ) {
207
+ continue ; // Undef index.
208
+ } else if (static_cast <unsigned >(Index) < Mask1.size ()) {
209
+ UsesInput1 = true ;
210
+ } else if (static_cast <unsigned >(Index) < Mask1.size () + Mask2.size ()) {
211
+ UsesInput2 = true ;
212
+ } else {
213
+ IGC_ASSERT_EXIT (false && " Unexpected index in mask" );
214
+ }
215
+ }
216
+ }
217
+
218
+ static int getInputLen (ShuffleVectorInst *CheckShuffle) {
219
+ auto *Type =
220
+ cast<IGCLLVM::FixedVectorType>(CheckShuffle->getOperand (0 )->getType ());
221
+ return static_cast <int >(Type->getNumElements ());
222
+ }
223
+
224
+ // Simplify ShuffleVector one-instruction chain
225
+ // Transform from:
226
+ // %shuffle1 = shufflevector <16 x i1> %input1, <16 x i1> poison,
227
+ // <256 x i32> <i32 undef (224 times), i32 0-15 (16), i32 undef (16 times)>
228
+ // %combinedShuffle = shufflevector <256 x i1> %shuffle1, <256 x i1> poison,
229
+ // <256 x i32> <i32 undef (224 times), i32 224-239 & 256-271 (32)>
230
+ // To:
231
+ // %finalShuffle = shufflevector <16 x i1> %input1, <16 x i1> poison,
232
+ // <32 x i32> <i32 undef, ..., i32 0-15>
233
+ static Value *propagateShuffleVector (ShuffleVectorInst *Shuffle) {
234
+ LLVM_DEBUG (dbgs () << " Simplifying shufflevector: " << *Shuffle << " \n " );
235
+
236
+ auto *Input1 = dyn_cast<ShuffleVectorInst>(Shuffle->getOperand (0 ));
237
+ auto *Input2 = dyn_cast<ShuffleVectorInst>(Shuffle->getOperand (1 ));
238
+
239
+ if (!Input1 && !Input2) {
240
+ LLVM_DEBUG (
241
+ dbgs ()
242
+ << " propagateShuffleVector: No chain detected, nothing to optimize.\n " );
243
+ return Shuffle;
244
+ }
245
+
246
+ bool UseInput1 = !Input2;
247
+ auto *CheckShuffle = Input1 ? Input1 : Input2;
248
+ ArrayRef<int > Mask = CheckShuffle->getShuffleMask ();
249
+ ArrayRef<int > CurrentMask = Shuffle->getShuffleMask ();
250
+
251
+ bool UsesInput1 = false ;
252
+ bool UsesInput2 = false ;
253
+ checkInputsForMask (Mask, Mask, CurrentMask, UsesInput1, UsesInput2);
254
+
255
+ if (UsesInput1 && UsesInput2 || (UsesInput1 && !UseInput1) ||
256
+ (UsesInput2 && UseInput1)) {
257
+ LLVM_DEBUG (dbgs () << " Expected only one use in shuffle.\n " );
258
+ return Shuffle;
259
+ }
260
+
261
+ auto InputLen = getInputLen (CheckShuffle);
262
+
263
+ SmallVector<int , 32 > CombinedMask;
264
+ // Combine the masks into a single mask.
265
+ for (auto Index : CurrentMask) {
266
+ if (Index < 0 ) {
267
+ CombinedMask.push_back (-1 ); // Undef index.
268
+ } else if (static_cast <unsigned >(Index) < Mask.size ()) {
269
+ CombinedMask.push_back (Mask[Index]);
270
+ } else {
271
+ CombinedMask.push_back (Mask[Index - Mask.size ()] + InputLen);
272
+ }
273
+ }
274
+
275
+ LLVM_DEBUG (dbgs () << " Combined mask: " ; for (int Val
276
+ : CombinedMask) dbgs ()
277
+ << Val << " " ;
278
+ dbgs () << " \n " );
279
+
280
+ // Create the final shuffle vector with the combined mask.
281
+ IRBuilder<> Builder (Shuffle);
282
+ auto *NewShuffle = Builder.CreateShuffleVector (
283
+ CheckShuffle->getOperand (0 ), CheckShuffle->getOperand (1 ), CombinedMask);
284
+
285
+ LLVM_DEBUG (dbgs () << " Created new shufflevector: " << *NewShuffle << " \n " );
286
+
287
+ return NewShuffle;
288
+ }
289
+
290
+ static bool checkInputsForMaskIndex (int InputLen, ArrayRef<int > Mask,
291
+ int &InputOperandIdx) {
292
+ for (auto Index : Mask) {
293
+ if (Index < 0 )
294
+ continue ; // skip undef
295
+ int OperandIdx = 0 ;
296
+ if (Index >= InputLen) {
297
+ OperandIdx = 1 ;
298
+ }
299
+ if (InputOperandIdx == -1 ) {
300
+ InputOperandIdx = OperandIdx;
301
+ } else if (InputOperandIdx != OperandIdx) {
302
+ // Both operands are used in the first shuffle.
303
+ return false ;
304
+ }
305
+ }
306
+ return true ;
307
+ }
308
+
309
+ // Simplify ShuffleVector multi-instructions chain
310
+ // Transform from:
311
+ // %shuffle1 = shufflevector <16 x i1> %input1, <16 x i1> poison,
312
+ // <256 x i32> <i32 undef (224 times), i32 0-15 (16), i32 undef (16 times)>
313
+ // %shuffle2 = shufflevector <16 x i1> %input2, <16 x i1> poison,
314
+ // <256 x i32> <i32 0-15 (16), i32 undef (240 times)>
315
+ // %combinedShuffle = shufflevector <256 x i1> %shuffle1, <256 x i1>
316
+ // %shuffle2,
317
+ // <256 x i32> <i32 undef (224 times), i32 224-239 & 256-271 (32)>
318
+ // To:
319
+ // %finalShuffle = shufflevector <16 x i1> %input1, <16 x i1> %input2,
320
+ // <32 x i32> <i32 undef, ..., i32 0-15, i32 16-31>
321
+ static Value *simplifyShuffleVectorChain (ShuffleVectorInst *Shuffle) {
322
+ LLVM_DEBUG (dbgs () << " Simplifying shufflevector: " << *Shuffle << " \n " );
323
+
324
+ auto *Input1 = dyn_cast<ShuffleVectorInst>(Shuffle->getOperand (0 ));
325
+ auto *Input2 = dyn_cast<ShuffleVectorInst>(Shuffle->getOperand (1 ));
326
+
327
+ if (!Input1 || !Input2) {
328
+ LLVM_DEBUG (dbgs () << " simplifyShuffleVectorChain: No chain detected, "
329
+ " nothing to optimize.\n " );
330
+ return Shuffle;
331
+ }
332
+
333
+ ArrayRef<int > Mask1 = Input1->getShuffleMask ();
334
+ ArrayRef<int > Mask2 = Input2->getShuffleMask ();
335
+ ArrayRef<int > CurrentMask = Shuffle->getShuffleMask ();
336
+
337
+ bool UsesInput1 = false ;
338
+ bool UsesInput2 = false ;
339
+
340
+ checkInputsForMask (Mask1, Mask2, CurrentMask, UsesInput1, UsesInput2);
341
+
342
+ if (!UsesInput1 || !UsesInput2) {
343
+ LLVM_DEBUG (dbgs () << " Only one input used in shuffle.\n " );
344
+ return Shuffle;
345
+ }
346
+
347
+ int Input1OperandIdx = -1 ;
348
+ auto Input1Len = getInputLen (Input1);
349
+ if (!checkInputsForMaskIndex (Input1Len, Mask1, Input1OperandIdx))
350
+ return Shuffle;
351
+
352
+ int Input2OperandIdx = -1 ;
353
+ auto Input2Len = getInputLen (Input2);
354
+ if (!checkInputsForMaskIndex (Input2Len, Mask2, Input2OperandIdx))
355
+ return Shuffle;
356
+
357
+ // If we have only one operand used in the first shuffle, we need to
358
+ // set the other operand to 0.
359
+ if (Input1OperandIdx == -1 )
360
+ Input1OperandIdx = 0 ;
361
+ if (Input2OperandIdx == -1 )
362
+ Input2OperandIdx = 0 ;
363
+
364
+ if (Input1->getOperand (Input1OperandIdx)->getType () !=
365
+ Input2->getOperand (Input2OperandIdx)->getType ()) {
366
+ // The types of the operands are different.
367
+ return Shuffle;
368
+ }
369
+
370
+ SmallVector<int , 32 > CombinedMask;
371
+ // Combine the masks into a single mask.
372
+ for (auto Index : CurrentMask) {
373
+ if (Index < 0 ) {
374
+ CombinedMask.push_back (-1 ); // Undef index.
375
+ } else if (static_cast <unsigned >(Index) < Mask1.size ()) {
376
+ CombinedMask.push_back (Mask1[Index] - Input1Len * Input1OperandIdx);
377
+ } else {
378
+ CombinedMask.push_back (Mask2[Index - Mask1.size ()] +
379
+ Input1Len * (1 - Input2OperandIdx));
380
+ }
381
+ }
382
+
383
+ LLVM_DEBUG (dbgs () << " Combined mask: " ; for (int Val
384
+ : CombinedMask) dbgs ()
385
+ << Val << " " ;
386
+ dbgs () << " \n " );
387
+
388
+ // Create the final shuffle vector with the combined mask.
389
+ IRBuilder<> Builder (Shuffle);
390
+ auto *NewShuffle = Builder.CreateShuffleVector (
391
+ Input1->getOperand (Input1OperandIdx),
392
+ Input2->getOperand (Input2OperandIdx), CombinedMask);
393
+
394
+ LLVM_DEBUG (dbgs () << " Created new shufflevector: " << *NewShuffle << " \n " );
395
+
396
+ return NewShuffle;
397
+ }
398
+
186
399
// Returns true when BC casts from a vector type to an integer type.
static inline bool isBitcastFits(BitCastInst *BC) {
  if (!BC->getSrcTy()->isVectorTy())
    return false;
  return BC->getDestTy()->isIntegerTy();
}
@@ -251,6 +464,14 @@ static Value *GenXSimplifyInstruction(llvm::Instruction *Inst) {
251
464
if (Inst->getOpcode () == Instruction::Mul)
252
465
return simplifyMulDDQ (*cast<BinaryOperator>(Inst));
253
466
467
+ if (auto *Shuffle = dyn_cast<ShuffleVectorInst>(Inst)) {
468
+ auto *Chain = simplifyShuffleVectorChain (Shuffle);
469
+ if (Chain != Shuffle) {
470
+ return Chain;
471
+ }
472
+ return propagateShuffleVector (Shuffle);
473
+ }
474
+
254
475
return nullptr ;
255
476
}
256
477
0 commit comments