From 629385ab331c183772ff16c47931a587e9332498 Mon Sep 17 00:00:00 2001 From: PikachuHy Date: Wed, 26 Feb 2025 15:42:22 +0800 Subject: [PATCH] [CIR][LowerToLLVM] optimize MemRefType handling for array types for example, `cir.alloca !cir.array, !cir.ptr>` to `memref.alloca() : memref` see https://github.com/llvm/clangir/issues/1405 --- .../Lowering/ThroughMLIR/LowerCIRToMLIR.cpp | 9 +++++- clang/test/CIR/Lowering/ThroughMLIR/array.c | 31 +++++++++++++++++++ clang/test/CIR/Lowering/ThroughMLIR/array.cir | 2 +- 3 files changed, 40 insertions(+), 2 deletions(-) create mode 100644 clang/test/CIR/Lowering/ThroughMLIR/array.c diff --git a/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp b/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp index 3d88820f5033..0a5b4712df6b 100644 --- a/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp +++ b/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp @@ -173,7 +173,14 @@ class CIRAllocaOpLowering : public mlir::OpConversionPattern { if (!mlirType) return mlir::LogicalResult::failure(); - auto memreftype = mlir::MemRefType::get({}, mlirType); + auto memreftype = mlir::dyn_cast(mlirType); + if (memreftype && mlir::isa(adaptor.getAllocaType())) { + // if the type is an array, + // we don't need to wrap with memref. + } else { + memreftype = mlir::MemRefType::get({}, mlirType); + } + rewriter.replaceOpWithNewOp(op, memreftype, op.getAlignmentAttr()); return mlir::LogicalResult::success(); diff --git a/clang/test/CIR/Lowering/ThroughMLIR/array.c b/clang/test/CIR/Lowering/ThroughMLIR/array.c new file mode 100644 index 000000000000..9f57fb9e9f52 --- /dev/null +++ b/clang/test/CIR/Lowering/ThroughMLIR/array.c @@ -0,0 +1,31 @@ +// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir %s -o %t.cir +// RUN: FileCheck --input-file=%t.cir %s --check-prefix=CIR +// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -fno-clangir-direct-lowering -emit-mlir %s -o %t.mlir +// RUN: FileCheck --input-file=%t.mlir %s --check-prefix=MLIR + +int test_array1() { + // CIR-LABEL: cir.func {{.*}} @test_array1 + // CIR: %[[ARRAY:.*]] = cir.alloca !cir.array, !cir.ptr>, ["a"] {alignment = 4 : i64} + // CIR: %{{.*}} = cir.cast(array_to_ptrdecay, %[[ARRAY]] : !cir.ptr>), !cir.ptr + + // MLIR-LABEL: func @test_array1 + // MLIR: %{{.*}} = memref.alloca() {alignment = 4 : i64} : memref + // MLIR: %[[ARRAY:.*]] = memref.alloca() {alignment = 4 : i64} : memref<3xi32> + // MLIR: %{{.*}} = memref.load %[[ARRAY]][%{{.*}}] : memref<3xi32> + int a[3]; + return a[1]; +} + +int test_array2() { + // CIR-LABEL: cir.func {{.*}} @test_array2 + // CIR: %[[ARRAY:.*]] = cir.alloca !cir.array x 3>, !cir.ptr x 3>>, ["a"] {alignment = 16 : i64} + // CIR: %{{.*}} = cir.cast(array_to_ptrdecay, %[[ARRAY]] : !cir.ptr x 3>>), !cir.ptr> + // CIR: %{{.*}} = cir.cast(array_to_ptrdecay, %{{.*}} : !cir.ptr>), !cir.ptr + + // MLIR-LABEL: func @test_array2 + // MLIR: %{{.*}} = memref.alloca() {alignment = 4 : i64} : memref + // MLIR: %[[ARRAY:.*]] = memref.alloca() {alignment = 16 : i64} : memref<3x4xi32> + // MLIR: %{{.*}} = memref.load %[[ARRAY]][%{{.*}}, %{{.*}}] : memref<3x4xi32> + int a[3][4]; + return a[1][2]; +} diff --git a/clang/test/CIR/Lowering/ThroughMLIR/array.cir b/clang/test/CIR/Lowering/ThroughMLIR/array.cir index dc1eb97c80b3..cf22a2a41579 100644 --- a/clang/test/CIR/Lowering/ThroughMLIR/array.cir +++ b/clang/test/CIR/Lowering/ThroughMLIR/array.cir @@ -11,7 +11,7 @@ module { // CHECK: module { // CHECK: func @foo() { -// CHECK: = memref.alloca() {alignment = 16 : i64} : memref> +// CHECK: = memref.alloca() {alignment = 16 : i64} : memref<10xi32> // CHECK: return // CHECK: } // CHECK: }