Skip to content

Commit 6d579fb

Browse files
committed
[mlir] Expose linalg vectorization without replacement
Signed-off-by: Max Dawkins <[email protected]>
1 parent 329dfa1 commit 6d579fb

File tree

2 files changed

+73
-61
lines changed

2 files changed

+73
-61
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -772,17 +772,26 @@ LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value /*buffer*/);
772772
bool hasVectorizationImpl(Operation *);
773773

774774
/// Emit a suitable vector form for an operation. If provided,
775-
/// `inputVectorSizes` are used to vectorize this operation. `inputVectorSizes`
776-
/// must match the rank of the iteration space of the operation and the sizes
777-
/// must be smaller or equal than their counterpart interation space sizes, if
778-
/// static. `inputVectorShapes` also allows the vectorization of operations with
779-
/// dynamic shapes.
775+
/// `inputVectorSizes` are used to vectorize this operation.
776+
/// `inputVectorSizes` must match the rank of the iteration space of the
777+
/// operation and the input vector sizes must be greater than or equal to
778+
/// their counterpart iteration space sizes, if static. `inputVectorShapes`
779+
/// also allows the vectorization of operations with dynamic shapes.
780780
LogicalResult vectorize(RewriterBase &rewriter, Operation *op,
781781
ArrayRef<int64_t> inputVectorSizes = {},
782782
ArrayRef<bool> inputScalableVecDims = {},
783783
bool vectorizeNDExtract = false,
784784
bool flatten1DDepthwiseConv = false);
785785

786+
/// Vectorize and store new vectorized results in `newResults`, without replacing
787+
/// the old `op`.
788+
LogicalResult vectorize(RewriterBase &rewriter, Operation *op,
789+
SmallVector<Value> &newResults,
790+
ArrayRef<int64_t> inputVectorSizes = {},
791+
ArrayRef<bool> inputScalableVecDims = {},
792+
bool vectorizeNDExtract = false,
793+
bool flatten1DDepthwiseConv = false);
794+
786795
/// Emit a suitable vector form for a Copy op with fully static shape.
787796
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);
788797

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 59 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -2522,13 +2522,8 @@ bool mlir::linalg::hasVectorizationImpl(Operation *op) {
25222522
tensor::InsertSliceOp>(op);
25232523
}
25242524

2525-
/// Emit a suitable vector form for an operation. If provided,
2526-
/// `inputVectorSizes` are used to vectorize this operation.
2527-
/// `inputVectorSizes` must match the rank of the iteration space of the
2528-
/// operation and the input vector sizes must be greater than or equal to
2529-
/// their counterpart iteration space sizes, if static. `inputVectorShapes`
2530-
/// also allows the vectorization of operations with dynamic shapes.
25312525
LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
2526+
SmallVector<Value> &newResults,
25322527
ArrayRef<int64_t> inputVectorSizes,
25332528
ArrayRef<bool> inputScalableVecDims,
25342529
bool vectorizeNDExtract,
@@ -2558,57 +2553,65 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
25582553
}
25592554
}
25602555

