
Commit 6b51057

Author: Anurag Dixit (committed)
feat: Added support for aten::tile converter
Signed-off-by: Anurag Dixit <[email protected]>
1 parent e7f4752 commit 6b51057
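
For context, aten::tile repeats a tensor along each dimension, much like aten::repeat; when more repeat values are given than the input has dimensions, leading size-1 dimensions are treated as prepended before tiling. A minimal LibTorch sketch of the op this converter targets (illustrative only; it assumes the Tensor::tile method is available in your LibTorch build, and the shapes simply mirror the new tests rather than anything stated in this commit):

#include <torch/torch.h>
#include <iostream>

int main() {
  // A (1, 3) input, matching the shape used by the new tests.
  auto x = torch::arange(3).reshape({1, 3});

  // tile with dims [4, 1]: 4 copies along dim 0, 1 along dim 1 -> shape (4, 3).
  auto t = x.tile({4, 1});

  // dims longer than x.dim(): the input is treated as (1, 1, 3),
  // so tiling with [4, 1, 2] yields shape (4, 1, 6).
  auto t_rank = x.tile({4, 1, 2});

  std::cout << t.sizes() << " " << t_rank.sizes() << std::endl;
  return 0;
}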

File tree

2 files changed: +181 −54 lines


core/conversion/converters/impl/expand.cpp (+61 −54)
@@ -194,6 +194,60 @@ bool add_expand_dynamic(
   return true;
 }
 
+bool add_repeat(ConversionCtx* ctx, const torch::jit::Node* n, args& args, const std::string& layer) {
+  auto in = args[0].ITensorOrFreeze(ctx);
+  auto input_dims = in->getDimensions();
+  auto repeats = args[1].unwrapToIntList().vec();
+  int repeats_rank = repeats.size();
+  TORCHTRT_CHECK(
+      repeats_rank >= input_dims.nbDims,
+      "Number of repeat dimensions cannot be smaller than number of input dimensions");
+
+  auto num_expand_dims = repeats_rank - input_dims.nbDims;
+
+  if (ctx->input_is_dynamic) {
+    int input_rank = input_dims.nbDims;
+    int output_rank = repeats_rank;
+    auto new_input_shape_tensor = concat(output_rank, input_rank, ctx, in);
+
+    auto shuffle = ctx->net->addShuffle(*in);
+    shuffle->setInput(1, *new_input_shape_tensor);
+    in = shuffle->getOutput(0);
+  } else {
+    if (num_expand_dims > 0) {
+      nvinfer1::Dims reshape_dims;
+      reshape_dims.nbDims = repeats.size();
+      for (int i = 0; i < num_expand_dims; i++) {
+        reshape_dims.d[i] = 1;
+      }
+      for (int i = 0; i < input_dims.nbDims; i++) {
+        reshape_dims.d[num_expand_dims + i] = input_dims.d[i];
+      }
+      // Add a reshape layer to expand dims
+      auto reshape_layer = ctx->net->addShuffle(*in);
+      reshape_layer->setReshapeDimensions(reshape_dims);
+      in = reshape_layer->getOutput(0);
+      LOG_DEBUG("Input reshaped to : " << in->getDimensions() << " from " << input_dims);
+    }
+    LOG_DEBUG("Repeats: " << repeats);
+  }
+
+  // Concat across all repeat axes.
+  for (int i = repeats.size() - 1; i >= 0; --i) {
+    std::vector<nvinfer1::ITensor*> tensors_vec;
+    for (int j = 0; j < repeats[i]; j++) {
+      tensors_vec.push_back(in);
+    }
+    auto concat_layer = ctx->net->addConcatenation(tensors_vec.data(), tensors_vec.size());
+    concat_layer->setAxis(i);
+    in = concat_layer->getOutput(0);
+  }
+
+  auto out = ctx->AssociateValueAndTensor(n->outputs()[0], in);
+  LOG_DEBUG(layer << " layer output tensor shape: " << out->getDimensions());
+  return true;
+}
+
 auto expand_registrations TORCHTRT_UNUSED =
     RegisterNodeConversionPatterns()
         .pattern(
@@ -230,59 +284,7 @@ auto expand_registrations TORCHTRT_UNUSED =
         .pattern(
             {"aten::repeat(Tensor self, int[] repeats) -> (Tensor)",
              [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-               auto in = args[0].ITensorOrFreeze(ctx);
-               auto input_dims = in->getDimensions();
-               auto repeats = args[1].unwrapToIntList().vec();
-               int repeats_rank = repeats.size();
-               TORCHTRT_CHECK(
-                   repeats_rank >= input_dims.nbDims,
-                   "Number of repeat dimensions cannot be smaller than number of input dimensions");
-               auto num_expand_dims = repeats_rank - input_dims.nbDims;
-
-               if (ctx->input_is_dynamic) {
-                 int input_rank = input_dims.nbDims;
-                 int output_rank = repeats_rank;
-                 auto new_input_shape_tensor = concat(output_rank, input_rank, ctx, in);
-
-                 // Add a reshape layer to expand dims
-                 auto shuffle = ctx->net->addShuffle(*in);
-                 shuffle->setInput(1, *new_input_shape_tensor);
-                 in = shuffle->getOutput(0);
-               } else {
-                 if (num_expand_dims > 0) {
-                   nvinfer1::Dims reshape_dims;
-                   reshape_dims.nbDims = repeats.size();
-                   for (int i = 0; i < num_expand_dims; i++) {
-                     reshape_dims.d[i] = 1;
-                   }
-                   for (int i = 0; i < input_dims.nbDims; i++) {
-                     reshape_dims.d[num_expand_dims + i] = input_dims.d[i];
-                   }
-                   // Add a reshape layer to expand dims
-                   auto reshape_layer = ctx->net->addShuffle(*in);
-                   reshape_layer->setReshapeDimensions(reshape_dims);
-                   in = reshape_layer->getOutput(0);
-                   LOG_DEBUG("Input reshaped to : " << in->getDimensions() << " from " << input_dims);
-                 }
-                 LOG_DEBUG("Repeats: " << repeats);
-               }
-
-               // Concat across all repeat axes.
-               // TODO: Implementation might not be performant. Explore other strategies to improve performance.
-               for (int i = repeats.size() - 1; i >= 0; --i) {
-                 std::vector<nvinfer1::ITensor*> tensors_vec;
-                 for (int j = 0; j < repeats[i]; j++) {
-                   tensors_vec.push_back(in);
-                 }
-                 auto concat_layer = ctx->net->addConcatenation(tensors_vec.data(), tensors_vec.size());
-                 concat_layer->setAxis(i);
-                 in = concat_layer->getOutput(0);
-               }
-
-               auto out = ctx->AssociateValueAndTensor(n->outputs()[0], in);
-
-               LOG_DEBUG("Repeat layer output tensor shape: " << out->getDimensions());
-               return true;
+               return add_repeat(ctx, n, args, "Repeat");
             }})
         .pattern(
            {"aten::repeat_interleave.self_int(Tensor self, int repeats, int? dim=None, *, int? output_size=None) -> (Tensor)",
@@ -395,6 +397,11 @@ auto expand_registrations TORCHTRT_UNUSED =
 
               return true;
             }})
+        .pattern(
+            {"aten::tile(Tensor self, int[] dims) -> (Tensor)",
+             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+               return add_repeat(ctx, n, args, "Tile");
+             }})
         .pattern(
             {"aten::meshgrid(Tensor[] tensors) -> (Tensor[])",
              [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
@@ -484,4 +491,4 @@ auto expand_registrations TORCHTRT_UNUSED =
 } // namespace converters
 } // namespace conversion
 } // namespace core
-} // namespace torch_tensorrt
+} // namespace torch_tensorrt
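
The add_repeat helper added above handles both aten::repeat and aten::tile in two stages: it first brings the input up to the rank implied by the repeat list (prepending size-1 dimensions with an IShuffleLayer, or, for dynamic input shapes, with a shape tensor built by concat), then walks the repeat values from the innermost axis outward, concatenating repeats[i] copies of the running tensor along axis i. As a worked trace using the shapes from the new tests (illustrative, not taken from the commit): a (1, 3) input with repeats [4, 1, 2] is reshaped to (1, 1, 3); concatenating 2 copies along axis 2 gives (1, 1, 6); 1 copy along axis 1 leaves (1, 1, 6); and 4 copies along axis 0 gives (4, 1, 6), which matches torch.tile's output shape. The two registrations differ only in the label passed for logging ("Repeat" vs. "Tile").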

tests/core/conversion/converters/test_expand.cpp (+120)
@@ -670,6 +670,126 @@ TEST(Converters, ATenRepeatInterleave3dScalarNoDimConvertsCorrectlyWithDynamicIn
   ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
 }
 
+TEST(Converters, ATenTileConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %2 : int[] = prim::Constant[value=[4, 1]]()
+        %3 : Tensor = aten::tile(%x.1, %2)
+        return (%3))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(1, 10, {1, 3}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(jit_in);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
+
+TEST(Converters, ATenTileRepeatRankConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %2 : int[] = prim::Constant[value=[4, 1, 2]]()
+        %3 : Tensor = aten::tile(%x.1, %2)
+        return (%3))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(1, 10, {1, 3}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(jit_in);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
+
+TEST(Converters, ATenTileConvertsCorrectlyWithDynamicInput) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %2 : int[] = prim::Constant[value=[4, 1]]()
+        %3 : Tensor = aten::tile(%x.1, %2)
+        return (%3))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(1, 10, {1, 3}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(jit_in);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
+
+TEST(Converters, ATenTile3dConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %2 : int[] = prim::Constant[value=[2, 2, 2]]()
+        %3 : Tensor = aten::tile(%x.1, %2)
+        return (%3))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(1, 10, {2, 3, 2}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(jit_in);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
+
+TEST(Converters, ATenTile3dConvertsCorrectlyWithDynamicInput) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %2 : int[] = prim::Constant[value=[2, 2, 2]]()
+        %3 : Tensor = aten::tile(%x.1, %2)
+        return (%3))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(1, 10, {2, 3, 2}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(jit_in);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
+
 TEST(Converters, ATenMeshGridConvertsCorrectly) {
   const auto graph = R"IR(
       graph(%x : Tensor, %y : Tensor, %z : Tensor):
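
Taken together, the new tests cover a rank-preserving tile, a rank-expanding tile, and 3-D tiling, plus dynamic-shape variants of the rank-preserving and 3-D cases, each comparing TorchScript output against the TensorRT engine output to within 2e-6. Assuming the repository's usual Bazel layout for converter tests (the exact target name below is an assumption, not something stated in this commit), they would typically be run with a command along the lines of:

bazel test //tests/core/conversion/converters:test_expand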
