Skip to content

[mlir][arith][transforms] Adds Truncf f32 to f4e2m1 #144157

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 198 additions & 0 deletions mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,18 @@ static Value createConst(Location loc, Type type, int value,
return rewriter.create<arith::ConstantOp>(loc, attr);
}

/// Materializes `value` as an `arith.constant` of type `type`.
///
/// The float attribute is built on the element type of `type`. When `type` is
/// a shaped type the attribute is wrapped in a `DenseElementsAttr` of that
/// shaped type; otherwise the attribute is used directly.
static Value createFloatConst(Location loc, Type type, APFloat value,
                              PatternRewriter &rewriter) {
  auto elementAttr = rewriter.getFloatAttr(getElementTypeOrSelf(type), value);

  auto shapedTy = dyn_cast<ShapedType>(type);
  if (!shapedTy)
    return rewriter.create<arith::ConstantOp>(loc, elementAttr);

  return rewriter.create<arith::ConstantOp>(
      loc, DenseElementsAttr::get(shapedTy, elementAttr));
}

/// Creates shapedType using shape from cloneFrom and base type from cloneTo
static Type cloneToShapedType(Type cloneFrom, Type cloneTo) {
if (auto shapedTy = dyn_cast<ShapedType>(cloneFrom)) {
Expand Down Expand Up @@ -322,6 +334,70 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
}
};

/// Expands `arith.extf` from F4E2M1FN to F32 into integer bit manipulation.
///
/// F4E2M1FN bit layout: [sign:1][exponent:2][mantissa:1], exponent bias 1.
/// F32 bit layout:      [sign:1][exponent:8][mantissa:23], exponent bias 127.
///
/// Conversion procedure (all arithmetic is done on i32):
///   1. Move the sign bit from bit 3 to bit 31.
///   2. Re-bias the exponent (+126 == 127 - 1) and place it at bit 23.
///   3. Place the single mantissa bit at bit 22 (top F32 mantissa bit).
///   4. For a zero exponent the value is subnormal: the magnitude is either
///      0.0 (mantissa clear) or 0.5 (mantissa set), so the bit pattern is
///      selected directly instead of using the normal-number formula.
///   5. Combine the sign with the magnitude and bitcast to F32.
struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(arith::ExtFOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder b(loc, rewriter);
    Value operand = op.getOperand();
    Type operandTy = operand.getType();
    Type resultTy = op.getType();
    Type operandETy = getElementTypeOrSelf(operandTy);
    Type resultETy = getElementTypeOrSelf(resultTy);

    if (!llvm::isa<Float4E2M1FNType>(operandETy) ||
        !llvm::isa<Float32Type>(resultETy)) {
      return rewriter.notifyMatchFailure(op, "not a ext of F4E2M1FN to F32");
    }

    Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
    Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());

    // Work on a zero-extended i32 copy of the 4 input bits so that every
    // shift below has enough room and all operands share one type.
    Value f4Bits = b.create<arith::BitcastOp>(i4Ty, operand);
    Value bits = b.create<arith::ExtUIOp>(i32Ty, f4Bits);

    Value cZero = createConst(loc, i32Ty, 0, rewriter);
    Value cOne = createConst(loc, i32Ty, 1, rewriter);
    Value cExpMask = createConst(loc, i32Ty, 0x6, rewriter);
    Value cSignMask = createConst(loc, i32Ty, 0x8, rewriter);
    Value cManShift = createConst(loc, i32Ty, 22, rewriter);
    Value cExpShift = createConst(loc, i32Ty, 23, rewriter);
    Value cSignShift = createConst(loc, i32Ty, 28, rewriter);
    // 127 (F32 bias) - 1 (F4E2M1 bias).
    Value biasAdjustment = createConst(loc, i32Ty, 126, rewriter);
    // Bit pattern of 0.5f.
    Value cHalfBits = createConst(loc, i32Ty, 0x3f000000, rewriter);

    // Step 1: sign bit, bit 3 -> bit 31.
    Value signBits = b.create<arith::AndIOp>(bits, cSignMask);
    signBits = b.create<arith::ShLIOp>(signBits, cSignShift);

    // Step 2: exponent, bits [2:1] -> re-biased into bits [30:23].
    Value f4Exp = b.create<arith::AndIOp>(bits, cExpMask);
    f4Exp = b.create<arith::ShRUIOp>(f4Exp, cOne);
    Value f32ExpBits = b.create<arith::AddIOp>(f4Exp, biasAdjustment);
    f32ExpBits = b.create<arith::ShLIOp>(f32ExpBits, cExpShift);

    // Step 3: mantissa, bit 0 -> bit 22.
    Value f4Man = b.create<arith::AndIOp>(bits, cOne);
    Value f32ManBits = b.create<arith::ShLIOp>(f4Man, cManShift);
    Value normalBits = b.create<arith::OrIOp>(f32ExpBits, f32ManBits);

    // Step 4: special consideration for subnormal inputs (exponent == 0).
    Value isSubnormal =
        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f4Exp, cZero);
    Value isManSet =
        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f4Man, cOne);
    Value subnormalBits =
        b.create<arith::SelectOp>(isManSet, cHalfBits, cZero);
    Value magnitudeBits =
        b.create<arith::SelectOp>(isSubnormal, subnormalBits, normalBits);

    // Step 5: re-attach the sign (also keeps -0.0 and -0.5 negative).
    Value f32Bits = b.create<arith::OrIOp>(signBits, magnitudeBits);

    Value result = b.create<arith::BitcastOp>(resultTy, f32Bits);
    rewriter.replaceOp(op, result);
    return success();
  }
};

struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(arith::ExtFOp op,
Expand Down Expand Up @@ -366,6 +442,128 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
}
};

/// Conversion from F32 to F4E2M1 according to the OCP Spec:
/// www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
///
/// The spec requires us to perform Round to Nearest, Ties to Even.
///
/// This means that after rounding, we should break ties by choosing the option
/// which results in a mantissa of 0 in the least significant digit.
///
/// Table of representable values in F4E2M1:
///
/// Note: x is sign bit
/// | Binary | Value ( + / - )
/// | x000   | 0.0
/// | x001   | 0.5
/// | x010   | 1.0
/// | x011   | 1.5
/// | x100   | 2.0
/// | x101   | 3.0
/// | x110   | 4.0
/// | x111   | 6.0
///
/// Conversion procedure:
///   Step 1: Clamp to representable bounds [-6.0, 6.0].
///   Step 2: Extract the sign bit and move it to bit 3.
///   Step 3: Convert exponent by adjusting bias (127 -> 1).
///   Step 4: Set mantissa to the first (most significant) F32 mantissa bit.
///   Step 5: Round the magnitude to nearest, ties to even, using the
///           remaining 22 mantissa bits. Inputs in [0.5, 1.0) always round up
///           because the implicit leading one of F32 is not representable in
///           the F4E2M1 subnormal encoding.
///   Step 6: Special consideration for |x| < 0.5: the result is 0.5 when
///           0.25 < |x| < 0.5 and 0.0 otherwise (0.25 is a tie and goes to
///           the even encoding 0.0).
///   Step 7: Combine sign and magnitude.
///
/// NaN inputs are not given any dedicated handling by this expansion.
struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(arith::TruncFOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder b(loc, rewriter);
    Value operand = op.getOperand();
    Type operandTy = operand.getType();
    Type resultTy = op.getType();
    Type operandETy = getElementTypeOrSelf(operandTy);
    Type resultETy = getElementTypeOrSelf(resultTy);

    if (!isa<Float32Type>(operandETy) || !isa<Float4E2M1FNType>(resultETy)) {
      return rewriter.notifyMatchFailure(op, "not a trunc of F32 to F4E2M1FN");
    }

    Type i4Ty = cloneToShapedType(operandTy, b.getI4Type());
    Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
    Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());

    Value c0x0 = createConst(loc, i4Ty, 0x0, rewriter);
    Value c0x1 = createConst(loc, i4Ty, 0x1, rewriter);
    Value c0x3 = createConst(loc, i4Ty, 0x3, rewriter);
    Value c0x00000000 = createConst(loc, i32Ty, 0, rewriter);
    Value c0x00000016 = createConst(loc, i32Ty, 22, rewriter);
    Value c0x0000001f = createConst(loc, i32Ty, 31, rewriter);

    // Step 1: Clamp to bounds, i.e. min(6.0, max(-6.0, x)).
    Value cHigherBound = createFloatConst(loc, f32Ty, APFloat(6.0f), rewriter);
    Value cLowerBound = createFloatConst(loc, f32Ty, APFloat(-6.0f), rewriter);
    Value operandClamped = b.create<arith::MaximumFOp>(cLowerBound, operand);
    operandClamped = b.create<arith::MinimumFOp>(cHigherBound, operandClamped);
    Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operandClamped);

    // Step 2: Extract the sign bit (bit 31) and move it to bit 3 of the i4.
    Value f32Sign = b.create<arith::ShRUIOp>(f32Bits, c0x0000001f);
    Value f4Sign = b.create<arith::TruncIOp>(i4Ty, f32Sign);
    f4Sign = b.create<arith::ShLIOp>(f4Sign, c0x3);

    // Step 3: Convert exponent by adjusting bias.
    // The exponent field is masked so the sign bit never leaks into it.
    Value biasAdjustment = createConst(loc, i32Ty, 0x7e, rewriter); // 127 - 1
    Value cF4MantissaWidth = c0x1;                                   // 1
    Value cF32MantissaWidth = createConst(loc, i32Ty, 23, rewriter); // 23
    Value cF32ExpMask = createConst(loc, i32Ty, 0xff, rewriter);
    Value f32Exp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
    f32Exp = b.create<arith::AndIOp>(f32Exp, cF32ExpMask);
    // Only meaningful when f32Exp >= 0x7e; smaller exponents are replaced by
    // the Step 6 select below, so the wrapped value is never observed.
    Value biasAdjustedExp = b.create<arith::SubIOp>(f32Exp, biasAdjustment);
    Value f4Exp = b.create<arith::TruncIOp>(i4Ty, biasAdjustedExp);
    f4Exp = b.create<arith::ShLIOp>(f4Exp, cF4MantissaWidth);

    // Step 4: Set mantissa to first bit.
    Value cF32FirstBitMask = createConst(loc, i32Ty, 0x400000, rewriter);
    Value man1Bit = b.create<arith::AndIOp>(f32Bits, cF32FirstBitMask);
    man1Bit = b.create<arith::ShRUIOp>(man1Bit, c0x00000016);
    Value f4Man = b.create<arith::TruncIOp>(i4Ty, man1Bit);
    Value f4Magnitude = b.create<arith::AddIOp>(f4Exp, f4Man);

    // Step 5: Round to nearest, ties to even.
    Value cF32Last22BitMask = createConst(loc, i32Ty, 0x3fffff, rewriter);
    Value cRound = createConst(loc, i32Ty, 0x200000, rewriter); // 010 0000...
    Value man22Bits = b.create<arith::AndIOp>(f32Bits, cF32Last22BitMask);
    Value isAboveHalf =
        b.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, man22Bits, cRound);
    Value isTie =
        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, man22Bits, cRound);
    // On a tie only round up when the kept mantissa bit is odd, so that the
    // result ends with a zero mantissa bit.
    Value isOdd =
        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f4Man, c0x1);
    Value roundTieUp = b.create<arith::AndIOp>(isTie, isOdd);
    // Biased exponent 0x7e means 0.5 <= |x| < 1.0. The F32 implicit leading
    // one corresponds to the F4 subnormal mantissa bit, so adding one maps
    // [0.5, 0.75) to 0.5 (x001) and [0.75, 1.0) to 1.0 (x010).
    Value isZeroExp =
        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f32Exp, biasAdjustment);
    Value shouldRound = b.create<arith::OrIOp>(isAboveHalf, roundTieUp);
    shouldRound = b.create<arith::OrIOp>(shouldRound, isZeroExp);
    Value roundedMagnitude = b.create<arith::AddIOp>(f4Magnitude, c0x1);
    f4Magnitude =
        b.create<arith::SelectOp>(shouldRound, roundedMagnitude, f4Magnitude);

    // Step 6: Special consideration for |x| < 0.5.
    Value cF32MantissaMask = createConst(loc, i32Ty, 0x7fffff, rewriter);
    Value cNegOneExp = createConst(loc, i32Ty, 0x7d, rewriter);
    Value isBelowHalf = b.create<arith::CmpIOp>(arith::CmpIPredicate::ult,
                                                f32Exp, biasAdjustment);
    Value isNegOneExp =
        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f32Exp, cNegOneExp);
    Value man23Bits = b.create<arith::AndIOp>(f32Bits, cF32MantissaMask);
    Value isNonZeroMan = b.create<arith::CmpIOp>(arith::CmpIPredicate::ugt,
                                                 man23Bits, c0x00000000);
    // 0.25 < |x| < 0.5 rounds to 0.5; everything else below 0.5 rounds to 0.
    Value roundToHalf = b.create<arith::AndIOp>(isNegOneExp, isNonZeroMan);
    Value tinyMagnitude = b.create<arith::SelectOp>(roundToHalf, c0x1, c0x0);
    f4Magnitude =
        b.create<arith::SelectOp>(isBelowHalf, tinyMagnitude, f4Magnitude);

    // Step 7: Combine sign and magnitude. The magnitude never exceeds 0x7
    // because the input was clamped to 6.0, so it cannot spill into the sign.
    Value f4Bits = b.create<arith::OrIOp>(f4Sign, f4Magnitude);

    Value result = b.create<arith::BitcastOp>(resultTy, f4Bits);
    rewriter.replaceOp(op, result);
    return success();
  }
};

/*
TruncF to F8E8M0 is expected to extract exponent bits out of F32 type
Since All kinds of Infs and NaNs are mapped to same exponent bits in F32 type,
Expand Down