Skip to content

Commit e91ff94

Browse files
authored
Enable Pad->Conv(no pads) fusion (#22001)
### Description ### Motivation and Context For some model has pattern Pad -> Conv. If the Conv doesn't have pads attributes, the Pad can be fused into Conv.
1 parent 20d9464 commit e91ff94

File tree

3 files changed

+54
-6
lines changed

3 files changed

+54
-6
lines changed

onnxruntime/core/optimizer/pad_fusion.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,15 +31,15 @@ bool VerifyNotCastChild(const Node& child_node) {
3131
return false;
3232
}
3333

34-
// This pass currently assumed that this attribute already exists on the child node
35-
if (child_node.GetAttributes().find("pads") == child_node.GetAttributes().end()) {
36-
return false;
37-
}
38-
3934
return true;
4035
}
4136

4237
void UpdatePaddingAttribute(Node& child_node, const std::vector<int64_t>& pads_values, const uint32_t pads_size) {
38+
if (child_node.GetAttributes().find("pads") == child_node.GetAttributes().end()) {
39+
std::vector<int64_t> pads(pads_size - 4, 0);
40+
child_node.AddAttribute("pads", pads);
41+
}
42+
4343
auto child_pads = child_node.GetMutableAttributes()["pads"].mutable_ints();
4444
uint32_t child_pads_size = static_cast<uint32_t>(child_pads->size());
4545

@@ -162,4 +162,4 @@ Status PadFusion::Apply(Graph& graph, Node& pad_node, RewriteRuleEffect& rule_ef
162162
rule_effect = RewriteRuleEffect::kRemovedCurrentNode;
163163
return Status::OK();
164164
}
165-
} // namespace onnxruntime
165+
} // namespace onnxruntime

onnxruntime/test/optimizer/graph_transform_test.cc

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1469,6 +1469,54 @@ TEST_F(GraphTransformationTests, FusePadWithConv) {
14691469
}
14701470
}
14711471

1472+
TEST_F(GraphTransformationTests, FusePadWithNoPadsConv) {
1473+
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-pad-nopadsconv.onnx";
1474+
1475+
std::shared_ptr<Model> p_model;
1476+
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
1477+
Graph& graph = p_model->MainGraph();
1478+
1479+
std::vector<int64_t> expected_pads;
1480+
GraphViewer graphViewer(graph);
1481+
for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) {
1482+
auto& node = *graph.GetNode(node_index);
1483+
if (node.OpType() == "Pad") {
1484+
const auto* pads_proto = graph_utils::GetConstantInitializer(graph, node.InputDefs()[1]->Name());
1485+
Initializer pads{*pads_proto, graph.ModelPath()};
1486+
gsl::span<const int64_t> pads_values = pads.DataAsSpan<int64_t>();
1487+
expected_pads.resize(pads_values.size() - 4);
1488+
1489+
for (uint32_t pads_index = 2, index = 0; pads_index < pads_values.size() / 2; pads_index++, index++) {
1490+
expected_pads[index] = pads_values[pads_index];
1491+
expected_pads[index + (expected_pads.size() / 2)] = pads_values[pads_index + (pads_values.size() / 2)];
1492+
}
1493+
}
1494+
}
1495+
1496+
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
1497+
auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL1");
1498+
ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique<PadFusion>()));
1499+
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1));
1500+
1501+
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
1502+
1503+
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
1504+
ASSERT_EQ(op_to_count["Pad"], 0);
1505+
ASSERT_EQ(op_to_count["Conv"], 1);
1506+
1507+
for (auto& node : graph.Nodes()) {
1508+
if (node.OpType() == "Conv") {
1509+
auto child_pads = node.GetMutableAttributes()["pads"].mutable_ints();
1510+
ASSERT_EQ(child_pads->size(), static_cast<int32_t>(expected_pads.size()))
1511+
<< "fusion should produce the same size of pads integer as the Conv node";
1512+
for (uint32_t index = 0; index < expected_pads.size(); index++) {
1513+
ASSERT_EQ(expected_pads[index], child_pads->Get(index))
1514+
<< "fusion does not produce correct padding value";
1515+
}
1516+
}
1517+
}
1518+
}
1519+
14721520
TEST_F(GraphTransformationTests, FusePadWithMaxPool) {
14731521
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-pad-maxpool.onnx";
14741522

Binary file not shown.

0 commit comments

Comments
 (0)