Skip to content

Commit 6d45284

Browse files
author
Stephan Herhut
committed
[mlir][memref] Add better support for identity layouts in memref.collapse_shape canonicalizer
When computing the new type of a collapse_shape operation, we need to at least take into account whether the type has an identity layout, in which case we can easily support dynamic strides. Otherwise, the canonicalizer creates invalid IR. Longer term, both the verifier and the canoncializer need to be extended to support the general case. Differential Revision: https://reviews.llvm.org/D117772
1 parent c95cb4d commit 6d45284

File tree

2 files changed

+29
-4
lines changed

2 files changed

+29
-4
lines changed

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1334,6 +1334,7 @@ computeReshapeCollapsedType(MemRefType type,
13341334
AffineExpr offset;
13351335
SmallVector<AffineExpr, 4> strides;
13361336
auto status = getStridesAndOffset(type, strides, offset);
1337+
auto isIdentityLayout = type.getLayout().isIdentity();
13371338
(void)status;
13381339
assert(succeeded(status) && "expected strided memref");
13391340

@@ -1350,12 +1351,19 @@ computeReshapeCollapsedType(MemRefType type,
13501351
unsigned dim = m.getNumResults();
13511352
int64_t size = 1;
13521353
AffineExpr stride = strides[currentDim + dim - 1];
1353-
if (!isReshapableDimBand(currentDim, dim, sizes, strides)) {
1354+
if (isIdentityLayout ||
1355+
isReshapableDimBand(currentDim, dim, sizes, strides)) {
1356+
for (unsigned d = 0; d < dim; ++d) {
1357+
int64_t currentSize = sizes[currentDim + d];
1358+
if (ShapedType::isDynamic(currentSize)) {
1359+
size = ShapedType::kDynamicSize;
1360+
break;
1361+
}
1362+
size *= currentSize;
1363+
}
1364+
} else {
13541365
size = ShapedType::kDynamicSize;
13551366
stride = AffineExpr();
1356-
} else {
1357-
for (unsigned d = 0; d < dim; ++d)
1358-
size *= sizes[currentDim + d];
13591367
}
13601368
newSizes.push_back(size);
13611369
newStrides.push_back(stride);

mlir/test/Dialect/MemRef/canonicalize.mlir

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,8 @@ func @collapse_after_memref_cast_type_change(%arg0 : memref<?x512x1x1xf32>) -> m
406406
return %collapsed : memref<?x?xf32>
407407
}
408408

409+
// -----
410+
409411
// CHECK-LABEL: func @collapse_after_memref_cast(
410412
// CHECK-SAME: %[[INPUT:.*]]: memref<?x512x1x?xf32>) -> memref<?x?xf32> {
411413
// CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[INPUT]]
@@ -419,6 +421,21 @@ func @collapse_after_memref_cast(%arg0 : memref<?x512x1x?xf32>) -> memref<?x?xf3
419421

420422
// -----
421423

424+
// CHECK-LABEL: func @collapse_after_memref_cast_type_change_dynamic(
425+
// CHECK-SAME: %[[INPUT:.*]]: memref<1x1x1x?xi64>) -> memref<?x?xi64> {
426+
// CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[INPUT]]
427+
// CHECK_SAME: {{\[\[}}0, 1, 2], [3]] : memref<1x1x1x?xi64> into memref<1x?xi64>
428+
// CHECK: %[[DYNAMIC:.*]] = memref.cast %[[COLLAPSED]] :
429+
// CHECK-SAME: memref<1x?xi64> to memref<?x?xi64>
430+
// CHECK: return %[[DYNAMIC]] : memref<?x?xi64>
431+
func @collapse_after_memref_cast_type_change_dynamic(%arg0: memref<1x1x1x?xi64>) -> memref<?x?xi64> {
432+
%casted = memref.cast %arg0 : memref<1x1x1x?xi64> to memref<1x1x?x?xi64>
433+
%collapsed = memref.collapse_shape %casted [[0, 1, 2], [3]] : memref<1x1x?x?xi64> into memref<?x?xi64>
434+
return %collapsed : memref<?x?xi64>
435+
}
436+
437+
// -----
438+
422439
func @reduced_memref(%arg0: memref<2x5x7x1xf32>, %arg1 :index)
423440
-> memref<1x4x1xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 35 + s0 + d1 * 7 + d2)>> {
424441
%c0 = arith.constant 0 : index

0 commit comments

Comments
 (0)