Skip to content

Commit e162fc0

Browse files
author
zhangming8
committed
train Customer data; will support mult-object tracking soon
1 parent 86a4491 commit e162fc0

18 files changed

+514
-299
lines changed

.gitignore

+2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11

2+
#!demo/*.jpg
3+
24
*.so
35
*.o
46
.DS_Store

README.md

+36-13
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,54 @@
1-
# a Pytorch easy re-implement of "YOLOX: Exceeding YOLO Series in 2021"
1+
## a Pytorch easy re-implement of "YOLOX: Exceeding YOLO Series in 2021"
22

3+
## Notes
4+
1. this is a Pytorch easy re-implement of "YOLOX: Exceeding YOLO Series in 2021"
5+
2. the repo is still under development
6+
3. we needn't install apex, Pytorch(version >= 1.7.0) has supported it
37

4-
# environment
5-
pytorch>=1.7.0, python>=3.6
8+
## Environment
9+
pytorch>=1.7.0, python>=3.6, Ubuntu/Windows, see more in 'requirements.txt'
10+
11+
## Dataset
12+
put COCO dataset in following folders:
613

7-
# dataset
8-
put your COCO dataset in fllowing folder:
9-
1014
/path/to/dataset/annotations/instances_train2017.json
1115
/path/to/dataset/annotations/instances_val2017.json
1216
/path/to/dataset/images/train2017/*.jpg
1317
/path/to/dataset/images/val2017/*.jpg
1418

15-
modify 'config.py'
16-
opt.dataset_path = "/path/to/dataset"
19+
change opt.dataset_path = "/path/to/dataset" in 'config.py'
1720

18-
# train
21+
## Train
1922
sh train.sh
2023

21-
# evaluate
24+
## Evaluate
2225
sh evaluate.sh
2326

24-
# predict/inference/demo
27+
## Predict/Inference/Demo
2528
sh predict.sh
26-
2729

28-
# reference
30+
## Train Customer Dataset(VOC format)
31+
32+
1. put your annotations(.xml) and images(.jpg) into:
33+
/path/to/voc_data/images/train2017/*.jpg # train images
34+
/path/to/voc_data/images/train2017/*.xml # train xml annotations
35+
/path/to/voc_data/images/val2017/*.jpg # val images
36+
/path/to/voc_data/images/val2017/*.xml # val xml annotations
37+
38+
2. change opt.label_name = ['your', 'dataset', 'label'] in 'config.py'
39+
change opt.dataset_path = '/path/to/voc_data' in 'config.py'
40+
41+
3. python tools/voc_to_coco.py
42+
Converted COCO format annotation will be saved into:
43+
/path/to/voc_data/annotations/instances_train2017.json
44+
/path/to/voc_data/annotations/instances_val2017.json
45+
46+
4. (Optional) you can visualize the converted annotations by:
47+
python tools/show_coco_anns.py
48+
49+
5. run train.sh, evaluate.sh, predict.sh (are the same as COCO)
50+
51+
## Reference
2952
https://github.com/Megvii-BaseDetection/YOLOX
3053
https://github.com/PaddlePaddle/PaddleDetection
3154
https://github.com/open-mmlab/mmdetection

config.py

+17-15
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,25 @@ def update_nano_tiny(cfg):
3434
opt.random_size = (14, 26) # None
3535
opt.accumulate = 1 # real batch size = accumulate * batch_size
3636

37+
# coco 80 classes
38+
opt.label_name = [
39+
'person', 'bicycle', 'car', 'motorcycle', 'airplane',
40+
'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
41+
'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
42+
'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack',
43+
'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis',
44+
'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
45+
'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass',
46+
'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich',
47+
'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake',
48+
'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv',
49+
'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
50+
'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
51+
'scissors', 'teddy bear', 'hair drier', 'toothbrush']
52+
# opt.label_name = ['person']
3753
# TODO: support MOT(multi-object tracking) like FairMot/JDE when reid_dim > 0
3854
opt.reid_dim = 0 # 128
39-
opt.id_num = None # tracking id number in train dataset
55+
opt.tracking_id_nums = None # tracking id number in train dataset
4056

4157
opt.warmup_lr = 0
4258
opt.basic_lr_per_img = 0.01 / 64.0
@@ -72,20 +88,6 @@ def update_nano_tiny(cfg):
7288
opt.cuda_benchmark = True
7389
opt.nms_thresh = 0.65
7490

75-
opt.label_name = [
76-
'person', 'bicycle', 'car', 'motorcycle', 'airplane',
77-
'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
78-
'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
79-
'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack',
80-
'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis',
81-
'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
82-
'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass',
83-
'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich',
84-
'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake',
85-
'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv',
86-
'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
87-
'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
88-
'scissors', 'teddy bear', 'hair drier', 'toothbrush']
8991
opt.rgb_means = [0.485, 0.456, 0.406]
9092
opt.std = [0.229, 0.224, 0.225]
9193

data/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# -*- coding:utf-8 -*-
33
# Copyright (c) Megvii, Inc. and its affiliates.
44

5-
from .data_augment import TrainTransform, ValTransform
5+
from .data_augment import TrainTransform
66
from .data_prefetcher import DataPrefetcher
77
from .dataloading import DataLoader, get_yolox_datadir
88
from .datasets import *

data/coco_dataset.py

+24-9
Original file line numberDiff line numberDiff line change
@@ -9,28 +9,30 @@
99
import sys
1010

1111
sys.path.append(".")
12-
from data import (COCODataset, TrainTransform, YoloBatchSampler, DataLoader, InfiniteSampler, MosaicDetection,
13-
ValTransform)
12+
from data import (COCODataset, TrainTransform, YoloBatchSampler, DataLoader, InfiniteSampler, MosaicDetection)
1413

1514

1615
def get_dataloader(opt, no_aug=False):
1716
# train
17+
do_tracking = opt.reid_dim > 0
1818
train_dataset = COCODataset(data_dir=opt.data_dir,
1919
json_file=opt.train_ann,
2020
img_size=opt.input_size,
21-
preproc=TrainTransform(rgb_means=opt.rgb_means, std=opt.std, max_labels=50),
21+
tracking=do_tracking,
22+
preproc=TrainTransform(rgb_means=opt.rgb_means, std=opt.std, tracking=do_tracking),
2223
)
2324
train_dataset = MosaicDetection(
2425
train_dataset,
2526
mosaic=not no_aug,
2627
img_size=opt.input_size,
27-
preproc=TrainTransform(rgb_means=opt.rgb_means, std=opt.std, max_labels=120),
28+
preproc=TrainTransform(rgb_means=opt.rgb_means, std=opt.std, max_labels=120, tracking=do_tracking),
2829
degrees=opt.degrees,
2930
translate=opt.translate,
3031
scale=opt.scale,
3132
shear=opt.shear,
3233
perspective=opt.perspective,
3334
enable_mixup=opt.enable_mixup,
35+
tracking=do_tracking,
3436
)
3537
train_sampler = InfiniteSampler(len(train_dataset), seed=opt.seed)
3638
batch_sampler = YoloBatchSampler(
@@ -49,7 +51,9 @@ def get_dataloader(opt, no_aug=False):
4951
json_file=opt.val_ann,
5052
name="val2017",
5153
img_size=opt.test_size,
52-
preproc=ValTransform(rgb_means=opt.rgb_means, std=opt.std))
54+
tracking=do_tracking,
55+
preproc=TrainTransform(rgb_means=opt.rgb_means, std=opt.std, max_labels=120, tracking=do_tracking,
56+
augment=False))
5357
val_sampler = torch.utils.data.SequentialSampler(val_dataset)
5458
val_kwargs = {"num_workers": opt.data_num_workers, "pin_memory": True, "sampler": val_sampler,
5559
"batch_size": opt.batch_size}
@@ -68,24 +72,33 @@ def vis_inputs(inputs, targets, opt):
6872
img = (((inp.transpose((1, 2, 0)) * opt.std) + opt.rgb_means) * 255).astype(np.uint8)
6973
img = img[:, :, ::-1]
7074
img = np.ascontiguousarray(img)
75+
gt_n = 0
7176
for t in target:
7277
if t.sum() > 0:
73-
cls, c_x, c_y, w, h = [int(i) for i in t]
78+
if len(t) == 5:
79+
cls, c_x, c_y, w, h = [int(i) for i in t]
80+
tracking_id = None
81+
elif len(t) == 6:
82+
cls, c_x, c_y, w, h, tracking_id = [int(i) for i in t]
83+
else:
84+
raise ValueError("target shape != 5 or 6")
7485
bbox = [c_x - w // 2, c_y - h // 2, c_x + w // 2, c_y + h // 2]
7586
label = opt.label_name[cls]
7687
# print(label, bbox)
7788
color = label_color[cls]
7889
# show box
7990
cv2.rectangle(img, (bbox[0], bbox[1]), (bbox[2], bbox[3]), color, 2)
8091
# show label and conf
81-
txt = '{}'.format(label)
92+
txt = '{}-{}'.format(label, tracking_id) if tracking_id is not None else '{}'.format(label)
8293
font = cv2.FONT_HERSHEY_SIMPLEX
8394
txt_size = cv2.getTextSize(txt, font, 0.5, 2)[0]
8495
cv2.rectangle(img, (bbox[0], bbox[1] - txt_size[1] - 2), (bbox[0] + txt_size[0], bbox[1] - 2), color,
8596
-1)
8697
cv2.putText(img, txt, (bbox[0], bbox[1] - 2), font, 0.5, (255, 255, 255), thickness=1,
8798
lineType=cv2.LINE_AA)
99+
gt_n += 1
88100

101+
print("img {}/{} gt number: {}".format(b_i, len(inputs), gt_n))
89102
cv2.namedWindow("input", 0)
90103
cv2.imshow("input", img)
91104
key = cv2.waitKey(0)
@@ -99,7 +112,7 @@ def run_epoch(data_iter, loader, total_iter, e, phase, opt):
99112
batch = next(data_iter)
100113
inps, targets, img_info, ind = batch
101114
print("------------ epoch {} batch {}/{} ---------------".format(e, batch_i, total_iter))
102-
print(inps.shape, targets.shape)
115+
print("batch img shape {}, target shape {}".format(inps.shape, targets.shape))
103116
vis_inputs(inps, targets, opt)
104117
if batch_i == 0:
105118
print(ind)
@@ -119,8 +132,10 @@ def main():
119132
from config import opt
120133

121134
opt.input_size = (640, 640)
135+
opt.test_size = (640, 640)
122136
opt.batch_size = 2
123137
opt.data_num_workers = 0
138+
opt.reid_dim = 0 # 128
124139
print(opt)
125140
train_loader, val_loader = get_dataloader(opt, no_aug=False)
126141

@@ -132,8 +147,8 @@ def main():
132147
total_iter = len(loader)
133148
data_iter = iter(loader)
134149
for e in range(100):
135-
# train_loader.dataset.enable_mosaic = False
136150
# train_loader.dataset.enable_mixup = False
151+
# train_loader.dataset.enable_mosaic = False
137152
# train_loader.close_mosaic()
138153
# print(train_loader.batch_sampler.mosaic)
139154
run_epoch(data_iter, loader, total_iter, e, phase, opt)

0 commit comments

Comments
 (0)