
Commit 8e333e3

[mlir] Expose linearize/delinearize lowering transforms (#144156)
Moves the transformation logic from the AffineLinearizeOp and AffineDelinearizeOp lowerings into separate transform functions that can now be called separately. This provides a more controlled way to apply the op lowerings.

Signed-off-by: Max Dawkins <[email protected]>
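As a rough sketch of the "more controlled" application the message describes (illustrative only; the helper name, the per-function scope, and the selective walk are assumptions, not part of this commit), a caller might now drive the lowering op by op instead of through the pattern driver:

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Hypothetical helper: lower only the delinearize ops inside one function,
// leaving every other op untouched.
static LogicalResult lowerDelinearizeOpsIn(func::FuncOp funcOp) {
  IRRewriter rewriter(funcOp.getContext());
  WalkResult status = funcOp.walk([&](affine::AffineDelinearizeIndexOp op) {
    // Any filtering of which ops to lower would go here.
    rewriter.setInsertionPoint(op);
    if (failed(affine::lowerAffineDelinearizeIndexOp(rewriter, op)))
      return WalkResult::interrupt();
    return WalkResult::advance();
  });
  return failure(status.wasInterrupted());
}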
1 parent 595a273 commit 8e333e3

File tree

2 files changed: +125 additions, -107 deletions


mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h

Lines changed: 14 additions & 0 deletions
@@ -32,6 +32,20 @@ enum class BoundType;
 
 namespace affine {
 class AffineApplyOp;
+class AffineDelinearizeIndexOp;
+class AffineLinearizeIndexOp;
+
+/// Lowers `affine.delinearize_index` into a sequence of division and remainder
+/// operations.
+LogicalResult lowerAffineDelinearizeIndexOp(RewriterBase &rewriter,
+                                            AffineDelinearizeIndexOp op);
+
+/// Lowers `affine.linearize_index` into a sequence of multiplications and
+/// additions. Make a best effort to sort the input indices so that
+/// the most loop-invariant terms are at the left of the additions
+/// to enable loop-invariant code motion.
+LogicalResult lowerAffineLinearizeIndexOp(RewriterBase &rewriter,
+                                          AffineLinearizeIndexOp op);
 
 /// Populate patterns that expand affine index operations into more fundamental
 /// operations (not necessarily restricted to Affine dialect).
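For intuition, the arithmetic these two declarations promise can be spelled out on plain integers. A small self-contained illustration (the basis (2, 3, 5) and the index 23 are made up for the example; nonnegative values are assumed so that C++ `/` and `%` coincide with floordiv and mod):

#include <cassert>
#include <cstdint>

int main() {
  // Basis (2, 3, 5) yields strides (3*5, 5); the final stride is 1 and,
  // as in the lowering, is not materialized.
  const int64_t strides[] = {15, 5};
  const int64_t linear = 23;

  // delinearize_index: division and remainder by each stride.
  int64_t i0 = linear / strides[0];                // 23 floordiv 15 == 1
  int64_t i1 = (linear % strides[0]) / strides[1]; // (23 mod 15) floordiv 5 == 1
  int64_t i2 = linear % strides[1];                // 23 mod 5 == 3

  // linearize_index: multiplications and additions by the same strides.
  assert(i0 * strides[0] + i1 * strides[1] + i2 == linear);
  return 0;
}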

mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp

Lines changed: 111 additions & 107 deletions
@@ -84,126 +84,130 @@ static SmallVector<Value> computeStrides(Location loc, RewriterBase &rewriter,
   return result;
 }
 
+LogicalResult
+affine::lowerAffineDelinearizeIndexOp(RewriterBase &rewriter,
+                                      AffineDelinearizeIndexOp op) {
+  Location loc = op.getLoc();
+  Value linearIdx = op.getLinearIndex();
+  unsigned numResults = op.getNumResults();
+  ArrayRef<int64_t> staticBasis = op.getStaticBasis();
+  if (numResults == staticBasis.size())
+    staticBasis = staticBasis.drop_front();
+
+  if (numResults == 1) {
+    rewriter.replaceOp(op, linearIdx);
+    return success();
+  }
+
+  SmallVector<Value> results;
+  results.reserve(numResults);
+  SmallVector<Value> strides =
+      computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis,
+                     /*knownNonNegative=*/true);
+
+  Value zero = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
+
+  Value initialPart =
+      rewriter.create<arith::FloorDivSIOp>(loc, linearIdx, strides.front());
+  results.push_back(initialPart);
+
+  auto emitModTerm = [&](Value stride) -> Value {
+    Value remainder = rewriter.create<arith::RemSIOp>(loc, linearIdx, stride);
+    Value remainderNegative = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::slt, remainder, zero);
+    // If the correction is relevant, this term is <= stride, which is known
+    // to be positive in `index`. Otherwise, while 2 * stride might overflow,
+    // this branch won't be taken, so the risk of `poison` is fine.
+    Value corrected = rewriter.create<arith::AddIOp>(
+        loc, remainder, stride, arith::IntegerOverflowFlags::nsw);
+    Value mod = rewriter.create<arith::SelectOp>(loc, remainderNegative,
+                                                 corrected, remainder);
+    return mod;
+  };
+
+  // Generate all the intermediate parts
+  for (size_t i = 0, e = strides.size() - 1; i < e; ++i) {
+    Value thisStride = strides[i];
+    Value nextStride = strides[i + 1];
+    Value modulus = emitModTerm(thisStride);
+    // We know both inputs are positive, so floorDiv == div.
+    // This could potentially be a divui, but it's not clear if that would
+    // cause issues.
+    Value divided = rewriter.create<arith::DivSIOp>(loc, modulus, nextStride);
+    results.push_back(divided);
+  }
+
+  results.push_back(emitModTerm(strides.back()));
+
+  rewriter.replaceOp(op, results);
+  return success();
+}
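The select-based correction in `emitModTerm` above exists because `arith.remsi` follows the sign of the dividend, so the remainder can be negative when the linear index is. A scalar model of the same logic (illustration only, not from the commit):

#include <cstdint>

int64_t modTerm(int64_t linearIdx, int64_t stride) {
  int64_t remainder = linearIdx % stride; // models arith.remsi
  // stride is known positive, and when the correction actually applies the
  // remainder lies in (-stride, 0), so the addition stays in range.
  return remainder < 0 ? remainder + stride : remainder;
}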
+
+LogicalResult affine::lowerAffineLinearizeIndexOp(RewriterBase &rewriter,
+                                                  AffineLinearizeIndexOp op) {
+  // Should be folded away, included here for safety.
+  if (op.getMultiIndex().empty()) {
+    rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0);
+    return success();
+  }
+
+  Location loc = op.getLoc();
+  ValueRange multiIndex = op.getMultiIndex();
+  size_t numIndexes = multiIndex.size();
+  ArrayRef<int64_t> staticBasis = op.getStaticBasis();
+  if (numIndexes == staticBasis.size())
+    staticBasis = staticBasis.drop_front();
+
+  SmallVector<Value> strides =
+      computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis,
+                     /*knownNonNegative=*/op.getDisjoint());
+  SmallVector<std::pair<Value, int64_t>> scaledValues;
+  scaledValues.reserve(numIndexes);
+
+  // Note: strides doesn't contain a value for the final element (stride 1)
+  // and everything else lines up. We use the "mutable" accessor so we can get
+  // our hands on an `OpOperand&` for the loop invariant counting function.
+  for (auto [stride, idxOp] :
+       llvm::zip_equal(strides, llvm::drop_end(op.getMultiIndexMutable()))) {
+    Value scaledIdx = rewriter.create<arith::MulIOp>(
+        loc, idxOp.get(), stride, arith::IntegerOverflowFlags::nsw);
+    int64_t numHoistableLoops = numEnclosingInvariantLoops(idxOp);
+    scaledValues.emplace_back(scaledIdx, numHoistableLoops);
+  }
+  scaledValues.emplace_back(
+      multiIndex.back(),
+      numEnclosingInvariantLoops(op.getMultiIndexMutable()[numIndexes - 1]));
+
+  // Sort by how many enclosing loops there are, ties implicitly broken by
+  // size of the stride.
+  llvm::stable_sort(scaledValues,
+                    [&](auto l, auto r) { return l.second > r.second; });
+
+  Value result = scaledValues.front().first;
+  for (auto [scaledValue, numHoistableLoops] : llvm::drop_begin(scaledValues)) {
+    std::ignore = numHoistableLoops;
+    result = rewriter.create<arith::AddIOp>(loc, result, scaledValue,
+                                            arith::IntegerOverflowFlags::nsw);
+  }
+  rewriter.replaceOp(op, result);
+  return success();
+}
+
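The `llvm::stable_sort` above implements the best-effort ordering promised in the header comment: terms that are invariant in more enclosing loops come first, so the leftmost partial sums of the addition chain are themselves loop-invariant and loop-invariant code motion can hoist them. A scalar model of that ordering (illustration only, with made-up plain-integer terms standing in for the scaled index values):

#include <algorithm>
#include <cstdint>
#include <utility>
#include <vector>

int64_t sumMostInvariantFirst(std::vector<std::pair<int64_t, int64_t>> terms) {
  // .first models a scaled index value; .second models the count returned
  // by numEnclosingInvariantLoops for its operand.
  std::stable_sort(terms.begin(), terms.end(),
                   [](auto &l, auto &r) { return l.second > r.second; });
  int64_t result = 0;
  for (auto &[value, hoistable] : terms) {
    (void)hoistable; // only used for ordering, as in the lowering
    result += value;
  }
  return result;
}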
 namespace {
-/// Lowers `affine.delinearize_index` into a sequence of division and remainder
-/// operations.
 struct LowerDelinearizeIndexOps
     : public OpRewritePattern<AffineDelinearizeIndexOp> {
   using OpRewritePattern<AffineDelinearizeIndexOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(AffineDelinearizeIndexOp op,
                                 PatternRewriter &rewriter) const override {
-    Location loc = op.getLoc();
-    Value linearIdx = op.getLinearIndex();
-    unsigned numResults = op.getNumResults();
-    ArrayRef<int64_t> staticBasis = op.getStaticBasis();
-    if (numResults == staticBasis.size())
-      staticBasis = staticBasis.drop_front();
-
-    if (numResults == 1) {
-      rewriter.replaceOp(op, linearIdx);
-      return success();
-    }
-
-    SmallVector<Value> results;
-    results.reserve(numResults);
-    SmallVector<Value> strides =
-        computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis,
-                       /*knownNonNegative=*/true);
-
-    Value zero = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
-
-    Value initialPart =
-        rewriter.create<arith::FloorDivSIOp>(loc, linearIdx, strides.front());
-    results.push_back(initialPart);
-
-    auto emitModTerm = [&](Value stride) -> Value {
-      Value remainder = rewriter.create<arith::RemSIOp>(loc, linearIdx, stride);
-      Value remainderNegative = rewriter.create<arith::CmpIOp>(
-          loc, arith::CmpIPredicate::slt, remainder, zero);
-      // If the correction is relevant, this term is <= stride, which is known
-      // to be positive in `index`. Otherwise, while 2 * stride might overflow,
-      // this branch won't be taken, so the risk of `poison` is fine.
-      Value corrected = rewriter.create<arith::AddIOp>(
-          loc, remainder, stride, arith::IntegerOverflowFlags::nsw);
-      Value mod = rewriter.create<arith::SelectOp>(loc, remainderNegative,
-                                                   corrected, remainder);
-      return mod;
-    };
-
-    // Generate all the intermediate parts
-    for (size_t i = 0, e = strides.size() - 1; i < e; ++i) {
-      Value thisStride = strides[i];
-      Value nextStride = strides[i + 1];
-      Value modulus = emitModTerm(thisStride);
-      // We know both inputs are positive, so floorDiv == div.
-      // This could potentially be a divui, but it's not clear if that would
-      // cause issues.
-      Value divided = rewriter.create<arith::DivSIOp>(loc, modulus, nextStride);
-      results.push_back(divided);
-    }
-
-    results.push_back(emitModTerm(strides.back()));
-
-    rewriter.replaceOp(op, results);
-    return success();
+    return affine::lowerAffineDelinearizeIndexOp(rewriter, op);
   }
 };
 
-/// Lowers `affine.linearize_index` into a sequence of multiplications and
-/// additions. Make a best effort to sort the input indices so that
-/// the most loop-invariant terms are at the left of the additions
-/// to enable loop-invariant code motion.
 struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(AffineLinearizeIndexOp op,
                                 PatternRewriter &rewriter) const override {
-    // Should be folded away, included here for safety.
-    if (op.getMultiIndex().empty()) {
-      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0);
-      return success();
-    }
-
-    Location loc = op.getLoc();
-    ValueRange multiIndex = op.getMultiIndex();
-    size_t numIndexes = multiIndex.size();
-    ArrayRef<int64_t> staticBasis = op.getStaticBasis();
-    if (numIndexes == staticBasis.size())
-      staticBasis = staticBasis.drop_front();
-
-    SmallVector<Value> strides =
-        computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis,
-                       /*knownNonNegative=*/op.getDisjoint());
-    SmallVector<std::pair<Value, int64_t>> scaledValues;
-    scaledValues.reserve(numIndexes);
-
-    // Note: strides doesn't contain a value for the final element (stride 1)
-    // and everything else lines up. We use the "mutable" accessor so we can get
-    // our hands on an `OpOperand&` for the loop invariant counting function.
-    for (auto [stride, idxOp] :
-         llvm::zip_equal(strides, llvm::drop_end(op.getMultiIndexMutable()))) {
-      Value scaledIdx = rewriter.create<arith::MulIOp>(
-          loc, idxOp.get(), stride, arith::IntegerOverflowFlags::nsw);
-      int64_t numHoistableLoops = numEnclosingInvariantLoops(idxOp);
-      scaledValues.emplace_back(scaledIdx, numHoistableLoops);
-    }
-    scaledValues.emplace_back(
-        multiIndex.back(),
-        numEnclosingInvariantLoops(op.getMultiIndexMutable()[numIndexes - 1]));
-
-    // Sort by how many enclosing loops there are, ties implicitly broken by
-    // size of the stride.
-    llvm::stable_sort(scaledValues,
-                      [&](auto l, auto r) { return l.second > r.second; });
-
-    Value result = scaledValues.front().first;
-    for (auto [scaledValue, numHoistableLoops] :
-         llvm::drop_begin(scaledValues)) {
-      std::ignore = numHoistableLoops;
-      result = rewriter.create<arith::AddIOp>(loc, result, scaledValue,
-                                              arith::IntegerOverflowFlags::nsw);
-    }
-    rewriter.replaceOp(op, result);
-    return success();
+    return affine::lowerAffineLinearizeIndexOp(rewriter, op);
   }
 };
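The patterns are now thin wrappers, so whole-function lowering still works the conventional way. A sketch of driving them (the populate helper is the one the header comment above refers to; the greedy-driver entry point's exact name varies across MLIR versions, so treat it as an assumption):

#include "mlir/Dialect/Affine/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Lower every affine.linearize_index / affine.delinearize_index in a
// function by running the expansion patterns to a fixed point.
static LogicalResult expandAllIndexOps(func::FuncOp funcOp) {
  RewritePatternSet patterns(funcOp.getContext());
  affine::populateAffineExpandIndexOpsPatterns(patterns);
  return applyPatternsGreedily(funcOp, std::move(patterns));
}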
