
Commit c465c14

Author: zhangming8 (committed)
default seed=None; pre-allocate gpu memory
1 parent 5d6faac commit c465c14

File tree: 5 files changed, +33 -7 lines changed


README.md (+1, -1)

@@ -57,7 +57,7 @@ mAP was reevaluated on COCO val2017 and test2017, and some results are slightly
 
 d. Some tips:
    Ⅰ You can also change params in 'train.sh'(these params will replace opt.xxx in config.py) and use 'nohup sh train.sh &' to train
-   Ⅱ If you want to close mulit-size training, change opt.random_size = None or (20, 21) in 'config.py' or set random_size=None in 'train.sh'
+   Ⅱ If you want to close mulit-size training, change opt.random_size = None in 'config.py' or set random_size=None in 'train.sh'
    Ⅲ Mulit-gpu train: change opt.gpus = "3,5,6,7"
    Ⅳ Visualized log by tensorboard: tensorboard --logdir exp/your_exp_id/logs_2021-08-xx-xx-xx and visit http://localhost:6006
 Your can also use the following shell scripts:

config.py (+5, -3)

@@ -29,7 +29,7 @@ def update_nano_tiny(cfg, inp_params):
 # opt.dataset_path = r"D:\work\public_dataset\coco2017" # Windows system
 opt.backbone = "CSPDarknet-s" # CSPDarknet-nano, CSPDarknet-tiny, CSPDarknet-s, CSPDarknet-m, l, x
 opt.input_size = (640, 640)
-opt.random_size = (14, 26) # None; multi-size train: from 448 to 800, random sample an int value and *32 as input size
+opt.random_size = (14, 26) # None; multi-size train: from 448(14*32) to 832(26*32), set None to disable it
 opt.test_size = (640, 640) # evaluate size
 opt.gpus = "0" # "-1" "0" "3,4,5" "0,1,2,3,4,5,6,7" # -1 for cpu
 opt.batch_size = 24
@@ -78,7 +78,7 @@ def update_nano_tiny(cfg, inp_params):
 opt.shear = 2.0
 opt.perspective = 0.0
 opt.enable_mixup = True
-opt.seed = 0
+opt.seed = None # 0
 opt.data_num_workers = 4
 
 opt.momentum = 0.9
@@ -94,6 +94,7 @@ def update_nano_tiny(cfg, inp_params):
 opt.use_amp = False # True
 opt.cuda_benchmark = True
 opt.nms_thresh = 0.65
+opt.occupy_mem = False # pre-allocate gpu memory for training to avoid memory Fragmentation.
 
 opt.rgb_means = [0.485, 0.456, 0.406]
 opt.std = [0.229, 0.224, 0.225]
@@ -132,6 +133,7 @@ def update_nano_tiny(cfg, inp_params):
     opt.cuda_benchmark = False
 if opt.reid_dim > 0:
     assert opt.tracking_id_nums is not None
-
+if opt.random_size is None:
+    opt.test_size = opt.input_size
 os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpus_str
 print("\n{} final config: {}\n{}".format("-" * 20, "-" * 20, opt))

train.py (+5, -3)

@@ -19,7 +19,7 @@
 from models.yolox import get_model
 from models.post_process import yolox_post_process
 from utils.lr_scheduler import LRScheduler
-from utils.util import AverageMeter, write_log, configure_module
+from utils.util import AverageMeter, write_log, configure_module, occupy_mem
 from utils.model_utils import EMA, save_model, load_model, ensure_same, clip_grads
 from utils.data_parallel import set_device, _DataParallel
 from utils.logger import Logger
@@ -98,8 +98,7 @@ def run_epoch(model_with_loss, optimizer, scaler, ema, phase, epoch, data_iter,
                 avg_loss_stats[l] = AverageMeter()
             avg_loss_stats[l].update(loss_stats[l], inps.size(0))
             Bar.suffix = Bar.suffix + '|{} {:.4f} '.format(l, avg_loss_stats[l].avg)
-        Bar.suffix = Bar.suffix + '|Data {dt.val:.3f}s({dt.avg:.3f}s) |Net {bt.avg:.3f}s'.format(dt=data_time,
-                                                                                                 bt=batch_time)
+        Bar.suffix = Bar.suffix + '|Data {dt.val:.3f}s |Batch {bt.val:.3f}s'.format(dt=data_time, bt=batch_time)
         if opt.print_iter > 0 and iter_id % opt.print_iter == 0:
             print('{}| {}'.format(opt.exp_id, Bar.suffix))
             logger.write('{}| {}\n'.format(opt.exp_id, Bar.suffix))
@@ -178,6 +177,7 @@ def train(model, scaler, train_loader, val_loader, optimizer, lr_scheduler, star
             if loss_dict_val['loss'] <= best:
                 best = loss_dict_val['loss']
                 save_model(os.path.join(opt.save_dir, 'model_best.pth'), epoch, model)
+            del loss_dict_val, preds
 
         save_model(os.path.join(opt.save_dir, 'model_{}.pth'.format(epoch)), epoch,
                    model) if epoch % opt.save_epoch == 0 else ""
@@ -227,6 +227,8 @@ def main():
 
     # DP
     opt.device = torch.device('cuda' if opt.gpus[0] >= 0 else 'cpu')
+    if opt.occupy_mem and opt.device.type != 'cpu':
+        occupy_mem(opt.device)
     model, optimizer = set_device(model, optimizer, opt)
     train(model, scaler, train_loader, val_loader, optimizer, lr_scheduler, start_epoch, opt.accumulate, no_aug)

train.sh (+3)

@@ -19,3 +19,6 @@ python train.py gpus='0' backbone="CSPDarknet-s" num_epochs=300 exp_id="coco_CSP
 
 # resume 'model_last.pth', include weight, optimizer, scaler and epoch
 #python train.py gpus='0' backbone="CSPDarknet-s" num_epochs=300 exp_id="coco_CSPDarknet-s_640x640" use_amp=True val_intervals=2 data_num_workers=6 metric="ap" batch_size=48 load_model="exp/coco_CSPDarknet-s_640x640/model_last.pth" resume=True
+
+# GPU memory changes with the input size when multi-size training, which can be avoided by pre allocating memory
+#python train.py gpus='0' backbone="CSPDarknet-tiny" num_epochs=300 exp_id="coco_CSPDarknet-tiny_416x416" use_amp=True val_intervals=2 data_num_workers=6 metric="ap" batch_size=128 occupy_mem=True
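The new comment in train.sh points at why pre-allocation matters: with multi-size training the batch footprint changes as the input size changes. A rough back-of-the-envelope check of just the input tensor (assumptions: float32 inputs, batch_size=128 as in the command above, and the (14, 26) multiplier range from config.py):

    # Input-tensor size across the multi-size range; activations scale roughly
    # in proportion, which is why reserved GPU memory fluctuates without
    # pre-allocation.
    batch = 128
    for k in (14, 20, 26):
        side = k * 32
        mib = batch * 3 * side * side * 4 / 2 ** 20
        print("input {}x{} -> {:.0f} MiB".format(side, side, mib))

The exp_id in that command suggests a 416x416 base size for the tiny backbone, so the actual multiplier range there may differ; the arithmetic above only illustrates the trend.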

utils/util.py (+19)

@@ -60,6 +60,25 @@ def sync_time(inputs):
     return time.time()
 
 
+def get_total_and_free_memory_in_mb(cuda_device):
+    devices_info_str = os.popen("nvidia-smi --query-gpu=memory.total,memory.used --format=csv,nounits,noheader")
+    devices_info = devices_info_str.read().strip().split("\n")
+    total, used = devices_info[int(cuda_device)].split(",")
+    return int(total), int(used)
+
+
+def occupy_mem(cuda_device, mem_ratio=0.9):
+    """
+    pre-allocate gpu memory for training to avoid memory Fragmentation.
+    """
+    total, used = get_total_and_free_memory_in_mb(0)
+    max_mem = int(total * mem_ratio)
+    block_mem = max_mem - used
+    x = torch.FloatTensor(256, 1024, block_mem).to(cuda_device)
+    del x
+    time.sleep(5)
+
+
 def gpu_mem_usage():
     """
     Compute the GPU memory usage for the current device (MB).
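Two details worth spelling out about occupy_mem as added here. First, the sizing: block_mem is in MiB (nvidia-smi with --format=nounits reports MiB), and torch.FloatTensor(256, 1024, block_mem) holds 256*1024*block_mem float32 values, i.e. 256*1024*4 bytes = 1 MiB per unit of block_mem, so the throwaway tensor spans roughly block_mem MiB. Second, deleting the tensor does not return the memory to the driver; it stays in PyTorch's caching allocator, which is what lets later, differently-sized batches reuse it. A small sketch of how that could be observed (an illustration, not part of the commit):

    import torch

    def check_reserved(cuda_device="cuda:0", block_mem=1024):
        # Allocate ~block_mem MiB, drop it, and show that PyTorch keeps the
        # block cached ("reserved") rather than returning it to the driver.
        x = torch.FloatTensor(256, 1024, block_mem).to(cuda_device)
        del x
        reserved_mib = torch.cuda.memory_reserved(cuda_device) / 2 ** 20
        print("reserved after del: {:.0f} MiB".format(reserved_mib))
        return reserved_mib

Also note that occupy_mem calls get_total_and_free_memory_in_mb with a hard-coded 0, so the size of the pre-allocated block is always computed from GPU 0 as listed by nvidia-smi, even when cuda_device points elsewhere.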
