Skip to content

Commit e7bb8c2

Browse files
authored
feat: Add functionality to performance tooling (#1451)
- Add functionality for timing compilation in addition to inference
- Add bash scripting code for concatenating all model result outputs
1 parent cd1bda3 commit e7bb8c2

File tree

2 files changed

+19
-3
lines changed

2 files changed

+19
-3
lines changed

tools/perf/benchmark.sh

+8
Original file line number | Diff line number | Diff line change
@@ -62,3 +62,11 @@ do
6262
--truncate \
6363
--report "bert_base_perf_bs${bs}.txt"
6464
done
65+
66+
# Collect and concatenate all results
67+
echo "Concatenating all results"
68+
(echo "Output of All Model Runs"; echo) >> all_outputs.txt;
69+
70+
for i in $(ls *_bs*.txt);
71+
do (echo $i; cat $i; echo; echo) >> all_outputs.txt;
72+
done

tools/perf/perf_run.py

+11-3
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
from __future__ import absolute_import
33
from __future__ import division
44

5+
import time
56
import timeit
67
import numpy as np
78
import torch.backends.cudnn as cudnn
@@ -103,7 +104,10 @@ def run_torch_tensorrt(
103104
if precision == "int8":
104105
compile_settings.update({"calib": params.get("calibration_cache")})
105106

107+
start_compile = time.time_ns()
106108
model = torchtrt.compile(model, **compile_settings)
109+
end_compile = time.time_ns()
110+
compile_time_ms = (end_compile - start_compile) / 1e6
107111

108112
iters = params.get("iterations", 20)
109113
# Warm up
@@ -123,7 +127,7 @@ def run_torch_tensorrt(
123127
meas_time = end_time - start_time
124128
timings.append(meas_time)
125129

126-
recordStats("Torch-TensorRT", timings, precision, batch_size)
130+
recordStats("Torch-TensorRT", timings, precision, batch_size, compile_time_ms)
127131

128132

129133
# Runs inference using FX2TRT backend
@@ -136,13 +140,16 @@ def run_fx2trt(model, input_tensors, params, precision, batch_size):
136140
model.half()
137141
input_tensors = [tensor.half() for tensor in input_tensors]
138142
# Run lowering eager mode benchmark
143+
start_compile = time.time_ns()
139144
model = compile(
140145
model,
141146
input_tensors,
142147
max_batch_size=batch_size,
143148
lower_precision=precision,
144149
verbose_log=False,
145150
)
151+
end_compile = time.time_ns()
152+
compile_time_ms = (end_compile - start_compile) / 1e6
146153

147154
iters = params.get("iterations", 20)
148155
# Warm up
@@ -162,7 +169,7 @@ def run_fx2trt(model, input_tensors, params, precision, batch_size):
162169
meas_time = end_time - start_time
163170
timings.append(meas_time)
164171

165-
recordStats("FX-TensorRT", timings, precision, batch_size)
172+
recordStats("FX-TensorRT", timings, precision, batch_size, compile_time_ms)
166173

167174

168175
def torch_dtype_from_trt(dtype):
@@ -331,7 +338,7 @@ def run(
331338

332339

333340
# Generate report
334-
def recordStats(backend, timings, precision, batch_size=1):
341+
def recordStats(backend, timings, precision, batch_size=1, compile_time_ms=None):
335342
times = np.array(timings)
336343
steps = len(times)
337344
speeds = batch_size / times
@@ -350,6 +357,7 @@ def recordStats(backend, timings, precision, batch_size=1):
350357
"Mean(FPS)": speed_mean,
351358
"Median-Latency(ms)": time_med * 1000,
352359
"Mean-Latency(ms)": time_mean * 1000,
360+
"Compile Time(ms)": compile_time_ms,
353361
}
354362
results.append(stats)
355363

0 commit comments

Comments
 (0)