@@ -34,8 +34,8 @@ static Value createConst(Location loc, Type type, int value,
34
34
return rewriter.create <arith::ConstantOp>(loc, attr);
35
35
}
36
36
37
- // / Create an float constant.
38
- static Value createFloatConst (Location loc, Type type, float value,
37
+ // / Create a float constant.
38
+ static Value createFloatConst (Location loc, Type type, APFloat value,
39
39
PatternRewriter &rewriter) {
40
40
auto attr = rewriter.getFloatAttr (getElementTypeOrSelf (type), value);
41
41
if (auto shapedTy = dyn_cast<ShapedType>(type)) {
@@ -416,67 +416,66 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
416
416
Type operandETy = getElementTypeOrSelf (operandTy);
417
417
Type resultETy = getElementTypeOrSelf (resultTy);
418
418
419
- if (!llvm:: isa<Float4E2M1FNType>(resultETy)) {
420
- return rewriter.notifyMatchFailure (op, " not a trunc of F4E2M1FN" );
419
+ if (!isa<Float32Type>(operandETy) || ! isa<Float4E2M1FNType>(resultETy)) {
420
+ return rewriter.notifyMatchFailure (op, " not a trunc of F32 to F4E2M1FN" );
421
421
}
422
422
423
423
Type i4Ty = cloneToShapedType (operandTy, b.getI4Type ());
424
424
Type i8Ty = cloneToShapedType (operandTy, b.getI8Type ());
425
425
Type i32Ty = cloneToShapedType (operandTy, b.getI32Type ());
426
426
Type f32Ty = cloneToShapedType (operandTy, b.getF32Type ());
427
427
428
- // Constants
429
428
Value c0x1 = createConst (op->getLoc (), i4Ty, 1 , rewriter);
430
- Value cF32MantissaWidth = createConst (op->getLoc (), i32Ty, 23 , rewriter); // 23
431
- Value cF4MantissaWidth = c0x1; // 1
432
- Value cF32FirstBitMask = createConst (op.getLoc (), i32Ty, 0x400000 , rewriter);
433
429
Value c0x00000016 = createConst (op->getLoc (), i32Ty, 22 , rewriter);
434
430
Value c0x00 = createConst (op.getLoc (), i8Ty, 0x00 , rewriter);
435
431
Value c0xff = createConst (op.getLoc (), i8Ty, 0xff , rewriter);
436
- Value cF32MantissaMask = createConst (op->getLoc (), i32Ty, 0x7fffff , rewriter);
437
432
Value c0x00000000 = createConst (op.getLoc (), i32Ty, 0 , rewriter);
438
- Value cF32Last22BitMask = createConst (op->getLoc (), i32Ty, 0x3fffff , rewriter);;
439
433
440
434
// Step 1: Clamp to bounds.
441
435
Value cHigherBound = createFloatConst (op->getLoc (), f32Ty, 6.0 , rewriter);
442
436
Value cLowerBound = createFloatConst (op->getLoc (), f32Ty, -6.0 , rewriter);
443
437
Value operandClamped = b.create <arith::MinimumFOp>(clampLow, operand);
444
438
operandClamped = b.create <arith::MaximumFOp>(clampHigh, operandClamped);
445
439
Value f32Bits = b.create <arith::BitcastOp>(i32Ty, operandClamped);
446
-
440
+
447
441
// Step 2: Convert exponent by adjusting bias.
448
- Value f32SignExp = b.create <arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
449
442
Value biasAdjustment = createConst (op.getLoc (), i32Ty, 0x7e , rewriter);
443
+ Value cF4MantissaWidth = c0x1; // 1
444
+ Value cF32MantissaWidth = createConst (op->getLoc (), i32Ty, 23 , rewriter); // 23
445
+ Value f32SignExp = b.create <arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
450
446
Value biasAdjustedSignExp = b.create <arith::SubIOp>(f32SignExp, biasAdjustment);
451
447
Value f4SignExp = b.create <arith::TruncIOp>(i4Ty, biasAdjustedSignExp);
452
448
f4SignExp = b.create <arith::ShLIOp>(f4SignExp, cF4MantissaWidth);
453
449
454
450
// Step 3: Set mantissa to first bit.
451
+ Value cF32FirstBitMask = createConst (op.getLoc (), i32Ty, 0x400000 , rewriter);
455
452
Value man1Bit = b.create <arith::AndIOp>(f32Bits, cF32FirstBitMask);
456
453
man1Bit = b.create <arith::ShRUIOp>(man1Bit, c0x00000016);
457
454
Value f4Man = b.create <arith::TruncIOp>(i4Ty, man1Bit);
458
455
Value f4Bits = b.create <arith::AddIOp>(f4SignExp, f4Man);
459
456
460
457
// Step 4: Special consideration for conversion to 0.5.
458
+ Value cF32MantissaMask = createConst (op->getLoc (), i32Ty, 0x7fffff , rewriter);
461
459
Value f8Exp = b.create <arith::TruncIOp>(i8Ty, biasAdjustedSignExp);
462
460
Value isSubnormal =
463
- b.create <arith::CmpIOp>(arith::CmpIPredicate::sle, f8Exp, c0x00);
461
+ b.create <arith::CmpIOp>(arith::CmpIPredicate::sle, f8Exp, c0x00);
464
462
Value isNegOneExp =
465
- b.create <arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0xff);
463
+ b.create <arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0xff);
466
464
Value man23Bits = b.create <arith::AndIOp>(f32Bits, cF32MantissaMask);
467
465
Value isNonZeroMan =
468
- b.create <arith::CmpIOp>(arith::CmpIPredicate::ugt, man23Bits, c0x00000000);
466
+ b.create <arith::CmpIOp>(arith::CmpIPredicate::ugt, man23Bits, c0x00000000);
469
467
Value roundToHalf = b.create <arith::AndIOp>(isNegOneExp, isNonZeroMan);
470
- Value subnormalF4Bits = createConst (op->getLoc (), i4Ty, 0xf , rewriter);
471
- Value halfF4Bits = createConst (op->getLoc (), i4Ty, 0x0 , rewriter);
472
468
Value isZeroExp =
473
- b.create <arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0x00);
469
+ b.create <arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0x00);
474
470
471
+ Value subnormalF4Bits = createConst (op->getLoc (), i4Ty, 0xf , rewriter);
472
+ Value halfF4Bits = createConst (op->getLoc (), i4Ty, 0x0 , rewriter);
475
473
Value subResult = b.create <arith::SelectOp>(isSubnormal, subnormalF4Bits, f4Bits);
476
474
subResult = b.create <arith::SelectOp>(roundToHalf, halfF4Bits, subResult);
477
475
f4Bits = b.create <arith::SelectOp>(isZeroExp, f4Bits, subResult);
478
-
476
+
479
477
// Step 5: Round up if necessary.
478
+ Value cF32Last22BitMask = createConst (op->getLoc (), i32Ty, 0x3fffff , rewriter);
480
479
Value cRound = createConst (op.getLoc (), i32Ty, 0x200000 , rewriter); // 010 0000...
481
480
Value man22Bits = b.create <arith::AndIOp>(f32Bits, cF32Last22BitMask);
482
481
Value shouldRound =
0 commit comments