Skip to content

Commit 74be7a8

Browse files
PR Review round 1
Signed-off-by: Muzammiluddin Syed <[email protected]>
1 parent 63f2337 commit 74be7a8

File tree

1 file changed

+18
-19
lines changed

1 file changed

+18
-19
lines changed

mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp

Lines changed: 18 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -34,8 +34,8 @@ static Value createConst(Location loc, Type type, int value,
3434
return rewriter.create<arith::ConstantOp>(loc, attr);
3535
}
3636

37-
/// Create an float constant.
38-
static Value createFloatConst(Location loc, Type type, float value,
37+
/// Create a float constant.
38+
static Value createFloatConst(Location loc, Type type, APFloat value,
3939
PatternRewriter &rewriter) {
4040
auto attr = rewriter.getFloatAttr(getElementTypeOrSelf(type), value);
4141
if (auto shapedTy = dyn_cast<ShapedType>(type)) {
@@ -416,67 +416,66 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
416416
Type operandETy = getElementTypeOrSelf(operandTy);
417417
Type resultETy = getElementTypeOrSelf(resultTy);
418418

419-
if (!llvm::isa<Float4E2M1FNType>(resultETy)) {
420-
return rewriter.notifyMatchFailure(op, "not a trunc of F4E2M1FN");
419+
if (!isa<Float32Type>(operandETy) || !isa<Float4E2M1FNType>(resultETy)) {
420+
return rewriter.notifyMatchFailure(op, "not a trunc of F32 to F4E2M1FN");
421421
}
422422

423423
Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
424424
Type i8Ty = cloneToShapedType(operandTy, b.getI8Type());
425425
Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
426426
Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
427427

428-
// Constants
429428
Value c0x1 = createConst(op->getLoc(), i4Ty, 1, rewriter);
430-
Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter); // 23
431-
Value cF4MantissaWidth = c0x1; // 1
432-
Value cF32FirstBitMask = createConst(op.getLoc(), i32Ty, 0x400000, rewriter);
433429
Value c0x00000016 = createConst(op->getLoc(), i32Ty, 22, rewriter);
434430
Value c0x00 = createConst(op.getLoc(), i8Ty, 0x00, rewriter);
435431
Value c0xff = createConst(op.getLoc(), i8Ty, 0xff, rewriter);
436-
Value cF32MantissaMask = createConst(op->getLoc(), i32Ty, 0x7fffff, rewriter);
437432
Value c0x00000000 = createConst(op.getLoc(), i32Ty, 0, rewriter);
438-
Value cF32Last22BitMask = createConst(op->getLoc(), i32Ty, 0x3fffff, rewriter);;
439433

440434
// Step 1: Clamp to bounds.
441435
Value cHigherBound = createFloatConst(op->getLoc(), f32Ty, 6.0, rewriter);
442436
Value cLowerBound = createFloatConst(op->getLoc(), f32Ty, -6.0, rewriter);
443437
Value operandClamped = b.create<arith::MinimumFOp>(clampLow, operand);
444438
operandClamped = b.create<arith::MaximumFOp>(clampHigh, operandClamped);
445439
Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operandClamped);
446-
440+
447441
// Step 2: Convert exponent by adjusting bias.
448-
Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
449442
Value biasAdjustment = createConst(op.getLoc(), i32Ty, 0x7e, rewriter);
443+
Value cF4MantissaWidth = c0x1; // 1
444+
Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter); // 23
445+
Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
450446
Value biasAdjustedSignExp = b.create<arith::SubIOp>(f32SignExp, biasAdjustment);
451447
Value f4SignExp = b.create<arith::TruncIOp>(i4Ty, biasAdjustedSignExp);
452448
f4SignExp = b.create<arith::ShLIOp>(f4SignExp, cF4MantissaWidth);
453449

454450
// Step 3: Set mantissa to first bit.
451+
Value cF32FirstBitMask = createConst(op.getLoc(), i32Ty, 0x400000, rewriter);
455452
Value man1Bit = b.create<arith::AndIOp>(f32Bits, cF32FirstBitMask);
456453
man1Bit = b.create<arith::ShRUIOp>(man1Bit, c0x00000016);
457454
Value f4Man = b.create<arith::TruncIOp>(i4Ty, man1Bit);
458455
Value f4Bits = b.create<arith::AddIOp>(f4SignExp, f4Man);
459456

460457
// Step 4: Special consideration for conversion to 0.5.
458+
Value cF32MantissaMask = createConst(op->getLoc(), i32Ty, 0x7fffff, rewriter);
461459
Value f8Exp = b.create<arith::TruncIOp>(i8Ty, biasAdjustedSignExp);
462460
Value isSubnormal =
463-
b.create<arith::CmpIOp>(arith::CmpIPredicate::sle, f8Exp, c0x00);
461+
b.create<arith::CmpIOp>(arith::CmpIPredicate::sle, f8Exp, c0x00);
464462
Value isNegOneExp =
465-
b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0xff);
463+
b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0xff);
466464
Value man23Bits = b.create<arith::AndIOp>(f32Bits, cF32MantissaMask);
467465
Value isNonZeroMan =
468-
b.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, man23Bits, c0x00000000);
466+
b.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, man23Bits, c0x00000000);
469467
Value roundToHalf = b.create<arith::AndIOp>(isNegOneExp, isNonZeroMan);
470-
Value subnormalF4Bits = createConst(op->getLoc(), i4Ty, 0xf, rewriter);
471-
Value halfF4Bits = createConst(op->getLoc(), i4Ty, 0x0, rewriter);
472468
Value isZeroExp =
473-
b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0x00);
469+
b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0x00);
474470

471+
Value subnormalF4Bits = createConst(op->getLoc(), i4Ty, 0xf, rewriter);
472+
Value halfF4Bits = createConst(op->getLoc(), i4Ty, 0x0, rewriter);
475473
Value subResult = b.create<arith::SelectOp>(isSubnormal, subnormalF4Bits, f4Bits);
476474
subResult = b.create<arith::SelectOp>(roundToHalf, halfF4Bits, subResult);
477475
f4Bits = b.create<arith::SelectOp>(isZeroExp, f4Bits, subResult);
478-
476+
479477
// Step 5: Round up if necessary.
478+
Value cF32Last22BitMask = createConst(op->getLoc(), i32Ty, 0x3fffff, rewriter);
480479
Value cRound = createConst(op.getLoc(), i32Ty, 0x200000, rewriter); // 010 0000...
481480
Value man22Bits = b.create<arith::AndIOp>(f32Bits, cF32Last22BitMask);
482481
Value shouldRound =

0 commit comments

Comments
 (0)