Commit a225769

feat: Add support for TorchTensorRTModule in Dynamo
- Rename `TRTModuleNext` to `TorchTensorRTModule` across the repository, and move the source directory to `dynamo`
- Update imports across the repository
- Refactor `convert_module` code to support conversion to a `TorchTensorRTModule`
- Add logging information about which runtime is being used in Dynamo compile
- Add tests for `TorchTensorRTModule` functionality in Dynamo
1 parent bd9c29a commit a225769

17 files changed: +247 −42 lines
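For orientation, a minimal sketch of the user-facing change (assuming a CUDA device and a TensorRT-enabled build; the model and shapes are illustrative, not taken from this commit):

import torch
from torch_tensorrt.dynamo import compile

model = torch.nn.Linear(8, 4).eval().cuda()  # illustrative model
inputs = [torch.randn(2, 8).cuda()]

# use_experimental_rt=True selects the C++ TorchTensorRTModule runtime added
# by this commit; the default (False) keeps the Python TRTModule runtime.
trt_model = compile(
    model,
    inputs,
    min_block_size=1,
    use_experimental_rt=True,
)
print(trt_model(*inputs).shape)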

examples/fx/fx2trt_example_next.py

+4 −1

@@ -8,7 +8,10 @@
 import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer
 from torch_tensorrt.fx import InputTensorSpec, TRTInterpreter
 from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter, TRTSplitterSetting
-from torch_tensorrt import TRTModuleNext as TRTModule, Device
+from torch_tensorrt.dynamo._TorchTensorRTModule import (
+    TorchTensorRTModule as TRTModule,
+    Device,
+)

 # The purpose of this example is to demonstrate the overall flow of lowering a PyTorch
 # model to TensorRT via FX with existing FX based tooling. The general lowering flow

py/torch_tensorrt/__init__.py

−1

@@ -91,7 +91,6 @@ def _find_lib(name, paths):
 from torch_tensorrt import logging
 from torch_tensorrt._Input import Input
 from torch_tensorrt._Device import Device
-from torch_tensorrt._TRTModuleNext import TRTModuleNext

 from torch_tensorrt import fx
py/torch_tensorrt/_TRTModuleNext.py renamed to py/torch_tensorrt/dynamo/_TorchTensorRTModule.py

+10 −11

@@ -1,6 +1,5 @@
 import logging
-from operator import truediv
-from typing import Any, List, Sequence, Tuple
+from typing import Any, List, Tuple

 import torch
 from torch_tensorrt import _C
@@ -9,8 +8,8 @@
 logger = logging.getLogger(__name__)


-class TRTModuleNext(torch.nn.Module):
-    """TRTModuleNext is a PyTorch module which encompasses an arbitrary TensorRT Engine.
+class TorchTensorRTModule(torch.nn.Module):
+    """TorchTensorRTModule is a PyTorch module which encompasses an arbitrary TensorRT Engine.

     This module is backed by the Torch-TensorRT runtime and is fully compatibile with both
     FX / Python deployments (just ``import torch_tensorrt`` as part of the application) as
@@ -20,7 +19,7 @@ class TRTModuleNext(torch.nn.Module):
     The forward function is simpily forward(*args: torch.Tensor) -> Tuple[torch.Tensor] where
     the internal implementation is ``return Tuple(torch.ops.tensorrt.execute_engine(list(inputs), self.engine))``

-    > Note: TRTModuleNext only supports engines built with explict batch
+    > Note: TorchTensorRTModule only supports engines built with explict batch

     Attributes:
         name (str): Name of module (for easier debugging)
@@ -37,7 +36,7 @@ def __init__(
         output_binding_names: List[str] = [],
         target_device: Device = Device._current_device(),
     ):
-        """__init__ method for torch_tensorrt.TRTModuleNext
+        """__init__ method for torch_tensorrt.TorchTensorRTModule

         Takes a name, target device, serialized TensorRT engine, and binding names / order and constructs
         a PyTorch ``torch.nn.Module`` around it.
@@ -71,9 +70,9 @@ def __init__(

         """
         logger.warning(
-            "TRTModuleNext should be considered experimental stability, APIs are subject to change. Note: TRTModuleNext only supports engines built with explict batch"
+            "TorchTensorRTModule should be considered experimental stability, APIs are subject to change. Note: TorchTensorRTModule only supports engines built with explict batch"
         )
-        super(TRTModuleNext, self).__init__()
+        super(TorchTensorRTModule, self).__init__()

         if not isinstance(serialized_engine, bytearray):
             ValueError("Expected serialized engine as bytearray")
@@ -89,8 +88,8 @@ def __init__(
                    self.name + "_engine" if self.name != "" else "tensorrt_engine",
                    target_device._to_serialized_rt_device(),
                    serialized_engine,
-                    TRTModuleNext._pack_binding_names(self.input_binding_names),
-                    TRTModuleNext._pack_binding_names(self.output_binding_names),
+                    TorchTensorRTModule._pack_binding_names(self.input_binding_names),
+                    TorchTensorRTModule._pack_binding_names(self.output_binding_names),
                ]
            )
        else:
@@ -154,7 +153,7 @@ def is_non_tensor(i: Tuple[Any, bool]) -> bool:

            non_tensors = [i[0] for i in filter(zip(inputs, types), is_non_tensor)]
            raise RuntimeError(
-                f"TRTModuleNext expects a flattened list of tensors as input, found non tensors: {non_tensors}"
+                f"TorchTensorRTModule expects a flattened list of tensors as input, found non tensors: {non_tensors}"
            )

        outputs = torch.ops.tensorrt.execute_engine(list(inputs), self.engine)
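As a usage note, a hedged sketch of constructing the renamed module directly; the engine bytes and binding names below are placeholders standing in for the output of a real TensorRT build:

import torch
from torch_tensorrt.dynamo._TorchTensorRTModule import TorchTensorRTModule

# serialized_engine is assumed to come from a TensorRT build elsewhere,
# e.g. bytearray(engine.serialize()); it is a placeholder here.
trt_mod = TorchTensorRTModule(
    serialized_engine=serialized_engine,
    name="example_module",                 # hypothetical engine name
    input_binding_names=["input_0"],       # placeholder binding names
    output_binding_names=["output_0"],
)
outputs = trt_mod(torch.randn(1, 8).cuda())  # input shapes must match the engine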

py/torch_tensorrt/dynamo/backend/__init__.py

+11

@@ -16,6 +16,7 @@
     WORKSPACE_SIZE,
     MIN_BLOCK_SIZE,
     PASS_THROUGH_BUILD_FAILURES,
+    USE_EXPERIMENTAL_RT,
 )


@@ -45,6 +46,7 @@ def compile(
     torch_executed_ops=[],
     torch_executed_modules=[],
     pass_through_build_failures=PASS_THROUGH_BUILD_FAILURES,
+    use_experimental_rt=USE_EXPERIMENTAL_RT,
     **kwargs,
 ):
     if debug:
@@ -57,6 +59,11 @@ def compile(
             + "torch_executed_ops, pass_through_build_failures}"
         )

+    if "use_experimental_fx_rt" in kwargs:
+        use_experimental_rt = kwargs["use_experimental_fx_rt"]
+
+    logger.info(f"Using {'C++' if use_experimental_rt else 'Python'} TRT Runtime")
+
     if not isinstance(inputs, collections.abc.Sequence):
         inputs = [inputs]

@@ -91,6 +98,7 @@ def compile(
         min_block_size=min_block_size,
         torch_executed_ops=torch_executed_ops,
         pass_through_build_failures=pass_through_build_failures,
+        use_experimental_rt=use_experimental_rt,
         **kwargs,
     )

@@ -114,6 +122,7 @@ def create_backend(
     min_block_size: int = MIN_BLOCK_SIZE,
     torch_executed_ops: Sequence[str] = set(),
     pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES,
+    use_experimental_rt: bool = USE_EXPERIMENTAL_RT,
     **kwargs,
 ):
     """Create torch.compile backend given specified arguments
@@ -125,6 +134,7 @@ def create_backend(
         min_block_size: Minimum number of operators per TRT-Engine Block
         torch_executed_ops: Sequence of operations to run in Torch, regardless of converter coverage
         pass_through_build_failures: Whether to fail on TRT engine build errors (True) or not (False)
+        use_experimental_rt: Whether to use the new experimental TRTModuleNext for TRT engines
     Returns:
         Backend for torch.compile
     """
@@ -136,4 +146,5 @@ def create_backend(
         min_block_size=min_block_size,
         torch_executed_ops=torch_executed_ops,
         pass_through_build_failures=pass_through_build_failures,
+        use_experimental_rt=use_experimental_rt,
     )
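A short sketch of the backend path under the same assumptions as the earlier example; note that compile() also honors the legacy use_experimental_fx_rt kwarg by aliasing it onto use_experimental_rt, as the hunk above shows:

import torch
from torch_tensorrt.dynamo.backend import create_backend

# model and inputs as in the earlier sketch
backend = create_backend(min_block_size=1, use_experimental_rt=True)
compiled = torch.compile(model, backend=backend)
compiled(*inputs)  # first call triggers compilation with the C++ runtime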

py/torch_tensorrt/dynamo/backend/_defaults.py

+1

@@ -6,3 +6,4 @@
 WORKSPACE_SIZE = 0
 MIN_BLOCK_SIZE = 5
 PASS_THROUGH_BUILD_FAILURES = False
+USE_EXPERIMENTAL_RT = False

py/torch_tensorrt/dynamo/backend/_settings.py

+2

@@ -8,6 +8,7 @@
     WORKSPACE_SIZE,
     MIN_BLOCK_SIZE,
     PASS_THROUGH_BUILD_FAILURES,
+    USE_EXPERIMENTAL_RT,
 )


@@ -19,3 +20,4 @@ class CompilationSettings:
     min_block_size: int = MIN_BLOCK_SIZE
     torch_executed_ops: Sequence[str] = field(default_factory=set)
     pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES
+    use_experimental_rt: bool = USE_EXPERIMENTAL_RT
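For illustration, the flag is simply carried on the CompilationSettings dataclass and consumed later in conversion.py (a sketch using only fields shown in this diff):

from torch_tensorrt.dynamo.backend._settings import CompilationSettings

settings = CompilationSettings(min_block_size=1, use_experimental_rt=True)
assert settings.use_experimental_rt  # conversion.py branches on this flag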

py/torch_tensorrt/dynamo/backend/backends.py

+1

@@ -139,6 +139,7 @@ def _compile_module(
             submodule,
             submodule_inputs,
             settings=settings,
+            name=name,
         )

         trt_modules[name] = trt_mod

py/torch_tensorrt/dynamo/backend/conversion.py

+22 −7

@@ -1,7 +1,7 @@
 from typing import Sequence, Union
 import torch
+import io
 from torch_tensorrt.fx.trt_module import TRTModule
-from torch_tensorrt import TRTModuleNext
 from torch_tensorrt.dynamo.backend._settings import CompilationSettings
 from torch_tensorrt.dynamo.fx_ts_compat.fx2trt import (
     InputTensorSpec,
@@ -15,12 +15,14 @@ def convert_module(
     module: torch.fx.GraphModule,
     inputs: Sequence[torch.Tensor],
     settings: CompilationSettings = CompilationSettings(),
-) -> Union[TRTModuleNext, TRTModule]:
+    name: str = "",
+):
     """Convert an FX module to a TRT module
     Args:
         module: FX GraphModule to convert
         inputs: Sequence of Tensors representing inputs to the module
         settings: Compilation settings
+        name: TRT engine name
     Returns:
         TRTModule or TRTModuleNext
     """
@@ -50,8 +52,21 @@ def convert_module(
        ),
    )

-    return TRTModule(
-        engine=interpreter_result.engine,
-        input_names=interpreter_result.input_names,
-        output_names=interpreter_result.output_names,
-    )
+    if settings.use_experimental_rt:
+        from torch_tensorrt.dynamo._TorchTensorRTModule import TorchTensorRTModule
+
+        with io.BytesIO() as engine_bytes:
+            engine_bytes.write(interpreter_result.engine.serialize())
+            engine_str = engine_bytes.getvalue()
+        return TorchTensorRTModule(
+            serialized_engine=engine_str,
+            name=name,
+            input_binding_names=interpreter_result.input_names,
+            output_binding_names=interpreter_result.output_names,
+        )
+    else:
+        return TRTModule(
+            engine=interpreter_result.engine,
+            input_names=interpreter_result.input_names,
+            output_names=interpreter_result.output_names,
+        )
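Both wrappers subclass torch.nn.Module, so callers such as _compile_module in backends.py need no branching on the runtime choice. A hedged sketch of checking which runtime convert_module returned (module, inputs, and settings are assumed to be in scope, as at the call site):

from torch_tensorrt.fx.trt_module import TRTModule
from torch_tensorrt.dynamo._TorchTensorRTModule import TorchTensorRTModule
from torch_tensorrt.dynamo.backend.conversion import convert_module

trt_mod = convert_module(module, inputs, settings=settings, name="engine_0")
expected = TorchTensorRTModule if settings.use_experimental_rt else TRTModule
assert isinstance(trt_mod, expected)  # either way, trt_mod is an nn.Module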
@@ -0,0 +1,173 @@
+from torch_tensorrt.dynamo.backend.lowering import partition
+from torch.testing._internal.common_utils import run_tests, TestCase
+import torch
+from copy import deepcopy
+from torch_tensorrt.dynamo import compile
+from utils import lower_graph_testing
+from torch_tensorrt.dynamo.common_utils.test_utils import DECIMALS_OF_AGREEMENT
+
+
+class TestTRTModuleNextCompilation(TestCase):
+    def test_trt_module_next_full_support(self):
+        class FullySupportedMultiOp(torch.nn.Module):
+            def forward(self, x, y):
+                out = x - y
+                out = out + x
+                out = 2 * out
+                out = out + y
+                return torch.mean(out, dim=1)
+
+        fx_graph = torch.fx.symbolic_trace(FullySupportedMultiOp())
+        partitioned_graph = partition(deepcopy(fx_graph), min_block_size=3)
+
+        self.assertEquals(
+            len(list(partitioned_graph.named_children())),
+            1,
+            "All operators are supported, there should be one segment",
+        )
+
+        inputs = [
+            torch.randint(-5, 5, (16, 7), dtype=torch.float).cuda(),
+            torch.randint(-5, 5, (16, 7), dtype=torch.float).cuda(),
+        ]
+
+        torch._dynamo.reset()
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = compile(
+            fx_graph,
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+            torch_executed_ops={"torch.ops.aten.add.Tensor"},
+            use_experimental_rt=True,
+            debug=True,
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            f"TRT outputs don't match with the original model.",
+        )
+
+    def test_trt_module_next_partial_support(self):
+        class PartiallySupportedMultiOp(torch.nn.Module):
+            def forward(self, x, y):
+                out = x - y
+                out = out - 3 * x
+                out = out + y
+                out = out.to(torch.float)
+                out = 2 * out
+                return torch.mean(out, dim=-1)
+
+        fx_graph = torch.fx.symbolic_trace(PartiallySupportedMultiOp())
+        unexpected_ops = {torch.ops.aten.add.Tensor}
+
+        inputs = [
+            torch.randint(-40, 40, (16, 7, 5), dtype=torch.int).cuda(),
+            torch.randint(1, 40, (16, 7, 5), dtype=torch.int).cuda(),
+        ]
+
+        (unexpected_ops_seen, _, partitioned_graphs,) = lower_graph_testing(
+            fx_graph,
+            inputs,
+            unexpected_ops=unexpected_ops,
+            min_block_size=1,
+            torch_executed_ops={"torch.ops.aten.add.Tensor"},
+            testing_partitioning=True,
+        )
+
+        self.assertEquals(
+            len(unexpected_ops_seen),
+            0,
+            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
+        )
+        self.assertEquals(
+            len(partitioned_graphs),
+            1,
+            "Without control flow breaks, there should only be a single graph",
+        )
+        self.assertEquals(
+            len(list(partitioned_graphs[0].named_children())),
+            2,
+            "Certain operators are set to run in Torch, expected 2 segments",
+        )
+
+        torch._dynamo.reset()
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = compile(
+            fx_graph,
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+            torch_executed_ops={"torch.ops.aten.add.Tensor"},
+            use_experimental_rt=True,
+            debug=True,
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            f"TRT outputs don't match with the original model.",
+        )
+
+
+class TestCompilationOptions(TestCase):
+    def test_trt_specific_options(self):
+        class SupportedMultiOp(torch.nn.Module):
+            def forward(self, x, y):
+                out = x - y
+                out = out - 3 * x
+                out = out + y
+                out = out - y / 5
+                out = 2 * out
+                return torch.mean(out, dim=-1)
+
+        fx_graph = torch.fx.symbolic_trace(SupportedMultiOp())
+
+        inputs = [
+            torch.randint(-40, 40, (16, 7, 5), dtype=torch.float).cuda(),
+            torch.randint(1, 40, (16, 7, 5), dtype=torch.float).cuda(),
+        ]
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = compile(
+            fx_graph,
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+            use_experimental_rt=True,
+            optimization_level=4,
+            version_compatible=True,
+            max_aux_streams=5,
+            debug=True,
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            f"TRT outputs don't match with the original model.",
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
