Skip to content

[mlir][arith][transforms] Adds Truncf f32 to f4e2m1 #144157

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from

Conversation

Muzammiluddin-Syed-ECE
Copy link
Contributor

See work detail: iree-org/iree#20920

Add support for FP32 -> MXFP F4 in arith.truncf op

@llvmbot
Copy link
Member

llvmbot commented Jun 13, 2025

@llvm/pr-subscribers-mlir-arith

@llvm/pr-subscribers-mlir

Author: Muzammil (Muzammiluddin-Syed-ECE)

Changes

See work detail: iree-org/iree#20920

Add support for FP32 -> MXFP F4 in arith.truncf op


Full diff: https://github.com/llvm/llvm-project/pull/144157.diff

1 Files Affected:

  • (modified) mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp (+128)
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 534aff9562b7a..40d080d83ad38 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -34,6 +34,18 @@ static Value createConst(Location loc, Type type, int value,
   return rewriter.create<arith::ConstantOp>(loc, attr);
 }
 
+/// Create an float constant.
+static Value createFloatConst(Location loc, Type type, float value,
+                         PatternRewriter &rewriter) {
+  auto attr = rewriter.getFloatAttr(getElementTypeOrSelf(type), value);
+  if (auto shapedTy = dyn_cast<ShapedType>(type)) {
+    return rewriter.create<arith::ConstantOp>(
+        loc, DenseElementsAttr::get(shapedTy, attr));
+  }
+
+  return rewriter.create<arith::ConstantOp>(loc, attr);
+}
+
 /// Creates shapedType using shape from cloneFrom and base type from cloneTo
 static Type cloneToShapedType(Type cloneFrom, Type cloneTo) {
   if (auto shapedTy = dyn_cast<ShapedType>(cloneFrom)) {
@@ -366,6 +378,122 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
   }
 };
 
