@@ -84,126 +84,130 @@ static SmallVector<Value> computeStrides(Location loc, RewriterBase &rewriter,
  return result;
}

+LogicalResult
+affine::lowerAffineDelinearizeIndexOp(RewriterBase &rewriter,
+                                      AffineDelinearizeIndexOp op) {
+  Location loc = op.getLoc();
+  Value linearIdx = op.getLinearIndex();
+  unsigned numResults = op.getNumResults();
+  ArrayRef<int64_t> staticBasis = op.getStaticBasis();
+  if (numResults == staticBasis.size())
+    staticBasis = staticBasis.drop_front();
+
+  if (numResults == 1) {
+    rewriter.replaceOp(op, linearIdx);
+    return success();
+  }
+
+  SmallVector<Value> results;
+  results.reserve(numResults);
+  SmallVector<Value> strides =
+      computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis,
+                     /*knownNonNegative=*/true);
+
+  Value zero = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
+
+  Value initialPart =
+      rewriter.create<arith::FloorDivSIOp>(loc, linearIdx, strides.front());
+  results.push_back(initialPart);
+
+  auto emitModTerm = [&](Value stride) -> Value {
+    Value remainder = rewriter.create<arith::RemSIOp>(loc, linearIdx, stride);
+    Value remainderNegative = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::slt, remainder, zero);
+    // If the correction is relevant, this term is <= stride, which is known
+    // to be positive in `index`. Otherwise, while 2 * stride might overflow,
+    // this branch won't be taken, so the risk of `poison` is fine.
+    Value corrected = rewriter.create<arith::AddIOp>(
+        loc, remainder, stride, arith::IntegerOverflowFlags::nsw);
+    Value mod = rewriter.create<arith::SelectOp>(loc, remainderNegative,
+                                                 corrected, remainder);
+    return mod;
+  };
+
+  // Generate all the intermediate parts
+  for (size_t i = 0, e = strides.size() - 1; i < e; ++i) {
+    Value thisStride = strides[i];
+    Value nextStride = strides[i + 1];
+    Value modulus = emitModTerm(thisStride);
+    // We know both inputs are positive, so floorDiv == div.
+    // This could potentially be a divui, but it's not clear if that would
+    // cause issues.
+    Value divided = rewriter.create<arith::DivSIOp>(loc, modulus, nextStride);
+    results.push_back(divided);
+  }
+
+  results.push_back(emitModTerm(strides.back()));
+
+  rewriter.replaceOp(op, results);
+  return success();
+}
+
+LogicalResult affine::lowerAffineLinearizeIndexOp(RewriterBase &rewriter,
+                                                  AffineLinearizeIndexOp op) {
+  // Should be folded away, included here for safety.
+  if (op.getMultiIndex().empty()) {
+    rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0);
+    return success();
+  }
+
+  Location loc = op.getLoc();
+  ValueRange multiIndex = op.getMultiIndex();
+  size_t numIndexes = multiIndex.size();
+  ArrayRef<int64_t> staticBasis = op.getStaticBasis();
+  if (numIndexes == staticBasis.size())
+    staticBasis = staticBasis.drop_front();
+
+  SmallVector<Value> strides =
+      computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis,
+                     /*knownNonNegative=*/op.getDisjoint());
+  SmallVector<std::pair<Value, int64_t>> scaledValues;
+  scaledValues.reserve(numIndexes);
+
+  // Note: strides doesn't contain a value for the final element (stride 1)
+  // and everything else lines up. We use the "mutable" accessor so we can get
+  // our hands on an `OpOperand&` for the loop invariant counting function.
+  for (auto [stride, idxOp] :
+       llvm::zip_equal(strides, llvm::drop_end(op.getMultiIndexMutable()))) {
+    Value scaledIdx = rewriter.create<arith::MulIOp>(
+        loc, idxOp.get(), stride, arith::IntegerOverflowFlags::nsw);
+    int64_t numHoistableLoops = numEnclosingInvariantLoops(idxOp);
+    scaledValues.emplace_back(scaledIdx, numHoistableLoops);
+  }
+  scaledValues.emplace_back(
+      multiIndex.back(),
+      numEnclosingInvariantLoops(op.getMultiIndexMutable()[numIndexes - 1]));
+
+  // Sort by how many enclosing loops there are, ties implicitly broken by
+  // size of the stride.
+  llvm::stable_sort(scaledValues,
+                    [&](auto l, auto r) { return l.second > r.second; });
+
+  Value result = scaledValues.front().first;
+  for (auto [scaledValue, numHoistableLoops] : llvm::drop_begin(scaledValues)) {
+    std::ignore = numHoistableLoops;
+    result = rewriter.create<arith::AddIOp>(loc, result, scaledValue,
+                                            arith::IntegerOverflowFlags::nsw);
+  }
+  rewriter.replaceOp(op, result);
+  return success();
+}
+
namespace {
-/// Lowers `affine.delinearize_index` into a sequence of division and remainder
-/// operations.
struct LowerDelinearizeIndexOps
    : public OpRewritePattern<AffineDelinearizeIndexOp> {
  using OpRewritePattern<AffineDelinearizeIndexOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(AffineDelinearizeIndexOp op,
                                PatternRewriter &rewriter) const override {
-    Location loc = op.getLoc();
-    Value linearIdx = op.getLinearIndex();
-    unsigned numResults = op.getNumResults();
-    ArrayRef<int64_t> staticBasis = op.getStaticBasis();
-    if (numResults == staticBasis.size())
-      staticBasis = staticBasis.drop_front();
-
-    if (numResults == 1) {
-      rewriter.replaceOp(op, linearIdx);
-      return success();
-    }
-
-    SmallVector<Value> results;
-    results.reserve(numResults);
-    SmallVector<Value> strides =
-        computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis,
-                       /*knownNonNegative=*/true);
-
-    Value zero = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
-
-    Value initialPart =
-        rewriter.create<arith::FloorDivSIOp>(loc, linearIdx, strides.front());
-    results.push_back(initialPart);
-
-    auto emitModTerm = [&](Value stride) -> Value {
-      Value remainder = rewriter.create<arith::RemSIOp>(loc, linearIdx, stride);
-      Value remainderNegative = rewriter.create<arith::CmpIOp>(
-          loc, arith::CmpIPredicate::slt, remainder, zero);
-      // If the correction is relevant, this term is <= stride, which is known
-      // to be positive in `index`. Otherwise, while 2 * stride might overflow,
-      // this branch won't be taken, so the risk of `poison` is fine.
-      Value corrected = rewriter.create<arith::AddIOp>(
-          loc, remainder, stride, arith::IntegerOverflowFlags::nsw);
-      Value mod = rewriter.create<arith::SelectOp>(loc, remainderNegative,
-                                                   corrected, remainder);
-      return mod;
-    };
-
-    // Generate all the intermediate parts
-    for (size_t i = 0, e = strides.size() - 1; i < e; ++i) {
-      Value thisStride = strides[i];
-      Value nextStride = strides[i + 1];
-      Value modulus = emitModTerm(thisStride);
-      // We know both inputs are positive, so floorDiv == div.
-      // This could potentially be a divui, but it's not clear if that would
-      // cause issues.
-      Value divided = rewriter.create<arith::DivSIOp>(loc, modulus, nextStride);
-      results.push_back(divided);
-    }
-
-    results.push_back(emitModTerm(strides.back()));
-
-    rewriter.replaceOp(op, results);
-    return success();
+    return affine::lowerAffineDelinearizeIndexOp(rewriter, op);
  }
};

-/// Lowers `affine.linearize_index` into a sequence of multiplications and
-/// additions. Make a best effort to sort the input indices so that
-/// the most loop-invariant terms are at the left of the additions
-/// to enable loop-invariant code motion.
struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AffineLinearizeIndexOp op,
                                PatternRewriter &rewriter) const override {
-    // Should be folded away, included here for safety.
-    if (op.getMultiIndex().empty()) {
-      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0);
-      return success();
-    }
-
-    Location loc = op.getLoc();
-    ValueRange multiIndex = op.getMultiIndex();
-    size_t numIndexes = multiIndex.size();
-    ArrayRef<int64_t> staticBasis = op.getStaticBasis();
-    if (numIndexes == staticBasis.size())
-      staticBasis = staticBasis.drop_front();
-
-    SmallVector<Value> strides =
-        computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis,
-                       /*knownNonNegative=*/op.getDisjoint());
-    SmallVector<std::pair<Value, int64_t>> scaledValues;
-    scaledValues.reserve(numIndexes);
-
-    // Note: strides doesn't contain a value for the final element (stride 1)
-    // and everything else lines up. We use the "mutable" accessor so we can get
-    // our hands on an `OpOperand&` for the loop invariant counting function.
-    for (auto [stride, idxOp] :
-         llvm::zip_equal(strides, llvm::drop_end(op.getMultiIndexMutable()))) {
-      Value scaledIdx = rewriter.create<arith::MulIOp>(
-          loc, idxOp.get(), stride, arith::IntegerOverflowFlags::nsw);
-      int64_t numHoistableLoops = numEnclosingInvariantLoops(idxOp);
-      scaledValues.emplace_back(scaledIdx, numHoistableLoops);
-    }
-    scaledValues.emplace_back(
-        multiIndex.back(),
-        numEnclosingInvariantLoops(op.getMultiIndexMutable()[numIndexes - 1]));
-
-    // Sort by how many enclosing loops there are, ties implicitly broken by
-    // size of the stride.
-    llvm::stable_sort(scaledValues,
-                      [&](auto l, auto r) { return l.second > r.second; });
-
-    Value result = scaledValues.front().first;
-    for (auto [scaledValue, numHoistableLoops] :
-         llvm::drop_begin(scaledValues)) {
-      std::ignore = numHoistableLoops;
-      result = rewriter.create<arith::AddIOp>(loc, result, scaledValue,
-                                              arith::IntegerOverflowFlags::nsw);
-    }
-    rewriter.replaceOp(op, result);
-    return success();
+    return affine::lowerAffineLinearizeIndexOp(rewriter, op);
  }
};
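
For reference, the arithmetic that `lowerAffineDelinearizeIndexOp` emits can be modeled in plain, standalone C++ (a hypothetical sketch, not the rewriter code above): the outermost result is a floor-division by the largest suffix-product stride, each intermediate result takes a non-negative modulus first (the remsi / cmpi / addi / select sequence in `emitModTerm`) and then divides by the next stride, and the innermost result is the modulus by the last stride.

// Standalone model (hypothetical, for illustration only) of the arithmetic
// emitted by the delinearize lowering. `basis` stands in for the op's basis
// with its outermost entry already dropped, as in the code above.
#include <cassert>
#include <cstdint>
#include <vector>

// Non-negative ("floor") modulus built from C++'s truncating %, mirroring the
// remsi / cmpi / addi / select sequence in emitModTerm.
static int64_t floorMod(int64_t value, int64_t stride) {
  int64_t rem = value % stride;
  return rem < 0 ? rem + stride : rem;
}

static std::vector<int64_t> delinearize(int64_t linear,
                                        const std::vector<int64_t> &basis) {
  // strides[i] = product of basis[i..end), as computeStrides produces them;
  // the trailing stride of 1 is implicit.
  std::vector<int64_t> strides(basis.size());
  int64_t running = 1;
  for (size_t i = basis.size(); i-- > 0;) {
    running *= basis[i];
    strides[i] = running;
  }
  std::vector<int64_t> results;
  // Outermost result: floor-division by the largest stride (plain / here,
  // which only matches floordivsi for non-negative inputs).
  results.push_back(linear / strides.front());
  // Intermediate results: floor-mod by this stride, divided by the next one.
  for (size_t i = 0; i + 1 < strides.size(); ++i)
    results.push_back(floorMod(linear, strides[i]) / strides[i + 1]);
  // Innermost result: floor-mod by the smallest stride.
  results.push_back(floorMod(linear, strides.back()));
  return results;
}

int main() {
  // Delinearizing 37 against basis (3, 4, 5): the dropped basis is (4, 5),
  // the strides are (20, 5), and the results are (1, 3, 2), since
  // 1 * 20 + 3 * 5 + 2 == 37.
  std::vector<int64_t> r = delinearize(37, {4, 5});
  assert(r[0] == 1 && r[1] == 3 && r[2] == 2);
  return 0;
}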
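
Similarly, `lowerAffineLinearizeIndexOp` emits a strided sum. The sketch below (again a hypothetical standalone model, not the rewriter code) computes the same value; what it does not show is the reordering the lowering performs, where operands that are invariant in more enclosing loops are added first so that loop-invariant code motion can hoist the partial sums.

// Standalone model (hypothetical) of the value computed by the linearize
// lowering: each index is scaled by a suffix-product stride of the
// front-dropped basis and the scaled terms are summed; the innermost stride
// is 1. The real lowering additionally sorts the terms by how many enclosing
// loops they are invariant in before emitting the adds.
#include <cassert>
#include <cstdint>
#include <vector>

static int64_t linearize(const std::vector<int64_t> &indices,
                         const std::vector<int64_t> &basis) {
  // `basis` is already missing its outermost entry, matching the drop_front
  // in the code above, so there is one more index than basis element.
  assert(indices.size() == basis.size() + 1);
  int64_t result = 0;
  int64_t stride = 1;
  // Walk from the innermost index outward, growing the stride as we go.
  for (size_t i = indices.size(); i-- > 0;) {
    result += indices[i] * stride;
    if (i > 0)
      stride *= basis[i - 1];
  }
  return result;
}

int main() {
  // Linearizing (1, 3, 2) against basis (3, 4, 5): 1 * 20 + 3 * 5 + 2 == 37,
  // the inverse of the delinearize example above.
  assert(linearize({1, 3, 2}, {4, 5}) == 37);
  return 0;
}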