Commit cefb2f2

frank-wei (Wei Wei) and Wei Wei authored
add an example of aten2trt, fix batch norm pass (#1685)
Co-authored-by: Wei Wei <[email protected]>
1 parent deda87b · commit cefb2f2

File tree: 5 files changed, +212 −2 lines

.circleci/config.yml (+2 −2)

@@ -263,7 +263,7 @@ commands:
       parameters:
         torch-build:
           type: string
-          default: "2.0.0.dev20230129+cu117"
+          default: "2.0.0.dev20230219+cu117"
         torch-build-index:
           type: string
           default: "https://download.pytorch.org/whl/nightly/cu117"
@@ -1026,7 +1026,7 @@ parameters:
   # Nightly platform config
   torch-build:
     type: string
-    default: "2.0.0.dev20230129+cu117"
+    default: "2.0.0.dev20230219+cu117"
   torch-build-index:
     type: string
     default: "https://download.pytorch.org/whl/nightly/cu117"

examples/fx/lower_example.py (+1)

@@ -188,6 +188,7 @@ def run_configuration_benchmark(
             input,
             max_batch_size=conf.batch_size,
             lower_precision=LowerPrecision.FP16 if conf.fp16 else LowerPrecision.FP32,
+            explicit_batch_dimension=True,
         )
         time = benchmark_torch_function(conf.batch_iter, lambda: lowered_module(*input))
     else:
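A note on the flag added above: with explicit_batch_dimension=True, the sample inputs carry the batch dimension explicitly rather than leaving it implicit in the TRT network. A minimal sketch of the resulting call, reusing the names from this example (not a new API, just the hunk above in context):

    # Sketch only: `module`, `input`, and `conf` come from run_configuration_benchmark()
    lowered_module = compile(
        module,
        input,  # sample inputs include the batch dim, e.g. [torch.rand(128, 3, 224, 224)]
        max_batch_size=conf.batch_size,
        lower_precision=LowerPrecision.FP16 if conf.fp16 else LowerPrecision.FP32,
        explicit_batch_dimension=True,  # build the TRT network with an explicit batch dim
    )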

examples/fx/lower_example_aten.py (+196, new file)

@@ -0,0 +1,196 @@
+import typing as t
+from copy import deepcopy
+from dataclasses import dataclass, field, replace
+
+import torch
+import torchvision
+from torch_tensorrt.fx import compile
+from torch_tensorrt.fx.utils import LowerPrecision
+
+
+"""
+The purpose of this example is to demonstrate the overall flow of lowering a PyTorch model
+to TensorRT conveniently with lower.py.
+"""
+
+
+@dataclass
+class Configuration:
+    """
+    Specify the configuration used for fx2trt lowering and benchmark.
+
+    To extend, add a new configuration field to this class, and modify the
+    lowering or benchmark behavior in `run_configuration_benchmark()`
+    correspondingly.
+
+    It automatically prints all its values thanks to being a dataclass.
+    """
+
+    # Number of inferences to run
+    batch_iter: int
+
+    # Input batch size
+    batch_size: int
+
+    # Friendly name of the configuration
+    name: str = ""
+
+    # Whether to apply TRT lowering to the model before benchmarking
+    trt: bool = False
+
+    # Whether to apply engine holder to the lowered model
+    jit: bool = False
+
+    # Whether to enable FP16 mode for TRT lowering
+    fp16: bool = False
+
+    # Relative tolerance for accuracy check after lowering. -1 means do not
+    # check accuracy.
+    accuracy_rtol: float = -1  # disable
+
+
+@dataclass
+class Result:
+    """Holds and computes the benchmark results.
+
+    Holds raw essential benchmark result values like duration.
+    Also computes results that can be derived from the raw essential values
+    (QPS), in the form of auto properties.
+
+    """
+
+    module: torch.nn.Module = field(repr=False)
+    input: t.Any = field(repr=False)
+    conf: Configuration
+    time_sec: float
+    accuracy_res: t.Optional[bool] = None
+
+    @property
+    def time_per_iter_ms(self) -> float:
+        return self.time_sec * 1.0e3
+
+    @property
+    def qps(self) -> float:
+        return self.conf.batch_size / self.time_sec
+
+    def format(self) -> str:
+        return (
+            f"== Benchmark Result for: {self.conf}\n"
+            f"BS: {self.conf.batch_size}, "
+            f"Time per iter: {self.time_per_iter_ms:.2f}ms, "
+            f"QPS: {self.qps:.2f}, "
+            f"Accuracy: {self.accuracy_res} (rtol={self.conf.accuracy_rtol})"
+        )
+
+
+@torch.inference_mode()
+def benchmark(
+    model,
+    inputs,
+    batch_iter: int,
+    batch_size: int,
+) -> None:
+    """
+    Run fx2trt lowering and benchmark the given model according to the
+    specified benchmark configuration. Prints the benchmark result for each
+    configuration at the end of the run.
+    """
+
+    model = model.cuda().eval()
+    inputs = [x.cuda() for x in inputs]
+
+    # benchmark base configuration
+    conf = Configuration(batch_iter=batch_iter, batch_size=batch_size)
+
+    configurations = [
+        # Baseline
+        replace(conf, name="CUDA Eager", trt=False),
+        # FP16
+        replace(
+            conf,
+            name="TRT FP16 Eager",
+            trt=True,
+            jit=False,
+            fp16=True,
+            accuracy_rtol=1e-2,
+        ),
+    ]
+
+    results = [
+        run_configuration_benchmark(deepcopy(model), inputs, conf_)
+        for conf_ in configurations
+    ]
+
+    for res in results:
+        print(res.format())
+
+
+def benchmark_torch_function(iters: int, f, *args) -> float:
+    """Estimates the average duration of a single inference call in seconds.
+
+    If the input is batched, then the estimate is for one full-batch inference call.
+
+    Args:
+        iters: number of inference iterations to run
+        f: a function to perform a single inference call
+
+    Returns:
+        estimated average duration in seconds for a single inference call
+    """
+    with torch.inference_mode():
+        f(*args)
+    torch.cuda.synchronize()
+    start_event = torch.cuda.Event(enable_timing=True)
+    end_event = torch.cuda.Event(enable_timing=True)
+    print("== Start benchmark iterations")
+    with torch.inference_mode():
+        start_event.record()
+        for _ in range(iters):
+            f(*args)
+        end_event.record()
+    torch.cuda.synchronize()
+    print("== End benchmark iterations")
+    return (start_event.elapsed_time(end_event) * 1.0e-3) / iters
+
+
+def run_configuration_benchmark(
+    module,
+    input,
+    conf: Configuration,
+) -> Result:
+    """
+    Runs `module` through lowering logic and benchmarks the module before and
+    after lowering.
+    """
+    print(f"=== Running benchmark for: {conf}", "green")
+    time = -1.0
+
+    if conf.fp16:
+        module = module.half()
+        input = [i.half() for i in input]
+
+    if not conf.trt:
+        # Run eager mode benchmark
+        time = benchmark_torch_function(conf.batch_iter, lambda: module(*input))
+    elif not conf.jit:
+        # Run lowering eager mode benchmark
+        lowered_module = compile(
+            module,
+            input,
+            max_batch_size=conf.batch_size,
+            lower_precision=LowerPrecision.FP16 if conf.fp16 else LowerPrecision.FP32,
+            explicit_batch_dimension=True,
+            is_aten=True,
+        )
+        time = benchmark_torch_function(conf.batch_iter, lambda: lowered_module(*input))
+    else:
+        print("Lowering with JIT is not available!", "red")
+
+    result = Result(module=module, input=input, conf=conf, time_sec=time)
+    return result
+
+
+if __name__ == "__main__":
+    test_model = torchvision.models.resnet18(pretrained=True)
+    input = [torch.rand(128, 3, 224, 224)]  # type: ignore[attr-defined]
+    benchmark(test_model, input, 50, 128)
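For reference, a minimal way to drive the new aten2trt example with a different model and batch size, assuming it is run from examples/fx/ (resnet50 and the sizes below are illustrative substitutions, not part of the commit):

    import torch
    import torchvision

    from lower_example_aten import benchmark  # the module added above

    model = torchvision.models.resnet50(pretrained=True)  # any torchvision classifier
    inputs = [torch.rand(32, 3, 224, 224)]  # batch of 32, matching batch_size below
    benchmark(model, inputs, batch_iter=100, batch_size=32)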

py/setup.py (+2)

@@ -353,13 +353,15 @@ def run(self):
     "torch_tensorrt.fx.passes",
     "torch_tensorrt.fx.tools",
     "torch_tensorrt.fx.tracer.acc_tracer",
+    "torch_tensorrt.fx.tracer.dispatch_tracer",
 ]
 package_dir = {
     "torch_tensorrt.fx": "torch_tensorrt/fx",
     "torch_tensorrt.fx.converters": "torch_tensorrt/fx/converters",
     "torch_tensorrt.fx.passes": "torch_tensorrt/fx/passes",
     "torch_tensorrt.fx.tools": "torch_tensorrt/fx/tools",
     "torch_tensorrt.fx.tracer.acc_tracer": "torch_tensorrt/fx/tracer/acc_tracer",
+    "torch_tensorrt.fx.tracer.dispatch_tracer": "torch_tensorrt/fx/tracer/dispatch_tracer",
 }

 with open("README.md", "r", encoding="utf-8") as fh:

py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py (+11)

@@ -165,6 +165,7 @@ def replace_aten_op_with_indices(module: torch.fx.GraphModule) -> torch.fx.GraphModule:
             torch.ops.aten.max_pool3d_with_indices.default,
             torch.ops.aten.native_batch_norm.default,
             torch.ops.aten._native_batch_norm_legit.default,
+            torch.ops.aten._native_batch_norm_legit_no_training.default,
         ):
             modified = True
             if len(n.users) != 1:
@@ -185,6 +186,16 @@ def replace_aten_op_with_indices(module: torch.fx.GraphModule) -> torch.fx.GraphModule:
                 new_args = list(n.args)
                 new_args.append(False)
                 new_args = tuple(new_args)
+            elif (
+                n.target == torch.ops.aten._native_batch_norm_legit_no_training.default
+            ):
+                new_op = torch.ops.aten.batch_norm
+                new_args = list(n.args)
+                new_args.append(False)
+                # _native_batch_norm_legit_no_training doesn't take in a training arg
+                # (assumed to be False), but batch_norm takes a training arg at position 5.
+                new_args.insert(5, False)
+                new_args = tuple(new_args)

             getitem_node = next(iter(n.users))
             with module.graph.inserting_after(getitem_node):
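To see why the pass appends one False and inserts another at position 5: as I understand the aten schemas, _native_batch_norm_legit_no_training takes (input, weight, bias, running_mean, running_var, momentum, eps), while batch_norm takes (input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled). A standalone sketch of the argument shuffle, with placeholder strings standing in for the real fx node args:

    # Args as _native_batch_norm_legit_no_training receives them (placeholders):
    args = ("input", "weight", "bias", "running_mean", "running_var", 0.1, 1e-5)

    new_args = list(args)
    new_args.append(False)     # cudnn_enabled=False, appended at the end
    new_args.insert(5, False)  # training=False, inserted at position 5
    print(tuple(new_args))
    # ('input', 'weight', 'bias', 'running_mean', 'running_var', False, 0.1, 1e-5, False)
    # i.e. batch_norm(input, weight, bias, running_mean, running_var,
    #                 training, momentum, eps, cudnn_enabled)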
