@@ -16,127 +16,63 @@ auto constant_pad_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns
     {"aten::constant_pad_nd(Tensor self, int[] pad, Scalar value=0) -> (Tensor)",
      [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
        auto in = args[0].ITensor();
-       auto inDims = in->getDimensions();
-       int64_t inRank = inDims.nbDims;
+       auto in_dims = in->getDimensions();
+       int64_t in_rank = in_dims.nbDims;
        auto padding = args[1].unwrapToIntList().vec();
-       int64_t padSize = padding.size();
+       int64_t pad_size = padding.size();
        auto value = args[2].unwrapToScalar().to<float>();
        at::Tensor value_tensor = torch::tensor(value, util::TRTDataTypeToScalarType(in->getType()));
-       auto valueTensor = tensor_to_const(ctx, value_tensor);
-       TORCHTRT_CHECK(padSize % 2 == 0, "Length of pad must be even but instead it equals " << padSize);
-
-       int64_t l_pad = padSize / 2;
-       TORCHTRT_CHECK(
-           inRank >= (int64_t)l_pad,
-           "Length of pad should be no more than twice the number of "
-           "dimensions of the input. Pad length is "
-               << padSize << " while the input has " << inRank << " dimensions.");
-
-       // TODO negative padding. When the pad is negative, we need to crop the image.
-
-       std::vector<nvinfer1::ITensor*> tensors_vec;
-       // input: (N, C, D_in, H_in, W_in).
-       // padding: (padding_left, padding_right, padding_top, padding_bottom, padding_front, padding_back)
-       // When axis is inRank - 1, making W_out = W_in + padding_left + padding_right.
-       // When axis is inRank - 2, making H_out = H_in + padding_top + padding_bottom.
-       // When axis is inRank - 3, making D_out = D_in + padding_front + padding_back.
-       for (int64_t i = 0; i < l_pad; i++) {
-         int64_t axis = inRank - (i + 1); // axis = {inRank - 1, inRank - 2, inRank - 3}
-         int64_t padding_index = i * 2;
-
-         if (padding[padding_index] > 0) { // left/top/front padding value
-           tensors_vec.clear();
-           if (ctx->input_is_dynamic) {
-             at::Tensor left_indices = torch::tensor({0}, torch::kInt32);
-             auto indicesTensor = tensor_to_const(ctx, left_indices);
-             auto left_gather_layer = ctx->net->addGather(*in, *indicesTensor, axis);
-             auto left_gather_out = left_gather_layer->getOutput(0);
-
-             // fill the left_gather_out with value
-             auto fill_layer = ctx->net->addFill(nvinfer1::Dims{1, {1}}, nvinfer1::FillOperation::kLINSPACE);
-             auto shape_gather_out = ctx->net->addShape(*left_gather_out)->getOutput(0);
-             fill_layer->setInput(0, *shape_gather_out);
-             fill_layer->setInput(1, *valueTensor);
-             at::Tensor delta_tensor = torch::zeros(inRank, util::TRTDataTypeToScalarType(in->getType()));
-             auto deltaTensor = tensor_to_const(ctx, delta_tensor);
-             fill_layer->setInput(2, *deltaTensor);
-             auto padTensor = fill_layer->getOutput(0);
-
-             for (int i = 0; i < padding[padding_index]; i++) {
-               tensors_vec.push_back(padTensor);
-             }
-           } else {
-             inDims.d[axis] = padding[padding_index];
-             auto fill_layer = ctx->net->addFill(inDims, nvinfer1::FillOperation::kLINSPACE);
-             fill_layer->setInput(1, *valueTensor);
-             at::Tensor delta_tensor = torch::zeros(inRank, util::TRTDataTypeToScalarType(in->getType()));
-             auto deltaTensor = tensor_to_const(ctx, delta_tensor);
-             fill_layer->setInput(2, *deltaTensor);
-             auto padTensor = fill_layer->getOutput(0);
-
-             tensors_vec.push_back(padTensor);
-           }
+       auto value_itensor = tensor_to_const(ctx, value_tensor);
+       TORCHTRT_CHECK(pad_size % 2 == 0, "Length of pad must be even but instead it equals " << pad_size);
+
+       std::vector<int64_t> start(in_rank, 0);
+       std::vector<int64_t> total_padding(in_rank, 0);
+       std::vector<int64_t> stride(in_rank, 1);
+
+       // Padding is stored (left, right) starting from the last dim and working backwards
+       for (size_t i = 0UL; i < padding.size(); i += 2) {
+         auto left = padding[i];
+         TORCHTRT_CHECK(left >= 0, "Unsupported negative pad at index " << i);
+         auto right = padding[i + 1];
+         TORCHTRT_CHECK(right >= 0, "Unsupported negative pad at index " << i + 1);
+         auto idx = in_rank - ((i / 2) + 1);
+         start[idx] = -left;
+         total_padding[idx] = left + right;
+       }

-           tensors_vec.push_back(in);
-           auto concat_layer = ctx->net->addConcatenation(tensors_vec.data(), tensors_vec.size());
-           concat_layer->setAxis(axis);
-           in = concat_layer->getOutput(0);
-           inDims = in->getDimensions();
+       auto size = stride; // placeholder for the dynamic case
+       if (!ctx->input_is_dynamic) {
+         size = total_padding;
+         for (size_t i = 0UL; i < total_padding.size(); ++i) {
+           size[i] += in_dims.d[i];
          }
+       }

-         if (padding[padding_index + 1] > 0) { // right/bottom/back padding value
-           tensors_vec.clear();
-           tensors_vec.push_back(in);
-
-           nvinfer1::ITensor* indicesTensor = NULL;
-           if (inDims.d[axis] == -1) {
-             auto shapeTensor = ctx->net->addShape(*in)->getOutput(0);
-             at::Tensor dimValue = torch::tensor({axis}, torch::kInt32);
-             auto dimTensor = tensor_to_const(ctx, dimValue);
-             indicesTensor = ctx->net->addGather(*shapeTensor, *dimTensor, 0)->getOutput(0);
-             auto oneTensor = tensor_to_const(ctx, torch::tensor({1}, torch::kInt32));
-             indicesTensor = ctx->net->addElementWise(*indicesTensor, *oneTensor, nvinfer1::ElementWiseOperation::kSUB)
-                                 ->getOutput(0);
-           } else {
-             auto indices = torch::tensor({inDims.d[axis] - 1}, torch::kInt32);
-             indicesTensor = tensor_to_const(ctx, indices);
-           }
-           auto right_gather_layer = ctx->net->addGather(*in, *indicesTensor, axis);
-           auto right_gather_out = right_gather_layer->getOutput(0);
-
-           if (ctx->input_is_dynamic) {
-             // fill the right_gather_out with value
-             auto fill_layer = ctx->net->addFill(nvinfer1::Dims{1, {1}}, nvinfer1::FillOperation::kLINSPACE);
-             auto shape_gather_out = ctx->net->addShape(*right_gather_out)->getOutput(0);
-             fill_layer->setInput(0, *shape_gather_out);
-             fill_layer->setInput(1, *valueTensor);
-             at::Tensor delta_tensor = torch::zeros(inRank, util::TRTDataTypeToScalarType(in->getType()));
-             auto deltaTensor = tensor_to_const(ctx, delta_tensor);
-             fill_layer->setInput(2, *deltaTensor);
-             auto padTensor = fill_layer->getOutput(0);
-
-             for (int i = 0; i < padding[padding_index + 1]; i++) {
-               tensors_vec.push_back(padTensor);
-             }
-           } else {
-             inDims.d[axis] = padding[padding_index + 1];
-             auto fill_layer = ctx->net->addFill(inDims, nvinfer1::FillOperation::kLINSPACE);
-             fill_layer->setInput(1, *valueTensor);
-             at::Tensor delta_tensor = torch::zeros(inRank, util::TRTDataTypeToScalarType(in->getType()));
-             auto deltaTensor = tensor_to_const(ctx, delta_tensor);
-             fill_layer->setInput(2, *deltaTensor);
-             auto padTensor = fill_layer->getOutput(0);
-
-             tensors_vec.push_back(padTensor);
-           }
-           auto concat_layer = ctx->net->addConcatenation(tensors_vec.data(), tensors_vec.size());
-           concat_layer->setAxis(axis);
-           in = concat_layer->getOutput(0);
-           inDims = in->getDimensions();
-         }
+       auto slice_layer = ctx->net->addSlice(
+           *in,
+           util::toDims(c10::IntArrayRef(start)),
+           util::toDims(c10::IntArrayRef(size)),
+           util::toDims(c10::IntArrayRef(stride)));
+       TORCHTRT_CHECK(slice_layer, "Unable to create slice layer from node: " << *n);
+       slice_layer->setName((util::node_info(n) + "_slice").c_str());
+       slice_layer->setMode(nvinfer1::SliceMode::kFILL);
+       slice_layer->setInput(4, *value_itensor);
+
+       if (ctx->input_is_dynamic) {
+         // build the size using inetwork layers
+         auto shape_layer = ctx->net->addShape(*in);
+         TORCHTRT_CHECK(shape_layer, "Unable to create shape layer from node: " << *n);
+         shape_layer->setName((util::node_info(n) + "_shape").c_str());
+         auto total_padding_itensor = tensor_to_const(ctx, torch::tensor(total_padding, torch::kInt32));
+
+         auto add_layer = ctx->net->addElementWise(
+             *shape_layer->getOutput(0), *total_padding_itensor, nvinfer1::ElementWiseOperation::kSUM);
+         TORCHTRT_CHECK(add_layer, "Unable to create add layer from node: " << *n);
+         add_layer->setName((util::node_info(n) + "_add").c_str());
+         slice_layer->setInput(2, *add_layer->getOutput(0));
        }

-       auto out = ctx->AssociateValueAndTensor(n->outputs()[0], in);
+       auto out = ctx->AssociateValueAndTensor(n->outputs()[0], slice_layer->getOutput(0));
        LOG_DEBUG("Output tensor shape: " << out->getDimensions());
        return true;
      }});
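
The rewritten converter builds the padded output with a single TensorRT slice layer in kFILL mode: negative start offsets expose the fill region before the data, the slice size is the input shape plus the per-dimension total padding (computed at build time for static shapes, or at runtime from a shape layer plus an elementwise add when the input is dynamic), and the fill value is supplied as the slice layer's fifth input via setInput(4, ...). As a sanity check of the index bookkeeping only, here is a minimal standalone sketch in plain C++ with no TensorRT or Torch calls; the input shape and pad values are invented for illustration.

// Hypothetical example: rank-4 input (N, C, H, W) = (1, 3, 32, 32),
// pad = {1, 2, 3, 4} in PyTorch order, i.e. the last dim gets (1 left, 2 right)
// and the second-to-last dim gets (3 top, 4 bottom).
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int64_t> in_dims = {1, 3, 32, 32};
  std::vector<int64_t> padding = {1, 2, 3, 4};
  int64_t in_rank = static_cast<int64_t>(in_dims.size());

  std::vector<int64_t> start(in_rank, 0);
  std::vector<int64_t> total_padding(in_rank, 0);
  std::vector<int64_t> stride(in_rank, 1);

  // Pairs are stored (left, right) starting from the last dim and working backwards.
  for (size_t i = 0; i < padding.size(); i += 2) {
    int64_t left = padding[i];
    int64_t right = padding[i + 1];
    int64_t idx = in_rank - (static_cast<int64_t>(i) / 2 + 1);
    start[idx] = -left;                 // negative start exposes the fill region before the data
    total_padding[idx] = left + right;
  }

  // Static-shape case: output size = input size + total padding per dimension.
  std::vector<int64_t> size = total_padding;
  for (size_t i = 0; i < size.size(); ++i) {
    size[i] += in_dims[i];
  }

  // Prints start = 0 0 -3 -1, size = 1 3 39 35, stride = 1 1 1 1.
  for (auto v : start) std::cout << v << ' ';
  std::cout << '\n';
  for (auto v : size) std::cout << v << ' ';
  std::cout << '\n';
  for (auto v : stride) std::cout << v << ' ';
  std::cout << '\n';
  return 0;
}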