
Commit a7fbb3d

Update to opset 11

1 parent d03b68b commit a7fbb3d

12 files changed (+59, -49 lines)


INFERENCE.md
Lines changed: 4 additions & 2 deletions

````diff
@@ -62,10 +62,12 @@ odtk export model.pth model_fp32.plan --full-precision --size 800 1280
 
 In order to use INT8 precision with TensorRT, you need to provide calibration images (images that are representative of what will be seen at runtime) that will be used to rescale the network.
 ```bash
-odtk export model.pth model_int8.plan --int8 --calibration-images /data/val/ --calibration-batches 10 --calibration-table model_calibration_table
+odtk export model.pth model_int8.plan --int8 --calibration-images /data/val/ --calibration-batches 2 --calibration-table model_calibration_table
 ```
 
-This will randomly select 20 images from `/data/val/` to calibrate the network for INT8 precision. The results from calibration will be saved to `model_calibration_table` that can be used to create subsequent INT8 engines for this model without needed to recalibrate.
+This will randomly select 16 images from `/data/val/` to calibrate the network for INT8 precision. The results from calibration will be saved to `model_calibration_table` that can be used to create subsequent INT8 engines for this model without needed to recalibrate.
+
+**NOTE:** Number of images in `/data/val/` must be greater than or equal to the kOPT(middle) optimization profile from `--dynamic-batch-opts`. Here, the default kOPT is 8.
 
 Build an INT8 engine for a previously calibrated model:
 ```bash
````
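
The sizing rule behind the new NOTE can be sketched in a few lines of Python. The function below is illustrative only (it is not part of the repo); the arithmetic mirrors the selection logic added to `retinanet/main.py` further down: calibration wants `calibration_batches * kOPT` images, where kOPT is the middle value of `--dynamic-batch-opts`.

```python
def calibration_images_needed(calibration_batches, dynamic_batch_opts=(1, 8, 16)):
    """Illustrative only: how many calibration images the export step expects."""
    kopt = dynamic_batch_opts[1]       # middle value = kOPT of the TRT optimization profile
    return calibration_batches * kopt  # e.g. 2 batches * 8 = 16 images

print(calibration_images_needed(2))    # 16, matching the example above
```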

csrc/engine.cpp
Lines changed: 6 additions & 4 deletions

```diff
@@ -93,7 +93,7 @@ Engine::~Engine() {
 }
 
 Engine::Engine(const char *onnx_model, size_t onnx_size, const vector<int>& dynamic_batch_opts,
-    size_t batch, string precision, float score_thresh, int top_n, const vector<vector<float>>& anchors,
+    string precision, float score_thresh, int top_n, const vector<vector<float>>& anchors,
     bool rotated, float nms_thresh, int detections_per_im, const vector<string>& calibration_images,
     string model_name, string calibration_table, bool verbose, size_t workspace_size) {
 
@@ -134,9 +134,9 @@ Engine::Engine(const char *onnx_model, size_t onnx_size, const vector<int>& dyna
 
     std::unique_ptr<Int8EntropyCalibrator> calib;
     if (int8) {
-        // Calibration is performed using kOPT values of the profile.
-        // Calibration input data size must match this profile.
         builderConfig->setFlag(BuilderFlag::kINT8);
+        // Calibration is performed using kOPT values of the profile.
+        // Calibration batch size must match this profile.
         builderConfig->setCalibrationProfile(profile);
         ImageStream stream(dynamic_batch_opts[1], inputDims, calibration_images);
         calib = std::unique_ptr<Int8EntropyCalibrator>(new Int8EntropyCalibrator(stream, model_name, calibration_table));
@@ -201,6 +201,8 @@ Engine::Engine(const char *onnx_model, size_t onnx_size, const vector<int>& dyna
     network->destroy();
     builderConfig->destroy();
     builder->destroy();
+
+    _prepare();
 }
 
 void Engine::save(const string &path) {
@@ -236,4 +238,4 @@ int Engine::getStride() {
     return 1;
 }
 
-}
+}
```

csrc/engine.h
Lines changed: 3 additions & 4 deletions

```diff
@@ -29,7 +29,6 @@
 
 #include <cuda_runtime.h>
 
-
 using namespace std;
 using namespace nvinfer1;
 
@@ -43,9 +42,9 @@ class Engine {
 
     // Create engine from serialized onnx model
 
-    Engine(const char *onnx_model, size_t onnx_size, const vector<int>& dynamic_batch_opts, size_t batch,
-        string precision, float score_thresh, int top_n, const vector<vector<float>>& anchors, bool rotated,
-        float nms_thresh, int detections_per_im, const vector<string>& calibration_images,
+    Engine(const char *onnx_model, size_t onnx_size, const vector<int>& dynamic_batch_opts,
+        string precision, float score_thresh, int top_n, const vector<vector<float>>& anchors,
+        bool rotated, float nms_thresh, int detections_per_im, const vector<string>& calibration_images,
         string model_name, string calibration_table, bool verbose, size_t workspace_size=(1ULL << 30));
 
     ~Engine();
```

csrc/extensions.cpp
Lines changed: 2 additions & 2 deletions

```diff
@@ -183,8 +183,8 @@ vector<at::Tensor> infer(retinanet::Engine &engine, at::Tensor data, bool rotate
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     pybind11::class_<retinanet::Engine>(m, "Engine")
-        .def(pybind11::init<const char *, size_t, const vector<int>&, size_t, string, float,
-            int, const vector<vector<float>>&, bool, float, int, const vector<string>&, string, string, bool>())
+        .def(pybind11::init<const char *, size_t, const vector<int>&, string, float, int,
+            const vector<vector<float>>&, bool, float, int, const vector<string>&, string, string, bool>())
        .def("save", &retinanet::Engine::save)
        .def("infer", &retinanet::Engine::infer)
        .def_property_readonly("stride", &retinanet::Engine::getStride)
```
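
Seen from the Python side, the rebuilt binding is constructed the same way `retinanet/model.py` does further down; the `batch` argument simply disappears from the fourth position. A hedged sketch, with placeholder argument values and assuming the compiled extension imports as `retinanet._C`:

```python
from retinanet._C import Engine  # compiled from csrc/extensions.cpp

# Placeholder values for illustration; see Model.export in retinanet/model.py
# for how the repo actually fills these in.
engine = Engine(
    onnx_bytes, len(onnx_bytes),   # serialized ONNX model and its size
    [1, 8, 16],                    # dynamic_batch_opts (kMIN, kOPT, kMAX)
    'FP16', 0.05, 1000,            # precision, score_thresh, top_n
    anchors,                       # per-stride anchor lists
    False, 0.5, 100,               # rotated, nms_thresh, detections_per_im
    [], '', '',                    # calibration_images, model_name, calibration_table
    False,                         # verbose
)
```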

extras/cppapi/README.md
Lines changed: 1 addition & 1 deletion

````diff
@@ -32,7 +32,7 @@ msbuild retinanet_infer.sln
 
 If you don't have an ONNX core model, generate one from your RetinaNet model:
 ```bash
-retinanet export model.pth model.onnx
+odtk export model.pth model.onnx
 ```
 
 Load the ONNX core model and export it to a RetinaNet TensorRT engine (using FP16 precision):
````

extras/cppapi/export.cpp
Lines changed: 3 additions & 3 deletions

```diff
@@ -55,7 +55,7 @@ int main(int argc, char *argv[]) {
 
     // Define default RetinaNet parameters to use for TRT export
     const vector<int> dynamic_batch_opts{1, 8, 16};
-    int batch = 1;
+    int calibration_batches = 2; // must be >= 1
     float score_thresh = 0.05f;
     int top_n = 1000;
     size_t workspace_size =(1ULL << 30);
@@ -86,7 +86,7 @@ int main(int argc, char *argv[]) {
     }
 
     // For INT8 calibration, after setting COCO_PATH on line 10:
-    // const vector<string> calibration_files = glob(dynamic_batch_opts[1]);
+    // const vector<string> calibration_files = glob(calibration_batches*dynamic_batch_opts[1]);
     const vector<string> calibration_files;
     string model_name = "";
     string calibration_table = argc == 4 ? string(argv[3]) : "";
@@ -97,7 +97,7 @@ int main(int argc, char *argv[]) {
         precision = "INT8";
 
     cout << "Building engine..." << endl;
-    auto engine = retinanet::Engine(buffer, size, dynamic_batch_opts, batch, precision, score_thresh, top_n,
+    auto engine = retinanet::Engine(buffer, size, dynamic_batch_opts, precision, score_thresh, top_n,
        anchors, ROTATED, nms_thresh, detections_per_im, calibration_files, model_name, calibration_table, verbose, workspace_size);
    engine.save(string(argv[2]));
 
```

extras/cppapi/infervideo.cpp
Lines changed: 1 addition & 1 deletion

```diff
@@ -95,7 +95,7 @@ int main(int argc, char *argv[]) {
         cout << "Inferring on frame: " << count <<"/" << nframes << endl;
         count++;
         vector<void *> buffers = { data_d, scores_d, boxes_d, classes_d };
-        engine.infer(buffers);
+        engine.infer(buffers, 1);
 
         cudaMemcpy(scores.get(), scores_d, sizeof(float) * num_det, cudaMemcpyDeviceToHost);
         cudaMemcpy(boxes.get(), boxes_d, sizeof(float) * num_det * 4, cudaMemcpyDeviceToHost);
```

retinanet/dali.py
Lines changed: 4 additions & 6 deletions

```diff
@@ -32,7 +32,7 @@ def __init__(self, batch_size, num_threads, path, training, annotations, world,
 
         self.decode_train = ops.ImageDecoderSlice(device="mixed", output_type=types.RGB)
         self.decode_infer = ops.ImageDecoder(device="mixed", output_type=types.RGB)
-        self.bbox_crop = ops.RandomBBoxCrop(device='cpu', ltrb=True, scaling=[0.3, 1.0],
+        self.bbox_crop = ops.RandomBBoxCrop(device='cpu', bbox_layout="xyXY", scaling=[0.3, 1.0],
            thresholds=[0.1, 0.3, 0.5, 0.7, 0.9])
 
         self.bbox_flip = ops.BbFlip(device='cpu', ltrb=True)
@@ -122,7 +122,7 @@ def __init__(self, path, resize, max_size, batch_size, stride, world, annotation
            augment_brightness=augment_brightness,
            augment_contrast=augment_contrast, augment_hue=augment_hue,
            augment_saturation=augment_saturation)
-
+
         self.pipe.build()
 
     def __repr__(self):
@@ -149,7 +149,7 @@ def __iter__(self):
             id = int(dali_ids.at(batch)[0])
 
             # Convert dali tensor to pytorch
-            dali_tensor = dali_data.at(batch)
+            dali_tensor = dali_data[batch]
             tensor_shape = dali_tensor.shape()
 
             datum = torch.zeros(dali_tensor.shape(), dtype=torch.float, device=torch.device('cuda'))
@@ -158,7 +158,7 @@ def __iter__(self):
 
             # Calculate image resize ratio to rescale boxes
             prior_size = dali_attrs.as_cpu().at(batch)
-            resized_size = dali_resize_img.at(batch).shape()
+            resized_size = dali_resize_img[batch].shape()
             ratio = max(resized_size) / max(prior_size)
 
             if self.training:
@@ -192,12 +192,10 @@ def __iter__(self):
 
             if self.training:
                 pyt_targets = pyt_targets.cuda(non_blocking=True)
-
                 yield data, pyt_targets
 
             else:
                 ids = torch.Tensor(ids).int().cuda(non_blocking=True)
                 ratios = torch.Tensor(ratios).cuda(non_blocking=True)
-
                 yield data, ids, ratios
 
```

retinanet/infer.py
Lines changed: 0 additions & 5 deletions

```diff
@@ -33,10 +33,6 @@ def infer(model, path, detections_file, resize, max_size, batch_size, mixed_prec
 
     # Prepare dataset
     if verbose: print('Preparing dataset...')
-    data_iterator = (DaliDataIterator if use_dali else DataIterator)(
-        path, resize, max_size, batch_size, stride,
-        world, annotations, training=False)
-
     if rotated_bbox:
         if use_dali: raise NotImplementedError("This repo does not currently support DALI for rotated bbox.")
         data_iterator = RotatedDataIterator(path, resize, max_size, batch_size, stride,
@@ -45,7 +41,6 @@ def infer(model, path, detections_file, resize, max_size, batch_size, mixed_prec
         data_iterator = (DaliDataIterator if use_dali else DataIterator)(
             path, resize, max_size, batch_size, stride,
             world, annotations, training=False)
-
     if verbose: print(data_iterator)
 
     # Prepare model
```

retinanet/main.py
Lines changed: 12 additions & 9 deletions

```diff
@@ -98,12 +98,10 @@ def parse(args):
     parser_export.add_argument('--size', metavar='height width', type=int, nargs='+',
                                help='input size (square) or sizes (h w) to use when generating TensorRT engine',
                                default=[1280])
-    parser_export.add_argument('--batch', metavar='size', type=int, help='max batch size to use for TensorRT engine',
-                               default=2)
     parser_export.add_argument('--full-precision', help='export in full instead of half precision', action='store_true')
     parser_export.add_argument('--int8', help='calibrate model and export in int8 precision', action='store_true')
     parser_export.add_argument('--calibration-batches', metavar='size', type=int,
-                               help='number of batches to use for int8 calibration', default=4)
+                               help='number of batches to use for int8 calibration', default=2)
     parser_export.add_argument('--calibration-images', metavar='path', type=str,
                                help='path to calibration images to use for int8 calibration', default="")
     parser_export.add_argument('--calibration-table', metavar='path', type=str,
@@ -163,7 +161,7 @@ def worker(rank, args, world, model, state):
     torch.cuda.set_device(rank)
     torch.distributed.init_process_group(backend='nccl', init_method='env://')
 
-    if args.batch % world != 0:
+    if (args.command != 'export') and (args.batch % world != 0):
         raise RuntimeError('Batch size should be a multiple of the number of GPUs')
 
     if model and model.angles is not None:
@@ -204,11 +202,16 @@ def worker(rank, args, world, model, state):
             for ex in file_extensions:
                 calibration_files += glob.glob("{}/*{}".format(args.calibration_images, ex), recursive=True)
             # Only need enough images for specified num of calibration batches
-            if len(calibration_files) >= args.calibration_batches * args.batch:
-                calibration_files = calibration_files[:(args.calibration_batches * args.batch)]
+            if len(calibration_files) >= args.calibration_batches * args.dynamic_batch_opts[1]:
+                calibration_files = calibration_files[:(args.calibration_batches * args.dynamic_batch_opts[1])]
             else:
-                print('Only found enough images for {} batches. Continuing anyway...'.format(
-                    len(calibration_files) // args.batch))
+                # Number of images for calibration must be greater than or equal to the kOPT optimization profile
+                if len(calibration_files) >= args.dynamic_batch_opts[1]:
+                    print('Only found enough images for {} batches. Continuing anyway...'.format(
+                        len(calibration_files) // args.dynamic_batch_opts[1]))
+                else:
+                    raise RuntimeError('Not enough images found for calibration. ({} < {})'
+                                       .format(len(calibration_files), args.dynamic_batch_opts[1]))
 
             random.shuffle(calibration_files)
 
@@ -218,7 +221,7 @@ def worker(rank, args, world, model, state):
         elif not args.full_precision:
             precision = "FP16"
 
-        exported = model.export(input_size, args.dynamic_batch_opts, args.batch, precision, calibration_files,
+        exported = model.export(input_size, args.dynamic_batch_opts, precision, calibration_files,
                                 args.calibration_table, args.verbose, onnx_only=onnx_only)
         if onnx_only:
             with open(args.export, 'wb') as out:
```
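
For reference, the new calibration-file selection above reduces to the following standalone sketch. The helper is hypothetical (not part of the repo's API), but the behaviour mirrors the added guard: take `calibration_batches * kOPT` files when possible, fall back to fewer whole batches with a warning, and fail only when there are not even `kOPT` images.

```python
def select_calibration_files(files, calibration_batches, dynamic_batch_opts):
    """Illustrative mirror of the guard added in retinanet/main.py above."""
    kopt = dynamic_batch_opts[1]            # middle (kOPT) batch size of the profile
    wanted = calibration_batches * kopt
    if len(files) >= wanted:
        return files[:wanted]
    if len(files) >= kopt:
        print('Only found enough images for {} batches. Continuing anyway...'.format(len(files) // kopt))
        return files
    raise RuntimeError('Not enough images found for calibration. ({} < {})'.format(len(files), kopt))

# e.g. select_calibration_files(images, calibration_batches=2, dynamic_batch_opts=[1, 8, 16])
```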

retinanet/model.py
Lines changed: 20 additions & 11 deletions

```diff
@@ -15,9 +15,17 @@
 class Model(nn.Module):
     'RetinaNet - https://arxiv.org/abs/1708.02002'
 
-    def __init__(self, backbones='ResNet50FPN', classes=80,
-                 ratios=[1.0, 2.0, 0.5], scales=[4 * 2 ** (i / 3) for i in range(3)],
-                 angles=None, rotated_bbox=False, anchor_ious=[0.4, 0.5], config={}):
+    def __init__(
+        self,
+        backbones='ResNet50FPN',
+        classes=80,
+        ratios=[1.0, 2.0, 0.5],
+        scales=[4 * 2 ** (i / 3) for i in range(3)],
+        angles=None,
+        rotated_bbox=False,
+        anchor_ious=[0.4, 0.5],
+        config={}
+    ):
         super().__init__()
 
         if not isinstance(backbones, list):
@@ -242,15 +250,16 @@ def load(cls, filename, rotated_bbox=False):
 
         return model, state
 
-    def export(self, size, dynamic_batch_opts, batch, precision, calibration_files, calibration_table, verbose, onnx_only=False):
+    def export(self, size, dynamic_batch_opts, precision, calibration_files, calibration_table, verbose, onnx_only=False):
 
-        import torch.onnx.symbolic_opset10 as onnx_symbolic
+        import torch.onnx.symbolic_opset11 as onnx_symbolic
         def upsample_nearest2d(g, input, output_size, *args):
-            # Currently, TRT 5.1/6.0/7.0 ONNX Parser does not support all ONNX ops
+            # Currently, TRT 7.1 ONNX Parser does not support all ONNX ops
             # needed to support dynamic upsampling ONNX forumlation
             # Here we hardcode scale=2 as a temporary workaround
             scales = g.op("Constant", value_t=torch.tensor([1., 1., 2., 2.]))
-            return g.op("Resize", input, scales, mode_s="nearest")
+            empty_tensor = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32))
+            return g.op("Resize", input, empty_tensor, scales, mode_s="nearest", nearest_mode_s="floor")
 
         onnx_symbolic.upsample_nearest2d = upsample_nearest2d
 
@@ -265,7 +274,7 @@ def upsample_nearest2d(g, input, output_size, *args):
         dynamic_axes = {input_names[0]: {0:'batch'}}
         for _, name in enumerate(output_names):
             dynamic_axes[name] = dynamic_axes[input_names[0]]
-        extra_args = {'opset_version': 10, 'verbose': verbose,
+        extra_args = {'opset_version': 11, 'verbose': verbose,
                       'input_names': input_names, 'output_names': output_names,
                       'dynamic_axes': dynamic_axes}
         torch.onnx.export(self.cuda(), zero_input, onnx_bytes, **extra_args)
@@ -284,6 +293,6 @@ def upsample_nearest2d(g, input, output_size, *args):
             anchors = [generate_anchors_rotated(stride, self.ratios, self.scales,
                 self.angles)[0].view(-1).tolist() for stride in self.strides]
 
-        return Engine(onnx_bytes.getvalue(), len(onnx_bytes.getvalue()), dynamic_batch_opts, batch,
-                      precision, self.threshold, self.top_n, anchors, self.rotated_bbox, self.nms,
-                      self.detections, calibration_files, model_name, calibration_table, verbose)
+        return Engine(onnx_bytes.getvalue(), len(onnx_bytes.getvalue()), dynamic_batch_opts, precision,
+                      self.threshold, self.top_n, anchors, self.rotated_bbox, self.nms, self.detections,
+                      calibration_files, model_name, calibration_table, verbose)
```
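
The opset change above matters because opset 11's `Resize` takes `(X, roi, scales[, sizes])` rather than opset 10's `(X, scales)`, which is why the symbolic now inserts an empty ROI constant and pins `nearest_mode` to `floor`. A minimal, self-contained check, independent of this repo and assuming `torch` and `onnx` are installed, that an opset-11 export of nearest-neighbour upsampling produces a `Resize` node:

```python
import io

import onnx
import torch
import torch.nn as nn

# Tiny module exercising the op this commit re-maps onto the opset-11 Resize signature.
upsample = nn.Upsample(scale_factor=2, mode="nearest")

buf = io.BytesIO()
torch.onnx.export(upsample, torch.randn(1, 3, 8, 8), buf, opset_version=11)

graph = onnx.load_from_string(buf.getvalue()).graph
print([node.op_type for node in graph.node])  # expect a Resize node in the list
```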

retinanet/train.py
Lines changed: 3 additions & 1 deletion

```diff
@@ -31,10 +31,12 @@ def train(model, state, path, annotations, val_path, val_annotations, resize, ma
     # Setup optimizer and schedule
     optimizer = SGD(model.parameters(), lr=lr, weight_decay=regularization_l2, momentum=0.9)
 
+    loss_scale = "dynamic" if use_dali else "128.0"
+
     model, optimizer = amp.initialize(model, optimizer,
                                       opt_level='O2' if mixed_precision else 'O0',
                                       keep_batchnorm_fp32=True,
-                                      loss_scale=128.0,
+                                      loss_scale=loss_scale,
                                       verbosity=is_master)
 
     if world > 1:
```