2561-
SmallVector<Value> results;
2562-
auto vectorizeResult =
2563-
TypeSwitch<Operation *, LogicalResult>(op)
2564-
.Case<linalg::LinalgOp>([&](auto linalgOp) {
2565-
// TODO: isaConvolutionOpInterface that can also infer from
2566-
// generic features. Will require stride/dilation attributes
2567-
// inference.
2568-
if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
2569-
FailureOr<Operation *> convOr = vectorizeConvolution(
2570-
rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
2571-
flatten1DDepthwiseConv);
2572-
if (succeeded(convOr)) {
2573-
llvm::append_range(results, (*convOr)->getResults());
2574-
return success();
2575-
}
2576-
2577-
LDBG("Unsupported convolution can't be vectorized.\n");
2578-
return failure();
2579-
}
2580-
2581-
LDBG("Vectorize generic by broadcasting to the canonical vector "
2582-
"shape\n");
2583-
2584-
// Pre-process before proceeding.
2585-
convertAffineApply(rewriter, linalgOp);
2586-
2587-
// TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted
2588-
// to 'OpBuilder' when it is passed over to some methods like
2589-
// 'vectorizeAsLinalgGeneric'. This is highly problematic: if we
2590-
// erase an op within these methods, the actual rewriter won't be
2591-
// notified and we will end up with read-after-free issues!
2592-
return vectorizeAsLinalgGeneric(rewriter, state, linalgOp, results);
2593-
})
2594-
.Case<tensor::PadOp>([&](auto padOp) {
2595-
return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
2596-
results);
2597-
})
2598-
.Case<linalg::PackOp>([&](auto packOp) {
2599-
return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
2600-
results);
2601-
})
2602-
.Case<linalg::UnPackOp>([&](auto unpackOp) {
2603-
return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
2604-
inputVectorSizes, results);
2605-
})
2606-
.Case<tensor::InsertSliceOp>([&](auto sliceOp) {
2607-
return vectorizeAsInsertSliceOp(rewriter, sliceOp, inputVectorSizes,
2608-
results);
2609-
})
2610-
.Default([](auto) { return failure(); });
2556+
return TypeSwitch<Operation *, LogicalResult>(op)
2557+
.Case<linalg::LinalgOp>([&](auto linalgOp) {
2558+
// TODO: isaConvolutionOpInterface that can also infer from
2559+
// generic features. Will require stride/dilation attributes
2560+
// inference.
2561+
if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
2562+
FailureOr<Operation *> convOr = vectorizeConvolution(
2563+
rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
2564+
flatten1DDepthwiseConv);
2565+
if (succeeded(convOr)) {
2566+
llvm::append_range(newResults, (*convOr)->getResults());
2567+
return success();
2568+
}
2569+
2570+
LDBG("Unsupported convolution can't be vectorized.\n");
2571+
return failure();
2572+
}
2573+
2574+
LDBG("Vectorize generic by broadcasting to the canonical vector "
2575+
"shape\n");
2576+
2577+
// Pre-process before proceeding.
2578+
convertAffineApply(rewriter, linalgOp);
2579+
2580+
// TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted
2581+
// to 'OpBuilder' when it is passed over to some methods like
2582+
// 'vectorizeAsLinalgGeneric'. This is highly problematic: if we
2583+
// erase an op within these methods, the actual rewriter won't be
2584+
// notified and we will end up with read-after-free issues!
2585+
return vectorizeAsLinalgGeneric(rewriter, state, linalgOp, newResults);
2586+
})
2587+
.Case<tensor::PadOp>([&](auto padOp) {
2588+
return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
2589+
newResults);
2590+
})
2591+
.Case<linalg::PackOp>([&](auto packOp) {
2592+
return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
2593+
newResults);
2594+
})
2595+
.Case<linalg::UnPackOp>([&](auto unpackOp) {
2596+
return vectorizeAsTensorUnpackOp(rewriter, unpackOp, inputVectorSizes,
2597+
newResults);
2598+
})
2599+
.Case<tensor::InsertSliceOp>([&](auto sliceOp) {
2600+
return vectorizeAsInsertSliceOp(rewriter, sliceOp, inputVectorSizes,
2601+
newResults);
2602+
})
2603+
.Default([](auto) { return failure(); });
2604+
}
26112605

2606+
LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
2607+
ArrayRef<int64_t> inputVectorSizes,
2608+
ArrayRef<bool> inputScalableVecDims,
2609+
bool vectorizeNDExtract,
2610+
bool flatten1DDepthwiseConv) {
2611+
SmallVector<Value> results;
2612+
LogicalResult vectorizeResult = mlir::linalg::vectorize(
2613+
rewriter, op, results, inputVectorSizes, inputScalableVecDims,
2614+
vectorizeNDExtract, flatten1DDepthwiseConv);
26122615
if (failed(vectorizeResult)) {
26132616
LDBG("Vectorization failed\n");
26142617
return failure();

0 commit comments

Comments
 (0)