[mlir][vector] Remove MatrixMultiplyOp and FlatTransposeOp from Vector dialect #144307

Open · wants to merge 1 commit into main
@@ -13,12 +13,6 @@
namespace mlir {
class LLVMTypeConverter;

/// Collect a set of patterns to convert from Vector contractions to LLVM Matrix
/// Intrinsics. To lower to assembly, the LLVM flag -lower-matrix-intrinsics
/// will be needed when invoking LLVM.
void populateVectorToLLVMMatrixConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns);

/// Collect a set of patterns to convert from the Vector dialect to LLVM.
void populateVectorToLLVMConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
118 changes: 0 additions & 118 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2710,124 +2710,6 @@ def Vector_PrintOp :
}];
}

//===----------------------------------------------------------------------===//
// Ops used for supporting progressive lowering and conversion type changes.
// The Ops are typically not used directly by higher level dialects, but are
// used by intra-dialect rewriting rules to bring vector operations closer
// to the hardware ISA.
//===----------------------------------------------------------------------===//

/// Vector dialect matrix multiplication op that operates on flattened 1-D
/// MLIR vectors. This is the counterpart of llvm.matrix.multiply in MLIR.
/// This may seem redundant with vector.contract but it serves the purposes of
/// more progressive lowering and localized type conversion on the path:
/// `vector<...x...xf32> -> vector<...xf32> -> !llvm<... x float>`.
def Vector_MatmulOp : Vector_Op<"matrix_multiply", [Pure,
PredOpTrait<"lhs operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>,
PredOpTrait<"rhs operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 1>>]>,
Arguments<(
// TODO: tighten vector element types that make sense.
ins FixedVectorOfRankAndType<[1],
[AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$lhs,
FixedVectorOfRankAndType<[1],
[AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$rhs,
I32Attr:$lhs_rows, I32Attr:$lhs_columns, I32Attr:$rhs_columns)>,
Results<(
outs FixedVectorOfRankAndType<[1],
[AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$res)>
{
let summary = "Vector matrix multiplication op that operates on flattened 1-D"
" MLIR vectors";
let description = [{
This is the counterpart of llvm.matrix.multiply in MLIR. It serves the
purposes of more progressive lowering and localized type conversion.
Higher levels typically lower matrix multiplications into 'vector.contract'
operations. Subsequent rewriting rules progressively lower these operations
into 'vector.matrix_multiply' operations to bring the operations closer
to the hardware ISA.

The `vector.matrix_multiply` op treats `lhs` as a matrix with <lhs_rows> rows
and <lhs_columns> columns, `rhs` as a matrix with <lhs_columns> rows and
<rhs_columns> columns, and multiplies them. The result matrix is returned
embedded in the result vector.

Note, the corresponding LLVM intrinsic, `@llvm.matrix.multiply.*`, does not
support scalable vectors. Hence, this Op is only available for fixed-width
vectors. Also see:

http://llvm.org/docs/LangRef.html#llvm-matrix-multiply-intrinsic

Example:

```mlir
%C = vector.matrix_multiply %A, %B
{ lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_columns = 3: i32 } :
(vector<64xf64>, vector<48xf64>) -> vector<12xf64>
```
}];
let builders = [
OpBuilder<(ins "Value":$lhs, "Value":$rhs, "unsigned":$lhsRows,
"unsigned":$lhsColumns, "unsigned":$rhsColumns),
[{
$_state.addOperands({lhs, rhs});
$_state.addAttribute("lhs_rows",$_builder.getI32IntegerAttr(lhsRows));
$_state.addAttribute("lhs_columns",$_builder.getI32IntegerAttr(lhsColumns));
$_state.addAttribute("rhs_columns",$_builder.getI32IntegerAttr(rhsColumns));
$_state.addTypes(VectorType::get(lhsRows * rhsColumns,
::llvm::cast<VectorType>(lhs.getType()).getElementType()));
}]>,
];
let assemblyFormat = "$lhs `,` $rhs attr-dict "
"`:` `(` type($lhs) `,` type($rhs) `)` `->` type($res)";
}

/// Vector dialect matrix transposition op that operates on flattened 1-D
/// MLIR vectors. This is the counterpart of llvm.matrix.transpose in MLIR.
/// This may seem redundant with vector.transpose but it serves the purposes of
/// more progressive lowering and localized type conversion on the path:
/// `vector<...x...xf32> -> vector<...xf32> -> !llvm<... x float>`.
def Vector_FlatTransposeOp : Vector_Op<"flat_transpose", [Pure,
PredOpTrait<"source operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>]>,
Arguments<(
// TODO: tighten vector element types that make sense.
ins FixedVectorOfRankAndType<[1],
[AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$matrix,
I32Attr:$rows, I32Attr:$columns)>,
Results<(
outs FixedVectorOfRankAndType<[1],
[AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$res)> {
let summary = "Vector matrix transposition on flattened 1-D MLIR vectors";
let description = [{
This is the counterpart of llvm.matrix.transpose in MLIR. It serves
the purposes of more progressive lowering and localized type conversion.
Higher levels typically lower matrix transpositions into 'vector.transpose'
operations. Subsequent rewriting rules progressively lower these operations
into 'vector.flat_transpose' operations to bring the operations closer
to the hardware ISA.

The `vector.flat_transpose` op treats the 1-D input `matrix` as
a 2-D matrix with <rows> rows and <columns> columns, and returns the
transposed matrix in flattened form in 'res'.

Note, the corresponding LLVM intrinsic, `@llvm.matrix.transpose.*`, does not
support scalable vectors. Hence, this Op is only available for fixed-width
vectors. Also see:

http://llvm.org/docs/LangRef.html#llvm-matrix-transpose-intrinsic

Example:

```mlir
%1 = vector.flat_transpose %0 {columns = 4 : i32, rows = 4 : i32}
: vector<16xf32> -> vector<16xf32>
```
}];
let assemblyFormat = "$matrix attr-dict `:` type($matrix) `->` type($res)";
}

//===----------------------------------------------------------------------===//
// SplatOp
//===----------------------------------------------------------------------===//
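Both removed ops map one-to-one onto the existing LLVM dialect matrix intrinsics, so any remaining uses can be migrated by switching the op name and spelling out the flattened result type. A minimal migration sketch, reusing the operand names and shapes from the examples above (illustrative only, not part of this patch):

```mlir
// Hypothetical migration; %A, %B, %0 reuse the shapes from the op examples above.
// vector.matrix_multiply becomes llvm.intr.matrix.multiply:
%C = llvm.intr.matrix.multiply %A, %B
       {lhs_rows = 4 : i32, lhs_columns = 16 : i32, rhs_columns = 3 : i32}
       : (vector<64xf64>, vector<48xf64>) -> vector<12xf64>

// vector.flat_transpose becomes llvm.intr.matrix.transpose:
%1 = llvm.intr.matrix.transpose %0 {rows = 4 : i32, columns = 4 : i32}
       : vector<16xf32> into vector<16xf32>
```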
41 changes: 0 additions & 41 deletions mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -184,41 +184,6 @@ class VectorBitCastOpConversion
}
};

/// Conversion pattern for a vector.matrix_multiply.
/// This is lowered directly to the proper llvm.intr.matrix.multiply.
class VectorMatmulOpConversion
: public ConvertOpToLLVMPattern<vector::MatmulOp> {
public:
using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
matmulOp, typeConverter->convertType(matmulOp.getRes().getType()),
adaptor.getLhs(), adaptor.getRhs(), matmulOp.getLhsRows(),
matmulOp.getLhsColumns(), matmulOp.getRhsColumns());
return success();
}
};

/// Conversion pattern for a vector.flat_transpose.
/// This is lowered directly to the proper llvm.intr.matrix.transpose.
class VectorFlatTransposeOpConversion
: public ConvertOpToLLVMPattern<vector::FlatTransposeOp> {
public:
using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
transOp, typeConverter->convertType(transOp.getRes().getType()),
adaptor.getMatrix(), transOp.getRows(), transOp.getColumns());
return success();
}
};

/// Overloaded utility that replaces a vector.load, vector.store,
/// vector.maskedload and vector.maskedstore with their respective LLVM
counterparts.
@@ -2026,12 +1991,6 @@ void mlir::populateVectorToLLVMConversionPatterns(
VectorScalableStepOpLowering>(converter);
}

void mlir::populateVectorToLLVMMatrixConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
patterns.add<VectorMatmulOpConversion>(converter);
patterns.add<VectorFlatTransposeOpConversion>(converter);
}

namespace {
struct VectorToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
@@ -97,11 +97,9 @@ void ConvertVectorToLLVMPass::runOnOperation() {
LLVMTypeConverter converter(&getContext(), options);
RewritePatternSet patterns(&getContext());
populateVectorTransferLoweringPatterns(patterns);
populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
populateVectorToLLVMConversionPatterns(
converter, patterns, reassociateFPReductions, force32BitVectorIndices,
useVectorAlignment);
populateVectorToLLVMMatrixConversionPatterns(converter, patterns);

// Architecture specific augmentations.
LLVMConversionTarget target(getContext());
@@ -118,9 +118,9 @@ void mlir::arith::populateEmulateUnsupportedFloatsLegality(
return converter.isLegal(op);
});
// Manually mark arithmetic-performing vector instructions.
target.addDynamicallyLegalOp<
vector::ContractionOp, vector::ReductionOp, vector::MultiDimReductionOp,
vector::FMAOp, vector::OuterProductOp, vector::MatmulOp, vector::ScanOp>(
target.addDynamicallyLegalOp<vector::ContractionOp, vector::ReductionOp,
vector::MultiDimReductionOp, vector::FMAOp,
vector::OuterProductOp, vector::ScanOp>(
[&](Operation *op) { return converter.isLegal(op); });
target.addLegalOp<arith::BitcastOp, arith::ExtFOp, arith::TruncFOp,
arith::ConstantOp, vector::SplatOp>();
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -49,6 +49,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
MLIRTensorDialect
MLIRTransforms
MLIRVectorDialect
MLIRLLVMDialect
MLIRVectorInterfaces
MLIRVectorUtils
)
12 changes: 8 additions & 4 deletions mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -14,6 +14,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
@@ -1280,12 +1281,11 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
/// %mtb = maybe_transpose
/// %flattened_a = vector.shape_cast %mta
/// %flattened_b = vector.shape_cast %mtb
/// %flattened_d = vector.matrix_multiply %flattened_a, %flattened_b
/// %flattened_d = llvm.intr.matrix.multiply %flattened_a, %flattened_b
/// %mtd = vector.shape_cast %flattened_d
/// %d = maybe_untranspose %mtd
/// %e = add %c, %d
/// ```
/// `vector.matrix_multiply` later lowers to `llvm.matrix.multiply`.
//
/// This only kicks in when vectorContractLowering is set to `Matmul`.
/// vector.transpose operations are inserted if the vector.contract op is not a
@@ -1362,8 +1362,12 @@ FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs);

Value mul = rew.create<vector::MatmulOp>(loc, lhs, rhs, lhsRows, lhsColumns,
rhsColumns);
Value mul = rew.create<LLVM::MatrixMultiplyOp>(
loc,
VectorType::get(lhsRows * rhsColumns,
cast<VectorType>(lhs.getType()).getElementType()),
lhs, rhs, lhsRows, lhsColumns, rhsColumns);

mul = rew.create<vector::ShapeCastOp>(
loc,
VectorType::get({lhsRows, rhsColumns},
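With this change, ContractionOpToMatmulOpLowering builds the `llvm.intr.matrix.multiply` intrinsic directly instead of the removed `vector.matrix_multiply` wrapper. A rough sketch of the IR the pattern now produces, assuming a 4x16 by 16x3 f64 contraction (hypothetical SSA names, shapes chosen to match the removed test):

```mlir
// Illustrative lowering of a 4x16 * 16x3 f64 vector.contract (Matmul strategy).
%lhs = vector.shape_cast %a : vector<4x16xf64> to vector<64xf64>
%rhs = vector.shape_cast %b : vector<16x3xf64> to vector<48xf64>
%mul = llvm.intr.matrix.multiply %lhs, %rhs
         {lhs_rows = 4 : i32, lhs_columns = 16 : i32, rhs_columns = 3 : i32}
         : (vector<64xf64>, vector<48xf64>) -> vector<12xf64>
%d = vector.shape_cast %mul : vector<12xf64> to vector<4x3xf64>
%e = arith.addf %acc, %d : vector<4x3xf64>
```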
@@ -12,6 +12,7 @@
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
@@ -338,7 +339,7 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
rewriter.create<vector::ShapeCastOp>(loc, flattenedType, input);
auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
Value trans = rewriter.create<vector::FlatTransposeOp>(
Value trans = rewriter.create<LLVM::MatrixTransposeOp>(
loc, flattenedType, matrix, rows, columns);
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
return success();
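Likewise, TransposeOpLowering now creates `llvm.intr.matrix.transpose` directly where it previously created `vector.flat_transpose`. A minimal sketch of the resulting IR for a 4x4 f32 transpose (hypothetical SSA names; the intrinsic syntax matches the removed test below):

```mlir
// Illustrative lowering of a 4x4 f32 vector.transpose via the matrix intrinsic.
%flat = vector.shape_cast %src : vector<4x4xf32> to vector<16xf32>
%t = llvm.intr.matrix.transpose %flat {rows = 4 : i32, columns = 4 : i32}
       : vector<16xf32> into vector<16xf32>
%res = vector.shape_cast %t : vector<16xf32> to vector<4x4xf32>
```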
80 changes: 0 additions & 80 deletions mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1424,36 +1424,6 @@ func.func @fma_scalable(%vec_1d: vector<[8]xf32>, %vec_2d: vector<2x[4]xf32>, %v

return %0, %1, %2: vector<[8]xf32>, vector<2x[4]xf32>, vector<1x1x[1]xf32>
}
// -----

//===----------------------------------------------------------------------===//
// vector.matrix_multiply
//===----------------------------------------------------------------------===//

// 4x16 16x3 4x3
func.func @matrix_ops(%A: vector<64xf64>, %B: vector<48xf64>) -> vector<12xf64> {
%C = vector.matrix_multiply %A, %B
{ lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_columns = 3: i32 } :
(vector<64xf64>, vector<48xf64>) -> vector<12xf64>
return %C: vector<12xf64>
}
// CHECK-LABEL: @matrix_ops
// CHECK: llvm.intr.matrix.multiply %{{.*}}, %{{.*}} {
// CHECK-SAME: lhs_columns = 16 : i32, lhs_rows = 4 : i32, rhs_columns = 3 : i32
// CHECK-SAME: } : (vector<64xf64>, vector<48xf64>) -> vector<12xf64>

// -----

func.func @matrix_ops_index(%A: vector<64xindex>, %B: vector<48xindex>) -> vector<12xindex> {
%C = vector.matrix_multiply %A, %B
{ lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_columns = 3: i32 } :
(vector<64xindex>, vector<48xindex>) -> vector<12xindex>
return %C: vector<12xindex>
}
// CHECK-LABEL: @matrix_ops_index
// CHECK: llvm.intr.matrix.multiply %{{.*}}, %{{.*}} {
// CHECK-SAME: lhs_columns = 16 : i32, lhs_rows = 4 : i32, rhs_columns = 3 : i32
// CHECK-SAME: } : (vector<64xi64>, vector<48xi64>) -> vector<12xi64>

// -----

@@ -1602,56 +1572,6 @@ func.func @create_mask_1d_scalable(%num_elems : index) -> vector<[4]xi1> {

// -----

//===----------------------------------------------------------------------===//
// vector.flat_transpose
//===----------------------------------------------------------------------===//

func.func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> {
%0 = vector.flat_transpose %arg0 { rows = 4: i32, columns = 4: i32 }
: vector<16xf32> -> vector<16xf32>
return %0 : vector<16xf32>
}

// CHECK-LABEL: func @flat_transpose
// CHECK-SAME: %[[A:.*]]: vector<16xf32>
// CHECK: %[[T:.*]] = llvm.intr.matrix.transpose %[[A]]
// CHECK-SAME: {columns = 4 : i32, rows = 4 : i32} :
// CHECK-SAME: vector<16xf32> into vector<16xf32>
// CHECK: return %[[T]] : vector<16xf32>

// -----

func.func @flat_transpose_index(%arg0: vector<16xindex>) -> vector<16xindex> {
%0 = vector.flat_transpose %arg0 { rows = 4: i32, columns = 4: i32 }
: vector<16xindex> -> vector<16xindex>
return %0 : vector<16xindex>
}
// CHECK-LABEL: func @flat_transpose_index
// CHECK-SAME: %[[A:.*]]: vector<16xindex>
// CHECK: %[[T0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<16xindex> to vector<16xi64>
// CHECK: %[[T1:.*]] = llvm.intr.matrix.transpose %[[T0]]
// CHECK-SAME: {columns = 4 : i32, rows = 4 : i32} :
// CHECK-SAME: vector<16xi64> into vector<16xi64>
// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<16xi64> to vector<16xindex>
// CHECK: return %[[T2]] : vector<16xindex>

// -----

func.func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> {
%0 = vector.flat_transpose %arg0 { rows = 4: i32, columns = 4: i32 }
: vector<16xf32> -> vector<16xf32>
return %0 : vector<16xf32>
}

// CHECK-LABEL: func @flat_transpose
// CHECK-SAME: %[[A:.*]]: vector<16xf32>
// CHECK: %[[T:.*]] = llvm.intr.matrix.transpose %[[A]]
// CHECK-SAME: {columns = 4 : i32, rows = 4 : i32} :
// CHECK-SAME: vector<16xf32> into vector<16xf32>
// CHECK: return %[[T]] : vector<16xf32>

// -----

//===----------------------------------------------------------------------===//
// vector.gather
//