-
Notifications
You must be signed in to change notification settings - Fork 13.8k
[mlir][arith][transforms] Adds Truncf f32 to f4e2m1 #144157
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[mlir][arith][transforms] Adds Truncf f32 to f4e2m1 #144157
Conversation
Signed-off-by: Muzammiluddin Syed <[email protected]>
@llvm/pr-subscribers-mlir-arith @llvm/pr-subscribers-mlir Author: Muzammil (Muzammiluddin-Syed-ECE) Changes: See work detail: iree-org/iree#20920. Add support for FP32 -> MXFP F4 in the `arith.truncf` op. Full diff: https://github.com/llvm/llvm-project/pull/144157.diff 1 Files Affected:
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 534aff9562b7a..40d080d83ad38 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -34,6 +34,18 @@ static Value createConst(Location loc, Type type, int value,
return rewriter.create<arith::ConstantOp>(loc, attr);
}
+/// Create a float constant of the given type. The float attribute is built
+/// with the element type of `type`; when `type` is a ShapedType the attribute
+/// is wrapped in a DenseElementsAttr of that shaped type, otherwise it is used
+/// directly as the constant's value.
+static Value createFloatConst(Location loc, Type type, float value,
+                              PatternRewriter &rewriter) {
+  FloatAttr scalarAttr =
+      rewriter.getFloatAttr(getElementTypeOrSelf(type), value);
+
+  // Scalar case: the attribute itself is the constant.
+  auto shapedTy = dyn_cast<ShapedType>(type);
+  if (!shapedTy)
+    return rewriter.create<arith::ConstantOp>(loc, scalarAttr);
+
+  // Shaped case: build the dense attribute from the single scalar value.
+  auto denseAttr = DenseElementsAttr::get(shapedTy, scalarAttr);
+  return rewriter.create<arith::ConstantOp>(loc, denseAttr);
+}
+
/// Creates shapedType using shape from cloneFrom and base type from cloneTo
static Type cloneToShapedType(Type cloneFrom, Type cloneTo) {
if (auto shapedTy = dyn_cast<ShapedType>(cloneFrom)) {
@@ -366,6 +378,122 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
}
};
+/// Conversion from F32 to F4E2M1 according to the OCP Spec:
+/// www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
+///
+/// The spec requires us to perform Round to Nearest, Ties to Even.
+///
+/// This means that when the input lies exactly halfway between two
+/// representable values, we break the tie by choosing the option which
+/// results in a mantissa of 0 in the least significant digit.
+///
+/// Table of representable values in F4E2M1:
+///
+/// Note: x is sign bit
+/// | Binary | Value ( + / - )
+/// | x000   | 0.0
+/// | x001   | 0.5
+/// | x010   | 1.0
+/// | x011   | 1.5
+/// | x100   | 2.0
+/// | x101   | 3.0
+/// | x110   | 4.0
+/// | x111   | 6.0
+///
+/// Conversion procedure:
+///   Step 1: Clamp to representable bounds.
+///   Step 2: Convert exponent by adjusting bias.
+///   Step 3: Set mantissa to first bit.
+///   Step 4: Special consideration for subnormal and zero exponent.
+///   Step 5: Round up if mantissa[1:] is above the halfway point, if it is
+///           exactly halfway and the kept mantissa bit is 1, or if subnormal.
+///   Step 6: Restore the sign bit.
+///
+/// The pattern only matches when the result element type is F4E2M1FN and the
+/// operand element type is F32; anything else is reported as a match failure.
+struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(arith::TruncFOp op,
+                                PatternRewriter &rewriter) const final {
+    Location loc = op.getLoc();
+    ImplicitLocOpBuilder b(loc, rewriter);
+    Value operand = op.getOperand();
+    Type operandTy = operand.getType();
+    Type resultTy = op.getType();
+    Type operandETy = getElementTypeOrSelf(operandTy);
+    Type resultETy = getElementTypeOrSelf(resultTy);
+
+    if (!llvm::isa<Float4E2M1FNType>(resultETy)) {
+      return rewriter.notifyMatchFailure(op, "not a trunc of F4E2M1FN");
+    }
+    // The expansion below bitcasts the operand to i32 and relies on the F32
+    // bit layout, so any other source type must not be rewritten here.
+    if (!llvm::isa<Float32Type>(operandETy)) {
+      return rewriter.notifyMatchFailure(op, "not a trunc from F32");
+    }
+
+    Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
+    Type i8Ty = cloneToShapedType(operandTy, b.getI8Type());
+    Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
+    Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
+
+    // Constants
+    Value c0x1 = createConst(loc, i4Ty, 1, rewriter);
+    Value c0x3 = createConst(loc, i4Ty, 3, rewriter);
+    Value cF32MantissaWidth = createConst(loc, i32Ty, 23, rewriter);
+    Value cF32ExpManWidth = createConst(loc, i32Ty, 31, rewriter);
+    Value cF4MantissaWidth = c0x1;
+    Value cF32FirstBitMask = createConst(loc, i32Ty, 0x400000, rewriter);
+    Value c0x00000016 = createConst(loc, i32Ty, 22, rewriter);
+    Value c0x00 = createConst(loc, i8Ty, 0x00, rewriter);
+    Value c0xff = createConst(loc, i8Ty, 0xff, rewriter);
+    Value cF32MantissaMask = createConst(loc, i32Ty, 0x7fffff, rewriter);
+    Value c0x00000000 = createConst(loc, i32Ty, 0, rewriter);
+    Value cF32Last22BitMask = createConst(loc, i32Ty, 0x3fffff, rewriter);
+
+    // Step 1: Clamp to bounds.
+    // NOTE(review): UGT/ULT are unordered predicates, so a NaN operand
+    // appears to satisfy both comparisons and end up as the lower bound;
+    // confirm that this is the intended NaN behavior.
+    Value cHigherBound = createFloatConst(loc, f32Ty, 6.0, rewriter);
+    Value cLowerBound = createFloatConst(loc, f32Ty, -6.0, rewriter);
+    Value clampHigh = b.create<arith::CmpFOp>(arith::CmpFPredicate::UGT,
+                                              operand, cHigherBound);
+    Value clampLow = b.create<arith::CmpFOp>(arith::CmpFPredicate::ULT,
+                                             operand, cLowerBound);
+    Value operandClamped =
+        b.create<arith::SelectOp>(clampHigh, cHigherBound, operand);
+    operandClamped =
+        b.create<arith::SelectOp>(clampLow, cLowerBound, operandClamped);
+    Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operandClamped);
+
+    // Extract the sign now; it is moved to bit 3 of the i4 and merged back in
+    // at the very end. Truncating (sign|exponent) to i4 below keeps only the
+    // low exponent bits, so the sign has to be carried separately.
+    Value f32Sign = b.create<arith::ShRUIOp>(f32Bits, cF32ExpManWidth);
+    Value f4Sign = b.create<arith::TruncIOp>(i4Ty, f32Sign);
+    f4Sign = b.create<arith::ShLIOp>(f4Sign, c0x3);
+
+    // Step 2: Convert exponent by adjusting bias.
+    // Subtracting 0x7e (= 127 - 1) rebases the F32 exponent to the F4 bias.
+    Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
+    Value biasAdjustment = createConst(loc, i32Ty, 0x7e, rewriter);
+    Value biasAdjustedSignExp =
+        b.create<arith::SubIOp>(f32SignExp, biasAdjustment);
+    Value f4Exp = b.create<arith::TruncIOp>(i4Ty, biasAdjustedSignExp);
+    f4Exp = b.create<arith::ShLIOp>(f4Exp, cF4MantissaWidth);
+
+    // Step 3: Set mantissa to first bit.
+    Value man1BitMasked = b.create<arith::AndIOp>(f32Bits, cF32FirstBitMask);
+    Value man1Bit = b.create<arith::ShRUIOp>(man1BitMasked, c0x00000016);
+    Value f4Man = b.create<arith::TruncIOp>(i4Ty, man1Bit);
+    Value f4Bits = b.create<arith::AddIOp>(f4Exp, f4Man);
+
+    // Step 4: Special consideration for subnormal and zero exponent.
+    // Every input with an adjusted exponent <= 0 is unconditionally
+    // incremented in step 5, so the values selected here are "one below" the
+    // intended encoding: 0xf wraps to 0x0 (zero) and 0x0 becomes 0x1 (0.5).
+    Value f8Exp = b.create<arith::TruncIOp>(i8Ty, biasAdjustedSignExp);
+    Value isSubnormal =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::sle, f8Exp, c0x00);
+    Value isNegOneExp =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0xff);
+    Value man23Bits = b.create<arith::AndIOp>(f32Bits, cF32MantissaMask);
+    Value isNonZeroMan = b.create<arith::CmpIOp>(arith::CmpIPredicate::ugt,
+                                                 man23Bits, c0x00000000);
+    Value roundToHalf = b.create<arith::AndIOp>(isNegOneExp, isNonZeroMan);
+    Value subnormalF4Bits = createConst(loc, i4Ty, 0xf, rewriter);
+    Value halfF4Bits = createConst(loc, i4Ty, 0x0, rewriter);
+    Value isZeroExp =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0x00);
+
+    Value subResult =
+        b.create<arith::SelectOp>(isSubnormal, subnormalF4Bits, f4Bits);
+    subResult = b.create<arith::SelectOp>(roundToHalf, halfF4Bits, subResult);
+    f4Bits = b.create<arith::SelectOp>(isZeroExp, f4Bits, subResult);
+
+    // Step 5: Round up if necessary (round to nearest, ties to even).
+    // cRound is the halfway point of the 22 discarded mantissa bits.
+    Value cRound = createConst(loc, i32Ty, 0x200000, rewriter); // 10 0000...
+    Value man22Bits = b.create<arith::AndIOp>(f32Bits, cF32Last22BitMask);
+    Value isAboveHalf =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, man22Bits, cRound);
+    Value isTie =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, man22Bits, cRound);
+    // On an exact tie only round up when the kept mantissa bit is 1, so the
+    // result always ends with a 0 mantissa bit.
+    Value isOddMan = b.create<arith::CmpIOp>(arith::CmpIPredicate::ne,
+                                             man1Bit, c0x00000000);
+    Value roundTie = b.create<arith::AndIOp>(isTie, isOddMan);
+    Value shouldRound = b.create<arith::OrIOp>(isAboveHalf, roundTie);
+    shouldRound = b.create<arith::OrIOp>(shouldRound, isSubnormal);
+    Value roundedF4Bits = b.create<arith::AddIOp>(f4Bits, c0x1);
+    f4Bits = b.create<arith::SelectOp>(shouldRound, roundedF4Bits, f4Bits);
+
+    // Step 6: Restore the sign bit. After clamping and rounding the magnitude
+    // occupies only the low three bits, so bit 3 is free for the sign.
+    f4Bits = b.create<arith::OrIOp>(f4Bits, f4Sign);
+
+    Value result = b.create<arith::BitcastOp>(resultTy, f4Bits);
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+
/*
TruncF to F8E8M0 is expected to extract exponent bits out of F32 type
Since All kinds of Infs and NaNs are mapped to same exponent bits in F32 type,
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
188079d
to
d9abbe1
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems overall reasonable for a first crack at this sort of fallback lowering.
I think once we have an extf fallback we might want an integration test for this - and can get scaling_truncf/scaling_extf in there too.
70dfe0f
to
63f2337
Compare
Signed-off-by: Muzammiluddin Syed <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This needs some tests. I think that lit tests are not super useful because they will only check that the lowering is what the lowering is, but maybe a few integration tests would make sense here.
74be7a8
to
78d9e7a
Compare
Signed-off-by: Muzammiluddin Syed <[email protected]>
See work detail: iree-org/iree#20920
Add support for FP32 -> MXFP F4 in the `arith.truncf` op.