Skip to content

Commit dbc3172

Browse files
fix: Rewrite constant_pad_nd to use a single slice layer for performance (#1970)
1 parent 4494699 commit dbc3172

File tree

1 file changed

+49
-113
lines changed

1 file changed

+49
-113
lines changed

core/conversion/converters/impl/constant_pad.cpp

+49-113
Original file line number | Diff line number | Diff line change
@@ -16,127 +16,63 @@ auto constant_pad_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns
1616
{"aten::constant_pad_nd(Tensor self, int[] pad, Scalar value=0) -> (Tensor)",
1717
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
1818
auto in = args[0].ITensor();
19-
auto inDims = in->getDimensions();
20-
int64_t inRank = inDims.nbDims;
19+
auto in_dims = in->getDimensions();
20+
int64_t in_rank = in_dims.nbDims;
2121
auto padding = args[1].unwrapToIntList().vec();
22-
int64_t padSize = padding.size();
22+
int64_t pad_size = padding.size();
2323
auto value = args[2].unwrapToScalar().to<float>();
2424
at::Tensor value_tensor = torch::tensor(value, util::TRTDataTypeToScalarType(in->getType()));
25-
auto valueTensor = tensor_to_const(ctx, value_tensor);
26-
TORCHTRT_CHECK(padSize % 2 == 0, "Length of pad must be even but instead it equals " << padSize);
27-
28-
int64_t l_pad = padSize / 2;
29-
TORCHTRT_CHECK(
30-
inRank >= (int64_t)l_pad,
31-
"Length of pad should be no more than twice the number of "
32-
"dimensions of the input. Pad length is "
33-
<< padSize << "while the input has " << inRank << "dimensions.");
34-
35-
// TODO negative padding. When the pad is negative, we need to crop the image.
36-
37-
std::vector<nvinfer1::ITensor*> tensors_vec;
38-
// input: (N, C, D_in, H_in, W_in).
39-
// padding: (padding_left, padding_right, padding_top, padding_bottom, padding_front, padding_back)
40-
// When axis is inRank - 1, making W_out = W_in + padding_left + padding_right.
41-
// When axis is inRank - 2, making H_out = H_in + padding_top + padding_bottom.
42-
// When axis is inRank - 3, making D_out = D_in + padding_front + padding_back.
43-
for (int64_t i = 0; i < l_pad; i++) {
44-
int64_t axis = inRank - (i + 1); // axis = {inRank - 1, inRank - 2, inRank - 3}
45-
int64_t padding_index = i * 2;
46-
47-
if (padding[padding_index] > 0) { // left/top/front padding value
48-
tensors_vec.clear();
49-
if (ctx->input_is_dynamic) {
50-
at::Tensor left_indices = torch::tensor({0}, torch::kInt32);
51-
auto indicesTensor = tensor_to_const(ctx, left_indices);
52-
auto left_gather_layer = ctx->net->addGather(*in, *indicesTensor, axis);
53-
auto left_gather_out = left_gather_layer->getOutput(0);
54-
55-
// fill the left_gather_out with value
56-
auto fill_layer = ctx->net->addFill(nvinfer1::Dims{1, {1}}, nvinfer1::FillOperation::kLINSPACE);
57-
auto shape_gather_out = ctx->net->addShape(*left_gather_out)->getOutput(0);
58-
fill_layer->setInput(0, *shape_gather_out);
59-
fill_layer->setInput(1, *valueTensor);
60-
at::Tensor delta_tensor = torch::zeros(inRank, util::TRTDataTypeToScalarType(in->getType()));
61-
auto deltaTensor = tensor_to_const(ctx, delta_tensor);
62-
fill_layer->setInput(2, *deltaTensor);
63-
auto padTensor = fill_layer->getOutput(0);
64-
65-
for (int i = 0; i < padding[padding_index]; i++) {
66-
tensors_vec.push_back(padTensor);
67-
}
68-
} else {
69-
inDims.d[axis] = padding[padding_index];
70-
auto fill_layer = ctx->net->addFill(inDims, nvinfer1::FillOperation::kLINSPACE);
71-
fill_layer->setInput(1, *valueTensor);
72-
at::Tensor delta_tensor = torch::zeros(inRank, util::TRTDataTypeToScalarType(in->getType()));
73-
auto deltaTensor = tensor_to_const(ctx, delta_tensor);
74-
fill_layer->setInput(2, *deltaTensor);
75-
auto padTensor = fill_layer->getOutput(0);
76-
77-
tensors_vec.push_back(padTensor);
78-
}
25+
auto value_itensor = tensor_to_const(ctx, value_tensor);
26+
TORCHTRT_CHECK(pad_size % 2 == 0, "Length of pad must be even but instead it equals " << pad_size);
27+
28+
std::vector<int64_t> start(in_rank, 0);
29+
std::vector<int64_t> total_padding(in_rank, 0);
30+
std::vector<int64_t> stride(in_rank, 1);
31+
32+
// Padding is stored (left, right) starting from the last dim and working backwards
33+
for (size_t i = 0UL; i < padding.size(); i += 2) {
34+
auto left = padding[i];
35+
TORCHTRT_CHECK(left >= 0, "Unsupported negative pad at index " << i);
36+
auto right = padding[i + 1];
37+
TORCHTRT_CHECK(right >= 0, "Unsupported negative pad at index " << i + 1);
38+
auto idx = in_rank - ((i / 2) + 1);
39+
start[idx] = -left;
40+
total_padding[idx] = left + right;
41+
}
7942

80-
tensors_vec.push_back(in);
81-
auto concat_layer = ctx->net->addConcatenation(tensors_vec.data(), tensors_vec.size());
82-
concat_layer->setAxis(axis);
83-
in = concat_layer->getOutput(0);
84-
inDims = in->getDimensions();
43+
auto size = stride; // placeholder for the dynamic case
44+
if (!ctx->input_is_dynamic) {
45+
size = total_padding;
46+
for (size_t i = 0UL; i < total_padding.size(); ++i) {
47+
size[i] += in_dims.d[i];
8548
}
49+
}
8650

87-
if (padding[padding_index + 1] > 0) { // right/bottom/back padding value
88-
tensors_vec.clear();
89-
tensors_vec.push_back(in);
90-
91-
nvinfer1::ITensor* indicesTensor = NULL;
92-
if (inDims.d[axis] == -1) {
93-
auto shapeTensor = ctx->net->addShape(*in)->getOutput(0);
94-
at::Tensor dimValue = torch::tensor({axis}, torch::kInt32);
95-
auto dimTensor = tensor_to_const(ctx, dimValue);
96-
indicesTensor = ctx->net->addGather(*shapeTensor, *dimTensor, 0)->getOutput(0);
97-
auto oneTensor = tensor_to_const(ctx, torch::tensor({1}, torch::kInt32));
98-
indicesTensor = ctx->net->addElementWise(*indicesTensor, *oneTensor, nvinfer1::ElementWiseOperation::kSUB)
99-
->getOutput(0);
100-
} else {
101-
auto indices = torch::tensor({inDims.d[axis] - 1}, torch::kInt32);
102-
indicesTensor = tensor_to_const(ctx, indices);
103-
}
104-
auto right_gather_layer = ctx->net->addGather(*in, *indicesTensor, axis);
105-
auto right_gather_out = right_gather_layer->getOutput(0);
106-
107-
if (ctx->input_is_dynamic) {
108-
// fill the right_gather_out with value
109-
auto fill_layer = ctx->net->addFill(nvinfer1::Dims{1, {1}}, nvinfer1::FillOperation::kLINSPACE);
110-
auto shape_gather_out = ctx->net->addShape(*right_gather_out)->getOutput(0);
111-
fill_layer->setInput(0, *shape_gather_out);
112-
fill_layer->setInput(1, *valueTensor);
113-
at::Tensor delta_tensor = torch::zeros(inRank, util::TRTDataTypeToScalarType(in->getType()));
114-
auto deltaTensor = tensor_to_const(ctx, delta_tensor);
115-
fill_layer->setInput(2, *deltaTensor);
116-
auto padTensor = fill_layer->getOutput(0);
117-
118-
for (int i = 0; i < padding[padding_index + 1]; i++) {
119-
tensors_vec.push_back(padTensor);
120-
}
121-
} else {
122-
inDims.d[axis] = padding[padding_index + 1];
123-
auto fill_layer = ctx->net->addFill(inDims, nvinfer1::FillOperation::kLINSPACE);
124-
fill_layer->setInput(1, *valueTensor);
125-
at::Tensor delta_tensor = torch::zeros(inRank, util::TRTDataTypeToScalarType(in->getType()));
126-
auto deltaTensor = tensor_to_const(ctx, delta_tensor);
127-
fill_layer->setInput(2, *deltaTensor);
128-
auto padTensor = fill_layer->getOutput(0);
129-
130-
tensors_vec.push_back(padTensor);
131-
}
132-
auto concat_layer = ctx->net->addConcatenation(tensors_vec.data(), tensors_vec.size());
133-
concat_layer->setAxis(axis);
134-
in = concat_layer->getOutput(0);
135-
inDims = in->getDimensions();
136-
}
51+
auto slice_layer = ctx->net->addSlice(
52+
*in,
53+
util::toDims(c10::IntArrayRef(start)),
54+
util::toDims(c10::IntArrayRef(size)),
55+
util::toDims(c10::IntArrayRef(stride)));
56+
TORCHTRT_CHECK(slice_layer, "Unable to create slice layer from node: " << *n);
57+
slice_layer->setName((util::node_info(n) + "_slice").c_str());
58+
slice_layer->setMode(nvinfer1::SliceMode::kFILL);
59+
slice_layer->setInput(4, *value_itensor);
60+
61+
if (ctx->input_is_dynamic) {
62+
// build the size using network layers
63+
auto shape_layer = ctx->net->addShape(*in);
64+
TORCHTRT_CHECK(shape_layer, "Unable to create shape layer from node: " << *n);
65+
shape_layer->setName((util::node_info(n) + "_shape").c_str());
66+
auto total_padding_itensor = tensor_to_const(ctx, torch::tensor(total_padding, torch::kInt32));
67+
68+
auto add_layer = ctx->net->addElementWise(
69+
*shape_layer->getOutput(0), *total_padding_itensor, nvinfer1::ElementWiseOperation::kSUM);
70+
TORCHTRT_CHECK(add_layer, "Unable to create add layer from node: " << *n);
71+
add_layer->setName((util::node_info(n) + "_add").c_str());
72+
slice_layer->setInput(2, *add_layer->getOutput(0));
13773
}
13874

139-
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], in);
75+
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], slice_layer->getOutput(0));
14076
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
14177
return true;
14278
}});

0 commit comments

Comments
 (0)