Skip to content

Commit 7421040

Browse files
authored
[mlir] Move supplemental patterns before op replacement (#66959)
This moves the C++ code generated from supplemental patterns before op replacement. It is necessary if the supllemental patterns need to access the source op.
1 parent a1584dd commit 7421040

File tree

1 file changed

+14
-11
lines changed

1 file changed

+14
-11
lines changed

mlir/tools/mlir-tblgen/RewriterGen.cpp

+14-11
Original file line numberDiff line numberDiff line change
@@ -1173,9 +1173,22 @@ void PatternEmitter::emitRewriteLogic() {
11731173
os << val << ";\n";
11741174
}
11751175

1176+
auto processSupplementalPatterns = [&]() {
1177+
int numSupplementalPatterns = pattern.getNumSupplementalPatterns();
1178+
for (int i = 0, offset = -numSupplementalPatterns;
1179+
i < numSupplementalPatterns; ++i) {
1180+
DagNode resultTree = pattern.getSupplementalPattern(i);
1181+
auto val = handleResultPattern(resultTree, offset++, 0);
1182+
if (resultTree.isNativeCodeCall() &&
1183+
resultTree.getNumReturnsOfNativeCode() == 0)
1184+
os << val << ";\n";
1185+
}
1186+
};
1187+
11761188
if (numExpectedResults == 0) {
11771189
assert(replStartIndex >= numResultPatterns &&
11781190
"invalid auxiliary vs. replacement pattern division!");
1191+
processSupplementalPatterns();
11791192
// No result to replace. Just erase the op.
11801193
os << "rewriter.eraseOp(op0);\n";
11811194
} else {
@@ -1197,20 +1210,10 @@ void PatternEmitter::emitRewriteLogic() {
11971210
" tblgen_repl_values.push_back(v);\n}\n",
11981211
"\n");
11991212
}
1213+
processSupplementalPatterns();
12001214
os << "\nrewriter.replaceOp(op0, tblgen_repl_values);\n";
12011215
}
12021216

1203-
// Process supplemtal patterns.
1204-
int numSupplementalPatterns = pattern.getNumSupplementalPatterns();
1205-
for (int i = 0, offset = -numSupplementalPatterns;
1206-
i < numSupplementalPatterns; ++i) {
1207-
DagNode resultTree = pattern.getSupplementalPattern(i);
1208-
auto val = handleResultPattern(resultTree, offset++, 0);
1209-
if (resultTree.isNativeCodeCall() &&
1210-
resultTree.getNumReturnsOfNativeCode() == 0)
1211-
os << val << ";\n";
1212-
}
1213-
12141217
LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n");
12151218
}
12161219

0 commit comments

Comments
 (0)