Skip to content

Commit 6ce8653

Browse files
authored
[mlir][cf] Preserve branch weights during cf.cond_br canonicalization. (#144822)
1 parent 0816bb3 commit 6ce8653

File tree

3 files changed

+42
-9
lines changed

3 files changed

+42
-9
lines changed

mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -153,17 +153,25 @@ def CondBranchOp
153153
let builders = [OpBuilder<(ins "Value":$condition, "Block *":$trueDest,
154154
"ValueRange":$trueOperands,
155155
"Block *":$falseDest,
156-
"ValueRange":$falseOperands),
156+
"ValueRange":$falseOperands,
157+
CArg<"ArrayRef<int32_t>", "{}">:$branchWeights),
157158
[{
158-
build($_builder, $_state, condition, trueOperands, falseOperands, /*branch_weights=*/{}, trueDest,
159-
falseDest);
159+
DenseI32ArrayAttr weights;
160+
if (!branchWeights.empty())
161+
weights = $_builder.getDenseI32ArrayAttr(branchWeights);
162+
build($_builder, $_state, condition, trueOperands, falseOperands,
163+
weights, trueDest, falseDest);
160164
}]>,
161165
OpBuilder<(ins "Value":$condition, "Block *":$trueDest,
162166
"Block *":$falseDest,
163-
CArg<"ValueRange", "{}">:$falseOperands),
167+
CArg<"ValueRange", "{}">:$falseOperands,
168+
CArg<"ArrayRef<int32_t>", "{}">:$branchWeights),
164169
[{
165-
build($_builder, $_state, condition, trueDest, ValueRange(), falseDest,
166-
falseOperands);
170+
DenseI32ArrayAttr weights;
171+
if (!branchWeights.empty())
172+
weights = $_builder.getDenseI32ArrayAttr(branchWeights);
173+
build($_builder, $_state, condition, ValueRange(), falseOperands,
174+
weights, trueDest, falseDest);
167175
}]>];
168176

169177
let extraClassDeclaration = [{

mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -265,9 +265,9 @@ struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> {
265265
return failure();
266266

267267
// Create a new branch with the collapsed successors.
268-
rewriter.replaceOpWithNewOp<CondBranchOp>(condbr, condbr.getCondition(),
269-
trueDest, trueDestOperands,
270-
falseDest, falseDestOperands);
268+
rewriter.replaceOpWithNewOp<CondBranchOp>(
269+
condbr, condbr.getCondition(), trueDest, trueDestOperands, falseDest,
270+
falseDestOperands, condbr.getWeights());
271271
return success();
272272
}
273273
};

mlir/test/Dialect/ControlFlow/canonicalize.mlir

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,31 @@ func.func @cond_br_and_br_folding(%a : i32) {
102102

103103
/// Test that pass-through successors of CondBranchOp get folded.
104104

105+
// Test that the weights are preserved:
106+
// CHECK-LABEL: func.func @cond_br_passthrough_weights(
107+
// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i1) -> i32 {
108+
func.func @cond_br_passthrough_weights(%arg0 : i32, %arg1 : i32, %cond : i1) -> i32 {
109+
// CHECK: cf.cond_br %[[ARG2]] weights([30, 70]), ^bb1, ^bb2
110+
// CHECK: ^bb1:
111+
// CHECK: return %[[ARG0]] : i32
112+
// CHECK: ^bb2:
113+
// CHECK: return %[[ARG1]] : i32
114+
// CHECK: }
115+
cf.cond_br %cond weights([30,70]), ^bb1, ^bb3
116+
117+
^bb1:
118+
cf.br ^bb2
119+
120+
^bb3:
121+
cf.br ^bb4
122+
123+
^bb2:
124+
return %arg0 : i32
125+
126+
^bb4:
127+
return %arg1 : i32
128+
}
129+
105130
// CHECK-LABEL: func @cond_br_passthrough(
106131
// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[COND:.*]]: i1
107132
func.func @cond_br_passthrough(%arg0 : i32, %arg1 : i32, %arg2 : i32, %cond : i1) -> (i32, i32) {

0 commit comments

Comments
 (0)