@@ -194,60 +194,6 @@ bool add_expand_dynamic(
   return true;
 }

-bool add_repeat(ConversionCtx* ctx, const torch::jit::Node* n, args& args, const std::string& layer) {
-  auto in = args[0].ITensorOrFreeze(ctx);
-  auto input_dims = in->getDimensions();
-  auto repeats = args[1].unwrapToIntList().vec();
-  int repeats_rank = repeats.size();
-  TORCHTRT_CHECK(
-      repeats_rank >= input_dims.nbDims,
-      "Number of repeat dimensions cannot be smaller than number of input dimensions");
-
-  auto num_expand_dims = repeats_rank - input_dims.nbDims;
-
-  if (ctx->input_is_dynamic) {
-    int input_rank = input_dims.nbDims;
-    int output_rank = repeats_rank;
-    auto new_input_shape_tensor = concat(output_rank, input_rank, ctx, in);
-
-    auto shuffle = ctx->net->addShuffle(*in);
-    shuffle->setInput(1, *new_input_shape_tensor);
-    in = shuffle->getOutput(0);
-  } else {
-    if (num_expand_dims > 0) {
-      nvinfer1::Dims reshape_dims;
-      reshape_dims.nbDims = repeats.size();
-      for (int i = 0; i < num_expand_dims; i++) {
-        reshape_dims.d[i] = 1;
-      }
-      for (int i = 0; i < input_dims.nbDims; i++) {
-        reshape_dims.d[num_expand_dims + i] = input_dims.d[i];
-      }
-      // Add a reshape layer to expand dims
-      auto reshape_layer = ctx->net->addShuffle(*in);
-      reshape_layer->setReshapeDimensions(reshape_dims);
-      in = reshape_layer->getOutput(0);
-      LOG_DEBUG("Input reshaped to : " << in->getDimensions() << " from " << input_dims);
-    }
-    LOG_DEBUG("Repeats: " << repeats);
-  }
-
-  // Concat across all repeat axes.
-  for (int i = repeats.size() - 1; i >= 0; --i) {
-    std::vector<nvinfer1::ITensor*> tensors_vec;
-    for (int j = 0; j < repeats[i]; j++) {
-      tensors_vec.push_back(in);
-    }
-    auto concat_layer = ctx->net->addConcatenation(tensors_vec.data(), tensors_vec.size());
-    concat_layer->setAxis(i);
-    in = concat_layer->getOutput(0);
-  }
-
-  auto out = ctx->AssociateValueAndTensor(n->outputs()[0], in);
-  LOG_DEBUG(layer << " layer output tensor shape: " << out->getDimensions());
-  return true;
-}
-
 auto expand_registrations TORCHTRT_UNUSED =
     RegisterNodeConversionPatterns()
         .pattern(
@@ -284,7 +230,59 @@ auto expand_registrations TORCHTRT_UNUSED =
         .pattern(
             {"aten::repeat(Tensor self, int[] repeats) -> (Tensor)",
              [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-               return add_repeat(ctx, n, args, "Repeat");
+               auto in = args[0].ITensorOrFreeze(ctx);
+               auto input_dims = in->getDimensions();
+               auto repeats = args[1].unwrapToIntList().vec();
+               int repeats_rank = repeats.size();
+               TORCHTRT_CHECK(
+                   repeats_rank >= input_dims.nbDims,
+                   "Number of repeat dimensions cannot be smaller than number of input dimensions");
+               auto num_expand_dims = repeats_rank - input_dims.nbDims;
+
+               if (ctx->input_is_dynamic) {
+                 int input_rank = input_dims.nbDims;
+                 int output_rank = repeats_rank;
+                 auto new_input_shape_tensor = concat(output_rank, input_rank, ctx, in);
+
+                 // Add a reshape layer to expand dims
+                 auto shuffle = ctx->net->addShuffle(*in);
+                 shuffle->setInput(1, *new_input_shape_tensor);
+                 in = shuffle->getOutput(0);
+               } else {
+                 if (num_expand_dims > 0) {
+                   nvinfer1::Dims reshape_dims;
+                   reshape_dims.nbDims = repeats.size();
+                   for (int i = 0; i < num_expand_dims; i++) {
+                     reshape_dims.d[i] = 1;
+                   }
+                   for (int i = 0; i < input_dims.nbDims; i++) {
+                     reshape_dims.d[num_expand_dims + i] = input_dims.d[i];
+                   }
+                   // Add a reshape layer to expand dims
+                   auto reshape_layer = ctx->net->addShuffle(*in);
+                   reshape_layer->setReshapeDimensions(reshape_dims);
+                   in = reshape_layer->getOutput(0);
+                   LOG_DEBUG("Input reshaped to : " << in->getDimensions() << " from " << input_dims);
+                 }
+                 LOG_DEBUG("Repeats: " << repeats);
+               }
+
+               // Concat across all repeat axes.
+               // TODO: Implementation might not be performant. Explore other strategies to improve performance.
+               for (int i = repeats.size() - 1; i >= 0; --i) {
+                 std::vector<nvinfer1::ITensor*> tensors_vec;
+                 for (int j = 0; j < repeats[i]; j++) {
+                   tensors_vec.push_back(in);
+                 }
+                 auto concat_layer = ctx->net->addConcatenation(tensors_vec.data(), tensors_vec.size());
+                 concat_layer->setAxis(i);
+                 in = concat_layer->getOutput(0);
+               }
+
+               auto out = ctx->AssociateValueAndTensor(n->outputs()[0], in);
+
+               LOG_DEBUG("Repeat layer output tensor shape: " << out->getDimensions());
+               return true;
              }})
         .pattern(
             {"aten::repeat_interleave.self_int(Tensor self, int repeats, int? dim=None, *, int? output_size=None) -> (Tensor)",
@@ -397,11 +395,6 @@ auto expand_registrations TORCHTRT_UNUSED =

               return true;
             }})
-        .pattern(
-            {"aten::tile(Tensor self, int[] dims) -> (Tensor)",
-             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-               return add_repeat(ctx, n, args, "Tile");
-             }})
         .pattern(
             {"aten::meshgrid(Tensor[] tensors) -> (Tensor[])",
              [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
@@ -491,4 +484,4 @@ auto expand_registrations TORCHTRT_UNUSED =
 } // namespace converters
 } // namespace conversion
 } // namespace core
-} // namespace torch_tensorrt
+} // namespace torch_tensorrt