
Commit b7b2725

Author: Anurag Dixit
feat: Moved from converter to lowering pass
Signed-off-by: Anurag Dixit <[email protected]>
1 parent 6b51057 commit b7b2725
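
In short, this change drops the dedicated aten::tile converter (and the shared add_repeat helper), inlines the repeat logic into the aten::repeat converter, and instead rewrites aten::tile to aten::repeat during lowering via a new ReplaceTileWithRepeat pass. The rewrite is safe because the two ops agree whenever the dims/repeats list is at least as long as the input rank, which is exactly the precondition the converter checks. A minimal libtorch sketch of that equivalence (illustrative only, not part of the commit; assumes a libtorch build that exposes torch::tile):

// Illustrative only: shows the aten::tile / aten::repeat equivalence that the
// new lowering pass relies on.
#include <torch/torch.h>
#include <iostream>

int main() {
  auto x = torch::arange(6).reshape({2, 3});
  // dims/repeats list longer than the input rank: both ops first view the
  // input as [1, 2, 3], then duplicate each axis the requested number of times.
  auto tiled = torch::tile(x, {2, 1, 3});
  auto repeated = x.repeat({2, 1, 3});
  std::cout << std::boolalpha << torch::equal(tiled, repeated) << std::endl;  // true
  std::cout << repeated.sizes() << std::endl;                                 // [2, 2, 9]
  return 0;
}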

10 files changed: +127, -61 lines changed

core/conversion/converters/impl/expand.cpp

+54, -61

@@ -194,60 +194,6 @@ bool add_expand_dynamic(
   return true;
 }
 
-bool add_repeat(ConversionCtx* ctx, const torch::jit::Node* n, args& args, const std::string& layer) {
-  auto in = args[0].ITensorOrFreeze(ctx);
-  auto input_dims = in->getDimensions();
-  auto repeats = args[1].unwrapToIntList().vec();
-  int repeats_rank = repeats.size();
-  TORCHTRT_CHECK(
-      repeats_rank >= input_dims.nbDims,
-      "Number of repeat dimensions cannot be smaller than number of input dimensions");
-
-  auto num_expand_dims = repeats_rank - input_dims.nbDims;
-
-  if (ctx->input_is_dynamic) {
-    int input_rank = input_dims.nbDims;
-    int output_rank = repeats_rank;
-    auto new_input_shape_tensor = concat(output_rank, input_rank, ctx, in);
-
-    auto shuffle = ctx->net->addShuffle(*in);
-    shuffle->setInput(1, *new_input_shape_tensor);
-    in = shuffle->getOutput(0);
-  } else {
-    if (num_expand_dims > 0) {
-      nvinfer1::Dims reshape_dims;
-      reshape_dims.nbDims = repeats.size();
-      for (int i = 0; i < num_expand_dims; i++) {
-        reshape_dims.d[i] = 1;
-      }
-      for (int i = 0; i < input_dims.nbDims; i++) {
-        reshape_dims.d[num_expand_dims + i] = input_dims.d[i];
-      }
-      // Add a reshape layer to expand dims
-      auto reshape_layer = ctx->net->addShuffle(*in);
-      reshape_layer->setReshapeDimensions(reshape_dims);
-      in = reshape_layer->getOutput(0);
-      LOG_DEBUG("Input reshaped to : " << in->getDimensions() << " from " << input_dims);
-    }
-    LOG_DEBUG("Repeats: " << repeats);
-  }
-
-  // Concat across all repeat axes.
-  for (int i = repeats.size() - 1; i >= 0; --i) {
-    std::vector<nvinfer1::ITensor*> tensors_vec;
-    for (int j = 0; j < repeats[i]; j++) {
-      tensors_vec.push_back(in);
-    }
-    auto concat_layer = ctx->net->addConcatenation(tensors_vec.data(), tensors_vec.size());
-    concat_layer->setAxis(i);
-    in = concat_layer->getOutput(0);
-  }
-
-  auto out = ctx->AssociateValueAndTensor(n->outputs()[0], in);
-  LOG_DEBUG(layer << " layer output tensor shape: " << out->getDimensions());
-  return true;
-}
-
 auto expand_registrations TORCHTRT_UNUSED =
     RegisterNodeConversionPatterns()
         .pattern(

@@ -284,7 +230,59 @@ auto expand_registrations TORCHTRT_UNUSED =
         .pattern(
            {"aten::repeat(Tensor self, int[] repeats) -> (Tensor)",
             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-              return add_repeat(ctx, n, args, "Repeat");
+              auto in = args[0].ITensorOrFreeze(ctx);
+              auto input_dims = in->getDimensions();
+              auto repeats = args[1].unwrapToIntList().vec();
+              int repeats_rank = repeats.size();
+              TORCHTRT_CHECK(
+                  repeats_rank >= input_dims.nbDims,
+                  "Number of repeat dimensions cannot be smaller than number of input dimensions");
+              auto num_expand_dims = repeats_rank - input_dims.nbDims;
+
+              if (ctx->input_is_dynamic) {
+                int input_rank = input_dims.nbDims;
+                int output_rank = repeats_rank;
+                auto new_input_shape_tensor = concat(output_rank, input_rank, ctx, in);
+
+                // Add a reshape layer to expand dims
+                auto shuffle = ctx->net->addShuffle(*in);
+                shuffle->setInput(1, *new_input_shape_tensor);
+                in = shuffle->getOutput(0);
+              } else {
+                if (num_expand_dims > 0) {
+                  nvinfer1::Dims reshape_dims;
+                  reshape_dims.nbDims = repeats.size();
+                  for (int i = 0; i < num_expand_dims; i++) {
+                    reshape_dims.d[i] = 1;
+                  }
+                  for (int i = 0; i < input_dims.nbDims; i++) {
+                    reshape_dims.d[num_expand_dims + i] = input_dims.d[i];
+                  }
+                  // Add a reshape layer to expand dims
+                  auto reshape_layer = ctx->net->addShuffle(*in);
+                  reshape_layer->setReshapeDimensions(reshape_dims);
+                  in = reshape_layer->getOutput(0);
+                  LOG_DEBUG("Input reshaped to : " << in->getDimensions() << " from " << input_dims);
+                }
+                LOG_DEBUG("Repeats: " << repeats);
+              }
+
+              // Concat across all repeat axes.
+              // TODO: Implementation might not be performant. Explore other strategies to improve performance.
+              for (int i = repeats.size() - 1; i >= 0; --i) {
+                std::vector<nvinfer1::ITensor*> tensors_vec;
+                for (int j = 0; j < repeats[i]; j++) {
+                  tensors_vec.push_back(in);
+                }
+                auto concat_layer = ctx->net->addConcatenation(tensors_vec.data(), tensors_vec.size());
+                concat_layer->setAxis(i);
+                in = concat_layer->getOutput(0);
+              }
+
+              auto out = ctx->AssociateValueAndTensor(n->outputs()[0], in);
+
+              LOG_DEBUG("Repeat layer output tensor shape: " << out->getDimensions());
+              return true;
            }})
        .pattern(
            {"aten::repeat_interleave.self_int(Tensor self, int repeats, int? dim=None, *, int? output_size=None) -> (Tensor)",

@@ -397,11 +395,6 @@ auto expand_registrations TORCHTRT_UNUSED =
 
               return true;
            }})
-       .pattern(
-           {"aten::tile(Tensor self, int[] dims) -> (Tensor)",
-            [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-              return add_repeat(ctx, n, args, "Tile");
-            }})
        .pattern(
            {"aten::meshgrid(Tensor[] tensors) -> (Tensor[])",
            [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {

@@ -491,4 +484,4 @@ auto expand_registrations TORCHTRT_UNUSED =
 } // namespace converters
 } // namespace conversion
 } // namespace core
-} // namespace torch_tensorrt
+} // namespace torch_tensorrt
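
For readers following the converter above: in the static-shape path it left-pads the input shape with 1s up to the repeat rank (the reshape/shuffle layer), then concatenates the tensor with itself repeats[i] times along each axis i. A small plain-C++ sketch of just the shape bookkeeping (no TensorRT; repeat_output_shape is a hypothetical helper used only for illustration):

// Plain C++ sketch (no TensorRT) of the shape bookkeeping performed by the
// aten::repeat converter above.
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int64_t> repeat_output_shape(const std::vector<int64_t>& input,
                                         const std::vector<int64_t>& repeats) {
  // Precondition (mirrors the converter's TORCHTRT_CHECK): repeats has at
  // least as many entries as the input has dimensions.
  // Left-pad with 1s: mirrors the reshape/shuffle layer inserted when
  // repeats.size() exceeds the input rank.
  std::vector<int64_t> shape(repeats.size() - input.size(), 1);
  shape.insert(shape.end(), input.begin(), input.end());
  // Concatenating the tensor with itself repeats[i] times along axis i
  // multiplies that dimension by repeats[i].
  for (size_t i = 0; i < repeats.size(); ++i) {
    shape[i] *= repeats[i];
  }
  return shape;
}

int main() {
  // Example: a [2, 3] input with repeats [2, 1, 3] becomes [2, 2, 9].
  for (int64_t d : repeat_output_shape({2, 3}, {2, 1, 3})) {
    std::cout << d << " ";
  }
  std::cout << std::endl;
  return 0;
}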

core/lowering/lowering.cpp

+1
@@ -153,6 +153,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, std::vector<torch::jit::I
   passes::ReplaceScalarImplicit(g);
   passes::RewriteInputsWithParams(g, params);
   passes::ReplaceAtenPad(g);
+  passes::ReplaceTileWithRepeat(g);
   LOG_GRAPH(*g);
 }
 

core/lowering/passes/BUILD

+1
@@ -31,6 +31,7 @@ cc_library(
         "replace_aten_pad.cpp",
         "rewrite_inputs_with_params.cpp",
         "silu_to_sigmoid_multiplication.cpp",
+        "tile_to_repeat.cpp",
         "unpack_addmm.cpp",
         "unpack_batch_norm.cpp",
         "unpack_hardsigmoid.cpp",

core/lowering/passes/CMakeLists.txt

+1
@@ -6,6 +6,7 @@ target_sources(${lib_name}
     "${CMAKE_CURRENT_SOURCE_DIR}/linear_to_addmm.cpp"
     "${CMAKE_CURRENT_SOURCE_DIR}/module_fallback.cpp"
     "${CMAKE_CURRENT_SOURCE_DIR}/op_aliasing.cpp"
+    "${CMAKE_CURRENT_SOURCE_DIR}/tile_to_repeat.cpp"
     "${CMAKE_CURRENT_SOURCE_DIR}/reduce_gelu.cpp"
     "${CMAKE_CURRENT_SOURCE_DIR}/reduce_remainder.cpp"
     "${CMAKE_CURRENT_SOURCE_DIR}/reduce_to.cpp"

core/lowering/passes/passes.h

+1
@@ -50,6 +50,7 @@ void UnpackAndCastNumToTensor(std::shared_ptr<torch::jit::Graph>& graph, std::st
 void UnpackAndCastFull(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name);
 void ReplaceScalarImplicit(std::shared_ptr<torch::jit::Graph>& graph);
 void ReplaceAtenPad(std::shared_ptr<torch::jit::Graph>& graph);
+void ReplaceTileWithRepeat(std::shared_ptr<torch::jit::Graph>& graph);
 
 // utility functions exposed for testing
 std::string unmangle_cls_name(const std::string& name);
core/lowering/passes/tile_to_repeat.cpp

+25

@@ -0,0 +1,25 @@
+#include "core/util/prelude.h"
+#include "torch/csrc/jit/passes/subgraph_rewrite.h"
+
+namespace torch_tensorrt {
+namespace core {
+namespace lowering {
+namespace passes {
+void ReplaceTileWithRepeat(std::shared_ptr<torch::jit::Graph>& graph) {
+  std::string tile_pattern = R"IR(
+        graph(%input, %1):
+            %2 = aten::tile(%input, %1)
+            return (%2))IR";
+  std::string repeat_pattern = R"IR(
+        graph(%input, %1):
+            %2 = aten::repeat(%input, %1)
+            return (%2))IR";
+  torch::jit::SubgraphRewriter tile_to_repeat;
+  tile_to_repeat.RegisterRewritePattern(tile_pattern, repeat_pattern);
+  tile_to_repeat.runOnGraph(graph);
+  LOG_GRAPH("Mapping tile -> repeat: " << *graph);
+}
+} // namespace passes
+} // namespace lowering
+} // namespace core
+} // namespace torch_tensorrt
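
A short sketch of exercising this pass in isolation (it mirrors the unit test added later in this commit; include paths and build setup are assumed):

// Sketch only: parse a toy graph that calls aten::tile, run the new lowering
// pass, and print the rewritten graph (it will now call aten::repeat).
#include <iostream>
#include <memory>
#include <string>
#include "core/lowering/passes/passes.h"
#include "torch/csrc/jit/ir/ir.h"
#include "torch/csrc/jit/ir/irparser.h"

int main() {
  std::string tile_graph = R"IR(
      graph(%input, %dims):
          %out : Tensor = aten::tile(%input, %dims)
          return (%out))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(tile_graph, g.get());

  // After the pass, only the aten::repeat converter is needed downstream.
  torch_tensorrt::core::lowering::passes::ReplaceTileWithRepeat(g);
  std::cout << *g << std::endl;
  return 0;
}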

docsrc/contributors/lowering.rst

+7
@@ -205,3 +205,10 @@ Unroll Loops
 `torch/csrc/jit/passes/loop_unrolling.h <https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/passes/loop_unrolling.h>`_
 
 Unrolls the operations of compatible loops (e.g. sufficiently short) so that you only have to go through the loop once.
+
+Replace Tile with Repeat
+***************************************
+
+`Torch-TensorRT/core/lowering/passes/tile_to_repeat.cpp <https://github.com/pytorch/TensorRT/blob/master/core/lowering/passes/tile_to_repeat.cpp>`_
+
+Replaces ``aten::tile`` with the semantically equivalent ``aten::repeat`` so that a single converter handles both ops.

tests/core/conversion/converters/test_expand.cpp

+6
@@ -1,6 +1,7 @@
 #include <torch/torch.h>
 #include <string>
 #include "core/compiler.h"
+#include "core/lowering/passes/passes.h"
 #include "gtest/gtest.h"
 #include "tests/util/util.h"
 #include "torch/csrc/jit/ir/irparser.h"

@@ -680,6 +681,7 @@ TEST(Converters, ATenTileConvertsCorrectly) {
   auto g = std::make_shared<torch::jit::Graph>();
 
   torch::jit::parseIR(graph, g.get());
+  torch_tensorrt::core::lowering::passes::ReplaceTileWithRepeat(g);
 
   auto in = at::randint(1, 10, {1, 3}, {at::kCUDA});
 
@@ -704,6 +706,7 @@ TEST(Converters, ATenTileRepeatRankConvertsCorrectly) {
   auto g = std::make_shared<torch::jit::Graph>();
 
   torch::jit::parseIR(graph, g.get());
+  torch_tensorrt::core::lowering::passes::ReplaceTileWithRepeat(g);
 
   auto in = at::randint(1, 10, {1, 3}, {at::kCUDA});
 
@@ -728,6 +731,7 @@ TEST(Converters, ATenTileConvertsCorrectlyWithDynamicInput) {
   auto g = std::make_shared<torch::jit::Graph>();
 
   torch::jit::parseIR(graph, g.get());
+  torch_tensorrt::core::lowering::passes::ReplaceTileWithRepeat(g);
 
   auto in = at::randint(1, 10, {1, 3}, {at::kCUDA});
 
@@ -752,6 +756,7 @@ TEST(Converters, ATenTile3dConvertsCorrectly) {
   auto g = std::make_shared<torch::jit::Graph>();
 
   torch::jit::parseIR(graph, g.get());
+  torch_tensorrt::core::lowering::passes::ReplaceTileWithRepeat(g);
 
   auto in = at::randint(1, 10, {2, 3, 2}, {at::kCUDA});
 
@@ -776,6 +781,7 @@ TEST(Converters, ATenTile3dConvertsCorrectlyWithDynamicInput) {
   auto g = std::make_shared<torch::jit::Graph>();
 
   torch::jit::parseIR(graph, g.get());
+  torch_tensorrt::core::lowering::passes::ReplaceTileWithRepeat(g);
 
   auto in = at::randint(1, 10, {2, 3, 2}, {at::kCUDA});
 
tests/core/lowering/BUILD

+5
@@ -103,6 +103,10 @@ lowering_test(
     name = "test_replace_aten_pad_pass",
 )
 
+lowering_test(
+    name = "test_tile_to_repeat_pass",
+)
+
 test_suite(
     name = "lowering_tests",
     tests = [

@@ -122,6 +126,7 @@ test_suite(
         ":test_remove_unnecessary_casts",
         ":test_replace_aten_pad_pass",
         ":test_rewrite_inputs_with_params",
+        ":test_tile_to_repeat_pass",
         ":test_unpack_hardsigmoid",
         ":test_unpack_hardswish",
         ":test_unpack_reduce_ops",
tests/core/lowering/test_tile_to_repeat_pass.cpp

+26

@@ -0,0 +1,26 @@
+#include <string>
+#include "core/compiler.h"
+#include "core/lowering/passes/passes.h"
+#include "gtest/gtest.h"
+#include "tests/util/util.h"
+#include "torch/csrc/jit/ir/irparser.h"
+#include "torch/csrc/jit/ir/subgraph_matcher.h"
+
+TEST(LoweringPasses, TileToRepeatCorrectly) {
+  std::string source_graph = R"IR(
+      graph(%input, %dim):
+          %o : Tensor = aten::tile(%input, %dim)
+          return (%o))IR";
+  std::string target_graph = R"IR(
+      graph(%input, %dim):
+          %o : Tensor = aten::repeat(%input, %dim)
+          return (%o))IR";
+  auto sg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(source_graph, sg.get());
+  torch_tensorrt::core::lowering::passes::ReplaceTileWithRepeat(sg);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph, tg.get());
+
+  ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
+}
