
Commit 7f16bda

chore: cherry-pick FP8 (#2892)
1 parent 29272fa

23 files changed (+491 −44 lines)

.github/workflows/build-test-linux.yml (+9 −1)

@@ -65,6 +65,7 @@ jobs:
             package-name: torch_tensorrt
             pre-script: packaging/pre_build_script.sh
             post-script: packaging/post_build_script.sh
+            smoke-test-script: packaging/smoke_test_script.sh
     uses: pytorch/tensorrt/.github/workflows/linux-test.yml@main
     with:
       job-name: tests-py-torchscript-fe
@@ -99,6 +100,7 @@ jobs:
             package-name: torch_tensorrt
             pre-script: packaging/pre_build_script.sh
             post-script: packaging/post_build_script.sh
+            smoke-test-script: packaging/smoke_test_script.sh
     uses: pytorch/tensorrt/.github/workflows/linux-test.yml@main
     with:
       job-name: tests-py-dynamo-converters
@@ -126,6 +128,7 @@ jobs:
             package-name: torch_tensorrt
             pre-script: packaging/pre_build_script.sh
             post-script: packaging/post_build_script.sh
+            smoke-test-script: packaging/smoke_test_script.sh
     uses: pytorch/tensorrt/.github/workflows/linux-test.yml@main
     with:
       job-name: tests-py-dynamo-fe
@@ -154,6 +157,7 @@ jobs:
             package-name: torch_tensorrt
             pre-script: packaging/pre_build_script.sh
             post-script: packaging/post_build_script.sh
+            smoke-test-script: packaging/smoke_test_script.sh
     uses: pytorch/tensorrt/.github/workflows/linux-test.yml@main
     with:
       job-name: tests-py-dynamo-serde
@@ -181,6 +185,7 @@ jobs:
             package-name: torch_tensorrt
             pre-script: packaging/pre_build_script.sh
             post-script: packaging/post_build_script.sh
+            smoke-test-script: packaging/smoke_test_script.sh
     uses: pytorch/tensorrt/.github/workflows/linux-test.yml@main
     with:
       job-name: tests-py-torch-compile-be
@@ -210,6 +215,7 @@ jobs:
             package-name: torch_tensorrt
             pre-script: packaging/pre_build_script.sh
             post-script: packaging/post_build_script.sh
+            smoke-test-script: packaging/smoke_test_script.sh
     uses: pytorch/tensorrt/.github/workflows/linux-test.yml@main
     with:
       job-name: tests-py-dynamo-core
@@ -238,7 +244,9 @@ jobs:
           - repository: pytorch/tensorrt
             package-name: torch_tensorrt
             pre-script: packaging/pre_build_script.sh
-    uses: pytorch/tensorrt/.github/workflows/linux-test.yml@main
+            post-script: packaging/post_build_script.sh
+            smoke-test-script: packaging/smoke_test_script.sh
+    uses: pytorch/tensorrt/.github/workflows/linux-test.yml@release/2.3
     with:
       job-name: tests-py-core
       repository: "pytorch/tensorrt"

dev_dep_versions.yml (+1 −1)

@@ -1,3 +1,3 @@
 __version__: "2.4.0.dev0"
 __cuda_version__: "12.1"
-__tensorrt_version__: "10.0.1"
+__tensorrt_version__: "10.0.1"

docsrc/index.rst (+1 −1)

@@ -114,7 +114,7 @@ Tutorials
    tutorials/_rendered_examples/dynamo/custom_kernel_plugins
    tutorials/_rendered_examples/distributed_inference/data_parallel_gpt2
    tutorials/_rendered_examples/distributed_inference/data_parallel_stable_diffusion
-
+   tutorials/_rendered_examples/dynamo/vgg16_fp8_ptq

 Python API Documentation
 ------------------------

examples/dynamo/README.rst (+1 −0)

@@ -11,3 +11,4 @@ a number of ways you can leverage this backend to accelerate inference.
 * :ref:`torch_compile_advanced_usage`: Advanced usage including making a custom backend to use directly with the ``torch.compile`` API
 * :ref:`torch_compile_stable_diffusion`: Compiling a Stable Diffusion model using ``torch.compile``
 * :ref:`custom_kernel_plugins`: Creating a plugin to use a custom kernel inside TensorRT engines
+* :ref:`vgg16_fp8_ptq`: Compiling a VGG16 model with FP8 and PTQ using ``torch.compile``

examples/dynamo/vgg16_fp8_ptq.py (new file, +251)

@@ -0,0 +1,251 @@
"""
.. _vgg16_fp8_ptq:

Torch Compile VGG16 with FP8 and PTQ
======================================================

This script is intended as a sample of the Torch-TensorRT workflow with `torch.compile` on a VGG16 model with FP8 and PTQ.
"""

# %%
# Imports and Model Definition
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

import argparse

import modelopt.torch.quantization as mtq
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_tensorrt as torchtrt
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from modelopt.torch.quantization.utils import export_torch_mode


class VGG(nn.Module):
    def __init__(self, layer_spec, num_classes=1000, init_weights=False):
        super(VGG, self).__init__()

        layers = []
        in_channels = 3
        for l in layer_spec:
            if l == "pool":
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            else:
                layers += [
                    nn.Conv2d(in_channels, l, kernel_size=3, padding=1),
                    nn.BatchNorm2d(l),
                    nn.ReLU(),
                ]
                in_channels = l

        self.features = nn.Sequential(*layers)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(
            nn.Linear(512 * 1 * 1, 4096),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x


def vgg16(num_classes=1000, init_weights=False):
    vgg16_cfg = [
        64,
        64,
        "pool",
        128,
        128,
        "pool",
        256,
        256,
        256,
        "pool",
        512,
        512,
        512,
        "pool",
        512,
        512,
        512,
        "pool",
    ]
    return VGG(vgg16_cfg, num_classes, init_weights)


PARSER = argparse.ArgumentParser(
    description="Load pre-trained VGG model and then tune with FP8 and PTQ"
)
PARSER.add_argument(
    "--ckpt", type=str, required=True, help="Path to the pre-trained checkpoint"
)
PARSER.add_argument(
    "--batch-size",
    default=128,
    type=int,
    help="Batch size for tuning the model with PTQ and FP8",
)

args = PARSER.parse_args()

model = vgg16(num_classes=10, init_weights=False)
model = model.cuda()

# %%
# Load the pre-trained model weights
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

ckpt = torch.load(args.ckpt)
weights = ckpt["model_state_dict"]

if torch.cuda.device_count() > 1:
    from collections import OrderedDict

    new_state_dict = OrderedDict()
    for k, v in weights.items():
        name = k[7:]  # strip the `module.` prefix added by DataParallel
        new_state_dict[name] = v
    weights = new_state_dict

model.load_state_dict(weights)
# Don't forget to set the model to evaluation mode!
model.eval()

# %%
# Load training dataset and define loss function for PTQ
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

training_dataset = datasets.CIFAR10(
    root="./data",
    train=True,
    download=True,
    transform=transforms.Compose(
        [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ]
    ),
)
training_dataloader = torch.utils.data.DataLoader(
    training_dataset, batch_size=args.batch_size, shuffle=True, num_workers=2
)

data = iter(training_dataloader)
images, _ = next(data)

crit = nn.CrossEntropyLoss()

# %%
# Define Calibration Loop for quantization
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^


def calibrate_loop(model):
    # calibrate over the training dataset
    total = 0
    correct = 0
    loss = 0.0
    for data, labels in training_dataloader:
        data, labels = data.cuda(), labels.cuda(non_blocking=True)
        out = model(data)
        loss += crit(out, labels)
        preds = torch.max(out, 1)[1]
        total += labels.size(0)
        correct += (preds == labels).sum().item()

    print("PTQ Loss: {:.5f} Acc: {:.2f}%".format(loss / total, 100 * correct / total))


# %%
# Tune the pre-trained model with FP8 and PTQ
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

quant_cfg = mtq.FP8_DEFAULT_CFG
# PTQ with in-place replacement to quantized modules
mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
# model has FP8 qdq nodes at this point

# %%
# Inference
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# Load the testing dataset
testing_dataset = datasets.CIFAR10(
    root="./data",
    train=False,
    download=True,
    transform=transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ]
    ),
)

testing_dataloader = torch.utils.data.DataLoader(
    testing_dataset, batch_size=args.batch_size, shuffle=False, num_workers=2
)

with torch.no_grad():
    with export_torch_mode():
        # Compile the model with Torch-TensorRT Dynamo backend
        input_tensor = images.cuda()
        exp_program = torch.export.export(model, (input_tensor,))
        trt_model = torchtrt.dynamo.compile(
            exp_program,
            inputs=[input_tensor],
            enabled_precisions={torch.float8_e4m3fn},
            min_block_size=1,
            debug=False,
        )

        # Inference compiled Torch-TensorRT model over the testing dataset
        total = 0
        correct = 0
        loss = 0.0
        class_probs = []
        class_preds = []
        model.eval()
        for data, labels in testing_dataloader:
            data, labels = data.cuda(), labels.cuda(non_blocking=True)
            out = trt_model(data)
            loss += crit(out, labels)
            preds = torch.max(out, 1)[1]
            class_probs.append([F.softmax(i, dim=0) for i in out])
            class_preds.append(preds)
            total += labels.size(0)
            correct += (preds == labels).sum().item()

        test_probs = torch.cat([torch.stack(batch) for batch in class_probs])
        test_preds = torch.cat(class_preds)
        test_loss = loss / total
        test_acc = correct / total
        print("Test Loss: {:.5f} Test Acc: {:.2f}%".format(test_loss, 100 * test_acc))

examples/int8/training/vgg16/requirements.txt (+2 −0)

@@ -4,3 +4,5 @@ nvidia-pyindex
 --extra-index-url https://pypi.nvidia.com
 pytorch-quantization
 tqdm
+nvidia-modelopt
+--extra-index-url https://pypi.nvidia.com
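nvidia-modelopt, served from NVIDIA's PyPI index, provides the modelopt.torch.quantization module the new example imports. A quick sanity check that the dependency resolved, using only names that appear in the diff above:

# Confirm nvidia-modelopt is importable and exposes the FP8 config used by
# examples/dynamo/vgg16_fp8_ptq.py.
import modelopt.torch.quantization as mtq

assert hasattr(mtq, "FP8_DEFAULT_CFG")
print("nvidia-modelopt OK")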

packaging/pre_build_script.sh (+0 −1)

@@ -3,7 +3,6 @@
 # Install dependencies
 python3 -m pip install pyyaml
 yum install -y ninja-build gettext
-TRT_VERSION=$(python3 -c "import versions; versions.tensorrt_version()")
 wget https://github.com/bazelbuild/bazelisk/releases/download/v1.17.0/bazelisk-linux-amd64 \
     && mv bazelisk-linux-amd64 /usr/bin/bazel \
     && chmod +x /usr/bin/bazel
