Skip to content

[mlir][vector] Remove MatrixMultiplyOp and FlatTransposeOp from Vector dialect #144307

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

banach-space
Copy link
Contributor

This patch deletes vector.matrix_multiply and vector.flat_transpose,
which are thin wrappers around the corresponding LLVM intrinsics:

  • llvm.intr.matrix.multiply
  • llvm.intr.matrix.transpose

These Vector dialect ops did not provide additional semantics or
abstraction beyond the LLVM intrinsics. Their removal simplifies the
lowering pipeline without losing any functionality.

The lowering chains:

  • vector.contractvector.matrix_multiplyllvm.intr.matrix.multiply
  • vector.transposevector.flat_transposellvm.intr.matrix.transpose

are now replaced with:

  • vector.contractllvm.intr.matrix.multiply
  • vector.transposellvm.intr.matrix.transpose

This was accomplished by directly replacing:

  • vector::MatrixMultiplyOp with LLVM::MatrixMultiplyOp
  • vector::FlatTransposeOp with LLVM::MatrixTransposeOp

Note: This change introduces a build-time dependency from Vector to
LLVM. Ideally, such dependencies should be confined to dialect
conversion (ConvertVectorToLLVM). However, moving the lowering code
there would introduce notable churn, so this patch leaves the new
dependency in place for now.

…r dialect

This patch deletes `vector.matrix_multiply` and `vector.flat_transpose`,
which are thin wrappers around the corresponding LLVM intrinsics:
  - `llvm.intr.matrix.multiply`
  - `llvm.intr.matrix.transpose`

These Vector dialect ops did not provide additional semantics or
abstraction beyond the LLVM intrinsics. Their removal simplifies the
lowering pipeline without losing any functionality.

The lowering chains:
  - `vector.contract` → `vector.matrix_multiply` → `llvm.intr.matrix.multiply`
  - `vector.transpose` → `vector.flat_transpose` → `llvm.intr.matrix.transpose`

are now replaced with:
  - `vector.contract` → `llvm.intr.matrix.multiply`
  - `vector.transpose` → `llvm.intr.matrix.transpose`

This was accomplished by directly replacing:
  - `vector::MatrixMultiplyOp` with `LLVM::MatrixMultiplyOp`
  - `vector::FlatTransposeOp` with `LLVM::MatrixTransposeOp`

Note: This change introduces a build-time dependency from `Vector` to
`LLVM`. Ideally, such dependencies should be confined to dialect
conversion (`ConvertVectorToLLVM`). However, moving the lowering code
there would introduce notable churn, so this patch leaves the new
dependency in place for now.
@llvmbot
Copy link
Member

llvmbot commented Jun 16, 2025

@llvm/pr-subscribers-mlir-vector
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-arith

Author: Andrzej Warzyński (banach-space)

Changes

This patch deletes vector.matrix_multiply and vector.flat_transpose,
which are thin wrappers around the corresponding LLVM intrinsics:

  • llvm.intr.matrix.multiply
  • llvm.intr.matrix.transpose

These Vector dialect ops did not provide additional semantics or
abstraction beyond the LLVM intrinsics. Their removal simplifies the
lowering pipeline without losing any functionality.

The lowering chains:

  • vector.contractvector.matrix_multiplyllvm.intr.matrix.multiply
  • vector.transposevector.flat_transposellvm.intr.matrix.transpose

are now replaced with:

  • vector.contractllvm.intr.matrix.multiply
  • vector.transposellvm.intr.matrix.transpose

This was accomplished by directly replacing:

  • vector::MatrixMultiplyOp with LLVM::MatrixMultiplyOp
  • vector::FlatTransposeOp with LLVM::MatrixTransposeOp

Note: This change introduces a build-time dependency from Vector to
LLVM. Ideally, such dependencies should be confined to dialect
conversion (ConvertVectorToLLVM). However, moving the lowering code
there would introduce notable churn, so this patch leaves the new
dependency in place for now.


Patch is 29.94 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/144307.diff

