
[mlir] Expose linalg vectorization without replacement #144158

Open

Max191 wants to merge 1 commit into main

Conversation

@Max191 (Contributor) commented Jun 13, 2025

Exposes a new `linalg::vectorize` overload that populates the replacement values for the vectorized op instead of immediately replacing the results of the original operation. This gives callers more control over the vectorization transformation. The existing `vectorize` function behaves as before: its implementation simply calls the new overload and then replaces the op results itself.
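For reference, a minimal caller sketch of the new overload (hedged: `rewriter`, `linalgOp`, and the empty vector-size arguments below are placeholders for illustration, not part of this PR):

```cpp
// Sketch: vectorize without immediately replacing the original op.
// Assumes `rewriter` is a RewriterBase and `linalgOp` is an operation for
// which linalg::hasVectorizationImpl(linalgOp) holds.
SmallVector<Value> newResults;
if (failed(linalg::vectorize(rewriter, linalgOp, newResults,
                             /*inputVectorSizes=*/{},
                             /*inputScalableVecDims=*/{})))
  return failure();

// The original op is still in place; the caller decides how and when to
// replace it, e.g. after wrapping `newResults` in additional casts or masking.
if (!newResults.empty())
  rewriter.replaceOp(linalgOp, newResults);
else
  rewriter.eraseOp(linalgOp);
```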

@llvmbot (Member) commented Jun 13, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: None (Max191)

Changes

Exposes a new `linalg::vectorize` overload that populates the replacement values for the vectorized op instead of immediately replacing the results of the original operation. This gives callers more control over the vectorization transformation. The existing `vectorize` function behaves as before: its implementation simply calls the new overload and then replaces the op results itself.


Full diff: https://github.com/llvm/llvm-project/pull/144158.diff

2 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+14-5)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+59-56)
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 2eef0a06d0eb4..de09dae24eccf 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -772,17 +772,26 @@ LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value /*buffer*/);
 bool hasVectorizationImpl(Operation *);
 
 /// Emit a suitable vector form for an operation. If provided,
-/// `inputVectorSizes` are used to vectorize this operation. `inputVectorSizes`
-/// must match the rank of the iteration space of the operation and the sizes
-/// must be smaller or equal than their counterpart interation space sizes, if
-/// static. `inputVectorShapes` also allows the vectorization of operations with
-/// dynamic shapes.
+/// `inputVectorSizes` are used to vectorize this operation.
+/// `inputVectorSizes` must match the rank of the iteration space of the
+/// operation and the input vector sizes must be greater than or equal to
+/// their counterpart iteration space sizes, if static. `inputVectorShapes`
+/// also allows the vectorization of operations with dynamic shapes.
 LogicalResult vectorize(RewriterBase &rewriter, Operation *op,
                         ArrayRef<int64_t> inputVectorSizes = {},
                         ArrayRef<bool> inputScalableVecDims = {},
                         bool vectorizeNDExtract = false,
                         bool flatten1DDepthwiseConv = false);
 
+/// Vectorize and store new vectorized results in `newResults`, without
+/// replacing the old `op`.
+LogicalResult vectorize(RewriterBase &rewriter, Operation *op,
+                        SmallVector<Value> &newResults,
+                        ArrayRef<int64_t> inputVectorSizes = {},
+                        ArrayRef<bool> inputScalableVecDims = {},
+                        bool vectorizeNDExtract = false,
+                        bool flatten1DDepthwiseConv = false);
+
 /// Emit a suitable vector form for a Copy op with fully static shape.
 LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index ff28bd7c48342..3efef3af93fa3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2522,13 +2522,8 @@ bool mlir::linalg::hasVectorizationImpl(Operation *op) {
              tensor::InsertSliceOp>(op);
 }
 
-/// Emit a suitable vector form for an operation. If provided,
-/// `inputVectorSizes` are used to vectorize this operation.
-/// `inputVectorSizes` must match the rank of the iteration space of the
-/// operation and the input vector sizes must be greater than or equal to
-/// their counterpart iteration space sizes, if static. `inputVectorShapes`
-/// also allows the vectorization of operations with dynamic shapes.
 LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
+                                      SmallVector<Value> &newResults,
                                       ArrayRef<int64_t> inputVectorSizes,
                                       ArrayRef<bool> inputScalableVecDims,
                                       bool vectorizeNDExtract,
@@ -2558,57 +2553,65 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
     }
   }
 
-  SmallVector<Value> results;
-  auto vectorizeResult =
-      TypeSwitch<Operation *, LogicalResult>(op)
-          .Case<linalg::LinalgOp>([&](auto linalgOp) {
-            // TODO: isaConvolutionOpInterface that can also infer from
-            // generic features. Will require stride/dilation attributes
-            // inference.
-            if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
-              FailureOr<Operation *> convOr = vectorizeConvolution(
-                  rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
-                  flatten1DDepthwiseConv);
-              if (succeeded(convOr)) {
-                llvm::append_range(results, (*convOr)->getResults());
-                return success();
-              }
-
-              LDBG("Unsupported convolution can't be vectorized.\n");
-              return failure();
-            }
-
-            LDBG("Vectorize generic by broadcasting to the canonical vector "
-                 "shape\n");
-
-            // Pre-process before proceeding.
-            convertAffineApply(rewriter, linalgOp);
-
-            // TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted
-            // to 'OpBuilder' when it is passed over to some methods like
-            // 'vectorizeAsLinalgGeneric'. This is highly problematic: if we
-            // erase an op within these methods, the actual rewriter won't be
-            // notified and we will end up with read-after-free issues!
-            return vectorizeAsLinalgGeneric(rewriter, state, linalgOp, results);
-          })
-          .Case<tensor::PadOp>([&](auto padOp) {
-            return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
-                                          results);
-          })
-          .Case<linalg::PackOp>([&](auto packOp) {
-            return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
-                                           results);
-          })
-          .Case<linalg::UnPackOp>([&](auto unpackOp) {
-            return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
-                                             inputVectorSizes, results);
-          })
-          .Case<tensor::InsertSliceOp>([&](auto sliceOp) {
-            return vectorizeAsInsertSliceOp(rewriter, sliceOp, inputVectorSizes,
-                                            results);
-          })
-          .Default([](auto) { return failure(); });
+  return TypeSwitch<Operation *, LogicalResult>(op)
+      .Case<linalg::LinalgOp>([&](auto linalgOp) {
+        // TODO: isaConvolutionOpInterface that can also infer from
+        // generic features. Will require stride/dilation attributes
+        // inference.
+        if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
+          FailureOr<Operation *> convOr = vectorizeConvolution(
+              rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
+              flatten1DDepthwiseConv);
+          if (succeeded(convOr)) {
+            llvm::append_range(newResults, (*convOr)->getResults());
+            return success();
+          }
+
+          LDBG("Unsupported convolution can't be vectorized.\n");
+          return failure();
+        }
+
+        LDBG("Vectorize generic by broadcasting to the canonical vector "
+             "shape\n");
+
+        // Pre-process before proceeding.
+        convertAffineApply(rewriter, linalgOp);
+
+        // TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted
+        // to 'OpBuilder' when it is passed over to some methods like
+        // 'vectorizeAsLinalgGeneric'. This is highly problematic: if we
+        // erase an op within these methods, the actual rewriter won't be
+        // notified and we will end up with read-after-free issues!
+        return vectorizeAsLinalgGeneric(rewriter, state, linalgOp, newResults);
+      })
+      .Case<tensor::PadOp>([&](auto padOp) {
+        return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
+                                      newResults);
+      })
+      .Case<linalg::PackOp>([&](auto packOp) {
+        return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
+                                       newResults);
+      })
+      .Case<linalg::UnPackOp>([&](auto unpackOp) {
+        return vectorizeAsTensorUnpackOp(rewriter, unpackOp, inputVectorSizes,
+                                         newResults);
+      })
+      .Case<tensor::InsertSliceOp>([&](auto sliceOp) {
+        return vectorizeAsInsertSliceOp(rewriter, sliceOp, inputVectorSizes,
+                                        newResults);
+      })
+      .Default([](auto) { return failure(); });
+}
 
+LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
+                                      ArrayRef<int64_t> inputVectorSizes,
+                                      ArrayRef<bool> inputScalableVecDims,
+                                      bool vectorizeNDExtract,
+                                      bool flatten1DDepthwiseConv) {
+  SmallVector<Value> results;
+  LogicalResult vectorizeResult = mlir::linalg::vectorize(
+      rewriter, op, results, inputVectorSizes, inputScalableVecDims,
+      vectorizeNDExtract, flatten1DDepthwiseConv);
   if (failed(vectorizeResult)) {
     LDBG("Vectorization failed\n");
     return failure();

@rengolin requested a review from shahidact June 13, 2025 20:42