     trt_transposed_matmul,
 )
 from torch_tensorrt.fx.tracer.acc_tracer.acc_ops import contiguous
-from torch_tensorrt.fx.converters.impl import activation
+from torch_tensorrt.fx.converters.impl import activation, convolution
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
@@ -96,86 +96,20 @@ def acc_ops_conv1d(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"Conv received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
-
-    # Process 1d input with unsqueeze -> conv2d -> squeeze to calculated conv1d
-    unsqueeze_layer = network.add_shuffle(input=input_val)
-    unsqueeze_layer.reshape_dims = tuple([*input_val.shape, 1])
-    set_layer_name(unsqueeze_layer, target, name + "_unsqueeze")
-    input_val = unsqueeze_layer.get_output(0)
-
-    if has_dynamic_shape(input_val.shape):
-        assert input_val.shape[1] != -1, "Channel dim can't be dynamic for convolution."
-
-    # for now we'll assume bias is constant Tensor or None,
-    # and bias being ITensor is not supported in TensorRT api
-    # right now
-    if kwargs["bias"] is not None and not isinstance(kwargs["bias"], torch.Tensor):
-        raise RuntimeError(
-            f"linear {name} has bias of type {type(kwargs['bias'])}, Expect Optional[Tensor]"
-        )
-    bias = to_numpy(kwargs["bias"])  # type: ignore[arg-type]
-    if bias is not None:
-        bias = bias[None]
-    weight = kwargs["weight"]
-
-    if network.has_explicit_precision or isinstance(weight, TRTTensor):
-        weight = get_trt_tensor(network, weight, f"{name}_weight")
-        # Expand 1d weight with unsqueeze for calculation
-        unsqueeze_weight_layer = network.add_shuffle(input=weight)
-        unsqueeze_weight_layer.reshape_dims = tuple([*weight.shape, 1])
-        set_layer_name(unsqueeze_layer, target, name + "_unsqueeze_weight")
-        weight = unsqueeze_weight_layer.get_output(0)
-        weight_shape = tuple(kwargs["weight"].shape)  # type: ignore[union-attr]
-        # will need to use uninitialized weight and set it later to support
-        # ITensor weights
-        dummy_weight = trt.Weights()
-        layer = network.add_convolution_nd(
-            input=input_val,
-            num_output_maps=weight.shape[0],
-            kernel_shape=weight.shape[2:],
-            kernel=dummy_weight,
-            bias=bias,
-        )
-
-        layer.set_input(1, weight)
-    else:
-        if not isinstance(kwargs["weight"], torch.Tensor):
-            raise RuntimeError(
-                f"linear {name} has weight of type {type(kwargs['weight'])}, Expect Optional[Tensor]"
-            )
-        weight = to_numpy(weight)
-        weight = np.expand_dims(weight, -1)
-        layer = network.add_convolution_nd(
-            input=input_val,
-            num_output_maps=weight.shape[0],
-            kernel_shape=weight.shape[2:],
-            kernel=weight,
-            bias=bias,
-        )
-    # expand params to 2d for computation
-    padding = list(kwargs["padding"])
-    padding.append(0)
-    stride = extend_attr_to_tuple(kwargs["stride"], 2)
-    dilation = extend_attr_to_tuple(kwargs["dilation"], 2)
-
-    set_layer_name(layer, target, name)
-    layer.stride_nd = stride
-    layer.padding_nd = padding
-    layer.dilation_nd = dilation
-    if kwargs["groups"] is not None:
-        layer.num_groups = kwargs["groups"]
-
-    result = layer.get_output(0)
-    squeeze_layer = network.add_shuffle(input=result)
-    squeeze_layer.reshape_dims = tuple(result.shape[:-1])
-    set_layer_name(squeeze_layer, target, name + "_squeeze")
-    return squeeze_layer.get_output(0)
+    return convolution.convNd(
+        network,
+        target,
+        source_ir=SourceIR.ACC,
+        name=name,
+        is_conv1d=True,
+        input_val=kwargs["input"],
+        weight=kwargs["weight"],
+        bias=kwargs["bias"],
+        stride=kwargs["stride"],
+        padding=kwargs["padding"],
+        dilation=kwargs["dilation"],
+        groups=kwargs["groups"],
+    )
 
 
 @tensorrt_converter(acc_ops.conv3d)
@@ -187,63 +121,20 @@ def acc_ops_convnd(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-
-    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"Conv received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
-
-    if has_dynamic_shape(input_val.shape):
-        assert input_val.shape[1] != -1, "Channel dim can't be dynamic for convolution."
-
-    # for now we'll assume bias is constant Tensor or None,
-    # and bias being ITensor is not supported in TensorRT api
-    # right now
-    if kwargs["bias"] is not None and not isinstance(kwargs["bias"], torch.Tensor):
-        raise RuntimeError(
-            f"linear {name} has bias of type {type(kwargs['bias'])}, Expect Optional[Tensor]"
-        )
-    bias = to_numpy(kwargs["bias"])  # type: ignore[arg-type]
-
-    if network.has_explicit_precision or isinstance(kwargs["weight"], TRTTensor):
-        weight = get_trt_tensor(network, kwargs["weight"], f"{name}_weight")
-        weight_shape = tuple(kwargs["weight"].shape)  # type: ignore[union-attr]
-        # will need to use uninitialized weight and set it later to support
-        # ITensor weights
-        dummy_weight = trt.Weights()
-        layer = network.add_convolution_nd(
-            input=input_val,
-            num_output_maps=weight.shape[0],
-            kernel_shape=weight.shape[2:],
-            kernel=dummy_weight,
-            bias=bias,
-        )
-
-        layer.set_input(1, weight)
-    else:
-        if not isinstance(kwargs["weight"], torch.Tensor):
-            raise RuntimeError(
-                f"linear {name} has weight of type {type(kwargs['weight'])}, Expect Optional[Tensor]"
-            )
-        weight = to_numpy(kwargs["weight"])
-        layer = network.add_convolution_nd(
-            input=input_val,
-            num_output_maps=weight.shape[0],
-            kernel_shape=weight.shape[2:],
-            kernel=weight,
-            bias=bias,
-        )
-
-    set_layer_name(layer, target, name)
-    layer.stride_nd = kwargs["stride"]
-    layer.padding_nd = kwargs["padding"]
-    layer.dilation_nd = kwargs["dilation"]
-    if kwargs["groups"] is not None:
-        layer.num_groups = kwargs["groups"]
-
-    return layer.get_output(0)
+    return convolution.convNd(
+        network,
+        target,
+        source_ir=SourceIR.ACC,
+        name=name,
+        is_conv1d=False,
+        input_val=kwargs["input"],
+        weight=kwargs["weight"],
+        bias=kwargs["bias"],
+        stride=kwargs["stride"],
+        padding=kwargs["padding"],
+        dilation=kwargs["dilation"],
+        groups=kwargs["groups"],
+    )
 
 
 @tensorrt_converter(acc_ops.conv_transpose2d)
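Both acc converters now delegate to the shared convolution.convNd helper in torch_tensorrt.fx.converters.impl, with is_conv1d selecting the 1-d path. For reference, the sketch below illustrates the 1-d handling that helper consolidates, namely the unsqueeze -> conv2d -> squeeze trick from the deleted acc_ops_conv1d body. It is a minimal, hypothetical sketch assuming constant numpy weights and bias; the function name convNd_sketch and its exact argument handling are illustrative assumptions, not the library's actual implementation.

# Hypothetical sketch (not the torch_tensorrt implementation): compute a conv1d
# by appending a unit spatial dim, running add_convolution_nd as a conv2d, then
# squeezing the unit dim back out. Assumes weight/bias are constant numpy arrays.
import numpy as np
import tensorrt as trt


def convNd_sketch(network, name, is_conv1d, input_val, weight, bias,
                  stride, padding, dilation, groups):
    if is_conv1d:
        # Unsqueeze the input: (N, C, L) -> (N, C, L, 1).
        unsqueeze = network.add_shuffle(input=input_val)
        unsqueeze.reshape_dims = tuple([*input_val.shape, 1])
        input_val = unsqueeze.get_output(0)
        # Expand the 1-d kernel and parameters to 2-d; the extra dim is size 1,
        # so padding 0 / stride 1 / dilation 1 leave the result unchanged.
        weight = np.expand_dims(weight, -1)
        padding = tuple(list(padding) + [0])
        stride = tuple(list(stride) + [1])
        dilation = tuple(list(dilation) + [1])

    layer = network.add_convolution_nd(
        input=input_val,
        num_output_maps=weight.shape[0],
        kernel_shape=weight.shape[2:],
        kernel=trt.Weights(np.ascontiguousarray(weight)),
        bias=trt.Weights(np.ascontiguousarray(bias)) if bias is not None else trt.Weights(),
    )
    layer.stride_nd = stride
    layer.padding_nd = padding
    layer.dilation_nd = dilation
    if groups is not None:
        layer.num_groups = groups
    layer.name = name

    result = layer.get_output(0)
    if is_conv1d:
        # Squeeze the trailing unit dim to recover the 1-d output shape.
        squeeze = network.add_shuffle(input=result)
        squeeze.reshape_dims = tuple(result.shape[:-1])
        result = squeeze.get_output(0)
    return result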