17 Files Affected:

  • (modified) mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h (-6)
  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (-118)
  • (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (-41)
  • (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp (-2)
  • (modified) mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp (+3-3)
  • (modified) mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt (+1)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp (+8-4)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp (+2-1)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (-80)
  • (modified) mlir/test/Dialect/Vector/invalid.mlir (-29)
  • (modified) mlir/test/Dialect/Vector/ops.mlir (-16)
  • (modified) mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir (+2-2)
  • (modified) mlir/test/Dialect/Vector/vector-transpose-lowering.mlir (+2-2)
  • (modified) mlir/test/Integration/Dialect/Vector/CPU/flat-transpose-col.mlir (+4-4)
  • (modified) mlir/test/Integration/Dialect/Vector/CPU/flat-transpose-row.mlir (+4-4)
  • (modified) mlir/test/Integration/Dialect/Vector/CPU/matrix-multiply-col.mlir (+1-1)
  • (modified) mlir/test/Integration/Dialect/Vector/CPU/matrix-multiply-row.mlir (+1-1)
diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
index f6b09deb4e44c..cfb6cc313bc63 100644
--- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
+++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h
@@ -13,12 +13,6 @@
 namespace mlir {
 class LLVMTypeConverter;
 
-/// Collect a set of patterns to convert from Vector contractions to LLVM Matrix
-/// Intrinsics. To lower to assembly, the LLVM flag -lower-matrix-intrinsics
-/// will be needed when invoking LLVM.
-void populateVectorToLLVMMatrixConversionPatterns(
-    const LLVMTypeConverter &converter, RewritePatternSet &patterns);
-
 /// Collect a set of patterns to convert from the Vector dialect to LLVM.
 void populateVectorToLLVMConversionPatterns(
     const LLVMTypeConverter &converter, RewritePatternSet &patterns,
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 8353314ed958b..9c95677ee50da 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2710,124 +2710,6 @@ def Vector_PrintOp :
     }];
 }
 
-//===----------------------------------------------------------------------===//
-// Ops used for supporting progressive lowering and conversion type changes.
-// The Ops are typically not used directly by higher level dialects, but are
-// used by intra-dialect rewriting rules to bring vector operations closer
-// to the hardware ISA.
-//===----------------------------------------------------------------------===//
-
-/// Vector dialect matrix multiplication op that operates on flattened 1-D
-/// MLIR vectors. This is the counterpart of llvm.matrix.multiply in MLIR.
-/// This may seem redundant with vector.contract but it serves the purposes of
-/// more progressive lowering and localized type conversion on the path:
-///   `vector<...x...xf32> -> vector<...xf32> -> !llvm<... x float>`.
-def Vector_MatmulOp : Vector_Op<"matrix_multiply", [Pure,
-        PredOpTrait<"lhs operand and result have same element type",
-                    TCresVTEtIsSameAsOpBase<0, 0>>,
-        PredOpTrait<"rhs operand and result have same element type",
-                    TCresVTEtIsSameAsOpBase<0, 1>>]>,
-      Arguments<(
-        // TODO: tighten vector element types that make sense.
-        ins FixedVectorOfRankAndType<[1],
-              [AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$lhs,
-            FixedVectorOfRankAndType<[1],
-              [AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$rhs,
-            I32Attr:$lhs_rows, I32Attr:$lhs_columns, I32Attr:$rhs_columns)>,
-      Results<(
-        outs FixedVectorOfRankAndType<[1],
-               [AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$res)>
-{
-  let summary = "Vector matrix multiplication op that operates on flattened 1-D"
-    " MLIR vectors";
-  let description = [{
-    This is the counterpart of llvm.matrix.multiply in MLIR. It serves the
-    purposes of more progressive lowering and localized type conversion.
-    Higher levels typically lower matrix multiplications into 'vector.contract'
-    operations. Subsequent rewriting rule progressively lower these operations
-    into 'vector.matrix_multiply' operations to bring the operations closer
-    to the hardware ISA.
-
-    The ‘vector.matrix_multiply’ op treats `lhs` as matrix with <lhs_rows> rows
-    and <lhs_columns> columns, `rhs` as matrix with <lhs_columns> rows and
-    <rhs_columns> and multiplies them. The result matrix is returned embedded in
-    the result vector.
-
-    Note, the corresponding LLVM intrinsic, `@llvm.matrix.multiply.*`, does not
-    support scalable vectors. Hence, this Op is only available for fixed-width
-    vectors. Also see:
-
-    http://llvm.org/docs/LangRef.html#llvm-matrix-multiply-intrinsic
-
-    Example:
-
-    ```mlir
-    %C = vector.matrix_multiply %A, %B
-      { lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_columns = 3: i32 } :
-      (vector<64xf64>, vector<48xf64>) -> vector<12xf64>
-    ```
-  }];
-  let builders = [
-   OpBuilder<(ins "Value":$lhs, "Value":$rhs, "unsigned":$lhsRows,
-     "unsigned":$lhsColumns, "unsigned":$rhsColumns),
-   [{
-     $_state.addOperands({lhs, rhs});
-     $_state.addAttribute("lhs_rows",$_builder.getI32IntegerAttr(lhsRows));
-     $_state.addAttribute("lhs_columns",$_builder.getI32IntegerAttr(lhsColumns));
-     $_state.addAttribute("rhs_columns",$_builder.getI32IntegerAttr(rhsColumns));
-     $_state.addTypes(VectorType::get(lhsRows * rhsColumns,
-       ::llvm::cast<VectorType>(lhs.getType()).getElementType()));
-   }]>,
-  ];
-  let assemblyFormat = "$lhs `,` $rhs attr-dict "
-    "`:` `(` type($lhs) `,` type($rhs) `)` `->` type($res)";
-}
-
-/// Vector dialect matrix transposition op that operates on flattened 1-D
-/// MLIR vectors. This is the counterpart of llvm.matrix.transpose in MLIR.
-/// This may seem redundant with vector.transpose but it serves the purposes of
-/// more progressive lowering and localized type conversion on the path:
-///   `vector<...x...xf32> -> vector<...xf32> -> !llvm<... x float>`.
-def Vector_FlatTransposeOp : Vector_Op<"flat_transpose", [Pure,
-  PredOpTrait<"source operand and result have same element type",
-                 TCresVTEtIsSameAsOpBase<0, 0>>]>,
-    Arguments<(
-      // TODO: tighten vector element types that make sense.
-      ins FixedVectorOfRankAndType<[1],
-            [AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$matrix,
-          I32Attr:$rows, I32Attr:$columns)>,
-    Results<(
-      outs FixedVectorOfRankAndType<[1],
-             [AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$res)> {
-  let summary = "Vector matrix transposition on flattened 1-D MLIR vectors";
-  let description = [{
-    This is the counterpart of llvm.matrix.transpose in MLIR. It serves
-    the purposes of more progressive lowering and localized type conversion.
-    Higher levels typically lower matrix transpositions into 'vector.transpose'
-    operations. Subsequent rewriting rule progressively lower these operations
-    into 'vector.flat_transpose' operations to bring the operations closer
-    to the hardware ISA.
-
-    The `vector.flat_transpose` op treats the 1-D input `matrix` as
-    a 2-D matrix with <rows> rows and <columns> columns, and returns the
-    transposed matrix in flattened form in 'res'.
-
-    Note, the corresponding LLVM intrinsic, `@llvm.matrix.transpose.*`, does not
-    support scalable vectors. Hence, this Op is only available for fixed-width
-    vectors. Also see:
-
-    http://llvm.org/docs/LangRef.html#llvm-matrix-transpose-intrinsic
-
-    Example:
-
-    ```mlir
-    %1 = vector.flat_transpose %0 {columns = 4 : i32, rows = 4 : i32}
-       : vector<16xf32> -> vector<16xf32>
-    ```
-  }];
-  let assemblyFormat = "$matrix attr-dict `:` type($matrix) `->` type($res)";
-}
-
 //===----------------------------------------------------------------------===//
 // SplatOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index f725993635672..bcdeccc54cf17 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -184,41 +184,6 @@ class VectorBitCastOpConversion
   }
 };
 
-/// Conversion pattern for a vector.matrix_multiply.
-/// This is lowered directly to the proper llvm.intr.matrix.multiply.
-class VectorMatmulOpConversion
-    : public ConvertOpToLLVMPattern<vector::MatmulOp> {
-public:
-  using ConvertOpToLLVMPattern<vector::MatmulOp>::ConvertOpToLLVMPattern;
-
-  LogicalResult
-  matchAndRewrite(vector::MatmulOp matmulOp, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
-        matmulOp, typeConverter->convertType(matmulOp.getRes().getType()),
-        adaptor.getLhs(), adaptor.getRhs(), matmulOp.getLhsRows(),
-        matmulOp.getLhsColumns(), matmulOp.getRhsColumns());
-    return success();
-  }
-};
-
-/// Conversion pattern for a vector.flat_transpose.
-/// This is lowered directly to the proper llvm.intr.matrix.transpose.
-class VectorFlatTransposeOpConversion
-    : public ConvertOpToLLVMPattern<vector::FlatTransposeOp> {
-public:
-  using ConvertOpToLLVMPattern<vector::FlatTransposeOp>::ConvertOpToLLVMPattern;
-
-  LogicalResult
-  matchAndRewrite(vector::FlatTransposeOp transOp, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<LLVM::MatrixTransposeOp>(
-        transOp, typeConverter->convertType(transOp.getRes().getType()),
-        adaptor.getMatrix(), transOp.getRows(), transOp.getColumns());
-    return success();
-  }
-};
-
 /// Overloaded utility that replaces a vector.load, vector.store,
 /// vector.maskedload and vector.maskedstore with their respective LLVM
 /// couterparts.
@@ -2026,12 +1991,6 @@ void mlir::populateVectorToLLVMConversionPatterns(
                VectorScalableStepOpLowering>(converter);
 }
 
-void mlir::populateVectorToLLVMMatrixConversionPatterns(
-    const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
-  patterns.add<VectorMatmulOpConversion>(converter);
-  patterns.add<VectorFlatTransposeOpConversion>(converter);
-}
-
 namespace {
 struct VectorToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
   using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 293e01a5bf4d4..dcc5ded02341f 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -97,11 +97,9 @@ void ConvertVectorToLLVMPass::runOnOperation() {
   LLVMTypeConverter converter(&getContext(), options);
   RewritePatternSet patterns(&getContext());
   populateVectorTransferLoweringPatterns(patterns);
-  populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
   populateVectorToLLVMConversionPatterns(
       converter, patterns, reassociateFPReductions, force32BitVectorIndices,
       useVectorAlignment);
-  populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
 
   // Architecture specific augmentations.
   LLVMConversionTarget target(getContext());
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
index 62022bfb7df1e..f14264e2f55f3 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
@@ -118,9 +118,9 @@ void mlir::arith::populateEmulateUnsupportedFloatsLegality(
         return converter.isLegal(op);
       });
   // Manually mark arithmetic-performing vector instructions.
-  target.addDynamicallyLegalOp<
-      vector::ContractionOp, vector::ReductionOp, vector::MultiDimReductionOp,
-      vector::FMAOp, vector::OuterProductOp, vector::MatmulOp, vector::ScanOp>(
+  target.addDynamicallyLegalOp<vector::ContractionOp, vector::ReductionOp,
+                               vector::MultiDimReductionOp, vector::FMAOp,
+                               vector::OuterProductOp, vector::ScanOp>(
       [&](Operation *op) { return converter.isLegal(op); });
   target.addLegalOp<arith::BitcastOp, arith::ExtFOp, arith::TruncFOp,
                     arith::ConstantOp, vector::SplatOp>();
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index 8ca5cb6c6dfab..08aef70fc4d8a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -49,6 +49,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
   MLIRTensorDialect
   MLIRTransforms
   MLIRVectorDialect
+  MLIRLLVMDialect
   MLIRVectorInterfaces
   MLIRVectorUtils
   )
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index c6627b5ec0d77..0e8c60f6a9a6e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
@@ -1280,12 +1281,11 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
 ///    %mtb = maybe_transpose
 ///    %flattened_a = vector.shape_cast %mta
 ///    %flattened_b = vector.shape_cast %mtb
-///    %flattened_d = vector.matrix_multiply %flattened_a, %flattened_b
+///    %flattened_d = llvm.intr.matrix.multiply %flattened_a, %flattened_b
 ///    %mtd = vector.shape_cast %flattened_d
 ///    %d = maybe_untranspose %mtd
 ///    %e = add %c, %d
 /// ```
-/// `vector.matrix_multiply` later lowers to `llvm.matrix.multiply`.
 //
 /// This only kicks in when vectorContractLowering is set to `Matmul`.
 /// vector.transpose operations are inserted if the vector.contract op is not a
@@ -1362,8 +1362,12 @@ FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
       VectorType::get(rhsType.getNumElements(), rhsType.getElementType());
   rhs = rew.create<vector::ShapeCastOp>(loc, flattenedRHSType, rhs);
 
-  Value mul = rew.create<vector::MatmulOp>(loc, lhs, rhs, lhsRows, lhsColumns,
-                                           rhsColumns);
+  Value mul = rew.create<LLVM::MatrixMultiplyOp>(
+      loc,
+      VectorType::get(lhsRows * rhsColumns,
+                      cast<VectorType>(lhs.getType()).getElementType()),
+      lhs, rhs, lhsRows, lhsColumns, rhsColumns);
+
   mul = rew.create<vector::ShapeCastOp>(
       loc,
       VectorType::get({lhsRows, rhsColumns},
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index 732e316c93381..05fb613393584 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -12,6 +12,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
@@ -338,7 +339,7 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
           rewriter.create<vector::ShapeCastOp>(loc, flattenedType, input);
       auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
       auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
-      Value trans = rewriter.create<vector::FlatTransposeOp>(
+      Value trans = rewriter.create<LLVM::MatrixTransposeOp>(
           loc, flattenedType, matrix, rows, columns);
       rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
       return success();
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 64e51f5554628..72810b5dddaa3 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1424,36 +1424,6 @@ func.func @fma_scalable(%vec_1d: vector<[8]xf32>, %vec_2d: vector<2x[4]xf32>, %v
 
   return %0, %1, %2: vector<[8]xf32>, vector<2x[4]xf32>, vector<1x1x[1]xf32>
 }
-// -----
-
-//===----------------------------------------------------------------------===//
-// vector.matrix_multiply
-//===----------------------------------------------------------------------===//
-
-//                          4x16                16x3               4x3
-func.func @matrix_ops(%A: vector<64xf64>, %B: vector<48xf64>) -> vector<12xf64> {
-  %C = vector.matrix_multiply %A, %B
-    { lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_columns = 3: i32 } :
-    (vector<64xf64>, vector<48xf64>) -> vector<12xf64>
-  return %C: vector<12xf64>
-}
-// CHECK-LABEL: @matrix_ops
-//       CHECK:   llvm.intr.matrix.multiply %{{.*}}, %{{.*}} {
-//  CHECK-SAME: lhs_columns = 16 : i32, lhs_rows = 4 : i32, rhs_columns = 3 : i32
-//  CHECK-SAME: } : (vector<64xf64>, vector<48xf64>) -> vector<12xf64>
-
-// -----
-
-func.func @matrix_ops_index(%A: vector<64xindex>, %B: vector<48xindex>) -> vector<12xindex> {
-  %C = vector.matrix_multiply %A, %B
-    { lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_columns = 3: i32 } :
-    (vector<64xindex>, vector<48xindex>) -> vector<12xindex>
-  return %C: vector<12xindex>
-}
-// CHECK-LABEL: @matrix_ops_index
-//       CHECK:   llvm.intr.matrix.multiply %{{.*}}, %{{.*}} {
-//  CHECK-SAME: lhs_columns = 16 : i32, lhs_rows = 4 : i32, rhs_columns = 3 : i32
-//  CHECK-SAME: } : (vector<64xi64>, vector<48xi64>) -> vector<12xi64>
 
 // -----
 
@@ -1602,56 +1572,6 @@ func.func @create_mask_1d_scalable(%num_elems : index) -> vector<[4]xi1> {
 
 // -----
 
-//===----------------------------------------------------------------------===//
-// vector.flat_transpose
-//===----------------------------------------------------------------------===//
-
-func.func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> {
-  %0 = vector.flat_transpose %arg0 { rows = 4: i32, columns = 4: i32 }
-     : vector<16xf32> -> vector<16xf32>
-  return %0 : vector<16xf32>
-}
-
-// CHECK-LABEL: func @flat_transpose
-// CHECK-SAME:  %[[A:.*]]: vector<16xf32>
-// CHECK:       %[[T:.*]] = llvm.intr.matrix.transpose %[[A]]
-// CHECK-SAME:      {columns = 4 : i32, rows = 4 : i32} :
-// CHECK-SAME:      vector<16xf32> into vector<16xf32>
-// CHECK:       return %[[T]] : vector<16xf32>
-
-// -----
-
-func.func @flat_transpose_index(%arg0: vector<16xindex>) -> vector<16xindex> {
-  %0 = vector.flat_transpose %arg0 { rows = 4: i32, columns = 4: i32 }
-     : vector<16xindex> -> vector<16xindex>
-  return %0 : vector<16xindex>
-}
-// CHECK-LABEL: func @flat_transpose_index
-// CHECK-SAME:  %[[A:.*]]: vector<16xindex>
-// CHECK:       %[[T0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<16xindex> to vector<16xi64>
-// CHECK:       %[[T1:.*]] = llvm.intr.matrix.transpose %[[T0]]
-// CHECK-SAME:      {columns = 4 : i32, rows = 4 : i32} :
-// CHECK-SAME:      vector<16xi64> into vector<16xi64>
-// CHECK:       %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<16xi64> to vector<16xindex>
-// CHECK:       return %[[T2]] : vector<16xindex>
-
-// -----
-
-func.func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> {
-  %0 = vector.flat_transpose %arg0 { rows = 4: i32, columns = 4: i32 }
-     : vector<16xf32> -> vector<16xf32>
-  return %0 : vector<16xf32>
-}
-
-// CHECK-LABEL: func @flat_transpose
-// CHECK-SAME:  %[[A:.*]]: vector<16xf32>
-// CHECK:       %[[T:.*]] = llvm.intr.matrix.transpose %[[A]]
-// CHECK-SAME:      {columns = 4 : i32, rows = 4 : i32} :
-// CHECK-SAME:      vector<16xf32> into vector<16xf32>
-// CHECK:       return %[[T]] : vector<16xf32>
-
-// -----
-
 //===----------------------------------------------------------------------===//
 // vector.gather
 //
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 04810ed52584f..a2ede475a1478 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1328,13 +1328,6 @@ func.func @transpose_dim_size_mismatch(%arg0: vector<11x7x3x2xi32>) {
 
 // -----
 
-func.func @flat_transpose_type_mismatch(%arg0: vector<16xf32>) {
-  // expected-error@+1 {{'vector.flat_transpose' op failed to verify that source operand and result have same element type}}
-  %0 = vector.flat_transpose %arg0 { rows = 4: i32, columns = 4: i32 } : vector<16xf32> -> vector<16xf64>
-}
-
-// -----
-
 func.func @type_cast_layout(%arg0: memref<4x3xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s0 + d1 * s1 + s2)>>) {
   // expected-error@+1 {{expects opera...
[truncated]

@Groverkss
Copy link
Member

LGTM on removing the ops, but I don't know if we can introduce a llvm dialect dependency like this. Someone going down the Vector -> SPIRV path might get suprised by this. I don't like churn too, but i think we need to move this down to conversions.

Maybe we can split this into two patches, one that changes to conversion directly to llvm intrinsic and removes the old transformation and then the op removal?

I'm ok with any way, as long as we also ensure we don't suprise anyone trying to use a backend other than LLVM.

@Groverkss Groverkss requested a review from antiagainst June 16, 2025 11:09
Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm +1 for the removal, but emitting llvm dialect directly in vector transforms seems like a layering violation to me

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants