
Commit 8c62fca

Merge pull request #2105 from andi4191/anurag.dixit/aten_tile
feat: Added support for aten::tile converter
2 parents a052cf0 + b7b2725 commit 8c62fca

9 files changed, +193 −0 lines changed

core/lowering/lowering.cpp (+1)

@@ -152,6 +152,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, std::vector<torch::jit::I
   passes::ReplaceScalarImplicit(g);
   passes::RewriteInputsWithParams(g, params);
   passes::ReplaceAtenPad(g);
+  passes::ReplaceTileWithRepeat(g);
   LOG_GRAPH(*g);
 }

core/lowering/passes/BUILD (+1)

@@ -31,6 +31,7 @@ cc_library(
         "replace_aten_pad.cpp",
         "rewrite_inputs_with_params.cpp",
         "silu_to_sigmoid_multiplication.cpp",
+        "tile_to_repeat.cpp",
         "unpack_addmm.cpp",
         "unpack_batch_norm.cpp",
         "unpack_hardsigmoid.cpp",

core/lowering/passes/CMakeLists.txt (+1)

@@ -6,6 +6,7 @@ target_sources(${lib_name}
     "${CMAKE_CURRENT_SOURCE_DIR}/linear_to_addmm.cpp"
     "${CMAKE_CURRENT_SOURCE_DIR}/module_fallback.cpp"
     "${CMAKE_CURRENT_SOURCE_DIR}/op_aliasing.cpp"
+    "${CMAKE_CURRENT_SOURCE_DIR}/tile_to_repeat.cpp"
     "${CMAKE_CURRENT_SOURCE_DIR}/reduce_gelu.cpp"
     "${CMAKE_CURRENT_SOURCE_DIR}/reduce_remainder.cpp"
     "${CMAKE_CURRENT_SOURCE_DIR}/reduce_to.cpp"

core/lowering/passes/passes.h (+1)

@@ -51,6 +51,7 @@ void UnpackAndCastNumToTensor(std::shared_ptr<torch::jit::Graph>& graph, std::st
 void UnpackAndCastFull(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name);
 void ReplaceScalarImplicit(std::shared_ptr<torch::jit::Graph>& graph);
 void ReplaceAtenPad(std::shared_ptr<torch::jit::Graph>& graph);
+void ReplaceTileWithRepeat(std::shared_ptr<torch::jit::Graph>& graph);
 
 // utility functions exposed for testing
 std::string unmangle_cls_name(const std::string& name);

core/lowering/passes/tile_to_repeat.cpp (+25)

@@ -0,0 +1,25 @@
+#include "core/util/prelude.h"
+#include "torch/csrc/jit/passes/subgraph_rewrite.h"
+
+namespace torch_tensorrt {
+namespace core {
+namespace lowering {
+namespace passes {
+void ReplaceTileWithRepeat(std::shared_ptr<torch::jit::Graph>& graph) {
+  std::string tile_pattern = R"IR(
+        graph(%input, %1):
+            %2 = aten::tile(%input, %1)
+            return (%2))IR";
+  std::string repeat_pattern = R"IR(
+        graph(%input, %1):
+            %2 = aten::repeat(%input, %1)
+            return (%2))IR";
+  torch::jit::SubgraphRewriter tile_to_repeat;
+  tile_to_repeat.RegisterRewritePattern(tile_pattern, repeat_pattern);
+  tile_to_repeat.runOnGraph(graph);
+  LOG_GRAPH("Mapping tile -> repeat: " << *graph);
+}
+} // namespace passes
+} // namespace lowering
+} // namespace core
+} // namespace torch_tensorrt
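
For context, the pass above is a pure pattern substitution: it is sound because aten::tile and aten::repeat agree whenever the dims list has at least as many entries as the input has dimensions. A minimal libtorch check of that equivalence (a sketch, not part of this commit; file and variable names are illustrative):

// check_tile_repeat.cpp -- illustrative only, not part of this commit.
// Verifies that aten::tile and aten::repeat agree when the dims list
// has at least as many entries as the tensor has dimensions.
#include <torch/torch.h>
#include <cassert>

int main() {
  auto x = torch::randint(1, 10, {1, 3});
  // Same number of dims entries as the input rank.
  assert(torch::equal(x.tile({4, 1}), x.repeat({4, 1})));
  // More dims entries than the input rank: both ops treat x as shape {1, 1, 3}.
  assert(torch::equal(x.tile({4, 1, 2}), x.repeat({4, 1, 2})));
  return 0;
}

The converter tests added later in this commit exercise exactly these two regimes: dims length equal to, and greater than, the input rank.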

docsrc/contributors/lowering.rst (+7)

@@ -205,3 +205,10 @@ Unroll Loops
 `torch/csrc/jit/passes/loop_unrolling.h <https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/passes/loop_unrolling.h>`_
 
 Unrolls the operations of compatible loops (e.g. sufficiently short) so that you only have to go through the loop once.
+
+Replace Tile with Repeat
+***************************************
+
+`Torch-TensorRT/core/lowering/passes/tile_to_repeat.cpp <https://github.com/pytorch/TensorRT/blob/master/core/lowering/passes/tile_to_repeat.cpp>`_
+
+Replaces aten::tile with the equivalent aten::repeat so that the existing repeat converter can be used.

tests/core/conversion/converters/test_expand.cpp (+126)

@@ -1,6 +1,7 @@
 #include <torch/torch.h>
 #include <string>
 #include "core/compiler.h"
+#include "core/lowering/passes/passes.h"
 #include "gtest/gtest.h"
 #include "tests/util/util.h"
 #include "torch/csrc/jit/ir/irparser.h"
@@ -670,6 +671,131 @@ TEST(Converters, ATenRepeatInterleave3dScalarNoDimConvertsCorrectlyWithDynamicIn
   ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
 }
 
+TEST(Converters, ATenTileConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %2 : int[] = prim::Constant[value=[4, 1]]()
+        %3 : Tensor = aten::tile(%x.1, %2)
+        return (%3))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, g.get());
+  torch_tensorrt::core::lowering::passes::ReplaceTileWithRepeat(g);
+
+  auto in = at::randint(1, 10, {1, 3}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(jit_in);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
+
+TEST(Converters, ATenTileRepeatRankConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %2 : int[] = prim::Constant[value=[4, 1, 2]]()
+        %3 : Tensor = aten::tile(%x.1, %2)
+        return (%3))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, g.get());
+  torch_tensorrt::core::lowering::passes::ReplaceTileWithRepeat(g);
+
+  auto in = at::randint(1, 10, {1, 3}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(jit_in);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
+
+TEST(Converters, ATenTileConvertsCorrectlyWithDynamicInput) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %2 : int[] = prim::Constant[value=[4, 1]]()
+        %3 : Tensor = aten::tile(%x.1, %2)
+        return (%3))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, g.get());
+  torch_tensorrt::core::lowering::passes::ReplaceTileWithRepeat(g);
+
+  auto in = at::randint(1, 10, {1, 3}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(jit_in);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
+
+TEST(Converters, ATenTile3dConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %2 : int[] = prim::Constant[value=[2, 2, 2]]()
+        %3 : Tensor = aten::tile(%x.1, %2)
+        return (%3))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, g.get());
+  torch_tensorrt::core::lowering::passes::ReplaceTileWithRepeat(g);
+
+  auto in = at::randint(1, 10, {2, 3, 2}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(jit_in);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
+
+TEST(Converters, ATenTile3dConvertsCorrectlyWithDynamicInput) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %2 : int[] = prim::Constant[value=[2, 2, 2]]()
+        %3 : Tensor = aten::tile(%x.1, %2)
+        return (%3))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, g.get());
+  torch_tensorrt::core::lowering::passes::ReplaceTileWithRepeat(g);
+
+  auto in = at::randint(1, 10, {2, 3, 2}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(jit_in);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
+
 TEST(Converters, ATenMeshGridConvertsCorrectly) {
   const auto graph = R"IR(
     graph(%x : Tensor, %y : Tensor, %z : Tensor):
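
One observation on the tests above (not part of the commit): every tile case uses a dims list at least as long as the input rank, which is the regime where the tile-to-repeat rewrite is exact. When dims is shorter than the rank, tile pads dims on the left with ones while repeat raises an error, so an exact repeat equivalent needs the padded dims. A small libtorch sketch of that corner case, assuming standard PyTorch tile/repeat semantics; names are illustrative:

// Sketch only: the dims-shorter-than-rank case, which the tests above do not cover.
#include <torch/torch.h>
#include <cassert>

int main() {
  auto x = torch::randint(1, 10, {2, 3});
  // tile({2}) implicitly pads dims to {1, 2}, so it matches repeat({1, 2}).
  assert(torch::equal(x.tile({2}), x.repeat({1, 2})));
  // x.repeat({2}) would throw here: repeat requires dims to have at least as
  // many entries as the tensor has dimensions.
  return 0;
}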

tests/core/lowering/BUILD (+5)

@@ -103,6 +103,10 @@ lowering_test(
     name = "test_replace_aten_pad_pass",
 )
 
+lowering_test(
+    name = "test_tile_to_repeat_pass",
+)
+
 test_suite(
     name = "lowering_tests",
     tests = [
@@ -122,6 +126,7 @@ test_suite(
         ":test_remove_unnecessary_casts",
         ":test_replace_aten_pad_pass",
        ":test_rewrite_inputs_with_params",
+        ":test_tile_to_repeat_pass",
         ":test_unpack_hardsigmoid",
         ":test_unpack_hardswish",
         ":test_unpack_reduce_ops",

tests/core/lowering/test_tile_to_repeat_pass.cpp (+26)

@@ -0,0 +1,26 @@
+#include <string>
+#include "core/compiler.h"
+#include "core/lowering/passes/passes.h"
+#include "gtest/gtest.h"
+#include "tests/util/util.h"
+#include "torch/csrc/jit/ir/irparser.h"
+#include "torch/csrc/jit/ir/subgraph_matcher.h"
+
+TEST(LoweringPasses, TileToRepeatCorrectly) {
+  std::string source_graph = R"IR(
+        graph(%input, %dim):
+            %o : Tensor = aten::tile(%input, %dim)
+            return (%o))IR";
+  std::string target_graph = R"IR(
+        graph(%input, %dim):
+            %o : Tensor = aten::repeat(%input, %dim)
+            return (%o))IR";
+  auto sg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(source_graph, sg.get());
+  torch_tensorrt::core::lowering::passes::ReplaceTileWithRepeat(sg);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph, tg.get());
+
+  ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
+}