+/// Conversion from F32 to F4E2M1 according to the OCP Spec:
+/// www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
+///
+/// The spec requiers us to perform Round to Nearest, Ties to Even.
+///
+/// This means that after rounding, we should break ties by choosing the option
+/// which results in a mantissa of 0 in the least significant digit.
+///
+/// Table of representable values in F4E2M1:
+///
+/// Note: x is sign bit
+/// | Binary | Value ( + / - ) 
+/// | x000   | 0.0
+/// | x001   | 0.5
+/// | x010   | 1.0
+/// | x011   | 1.5
+/// | x100   | 2.0
+/// | x101   | 3.0
+/// | x110   | 4.0
+/// | x111   | 6.0
+///
+/// Conversion procedure: 
+///   Step 1: Clamp to representable bounds.
+///   Step 2: Convert exponent by adjusting bias.
+///   Step 3: Set mantissa to first bit.
+///   Step 4: Special consideration for subnormal and zero exponent.
+///   Step 5: Round up if necessary, if mantissa[1:] greater than 1000000 or subnormal.
+struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(arith::TruncFOp op,
+                                PatternRewriter &rewriter) const final {
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    Value operand = op.getOperand();
+    Type operandTy = operand.getType();
+    Type resultTy = op.getType();
+    Type operandETy = getElementTypeOrSelf(operandTy);
+    Type resultETy = getElementTypeOrSelf(resultTy);
+
+    if (!llvm::isa<Float4E2M1FNType>(resultETy)) {
+      return rewriter.notifyMatchFailure(op, "not a trunc of F4E2M1FN");
+    }
+
+    Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
+    Type i8Ty = cloneToShapedType(operandTy, b.getI8Type());
+    Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
+    Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
+
+    // Constants
+    Value c0x1 = createConst(op->getLoc(), i4Ty, 1, rewriter);
+    Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter); // 23
+    Value cF4MantissaWidth = c0x1;         // 1
+    Value cF32FirstBitMask = createConst(op.getLoc(), i32Ty, 0x400000, rewriter);
+    Value c0x00000016 = createConst(op->getLoc(), i32Ty, 22, rewriter);
+    Value c0x00 = createConst(op.getLoc(), i8Ty, 0x00, rewriter);
+    Value c0xff = createConst(op.getLoc(), i8Ty, 0xff, rewriter);
+    Value cF32MantissaMask = createConst(op->getLoc(), i32Ty, 0x7fffff, rewriter);
+    Value c0x00000000 = createConst(op.getLoc(), i32Ty, 0, rewriter);
+    Value cF32Last22BitMask = createConst(op->getLoc(), i32Ty, 0x3fffff, rewriter);;
+    
+    // Step 1: Clamp to bounds.
+    Value cHigherBound = createFloatConst(op->getLoc(), f32Ty, 6.0, rewriter);
+    Value cLowerBound = createFloatConst(op->getLoc(), f32Ty, -6.0, rewriter);
+    Value clampHigh = b.create<arith::CmpFOp>(arith::CmpFPredicate::UGT, operand, cHigherBound);
+    Value clampLow = b.create<arith::CmpFOp>(arith::CmpFPredicate::ULT, operand, cLowerBound);
+    Value operandClamped = b.create<arith::SelectOp>(clampHigh, cHigherBound, operand);
+    operandClamped = b.create<arith::SelectOp>(clampLow, cLowerBound, operandClamped);
+    Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operandClamped);
+  
+    // Step 2: Convert exponent by adjusting bias.
+    Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
+    Value biasAdjustment = createConst(op.getLoc(), i32Ty, 0x7e, rewriter);
+    Value biasAdjustedSignExp = b.create<arith::SubIOp>(f32SignExp, biasAdjustment);
+    Value f4SignExp = b.create<arith::TruncIOp>(i4Ty, biasAdjustedSignExp);
+    f4SignExp = b.create<arith::ShLIOp>(f4SignExp, cF4MantissaWidth);
+    
+    // Step 3: Set mantissa to first bit.
+    Value man1Bit = b.create<arith::AndIOp>(f32Bits, cF32FirstBitMask);
+    man1Bit = b.create<arith::ShRUIOp>(man1Bit, c0x00000016);
+    Value f4Man = b.create<arith::TruncIOp>(i4Ty, man1Bit);
+    Value f4Bits = b.create<arith::AddIOp>(f4SignExp, f4Man);
+    
+    // Step 4: Special consideration for conversion to 0.5.
+    Value f8Exp = b.create<arith::TruncIOp>(i8Ty, biasAdjustedSignExp);
+    Value isSubnormal =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::sle, f8Exp, c0x00);
+    Value isNegOneExp =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0xff);
+    Value man23Bits = b.create<arith::AndIOp>(f32Bits, cF32MantissaMask);
+    Value isNonZeroMan =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, man23Bits, c0x00000000);
+    Value roundToHalf = b.create<arith::AndIOp>(isNegOneExp, isNonZeroMan);
+    Value subnormalF4Bits = createConst(op->getLoc(), i4Ty, 0xf, rewriter);
+    Value halfF4Bits = createConst(op->getLoc(), i4Ty, 0x0, rewriter);
+    Value isZeroExp = 
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0x00);
+    
+    Value subResult = b.create<arith::SelectOp>(isSubnormal, subnormalF4Bits, f4Bits);
+    subResult = b.create<arith::SelectOp>(roundToHalf, halfF4Bits, subResult);
+    f4Bits = b.create<arith::SelectOp>(isZeroExp, f4Bits, subResult);
+
+    // Step 5: Round up if necessary.
+    Value cRound = createConst(op.getLoc(), i32Ty, 0x200000, rewriter); // 010 0000...
+    Value man22Bits = b.create<arith::AndIOp>(f32Bits, cF32Last22BitMask);
+    Value shouldRound =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, man22Bits, cRound);
+    shouldRound =
+        b.create<arith::OrIOp>(shouldRound, isSubnormal);
+    Value roundedF4Bits = b.create<arith::AddIOp>(f4Bits, c0x1);
+    f4Bits = b.create<arith::SelectOp>(shouldRound, roundedF4Bits, f4Bits);
+
+    Value result = b.create<arith::BitcastOp>(resultTy, f4Bits);
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
 /*
 TruncF to F8E8M0 is expected to extract exponent bits out of F32 type
 Since All kinds of Infs and NaNs are mapped to same exponent bits in F32 type,

Copy link

github-actions bot commented Jun 13, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

Copy link
Contributor

@krzysz00 krzysz00 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems overall reasonable for a first crack at this sort of fallback lowering.

I think once we have an extf fallback we might want an integration test for this - and can get scaling_truncf/scaling_extf in there too.

@Muzammiluddin-Syed-ECE Muzammiluddin-Syed-ECE force-pushed the muzasyed/scale branch 2 times, most recently from 70dfe0f to 63f2337 Compare June 16, 2025 16:35
Signed-off-by: Muzammiluddin Syed <[email protected]>
Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This need some tests. I think that lit tests are not super useful because they will check that the lowering is what the lowering is, but maybe a few integration tests would make sense here.

@kuhar kuhar changed the title [Arith][Transforms] Adds Truncf f32 to f4e2m1 [mlir][arith][transforms] Adds Truncf f32 to f4e2m1 Jun 16, 2025
@Muzammiluddin-Syed-ECE Muzammiluddin-Syed-ECE force-pushed the muzasyed/scale branch 2 times, most recently from 74be7a8 to 78d9e7a Compare June 16, 2025 17:12
Signed-off-by: Muzammiluddin Syed <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants