Skip to content

Commit 8b09e71

Browse files
authored
fix/feat: Move convolution core to impl + add feature (FX converter refactor) (#1972)
1 parent 2844630 commit 8b09e71

File tree

8 files changed

+327
-362
lines changed

8 files changed

+327
-362
lines changed

py/torch_tensorrt/dynamo/test/test_dynamo_backend.py

+5
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def test_resnet18(ir):
3030
"device": torchtrt.Device("cuda:0"),
3131
"enabled_precisions": {torch.float},
3232
"ir": ir,
33+
"pass_through_build_failures": True,
3334
}
3435

3536
trt_mod = torchtrt.compile(model, **compile_spec)
@@ -60,6 +61,7 @@ def test_mobilenet_v2(ir):
6061
"device": torchtrt.Device("cuda:0"),
6162
"enabled_precisions": {torch.float},
6263
"ir": ir,
64+
"pass_through_build_failures": True,
6365
}
6466

6567
trt_mod = torchtrt.compile(model, **compile_spec)
@@ -90,6 +92,7 @@ def test_efficientnet_b0(ir):
9092
"device": torchtrt.Device("cuda:0"),
9193
"enabled_precisions": {torch.float},
9294
"ir": ir,
95+
"pass_through_build_failures": True,
9396
}
9497

9598
trt_mod = torchtrt.compile(model, **compile_spec)
@@ -129,6 +132,7 @@ def test_bert_base_uncased(ir):
129132
"enabled_precisions": {torch.float},
130133
"truncate_long_and_double": True,
131134
"ir": ir,
135+
"pass_through_build_failures": True,
132136
}
133137
trt_mod = torchtrt.compile(model, **compile_spec)
134138

@@ -163,6 +167,7 @@ def test_resnet18_half(ir):
163167
"device": torchtrt.Device("cuda:0"),
164168
"enabled_precisions": {torch.half},
165169
"ir": ir,
170+
"pass_through_build_failures": True,
166171
}
167172

168173
trt_mod = torchtrt.compile(model, **compile_spec)

py/torch_tensorrt/fx/converters/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,14 @@
55
from .adaptive_avgpool import * # noqa: F401 F403
66
from .add import * # noqa: F401 F403
77
from .batchnorm import * # noqa: F401 F403
8-
from .convolution import * # noqa: F401 F403
98
from .linear import * # noqa: F401 F403
109
from .maxpool import * # noqa: F401 F403
1110
from .mul import * # noqa: F401 F403
1211
from .transformation import * # noqa: F401 F403
1312
from .quantization import * # noqa: F401 F403
1413
from .acc_ops_converters import * # noqa: F401 F403
1514
from .aten_ops_converters import * # noqa: F401 F403
15+
from .nn_ops_converters import * # noqa: F401 F403
1616

1717
TRT_LOGGER = trt.Logger()
1818
trt.init_libnvinfer_plugins(TRT_LOGGER, "")

py/torch_tensorrt/fx/converters/acc_ops_converters.py

+29-138
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
trt_transposed_matmul,
2727
)
2828
from torch_tensorrt.fx.tracer.acc_tracer.acc_ops import contiguous
29-
from torch_tensorrt.fx.converters.impl import activation
29+
from torch_tensorrt.fx.converters.impl import activation, convolution
3030

3131
_LOGGER: logging.Logger = logging.getLogger(__name__)
3232

@@ -96,86 +96,20 @@ def acc_ops_conv1d(
9696
kwargs: Dict[str, Argument],
9797
name: str,
9898
) -> Union[TRTTensor, Sequence[TRTTensor]]:
99-
input_val = kwargs["input"]
100-
if not isinstance(input_val, TRTTensor):
101-
raise RuntimeError(
102-
f"Conv received input {input_val} that is not part "
103-
"of the TensorRT region!"
104-
)
105-
106-
# Process 1d input with unsqueeze -> conv2d -> squeeze to calculated conv1d
107-
unsqueeze_layer = network.add_shuffle(input=input_val)
108-
unsqueeze_layer.reshape_dims = tuple([*input_val.shape, 1])
109-
set_layer_name(unsqueeze_layer, target, name + "_unsqueeze")
110-
input_val = unsqueeze_layer.get_output(0)
111-
112-
if has_dynamic_shape(input_val.shape):
113-
assert input_val.shape[1] != -1, "Channel dim can't be dynamic for convolution."
114-
115-
# for now we'll assume bias is constant Tensor or None,
116-
# and bias being ITensor is not supported in TensorRT api
117-
# right now
118-
if kwargs["bias"] is not None and not isinstance(kwargs["bias"], torch.Tensor):
119-
raise RuntimeError(
120-
f"linear {name} has bias of type {type(kwargs['bias'])}, Expect Optional[Tensor]"
121-
)
122-
bias = to_numpy(kwargs["bias"]) # type: ignore[arg-type]
123-
if bias is not None:
124-
bias = bias[None]
125-
weight = kwargs["weight"]
126-
127-
if network.has_explicit_precision or isinstance(weight, TRTTensor):
128-
weight = get_trt_tensor(network, weight, f"{name}_weight")
129-
# Expand 1d weight with unsqueeze for calculation
130-
unsqueeze_weight_layer = network.add_shuffle(input=weight)
131-
unsqueeze_weight_layer.reshape_dims = tuple([*weight.shape, 1])
132-
set_layer_name(unsqueeze_layer, target, name + "_unsqueeze_weight")
133-
weight = unsqueeze_weight_layer.get_output(0)
134-
weight_shape = tuple(kwargs["weight"].shape) # type: ignore[union-attr]
135-
# will need to use uninitialized weight and set it later to support
136-
# ITensor weights
137-
dummy_weight = trt.Weights()
138-
layer = network.add_convolution_nd(
139-
input=input_val,
140-
num_output_maps=weight.shape[0],
141-
kernel_shape=weight.shape[2:],
142-
kernel=dummy_weight,
143-
bias=bias,
144-
)
145-
146-
layer.set_input(1, weight)
147-
else:
148-
if not isinstance(kwargs["weight"], torch.Tensor):
149-
raise RuntimeError(
150-
f"linear {name} has weight of type {type(kwargs['weight'])}, Expect Optional[Tensor]"
151-
)
152-
weight = to_numpy(weight)
153-
weight = np.expand_dims(weight, -1)
154-
layer = network.add_convolution_nd(
155-
input=input_val,
156-
num_output_maps=weight.shape[0],
157-
kernel_shape=weight.shape[2:],
158-
kernel=weight,
159-
bias=bias,
160-
)
161-
# expand params to 2d for computation
162-
padding = list(kwargs["padding"])
163-
padding.append(0)
164-
stride = extend_attr_to_tuple(kwargs["stride"], 2)
165-
dilation = extend_attr_to_tuple(kwargs["dilation"], 2)
166-
167-
set_layer_name(layer, target, name)
168-
layer.stride_nd = stride
169-
layer.padding_nd = padding
170-
layer.dilation_nd = dilation
171-
if kwargs["groups"] is not None:
172-
layer.num_groups = kwargs["groups"]
173-
174-
result = layer.get_output(0)
175-
squeeze_layer = network.add_shuffle(input=result)
176-
squeeze_layer.reshape_dims = tuple(result.shape[:-1])
177-
set_layer_name(squeeze_layer, target, name + "_squeeze")
178-
return squeeze_layer.get_output(0)
99+
return convolution.convNd(
100+
network,
101+
target,
102+
source_ir=SourceIR.ACC,
103+
name=name,
104+
is_conv1d=True,
105+
input_val=kwargs["input"],
106+
weight=kwargs["weight"],
107+
bias=kwargs["bias"],
108+
stride=kwargs["stride"],
109+
padding=kwargs["padding"],
110+
dilation=kwargs["dilation"],
111+
groups=kwargs["groups"],
112+
)
179113

180114

