Skip to content

Commit ab1d10b

Browse files
authored
[CIR][ThroughMLIR] remove nested memref wrapper for array types (#1412)
for example, lower `cir.alloca !cir.array<!s32i x N>, !cir.ptr<!cir.array<!s32i x N>>` to `memref.alloca() : memref<Nxi32>` see #1405
1 parent 8746bd4 commit ab1d10b

File tree

3 files changed

+40
-2
lines changed

3 files changed

+40
-2
lines changed

clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp

+8-1
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,14 @@ class CIRAllocaOpLowering : public mlir::OpConversionPattern<cir::AllocaOp> {
173173
if (!mlirType)
174174
return mlir::LogicalResult::failure();
175175

176-
auto memreftype = mlir::MemRefType::get({}, mlirType);
176+
auto memreftype = mlir::dyn_cast<mlir::MemRefType>(mlirType);
177+
if (memreftype && mlir::isa<cir::ArrayType>(adaptor.getAllocaType())) {
178+
// if the type is an array,
179+
// we don't need to wrap with memref.
180+
} else {
181+
memreftype = mlir::MemRefType::get({}, mlirType);
182+
}
183+
177184
rewriter.replaceOpWithNewOp<mlir::memref::AllocaOp>(op, memreftype,
178185
op.getAlignmentAttr());
179186
return mlir::LogicalResult::success();
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir %s -o %t.cir
2+
// RUN: FileCheck --input-file=%t.cir %s --check-prefix=CIR
3+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -fno-clangir-direct-lowering -emit-mlir %s -o %t.mlir
4+
// RUN: FileCheck --input-file=%t.mlir %s --check-prefix=MLIR
5+
6+
int test_array1() {
7+
// CIR-LABEL: cir.func {{.*}} @test_array1
8+
// CIR: %[[ARRAY:.*]] = cir.alloca !cir.array<!s32i x 3>, !cir.ptr<!cir.array<!s32i x 3>>, ["a"] {alignment = 4 : i64}
9+
// CIR: %{{.*}} = cir.cast(array_to_ptrdecay, %[[ARRAY]] : !cir.ptr<!cir.array<!s32i x 3>>), !cir.ptr<!s32i>
10+
11+
// MLIR-LABEL: func @test_array1
12+
// MLIR: %{{.*}} = memref.alloca() {alignment = 4 : i64} : memref<i32>
13+
// MLIR: %[[ARRAY:.*]] = memref.alloca() {alignment = 4 : i64} : memref<3xi32>
14+
// MLIR: %{{.*}} = memref.load %[[ARRAY]][%{{.*}}] : memref<3xi32>
15+
int a[3];
16+
return a[1];
17+
}
18+
19+
int test_array2() {
20+
// CIR-LABEL: cir.func {{.*}} @test_array2
21+
// CIR: %[[ARRAY:.*]] = cir.alloca !cir.array<!cir.array<!s32i x 4> x 3>, !cir.ptr<!cir.array<!cir.array<!s32i x 4> x 3>>, ["a"] {alignment = 16 : i64}
22+
// CIR: %{{.*}} = cir.cast(array_to_ptrdecay, %[[ARRAY]] : !cir.ptr<!cir.array<!cir.array<!s32i x 4> x 3>>), !cir.ptr<!cir.array<!s32i x 4>>
23+
// CIR: %{{.*}} = cir.cast(array_to_ptrdecay, %{{.*}} : !cir.ptr<!cir.array<!s32i x 4>>), !cir.ptr<!s32i>
24+
25+
// MLIR-LABEL: func @test_array2
26+
// MLIR: %{{.*}} = memref.alloca() {alignment = 4 : i64} : memref<i32>
27+
// MLIR: %[[ARRAY:.*]] = memref.alloca() {alignment = 16 : i64} : memref<3x4xi32>
28+
// MLIR: %{{.*}} = memref.load %[[ARRAY]][%{{.*}}, %{{.*}}] : memref<3x4xi32>
29+
int a[3][4];
30+
return a[1][2];
31+
}

clang/test/CIR/Lowering/ThroughMLIR/array.cir

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ module {
1111

1212
// CHECK: module {
1313
// CHECK: func @foo() {
14-
// CHECK: = memref.alloca() {alignment = 16 : i64} : memref<memref<10xi32>>
14+
// CHECK: = memref.alloca() {alignment = 16 : i64} : memref<10xi32>
1515
// CHECK: return
1616
// CHECK: }
1717
// CHECK: }

0 commit comments

Comments
 (0)