[mlir] Expose linearize/delinearize lowering transforms #144156
Conversation
@llvm/pr-subscribers-mlir

Author: None (Max191)

Changes

Moves the transformation logic from the AffineLinearizeOp and AffineDelinearizeOp lowerings into separate transform functions that can now be called separately. This provides a more controlled way to apply the op lowerings.

Full diff: https://github.com/llvm/llvm-project/pull/144156.diff

2 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
index bf830a29613fd..779571e911e1d 100644
--- a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
@@ -14,6 +14,7 @@
#ifndef MLIR_DIALECT_AFFINE_TRANSFORMS_TRANSFORMS_H
#define MLIR_DIALECT_AFFINE_TRANSFORMS_TRANSFORMS_H
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Support/LLVM.h"
@@ -33,6 +34,18 @@ enum class BoundType;
namespace affine {
class AffineApplyOp;
+/// Lowers `affine.delinearize_index` into a sequence of division and remainder
+/// operations.
+LogicalResult lowerAffineDelinearizeIndexOp(RewriterBase &rewriter,
+ AffineDelinearizeIndexOp op);
+
+/// Lowers `affine.linearize_index` into a sequence of multiplications and
+/// additions. Make a best effort to sort the input indices so that
+/// the most loop-invariant terms are at the left of the additions
+/// to enable loop-invariant code motion.
+LogicalResult lowerAffineLinearizeIndexOp(RewriterBase &rewriter,
+ AffineLinearizeIndexOp op);
+
/// Populate patterns that expand affine index operations into more fundamental
/// operations (not necessarily restricted to Affine dialect).
void populateAffineExpandIndexOpsPatterns(RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
index 35205a6ca2eee..c0ef28c648ac5 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
@@ -84,126 +84,130 @@ static SmallVector<Value> computeStrides(Location loc, RewriterBase &rewriter,
return result;
}
+LogicalResult
+affine::lowerAffineDelinearizeIndexOp(RewriterBase &rewriter,
+ AffineDelinearizeIndexOp op) {
+ Location loc = op.getLoc();
+ Value linearIdx = op.getLinearIndex();
+ unsigned numResults = op.getNumResults();
+ ArrayRef<int64_t> staticBasis = op.getStaticBasis();
+ if (numResults == staticBasis.size())
+ staticBasis = staticBasis.drop_front();
+
+ if (numResults == 1) {
+ rewriter.replaceOp(op, linearIdx);
+ return success();
+ }
+
+ SmallVector<Value> results;
+ results.reserve(numResults);
+ SmallVector<Value> strides =
+ computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis,
+ /*knownNonNegative=*/true);
+
+ Value zero = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
+
+ Value initialPart =
+ rewriter.create<arith::FloorDivSIOp>(loc, linearIdx, strides.front());
+ results.push_back(initialPart);
+
+ auto emitModTerm = [&](Value stride) -> Value {
+ Value remainder = rewriter.create<arith::RemSIOp>(loc, linearIdx, stride);
+ Value remainderNegative = rewriter.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::slt, remainder, zero);
+ // If the correction is relevant, this term is <= stride, which is known
+ // to be positive in `index`. Otherwise, while 2 * stride might overflow,
+ // this branch won't be taken, so the risk of `poison` is fine.
+ Value corrected = rewriter.create<arith::AddIOp>(
+ loc, remainder, stride, arith::IntegerOverflowFlags::nsw);
+ Value mod = rewriter.create<arith::SelectOp>(loc, remainderNegative,
+ corrected, remainder);
+ return mod;
+ };
+
+ // Generate all the intermediate parts
+ for (size_t i = 0, e = strides.size() - 1; i < e; ++i) {
+ Value thisStride = strides[i];
+ Value nextStride = strides[i + 1];
+ Value modulus = emitModTerm(thisStride);
+ // We know both inputs are positive, so floorDiv == div.
+ // This could potentially be a divui, but it's not clear if that would
+ // cause issues.
+ Value divided = rewriter.create<arith::DivSIOp>(loc, modulus, nextStride);
+ results.push_back(divided);
+ }
+
+ results.push_back(emitModTerm(strides.back()));
+
+ rewriter.replaceOp(op, results);
+ return success();
+}
+
+LogicalResult affine::lowerAffineLinearizeIndexOp(RewriterBase &rewriter,
+ AffineLinearizeIndexOp op) {
+ // Should be folded away, included here for safety.
+ if (op.getMultiIndex().empty()) {
+ rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0);
+ return success();
+ }
+
+ Location loc = op.getLoc();
+ ValueRange multiIndex = op.getMultiIndex();
+ size_t numIndexes = multiIndex.size();
+ ArrayRef<int64_t> staticBasis = op.getStaticBasis();
+ if (numIndexes == staticBasis.size())
+ staticBasis = staticBasis.drop_front();
+
+ SmallVector<Value> strides =
+ computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis,
+ /*knownNonNegative=*/op.getDisjoint());
+ SmallVector<std::pair<Value, int64_t>> scaledValues;
+ scaledValues.reserve(numIndexes);
+
+ // Note: strides doesn't contain a value for the final element (stride 1)
+ // and everything else lines up. We use the "mutable" accessor so we can get
+ // our hands on an `OpOperand&` for the loop invariant counting function.
+ for (auto [stride, idxOp] :
+ llvm::zip_equal(strides, llvm::drop_end(op.getMultiIndexMutable()))) {
+ Value scaledIdx = rewriter.create<arith::MulIOp>(
+ loc, idxOp.get(), stride, arith::IntegerOverflowFlags::nsw);
+ int64_t numHoistableLoops = numEnclosingInvariantLoops(idxOp);
+ scaledValues.emplace_back(scaledIdx, numHoistableLoops);
+ }
+ scaledValues.emplace_back(
+ multiIndex.back(),
+ numEnclosingInvariantLoops(op.getMultiIndexMutable()[numIndexes - 1]));
+
+ // Sort by how many enclosing loops there are, ties implicitly broken by
+ // size of the stride.
+ llvm::stable_sort(scaledValues,
+ [&](auto l, auto r) { return l.second > r.second; });
+
+ Value result = scaledValues.front().first;
+ for (auto [scaledValue, numHoistableLoops] : llvm::drop_begin(scaledValues)) {
+ std::ignore = numHoistableLoops;
+ result = rewriter.create<arith::AddIOp>(loc, result, scaledValue,
+ arith::IntegerOverflowFlags::nsw);
+ }
+ rewriter.replaceOp(op, result);
+ return success();
+}
+
namespace {
-/// Lowers `affine.delinearize_index` into a sequence of division and remainder
-/// operations.
struct LowerDelinearizeIndexOps
: public OpRewritePattern<AffineDelinearizeIndexOp> {
using OpRewritePattern<AffineDelinearizeIndexOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AffineDelinearizeIndexOp op,
PatternRewriter &rewriter) const override {
- Location loc = op.getLoc();
- Value linearIdx = op.getLinearIndex();
- unsigned numResults = op.getNumResults();
- ArrayRef<int64_t> staticBasis = op.getStaticBasis();
- if (numResults == staticBasis.size())
- staticBasis = staticBasis.drop_front();
-
- if (numResults == 1) {
- rewriter.replaceOp(op, linearIdx);
- return success();
- }
-
- SmallVector<Value> results;
- results.reserve(numResults);
- SmallVector<Value> strides =
- computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis,
- /*knownNonNegative=*/true);
-
- Value zero = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
-
- Value initialPart =
- rewriter.create<arith::FloorDivSIOp>(loc, linearIdx, strides.front());
- results.push_back(initialPart);
-
- auto emitModTerm = [&](Value stride) -> Value {
- Value remainder = rewriter.create<arith::RemSIOp>(loc, linearIdx, stride);
- Value remainderNegative = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::slt, remainder, zero);
- // If the correction is relevant, this term is <= stride, which is known
- // to be positive in `index`. Otherwise, while 2 * stride might overflow,
- // this branch won't be taken, so the risk of `poison` is fine.
- Value corrected = rewriter.create<arith::AddIOp>(
- loc, remainder, stride, arith::IntegerOverflowFlags::nsw);
- Value mod = rewriter.create<arith::SelectOp>(loc, remainderNegative,
- corrected, remainder);
- return mod;
- };
-
- // Generate all the intermediate parts
- for (size_t i = 0, e = strides.size() - 1; i < e; ++i) {
- Value thisStride = strides[i];
- Value nextStride = strides[i + 1];
- Value modulus = emitModTerm(thisStride);
- // We know both inputs are positive, so floorDiv == div.
- // This could potentially be a divui, but it's not clear if that would
- // cause issues.
- Value divided = rewriter.create<arith::DivSIOp>(loc, modulus, nextStride);
- results.push_back(divided);
- }
-
- results.push_back(emitModTerm(strides.back()));
-
- rewriter.replaceOp(op, results);
- return success();
+ return affine::lowerAffineDelinearizeIndexOp(rewriter, op);
}
};
-/// Lowers `affine.linearize_index` into a sequence of multiplications and
-/// additions. Make a best effort to sort the input indices so that
-/// the most loop-invariant terms are at the left of the additions
-/// to enable loop-invariant code motion.
struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AffineLinearizeIndexOp op,
PatternRewriter &rewriter) const override {
- // Should be folded away, included here for safety.
- if (op.getMultiIndex().empty()) {
- rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0);
- return success();
- }
-
- Location loc = op.getLoc();
- ValueRange multiIndex = op.getMultiIndex();
- size_t numIndexes = multiIndex.size();
- ArrayRef<int64_t> staticBasis = op.getStaticBasis();
- if (numIndexes == staticBasis.size())
- staticBasis = staticBasis.drop_front();
-
- SmallVector<Value> strides =
- computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis,
- /*knownNonNegative=*/op.getDisjoint());
- SmallVector<std::pair<Value, int64_t>> scaledValues;
- scaledValues.reserve(numIndexes);
-
- // Note: strides doesn't contain a value for the final element (stride 1)
- // and everything else lines up. We use the "mutable" accessor so we can get
- // our hands on an `OpOperand&` for the loop invariant counting function.
- for (auto [stride, idxOp] :
- llvm::zip_equal(strides, llvm::drop_end(op.getMultiIndexMutable()))) {
- Value scaledIdx = rewriter.create<arith::MulIOp>(
- loc, idxOp.get(), stride, arith::IntegerOverflowFlags::nsw);
- int64_t numHoistableLoops = numEnclosingInvariantLoops(idxOp);
- scaledValues.emplace_back(scaledIdx, numHoistableLoops);
- }
- scaledValues.emplace_back(
- multiIndex.back(),
- numEnclosingInvariantLoops(op.getMultiIndexMutable()[numIndexes - 1]));
-
- // Sort by how many enclosing loops there are, ties implicitly broken by
- // size of the stride.
- llvm::stable_sort(scaledValues,
- [&](auto l, auto r) { return l.second > r.second; });
-
- Value result = scaledValues.front().first;
- for (auto [scaledValue, numHoistableLoops] :
- llvm::drop_begin(scaledValues)) {
- std::ignore = numHoistableLoops;
- result = rewriter.create<arith::AddIOp>(loc, result, scaledValue,
- arith::IntegerOverflowFlags::nsw);
- }
- rewriter.replaceOp(op, result);
- return success();
+ return affine::lowerAffineLinearizeIndexOp(rewriter, op);
}
};
@llvm/pr-subscribers-mlir-affine

(Same summary and full diff as above.)
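For downstream users of these new entry points, here is a minimal sketch of a driver that calls them directly instead of going through populateAffineExpandIndexOpsPatterns. The `expandIndexOpsIn` helper and its walk/worklist structure are hypothetical illustrations, not part of this PR; only the two `affine::lower...` functions come from the diff above.

// Hypothetical driver (illustration only): lower every
// affine.delinearize_index / affine.linearize_index under `root` by
// invoking the newly exposed transform functions directly.
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

static LogicalResult expandIndexOpsIn(Operation *root) {
  IRRewriter rewriter(root->getContext());
  // Collect the ops first: the lowerings erase the ops they replace, which
  // would invalidate an in-flight walk.
  SmallVector<Operation *> worklist;
  root->walk([&](Operation *op) {
    if (isa<affine::AffineDelinearizeIndexOp, affine::AffineLinearizeIndexOp>(
            op))
      worklist.push_back(op);
  });
  for (Operation *op : worklist) {
    rewriter.setInsertionPoint(op);
    LogicalResult status = failure();
    if (auto delinearize = dyn_cast<affine::AffineDelinearizeIndexOp>(op))
      status = affine::lowerAffineDelinearizeIndexOp(rewriter, delinearize);
    else if (auto linearize = dyn_cast<affine::AffineLinearizeIndexOp>(op))
      status = affine::lowerAffineLinearizeIndexOp(rewriter, linearize);
    if (failed(status))
      return failure();
  }
  return success();
}

As a concrete reference for what the delinearize lowering computes: with basis (2, 3, 5) the computed strides are (15, 5), so at linear index 22 the emitted arithmetic yields (22 floordiv 15, (22 mod 15) floordiv 5, 22 mod 5) = (1, 1, 2), and indeed 1*15 + 1*5 + 2 = 22.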
Am curious about the motivation but see no issue here, approved
@@ -14,6 +14,7 @@
#ifndef MLIR_DIALECT_AFFINE_TRANSFORMS_TRANSFORMS_H
#define MLIR_DIALECT_AFFINE_TRANSFORMS_TRANSFORMS_H

+#include "mlir/Dialect/Affine/IR/AffineOps.h"
Prefer forward-declaring classes that are only used in function declarations instead of including their full definitions. This tends to improve compile time significantly.
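As a sketch of that suggestion (mirroring the `class AffineApplyOp;` forward declaration this header already uses; illustrative, assuming nothing else in the header needs the full op definitions):

namespace mlir {
namespace affine {
// Forward declarations are sufficient for by-value parameters in function
// declarations; the full op definitions are only needed in the .cpp file.
class AffineDelinearizeIndexOp;
class AffineLinearizeIndexOp;

LogicalResult lowerAffineDelinearizeIndexOp(RewriterBase &rewriter,
                                            AffineDelinearizeIndexOp op);
LogicalResult lowerAffineLinearizeIndexOp(RewriterBase &rewriter,
                                          AffineLinearizeIndexOp op);
} // namespace affine
} // namespace mlir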
Downstream in IREE, I have an op (map_scatter) that I want to vectorize by reusing the linalg.generic vectorization, but it can have some linearize/delinearize ops in the body, which aren't supported in the linalg.generic vectorizer.
@Max191 Makes sense, and seems reasonable for now. Though I'll note that allowing affine.linearize_index / affine.delinearize_index to take …
Yeah, I think that could be a better end state, although for now it's low priority for me. Perhaps as a future improvement!
Moves the transformation logic from the AffineLinearizeOp and AffineDelinearizeOp lowerings into separate transform functions that can now be called separately. This provides a more controlled way to apply the op lowerings. --------- Signed-off-by: Max Dawkins <[email protected]>