181115
@tensorrt_converter(acc_ops.conv3d)
@@ -187,63 +121,20 @@ def acc_ops_convnd(
187121
kwargs: Dict[str, Argument],
188122
name: str,
189123
) -> Union[TRTTensor, Sequence[TRTTensor]]:
190-
input_val = kwargs["input"]
191-
192-
if not isinstance(input_val, TRTTensor):
193-
raise RuntimeError(
194-
f"Conv received input {input_val} that is not part "
195-
"of the TensorRT region!"
196-
)
197-
198-
if has_dynamic_shape(input_val.shape):
199-
assert input_val.shape[1] != -1, "Channel dim can't be dynamic for convolution."
200-
201-
# for now we'll assume bias is constant Tensor or None,
202-
# and bias being ITensor is not supported in TensorRT api
203-
# right now
204-
if kwargs["bias"] is not None and not isinstance(kwargs["bias"], torch.Tensor):
205-
raise RuntimeError(
206-
f"linear {name} has bias of type {type(kwargs['bias'])}, Expect Optional[Tensor]"
207-
)
208-
bias = to_numpy(kwargs["bias"]) # type: ignore[arg-type]
209-
210-
if network.has_explicit_precision or isinstance(kwargs["weight"], TRTTensor):
211-
weight = get_trt_tensor(network, kwargs["weight"], f"{name}_weight")
212-
weight_shape = tuple(kwargs["weight"].shape) # type: ignore[union-attr]
213-
# will need to use uninitialized weight and set it later to support
214-
# ITensor weights
215-
dummy_weight = trt.Weights()
216-
layer = network.add_convolution_nd(
217-
input=input_val,
218-
num_output_maps=weight.shape[0],
219-
kernel_shape=weight.shape[2:],
220-
kernel=dummy_weight,
221-
bias=bias,
222-
)
223-
224-
layer.set_input(1, weight)
225-
else:
226-
if not isinstance(kwargs["weight"], torch.Tensor):
227-
raise RuntimeError(
228-
f"linear {name} has weight of type {type(kwargs['weight'])}, Expect Optional[Tensor]"
229-
)
230-
weight = to_numpy(kwargs["weight"])
231-
layer = network.add_convolution_nd(
232-
input=input_val,
233-
num_output_maps=weight.shape[0],
234-
kernel_shape=weight.shape[2:],
235-
kernel=weight,
236-
bias=bias,
237-
)
238-
239-
set_layer_name(layer, target, name)
240-
layer.stride_nd = kwargs["stride"]
241-
layer.padding_nd = kwargs["padding"]
242-
layer.dilation_nd = kwargs["dilation"]
243-
if kwargs["groups"] is not None:
244-
layer.num_groups = kwargs["groups"]
245-
246-
return layer.get_output(0)
124+
return convolution.convNd(
125+
network,
126+
target,
127+
source_ir=SourceIR.ACC,
128+
name=name,
129+
is_conv1d=False,
130+
input_val=kwargs["input"],
131+
weight=kwargs["weight"],
132+
bias=kwargs["bias"],
133+
stride=kwargs["stride"],
134+
padding=kwargs["padding"],
135+
dilation=kwargs["dilation"],
136+
groups=kwargs["groups"],
137+
)
247138

248139

249140
@tensorrt_converter(acc_ops.conv_transpose2d)

py/torch_tensorrt/fx/converters/aten_ops_converters.py

+28-5
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from .converter_utils import * # noqa: F403
2222
import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils
23-
from torch_tensorrt.fx.converters.impl import activation
23+
from torch_tensorrt.fx.converters.impl import activation, convolution
2424

2525
_LOGGER: logging.Logger = logging.getLogger(__name__)
2626

@@ -127,13 +127,36 @@ def aten_ops_convolution(
127127
# we do not handle output_padding.
128128
if args[7] not in ([0], [0, 0], [0, 0, 0]):
129129
raise RuntimeError(f"Target {target} has non-0 output_padding")
130+
130131
if len(kwargs_new["stride"]) == 1:
131-
return acc_ops_converters.acc_ops_conv1d(
132-
network, target, None, kwargs_new, name
132+
return convolution.convNd(
133+
network,
134+
target,
135+
source_ir=SourceIR.ATEN,
136+
name=name,
137+
is_conv1d=True,
138+
input_val=kwargs_new["input"],
139+
weight=kwargs_new["weight"],
140+
bias=kwargs_new["bias"],
141+
stride=kwargs_new["stride"],
142+
padding=kwargs_new["padding"],
143+
dilation=kwargs_new["dilation"],
144+
groups=kwargs_new["groups"],
133145
)
134146
else:
135-
return acc_ops_converters.acc_ops_convnd(
136-
network, target, None, kwargs_new, name
147+
return convolution.convNd(
148+
network,
149+
target,
150+
source_ir=SourceIR.ATEN,
151+
name=name,
152+
is_conv1d=False,
153+
input_val=kwargs_new["input"],
154+
weight=kwargs_new["weight"],
155+
bias=kwargs_new["bias"],
156+
stride=kwargs_new["stride"],
157+
padding=kwargs_new["padding"],
158+
dilation=kwargs_new["dilation"],
159+
groups=kwargs_new["groups"],
137160
)
138161

139162

py/torch_tensorrt/fx/converters/converter_utils.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -99,14 +99,17 @@ def get_positive_dim(dim: int, dim_size: int) -> int:
9999

100100

101101
def set_layer_name(
102-
layer: TRTLayer, target: Target, name: str, source_ir: Optional[SourceIR] = None
102+
layer: TRTLayer,
103+
target: Union[Target, torch.nn.Module, str],
104+
name: str,
105+
source_ir: Optional[SourceIR] = None,
103106
) -> None:
104107
"""
105108
Set the TensorRT layer name to "[TensorRT Layer Type]_[Original Op Name]_[FX Node Name with Suffix]"
106109
107110
Args:
108111
layer (TRTLayer): A TensorRT layer of which we want to set the name.
109-
target (Target): A fx node.target. For call_function node, it's the function that
112+
target (Union[Target, torch.nn.Module, str]): An fx node.target or submodule. For call_function node, it's the function that
110113
the node represents.
111114
name (str): Consists of fx node.name with optional suffix.
112115
source_ir: (Optional[SourceIR]): The IR producing the op.

0 commit comments

Comments
 (0)