diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 2eef0a06d0eb4..de09dae24eccf 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -772,17 +772,26 @@ LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value /*buffer*/);
 
 bool hasVectorizationImpl(Operation *);
 
 /// Emit a suitable vector form for an operation. If provided,
-/// `inputVectorSizes` are used to vectorize this operation. `inputVectorSizes`
-/// must match the rank of the iteration space of the operation and the sizes
-/// must be smaller or equal than their counterpart interation space sizes, if
-/// static. `inputVectorShapes` also allows the vectorization of operations with
-/// dynamic shapes.
+/// `inputVectorSizes` are used to vectorize this operation.
+/// `inputVectorSizes` must match the rank of the iteration space of the
+/// operation and the input vector sizes must be greater than or equal to
+/// their counterpart iteration space sizes, if static. `inputVectorSizes`
+/// also allows the vectorization of operations with dynamic shapes.
 LogicalResult vectorize(RewriterBase &rewriter, Operation *op,
                         ArrayRef<int64_t> inputVectorSizes = {},
                         ArrayRef<bool> inputScalableVecDims = {},
                         bool vectorizeNDExtract = false,
                         bool flatten1DDepthwiseConv = false);
+/// Vectorize and store new vectorized results in `newResults`, without
+/// replacing the old `op`.
+LogicalResult vectorize(RewriterBase &rewriter, Operation *op,
+                        SmallVector<Value> &newResults,
+                        ArrayRef<int64_t> inputVectorSizes = {},
+                        ArrayRef<bool> inputScalableVecDims = {},
+                        bool vectorizeNDExtract = false,
+                        bool flatten1DDepthwiseConv = false);
+
 
 /// Emit a suitable vector form for a Copy op with fully static shape.
 LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index ff28bd7c48342..3efef3af93fa3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2522,13 +2522,8 @@ bool mlir::linalg::hasVectorizationImpl(Operation *op) {
                     tensor::InsertSliceOp>(op);
 }
 
-/// Emit a suitable vector form for an operation. If provided,
-/// `inputVectorSizes` are used to vectorize this operation.
-/// `inputVectorSizes` must match the rank of the iteration space of the
-/// operation and the input vector sizes must be greater than or equal to
-/// their counterpart iteration space sizes, if static. `inputVectorShapes`
-/// also allows the vectorization of operations with dynamic shapes.
 LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
+                                      SmallVector<Value> &newResults,
                                       ArrayRef<int64_t> inputVectorSizes,
                                       ArrayRef<bool> inputScalableVecDims,
                                       bool vectorizeNDExtract,
@@ -2558,57 +2553,65 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
     }
   }
 
-  SmallVector<Value> results;
-  auto vectorizeResult =
-      TypeSwitch<Operation *, LogicalResult>(op)
-          .Case<linalg::LinalgOp>([&](auto linalgOp) {
-            // TODO: isaConvolutionOpInterface that can also infer from
-            // generic features. Will require stride/dilation attributes
-            // inference.
-            if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
-              FailureOr<Operation *> convOr = vectorizeConvolution(
-                  rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
-                  flatten1DDepthwiseConv);
-              if (succeeded(convOr)) {
-                llvm::append_range(results, (*convOr)->getResults());
-                return success();
-              }
-
-              LDBG("Unsupported convolution can't be vectorized.\n");
-              return failure();
-            }
-
-            LDBG("Vectorize generic by broadcasting to the canonical vector "
-                 "shape\n");
-
-            // Pre-process before proceeding.
-            convertAffineApply(rewriter, linalgOp);
-
-            // TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted
-            // to 'OpBuilder' when it is passed over to some methods like
-            // 'vectorizeAsLinalgGeneric'. This is highly problematic: if we
-            // erase an op within these methods, the actual rewriter won't be
-            // notified and we will end up with read-after-free issues!
-            return vectorizeAsLinalgGeneric(rewriter, state, linalgOp, results);
-          })
-          .Case<tensor::PadOp>([&](auto padOp) {
-            return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
-                                          results);
-          })
-          .Case<linalg::PackOp>([&](auto packOp) {
-            return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
-                                           results);
-          })
-          .Case<linalg::UnPackOp>([&](auto unpackOp) {
-            return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
-                                             inputVectorSizes, results);
-          })
-          .Case<tensor::InsertSliceOp>([&](auto sliceOp) {
-            return vectorizeAsInsertSliceOp(rewriter, sliceOp, inputVectorSizes,
-                                            results);
-          })
-          .Default([](auto) { return failure(); });
+  return TypeSwitch<Operation *, LogicalResult>(op)
+      .Case<linalg::LinalgOp>([&](auto linalgOp) {
+        // TODO: isaConvolutionOpInterface that can also infer from
+        // generic features. Will require stride/dilation attributes
+        // inference.
+        if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
+          FailureOr<Operation *> convOr = vectorizeConvolution(
+              rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
+              flatten1DDepthwiseConv);
+          if (succeeded(convOr)) {
+            llvm::append_range(newResults, (*convOr)->getResults());
+            return success();
+          }
+
+          LDBG("Unsupported convolution can't be vectorized.\n");
+          return failure();
+        }
+
+        LDBG("Vectorize generic by broadcasting to the canonical vector "
+             "shape\n");
+
+        // Pre-process before proceeding.
+        convertAffineApply(rewriter, linalgOp);
+
+        // TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted
+        // to 'OpBuilder' when it is passed over to some methods like
+        // 'vectorizeAsLinalgGeneric'. This is highly problematic: if we
+        // erase an op within these methods, the actual rewriter won't be
+        // notified and we will end up with read-after-free issues!
+        return vectorizeAsLinalgGeneric(rewriter, state, linalgOp, newResults);
+      })
+      .Case<tensor::PadOp>([&](auto padOp) {
+        return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
+                                      newResults);
+      })
+      .Case<linalg::PackOp>([&](auto packOp) {
+        return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
+                                       newResults);
+      })
+      .Case<linalg::UnPackOp>([&](auto unpackOp) {
+        return vectorizeAsTensorUnpackOp(rewriter, unpackOp, inputVectorSizes,
+                                         newResults);
+      })
+      .Case<tensor::InsertSliceOp>([&](auto sliceOp) {
+        return vectorizeAsInsertSliceOp(rewriter, sliceOp, inputVectorSizes,
+                                        newResults);
+      })
+      .Default([](auto) { return failure(); });
+}
+LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
+                                      ArrayRef<int64_t> inputVectorSizes,
+                                      ArrayRef<bool> inputScalableVecDims,
+                                      bool vectorizeNDExtract,
+                                      bool flatten1DDepthwiseConv) {
+  SmallVector<Value> results;
+  LogicalResult vectorizeResult = mlir::linalg::vectorize(
+      rewriter, op, results, inputVectorSizes, inputScalableVecDims,
+      vectorizeNDExtract, flatten1DDepthwiseConv);
   if (failed(vectorizeResult)) {
     LDBG("Vectorization failed\n");
    return failure();