|
1 | 1 | #include <torch/torch.h>
|
2 | 2 | #include <string>
|
3 | 3 | #include "core/compiler.h"
|
| 4 | +#include "core/lowering/passes/passes.h" |
4 | 5 | #include "gtest/gtest.h"
|
5 | 6 | #include "tests/util/util.h"
|
6 | 7 | #include "torch/csrc/jit/ir/irparser.h"
|
@@ -670,6 +671,131 @@ TEST(Converters, ATenRepeatInterleave3dScalarNoDimConvertsCorrectlyWithDynamicIn
|
670 | 671 | ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
|
671 | 672 | }
|
672 | 673 |
|
| 674 | +TEST(Converters, ATenTileConvertsCorrectly) { |
| 675 | + const auto graph = R"IR( |
| 676 | + graph(%x.1 : Tensor): |
| 677 | + %2 : int[] = prim::Constant[value=[4, 1]]() |
| 678 | + %3 : Tensor = aten::tile(%x.1, %2) |
| 679 | + return (%3))IR"; |
| 680 | + |
| 681 | + auto g = std::make_shared<torch::jit::Graph>(); |
| 682 | + |
| 683 | + torch::jit::parseIR(graph, g.get()); |
| 684 | + torch_tensorrt::core::lowering::passes::ReplaceTileWithRepeat(g); |
| 685 | + |
| 686 | + auto in = at::randint(1, 10, {1, 3}, {at::kCUDA}); |
| 687 | + |
| 688 | + auto jit_in = at::clone(in); |
| 689 | + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); |
| 690 | + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); |
| 691 | + |
| 692 | + auto trt_in = at::clone(jit_in); |
| 693 | + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); |
| 694 | + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); |
| 695 | + |
| 696 | + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); |
| 697 | +} |
| 698 | + |
| 699 | +TEST(Converters, ATenTileRepeatRankConvertsCorrectly) { |
| 700 | + const auto graph = R"IR( |
| 701 | + graph(%x.1 : Tensor): |
| 702 | + %2 : int[] = prim::Constant[value=[4, 1, 2]]() |
| 703 | + %3 : Tensor = aten::tile(%x.1, %2) |
| 704 | + return (%3))IR"; |
| 705 | + |
| 706 | + auto g = std::make_shared<torch::jit::Graph>(); |
| 707 | + |
| 708 | + torch::jit::parseIR(graph, g.get()); |
| 709 | + torch_tensorrt::core::lowering::passes::ReplaceTileWithRepeat(g); |
| 710 | + |
| 711 | + auto in = at::randint(1, 10, {1, 3}, {at::kCUDA}); |
| 712 | + |
| 713 | + auto jit_in = at::clone(in); |
| 714 | + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); |
| 715 | + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); |
| 716 | + |
| 717 | + auto trt_in = at::clone(jit_in); |
| 718 | + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); |
| 719 | + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); |
| 720 | + |
| 721 | + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); |
| 722 | +} |
| 723 | + |
| 724 | +TEST(Converters, ATenTileConvertsCorrectlyWithDynamicInput) { |
| 725 | + const auto graph = R"IR( |
| 726 | + graph(%x.1 : Tensor): |
| 727 | + %2 : int[] = prim::Constant[value=[4, 1]]() |
| 728 | + %3 : Tensor = aten::tile(%x.1, %2) |
| 729 | + return (%3))IR"; |
| 730 | + |
| 731 | + auto g = std::make_shared<torch::jit::Graph>(); |
| 732 | + |
| 733 | + torch::jit::parseIR(graph, g.get()); |
| 734 | + torch_tensorrt::core::lowering::passes::ReplaceTileWithRepeat(g); |
| 735 | + |
| 736 | + auto in = at::randint(1, 10, {1, 3}, {at::kCUDA}); |
| 737 | + |
| 738 | + auto jit_in = at::clone(in); |
| 739 | + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); |
| 740 | + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); |
| 741 | + |
| 742 | + auto trt_in = at::clone(jit_in); |
| 743 | + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); |
| 744 | + auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in}); |
| 745 | + |
| 746 | + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); |
| 747 | +} |
| 748 | + |
| 749 | +TEST(Converters, ATenTile3dConvertsCorrectly) { |
| 750 | + const auto graph = R"IR( |
| 751 | + graph(%x.1 : Tensor): |
| 752 | + %2 : int[] = prim::Constant[value=[2, 2, 2]]() |
| 753 | + %3 : Tensor = aten::tile(%x.1, %2) |
| 754 | + return (%3))IR"; |
| 755 | + |
| 756 | + auto g = std::make_shared<torch::jit::Graph>(); |
| 757 | + |
| 758 | + torch::jit::parseIR(graph, g.get()); |
| 759 | + torch_tensorrt::core::lowering::passes::ReplaceTileWithRepeat(g); |
| 760 | + |
| 761 | + auto in = at::randint(1, 10, {2, 3, 2}, {at::kCUDA}); |
| 762 | + |
| 763 | + auto jit_in = at::clone(in); |
| 764 | + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); |
| 765 | + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); |
| 766 | + |
| 767 | + auto trt_in = at::clone(jit_in); |
| 768 | + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); |
| 769 | + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in}); |
| 770 | + |
| 771 | + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); |
| 772 | +} |
| 773 | + |
| 774 | +TEST(Converters, ATenTile3dConvertsCorrectlyWithDynamicInput) { |
| 775 | + const auto graph = R"IR( |
| 776 | + graph(%x.1 : Tensor): |
| 777 | + %2 : int[] = prim::Constant[value=[2, 2, 2]]() |
| 778 | + %3 : Tensor = aten::tile(%x.1, %2) |
| 779 | + return (%3))IR"; |
| 780 | + |
| 781 | + auto g = std::make_shared<torch::jit::Graph>(); |
| 782 | + |
| 783 | + torch::jit::parseIR(graph, g.get()); |
| 784 | + torch_tensorrt::core::lowering::passes::ReplaceTileWithRepeat(g); |
| 785 | + |
| 786 | + auto in = at::randint(1, 10, {2, 3, 2}, {at::kCUDA}); |
| 787 | + |
| 788 | + auto jit_in = at::clone(in); |
| 789 | + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); |
| 790 | + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); |
| 791 | + |
| 792 | + auto trt_in = at::clone(jit_in); |
| 793 | + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); |
| 794 | + auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in}); |
| 795 | + |
| 796 | + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); |
| 797 | +} |
| 798 | + |
673 | 799 | TEST(Converters, ATenMeshGridConvertsCorrectly) {
|
674 | 800 | const auto graph = R"IR(
|
675 | 801 | graph(%x : Tensor, %y : Tensor, %z : Tensor):
|
|
0 commit comments