
[mlir] Expose linearize/delinearize lowering transforms #144156


Open
wants to merge 1 commit into main

Conversation

@Max191 (Contributor) commented Jun 13, 2025

Moves the transformation logic from the AffineLinearizeIndexOp and AffineDelinearizeIndexOp lowerings into standalone transform functions that can now be called directly. This provides a more controlled way to apply the op lowerings.
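As an illustration of what this enables, here is a minimal sketch (not part of the patch) of a downstream helper that applies the exposed lowerings directly instead of going through populateAffineExpandIndexOpsPatterns. The lowerAllIndexOps driver and its collect-then-rewrite structure are hypothetical:

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Hypothetical driver: lower every linearize/delinearize op under `root`
// by calling the newly exposed transform functions directly.
static LogicalResult lowerAllIndexOps(Operation *root) {
  // Collect the ops first so the rewrites don't invalidate the walk.
  SmallVector<affine::AffineDelinearizeIndexOp> delinOps;
  SmallVector<affine::AffineLinearizeIndexOp> linOps;
  root->walk([&](affine::AffineDelinearizeIndexOp op) { delinOps.push_back(op); });
  root->walk([&](affine::AffineLinearizeIndexOp op) { linOps.push_back(op); });

  IRRewriter rewriter(root->getContext());
  for (auto op : delinOps) {
    rewriter.setInsertionPoint(op);
    if (failed(affine::lowerAffineDelinearizeIndexOp(rewriter, op)))
      return failure();
  }
  for (auto op : linOps) {
    rewriter.setInsertionPoint(op);
    if (failed(affine::lowerAffineLinearizeIndexOp(rewriter, op)))
      return failure();
  }
  return success();
}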

@llvmbot (Member) commented Jun 13, 2025

@llvm/pr-subscribers-mlir

Author: None (Max191)

Changes

Moves the transformation logic from the AffineLinearizeIndexOp and AffineDelinearizeIndexOp lowerings into standalone transform functions that can now be called directly. This provides a more controlled way to apply the op lowerings.


Full diff: https://github.com/llvm/llvm-project/pull/144156.diff

2 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h (+13)
  • (modified) mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp (+111-107)
diff --git a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
index bf830a29613fd..779571e911e1d 100644
--- a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
@@ -14,6 +14,7 @@
 #ifndef MLIR_DIALECT_AFFINE_TRANSFORMS_TRANSFORMS_H
 #define MLIR_DIALECT_AFFINE_TRANSFORMS_TRANSFORMS_H
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
 #include "mlir/Support/LLVM.h"
 
@@ -33,6 +34,18 @@ enum class BoundType;
 namespace affine {
 class AffineApplyOp;
 
+/// Lowers `affine.delinearize_index` into a sequence of division and remainder
+/// operations.
+LogicalResult lowerAffineDelinearizeIndexOp(RewriterBase &rewriter,
+                                            AffineDelinearizeIndexOp op);
+
+/// Lowers `affine.linearize_index` into a sequence of multiplications and
+/// additions. Make a best effort to sort the input indices so that
+/// the most loop-invariant terms are at the left of the additions
+/// to enable loop-invariant code motion.
+LogicalResult lowerAffineLinearizeIndexOp(RewriterBase &rewriter,
+                                          AffineLinearizeIndexOp op);
+
 /// Populate patterns that expand affine index operations into more fundamental
 /// operations (not necessarily restricted to Affine dialect).
 void populateAffineExpandIndexOpsPatterns(RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
index 35205a6ca2eee..c0ef28c648ac5 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
@@ -84,126 +84,130 @@ static SmallVector<Value> computeStrides(Location loc, RewriterBase &rewriter,
   return result;
 }
 
+LogicalResult
+affine::lowerAffineDelinearizeIndexOp(RewriterBase &rewriter,
+                                      AffineDelinearizeIndexOp op) {
+  Location loc = op.getLoc();
+  Value linearIdx = op.getLinearIndex();
+  unsigned numResults = op.getNumResults();
+  ArrayRef<int64_t> staticBasis = op.getStaticBasis();
+  if (numResults == staticBasis.size())
+    staticBasis = staticBasis.drop_front();
+
+  if (numResults == 1) {
+    rewriter.replaceOp(op, linearIdx);
+    return success();
+  }
+
+  SmallVector<Value> results;
+  results.reserve(numResults);
+  SmallVector<Value> strides =
+      computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis,
+                     /*knownNonNegative=*/true);
+
+  Value zero = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
+
+  Value initialPart =
+      rewriter.create<arith::FloorDivSIOp>(loc, linearIdx, strides.front());
+  results.push_back(initialPart);
+
+  auto emitModTerm = [&](Value stride) -> Value {
+    Value remainder = rewriter.create<arith::RemSIOp>(loc, linearIdx, stride);
+    Value remainderNegative = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::slt, remainder, zero);
+    // If the correction is relevant, this term is <= stride, which is known
+    // to be positive in `index`. Otherwise, while 2 * stride might overflow,
+    // this branch won't be taken, so the risk of `poison` is fine.
+    Value corrected = rewriter.create<arith::AddIOp>(
+        loc, remainder, stride, arith::IntegerOverflowFlags::nsw);
+    Value mod = rewriter.create<arith::SelectOp>(loc, remainderNegative,
+                                                 corrected, remainder);
+    return mod;
+  };
+
+  // Generate all the intermediate parts
+  for (size_t i = 0, e = strides.size() - 1; i < e; ++i) {
+    Value thisStride = strides[i];
+    Value nextStride = strides[i + 1];
+    Value modulus = emitModTerm(thisStride);
+    // We know both inputs are positive, so floorDiv == div.
+    // This could potentially be a divui, but it's not clear if that would
+    // cause issues.
+    Value divided = rewriter.create<arith::DivSIOp>(loc, modulus, nextStride);
+    results.push_back(divided);
+  }
+
+  results.push_back(emitModTerm(strides.back()));
+
+  rewriter.replaceOp(op, results);
+  return success();
+}
+
+LogicalResult affine::lowerAffineLinearizeIndexOp(RewriterBase &rewriter,
+                                                  AffineLinearizeIndexOp op) {
+  // Should be folded away, included here for safety.
+  if (op.getMultiIndex().empty()) {
+    rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0);
+    return success();
+  }
+
+  Location loc = op.getLoc();
+  ValueRange multiIndex = op.getMultiIndex();
+  size_t numIndexes = multiIndex.size();
+  ArrayRef<int64_t> staticBasis = op.getStaticBasis();
+  if (numIndexes == staticBasis.size())
+    staticBasis = staticBasis.drop_front();
+
+  SmallVector<Value> strides =
+      computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis,
+                     /*knownNonNegative=*/op.getDisjoint());
+  SmallVector<std::pair<Value, int64_t>> scaledValues;
+  scaledValues.reserve(numIndexes);
+
+  // Note: strides doesn't contain a value for the final element (stride 1)
+  // and everything else lines up. We use the "mutable" accessor so we can get
+  // our hands on an `OpOperand&` for the loop invariant counting function.
+  for (auto [stride, idxOp] :
+       llvm::zip_equal(strides, llvm::drop_end(op.getMultiIndexMutable()))) {
+    Value scaledIdx = rewriter.create<arith::MulIOp>(
+        loc, idxOp.get(), stride, arith::IntegerOverflowFlags::nsw);
+    int64_t numHoistableLoops = numEnclosingInvariantLoops(idxOp);
+    scaledValues.emplace_back(scaledIdx, numHoistableLoops);
+  }
+  scaledValues.emplace_back(
+      multiIndex.back(),
+      numEnclosingInvariantLoops(op.getMultiIndexMutable()[numIndexes - 1]));
+
+  // Sort by how many enclosing loops there are, ties implicitly broken by
+  // size of the stride.
+  llvm::stable_sort(scaledValues,
+                    [&](auto l, auto r) { return l.second > r.second; });
+
+  Value result = scaledValues.front().first;
+  for (auto [scaledValue, numHoistableLoops] : llvm::drop_begin(scaledValues)) {
+    std::ignore = numHoistableLoops;
+    result = rewriter.create<arith::AddIOp>(loc, result, scaledValue,
+                                            arith::IntegerOverflowFlags::nsw);
+  }
+  rewriter.replaceOp(op, result);
+  return success();
+}
+
 namespace {
-/// Lowers `affine.delinearize_index` into a sequence of division and remainder
-/// operations.
 struct LowerDelinearizeIndexOps
     : public OpRewritePattern<AffineDelinearizeIndexOp> {
   using OpRewritePattern<AffineDelinearizeIndexOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(AffineDelinearizeIndexOp op,
                                 PatternRewriter &rewriter) const override {
-    Location loc = op.getLoc();
-    Value linearIdx = op.getLinearIndex();
-    unsigned numResults = op.getNumResults();
-    ArrayRef<int64_t> staticBasis = op.getStaticBasis();
-    if (numResults == staticBasis.size())
-      staticBasis = staticBasis.drop_front();
-
-    if (numResults == 1) {
-      rewriter.replaceOp(op, linearIdx);
-      return success();
-    }
-
-    SmallVector<Value> results;
-    results.reserve(numResults);
-    SmallVector<Value> strides =
-        computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis,
-                       /*knownNonNegative=*/true);
-
-    Value zero = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
-
-    Value initialPart =
-        rewriter.create<arith::FloorDivSIOp>(loc, linearIdx, strides.front());
-    results.push_back(initialPart);
-
-    auto emitModTerm = [&](Value stride) -> Value {
-      Value remainder = rewriter.create<arith::RemSIOp>(loc, linearIdx, stride);
-      Value remainderNegative = rewriter.create<arith::CmpIOp>(
-          loc, arith::CmpIPredicate::slt, remainder, zero);
-      // If the correction is relevant, this term is <= stride, which is known
-      // to be positive in `index`. Otherwise, while 2 * stride might overflow,
-      // this branch won't be taken, so the risk of `poison` is fine.
-      Value corrected = rewriter.create<arith::AddIOp>(
-          loc, remainder, stride, arith::IntegerOverflowFlags::nsw);
-      Value mod = rewriter.create<arith::SelectOp>(loc, remainderNegative,
-                                                   corrected, remainder);
-      return mod;
-    };
-
-    // Generate all the intermediate parts
-    for (size_t i = 0, e = strides.size() - 1; i < e; ++i) {
-      Value thisStride = strides[i];
-      Value nextStride = strides[i + 1];
-      Value modulus = emitModTerm(thisStride);
-      // We know both inputs are positive, so floorDiv == div.
-      // This could potentially be a divui, but it's not clear if that would
-      // cause issues.
-      Value divided = rewriter.create<arith::DivSIOp>(loc, modulus, nextStride);
-      results.push_back(divided);
-    }
-
-    results.push_back(emitModTerm(strides.back()));
-
-    rewriter.replaceOp(op, results);
-    return success();
+    return affine::lowerAffineDelinearizeIndexOp(rewriter, op);
   }
 };
 
-/// Lowers `affine.linearize_index` into a sequence of multiplications and
-/// additions. Make a best effort to sort the input indices so that
-/// the most loop-invariant terms are at the left of the additions
-/// to enable loop-invariant code motion.
 struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(AffineLinearizeIndexOp op,
                                 PatternRewriter &rewriter) const override {
-    // Should be folded away, included here for safety.
-    if (op.getMultiIndex().empty()) {
-      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0);
-      return success();
-    }
-
-    Location loc = op.getLoc();
-    ValueRange multiIndex = op.getMultiIndex();
-    size_t numIndexes = multiIndex.size();
-    ArrayRef<int64_t> staticBasis = op.getStaticBasis();
-    if (numIndexes == staticBasis.size())
-      staticBasis = staticBasis.drop_front();
-
-    SmallVector<Value> strides =
-        computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis,
-                       /*knownNonNegative=*/op.getDisjoint());
-    SmallVector<std::pair<Value, int64_t>> scaledValues;
-    scaledValues.reserve(numIndexes);
-
-    // Note: strides doesn't contain a value for the final element (stride 1)
-    // and everything else lines up. We use the "mutable" accessor so we can get
-    // our hands on an `OpOperand&` for the loop invariant counting function.
-    for (auto [stride, idxOp] :
-         llvm::zip_equal(strides, llvm::drop_end(op.getMultiIndexMutable()))) {
-      Value scaledIdx = rewriter.create<arith::MulIOp>(
-          loc, idxOp.get(), stride, arith::IntegerOverflowFlags::nsw);
-      int64_t numHoistableLoops = numEnclosingInvariantLoops(idxOp);
-      scaledValues.emplace_back(scaledIdx, numHoistableLoops);
-    }
-    scaledValues.emplace_back(
-        multiIndex.back(),
-        numEnclosingInvariantLoops(op.getMultiIndexMutable()[numIndexes - 1]));
-
-    // Sort by how many enclosing loops there are, ties implicitly broken by
-    // size of the stride.
-    llvm::stable_sort(scaledValues,
-                      [&](auto l, auto r) { return l.second > r.second; });
-
-    Value result = scaledValues.front().first;
-    for (auto [scaledValue, numHoistableLoops] :
-         llvm::drop_begin(scaledValues)) {
-      std::ignore = numHoistableLoops;
-      result = rewriter.create<arith::AddIOp>(loc, result, scaledValue,
-                                              arith::IntegerOverflowFlags::nsw);
-    }
-    rewriter.replaceOp(op, result);
-    return success();
+    return affine::lowerAffineLinearizeIndexOp(rewriter, op);
   }
 };
 

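To make the generated arithmetic concrete, the sequences these lowerings emit can be paraphrased in plain C++ as follows. This is an illustrative sketch, not part of the patch: it assumes a non-negative linear index and positive strides, matching the in-code comments, and the helper names are made up.

#include <cstdint>
#include <vector>

// Floored modulo, mirroring the remsi/cmpi/addi/select sequence the
// delinearize lowering emits per basis element (assumes stride > 0).
int64_t flooredMod(int64_t linearIdx, int64_t stride) {
  int64_t rem = linearIdx % stride;     // arith.remsi
  return rem < 0 ? rem + stride : rem;  // arith.cmpi + arith.addi nsw + arith.select
}

// Reference semantics of the delinearization: a floordiv by the outermost
// stride, then mod/div pairs for the middle results, then a final mod.
// Plain `/` matches arith.floordivsi here because linearIdx is non-negative.
std::vector<int64_t> delinearize(int64_t linearIdx,
                                 const std::vector<int64_t> &strides) {
  std::vector<int64_t> results;
  results.push_back(linearIdx / strides.front());
  for (size_t i = 0, e = strides.size() - 1; i < e; ++i)
    results.push_back(flooredMod(linearIdx, strides[i]) / strides[i + 1]);
  results.push_back(flooredMod(linearIdx, strides.back()));
  return results;
}

// Reference semantics of the linearization: indices.size() equals
// strides.size() + 1, and the last index has an implicit stride of 1.
// The actual lowering additionally reorders the additions so the most
// loop-invariant terms come first, to help loop-invariant code motion.
int64_t linearize(const std::vector<int64_t> &indices,
                  const std::vector<int64_t> &strides) {
  int64_t result = 0;
  for (size_t i = 0, e = strides.size(); i < e; ++i)
    result += indices[i] * strides[i];  // arith.muli nsw + arith.addi nsw
  return result + indices.back();
}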
@llvmbot (Member) commented Jun 13, 2025

@llvm/pr-subscribers-mlir-affine


@krzysz00 (Contributor) left a comment


Am curious about the motivation but see no issue here, approved

@@ -14,6 +14,7 @@
#ifndef MLIR_DIALECT_AFFINE_TRANSFORMS_TRANSFORMS_H
#define MLIR_DIALECT_AFFINE_TRANSFORMS_TRANSFORMS_H

#include "mlir/Dialect/Affine/IR/AffineOps.h"
Member


Prefer forward-declaring classes that are only used in function declarations instead of including their full definitions. This tends to improve compile time significantly.
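Concretely, the suggestion amounts to replacing the include with forward declarations, which suffices because the header only mentions the op types in function declarations (a sketch of the reviewer's proposal, not what the patch currently does):

// Sketch: drop #include "mlir/Dialect/Affine/IR/AffineOps.h" and instead
// forward-declare the two op classes used by the new declarations.
namespace mlir {
namespace affine {
class AffineDelinearizeIndexOp;
class AffineLinearizeIndexOp;
} // namespace affine
} // namespace mlir

The definitions in AffineExpandIndexOps.cpp already include AffineOps.h, so only the header's dependency footprint would change.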
