Skip to content

Commit 5b156dc

Browse files
authored
Adding converter aten::chunk in torchscript (#1802)
1 parent dbc3172 commit 5b156dc

File tree

5 files changed

+125
-0
lines changed

5 files changed

+125
-0
lines changed

core/conversion/converters/BUILD

+1
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ cc_library(
5757
"impl/batch_norm.cpp",
5858
"impl/bitwise.cpp",
5959
"impl/cast.cpp",
60+
"impl/chunk.cpp",
6061
"impl/concat.cpp",
6162
"impl/constant.cpp",
6263
"impl/constant_pad.cpp",

core/conversion/converters/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ target_sources(${lib_name}
77
"${CMAKE_CURRENT_SOURCE_DIR}/impl/activation.cpp"
88
"${CMAKE_CURRENT_SOURCE_DIR}/impl/batch_norm.cpp"
99
"${CMAKE_CURRENT_SOURCE_DIR}/impl/cast.cpp"
10+
"${CMAKE_CURRENT_SOURCE_DIR}/impl/chunk.cpp"
1011
"${CMAKE_CURRENT_SOURCE_DIR}/impl/concat.cpp"
1112
"${CMAKE_CURRENT_SOURCE_DIR}/impl/constant.cpp"
1213
"${CMAKE_CURRENT_SOURCE_DIR}/impl/constant_pad.cpp"
+84
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
#include "core/conversion/converters/converters.h"
#include "core/conversion/tensorcontainer/TensorContainer.h"
#include "core/util/prelude.h"

namespace torch_tensorrt {
namespace core {
namespace conversion {
namespace converters {
namespace impl {
namespace {

// clang-format off
// Converter for aten::chunk: splits `self` into up to `chunks` slices along
// `dim` by adding one TensorRT ISliceLayer per chunk and publishing the slice
// outputs as a TorchScript Tensor list on the node's output value.
auto chunk_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns()
    .pattern({"aten::chunk(Tensor(a -> *) self, int chunks, int dim=0) -> Tensor(a)[]",
              [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
                auto in = args[0].ITensorOrFreeze(ctx);
                auto chunks = args[1].unwrapToInt();
                auto dim = args[2].unwrapToInt();
                bool dynamic_shape = ctx->input_is_dynamic;
                int nbdims = in->getDimensions().nbDims;

                // BUGFIX: normalize a negative dim BEFORE it is used to index
                // the dimension array. The original computed d[dim] with a
                // possibly negative dim, reading out of bounds.
                if (dim < 0) {
                  dim = nbdims + dim;
                }
                int maxDim = static_cast<int32_t>(in->getDimensions().d[dim]);

                // Element type for the output TorchScript list (Tensor).
                c10::ListTypePtr lt = n->output()->type()->expect<c10::ListType>();
                c10::TypePtr elementType = lt->getElementType();

                if (dynamic_shape) {
                  // Slicing with static offsets along a dynamic dimension is
                  // not representable; reject it explicitly.
                  TORCHTRT_ASSERT(in->getDimensions().d[dim] != -1, "Can't chunk on dynamic shape dimension!");
                }
                if (chunks > maxDim) {
                  // BUGFIX: the original message had no separating spaces.
                  LOG_WARNING(
                      "The chunk count " << chunks << " along dimension " << dim
                                         << " is greater than the tensor size " << maxDim
                                         << "; the number of chunks will default to " << maxDim);
                }

                // Chunk width = ceil(maxDim / chunks); every chunk is `step`
                // wide except possibly the last, which is clamped below.
                int step = (maxDim + chunks - 1) / chunks;
                nvinfer1::Dims start_, size_, stride_;
                start_.nbDims = nbdims;
                size_.nbDims = nbdims;
                stride_.nbDims = nbdims;

                for (int i = 0; i < nbdims; i++) {
                  start_.d[i] = 0;
                  // BUGFIX: non-chunked dimensions must keep their full
                  // extent. The original initialized size_.d[i] to 0 and only
                  // ever updated the chunk dim, producing zero-size slices
                  // for any input of rank > 1.
                  size_.d[i] = in->getDimensions().d[i];
                  stride_.d[i] = 1;
                }

                auto list = c10::impl::GenericList(elementType);
                list.reserve(chunks);
                if (!dynamic_shape) {
                  int offset = 0;
                  // BUGFIX: stop once the dimension is exhausted. When
                  // chunks > maxDim, aten::chunk returns only maxDim chunks;
                  // the original kept emitting zero-width slices.
                  for (int chunk = 0; chunk < chunks && offset < maxDim; chunk++) {
                    start_.d[dim] = offset;
                    size_.d[dim] = std::min(step, maxDim - offset); // last chunk may be narrower
                    LOG_DEBUG("start_:" << start_);
                    LOG_DEBUG("size_:" << size_);
                    LOG_DEBUG("stride_:" << stride_);
                    auto slice_layer = ctx->net->addSlice(*in, start_, size_, stride_);
                    // Wrap the ITensor so it can travel through an IValue list.
                    auto tensor_holder = TensorContainer();
                    tensor_holder.hold_tensor(slice_layer->getOutput(0));
                    auto ival = c10::IValue(c10::make_intrusive<TensorContainer>(tensor_holder));
                    list.emplace_back(std::move(ival));
                    offset += step;
                  }
                }
                // NOTE(review): when input_is_dynamic the list is left empty —
                // downstream ListUnpack will see fewer outputs than expected;
                // confirm dynamic-shape chunking is intentionally unsupported.
                auto split_output_ivalue = torch::jit::IValue(list);
                ctx->AssociateValueAndIValue(n->outputs()[0], split_output_ivalue);
                return true;
              }});
// clang-format on

} // namespace
} // namespace impl
} // namespace converters
} // namespace conversion
} // namespace core
} // namespace torch_tensorrt

tests/core/conversion/converters/BUILD

+5
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,10 @@ converter_test(
3535
name = "test_cast",
3636
)
3737

38+
converter_test(
39+
name = "test_chunk",
40+
)
41+
3842
converter_test(
3943
name = "test_clone",
4044
)
@@ -208,6 +212,7 @@ test_suite(
208212
":test_batch_norm",
209213
":test_bitwise",
210214
":test_cast",
215+
":test_chunk",
211216
":test_clamp",
212217
":test_clone",
213218
":test_comparators",
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
#include <torch/torch.h>
#include <string>
#include "core/compiler.h"
#include "core/lowering/passes/passes.h"
#include "gtest/gtest.h"
#include "tests/util/util.h"
#include "torch/csrc/jit/ir/irparser.h"

// Chunk a 12-element 1-D tensor into 6 pieces along dim 0 and check that the
// TensorRT engine matches TorchScript for EVERY unpacked chunk.
TEST(Converters, ATenChunkConvertsCorrectly) {
  const auto graph = R"IR(
    graph(%0 : Tensor):
      %2 : int = prim::Constant[value=6]()
      %3 : int = prim::Constant[value=0]()
      %4 : Tensor[] = aten::chunk(%0, %2, %3)
      %5 : Tensor, %6 : Tensor, %7 : Tensor, %8 : Tensor, %9 : Tensor, %10 : Tensor = prim::ListUnpack(%4)
      return (%5, %6, %7, %8, %9, %10))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());
  auto in = at::randint(1, 10, {12}, {at::kCUDA});

  auto jit_in = at::clone(in);
  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});

  auto trt_in = at::clone(in);
  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});

  // BUGFIX: the original compared only results[0]; a converter that corrupts
  // chunks 1-5 would still pass. Compare all six outputs.
  ASSERT_EQ(jit_results.size(), trt_results.size());
  for (size_t i = 0; i < jit_results.size(); i++) {
    auto trt = trt_results[i].reshape(jit_results[i].sizes());
    ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6));
  }
}

0 commit comments

Comments
 (0)