Skip to content

Commit 4c6f61e

Browse files
committed
Update
1 parent 94f3552 commit 4c6f61e

File tree

3 files changed

+150
-98
lines changed

3 files changed

+150
-98
lines changed

gen_tfrecord.py

Lines changed: 126 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,13 @@
1+
'''
2+
tf2.x 版本转换PASCAL VOC至tfrecord格式
3+
1. 使用labelimg等标注工具制作pascal voc格式数据集,注意:图像存储在JPEGImages文件夹,xml标注文件存储在Annotations文件夹
4+
2. 将xml格式转换成csv格式,本脚本使用xml_to_csv函数已经在内部实现
5+
3. 将csv转成TFrecord格式,注意tf1.x版本和tf2.x版本接口是不一样的
6+
7+
参考链接:https://www.pythonf.cn/read/109620
8+
9+
注意事项:对于自定义数据集,需要指定labels列表
10+
'''
111
from __future__ import division
212
from __future__ import print_function
313
from __future__ import absolute_import
@@ -8,93 +18,141 @@
818
import tensorflow as tf
919

1020
from PIL import Image
11-
from object_detection.utils import dataset_util
21+
# from object_detection.utils import dataset_util
1222
from collections import namedtuple, OrderedDict
13-
import tqdm
23+
from tqdm import tqdm
1424
import argparse
15-
25+
import glob
26+
import xml.etree.ElementTree as ET
27+
from pathlib import Path
1628
# flags = tf.app.flags
1729
# flags.DEFINE_string('csv_input', '', 'Path to the CSV input')
1830
# flags.DEFINE_string('output_path', '', 'Path to output TFRecord')
1931
# FLAGS = flags.FLAGS
2032
# TO-DO replace this with label map
21-
labels = ['cow', 'tvmonitor', 'car', 'aeroplane', 'sheep',
22-
'motorbike', 'train', 'chair', 'person', 'sofa',
23-
'pottedplant', 'diningtable', 'horse', 'bottle',
24-
'boat', 'bus', 'bird', 'bicycle', 'cat', 'dog']
33+
# labels = ['cow', 'tvmonitor', 'car', 'aeroplane', 'sheep',
34+
# 'motorbike', 'train', 'chair', 'person', 'sofa',
35+
# 'pottedplant', 'diningtable', 'horse', 'bottle',
36+
# 'boat', 'bus', 'bird', 'bicycle', 'cat', 'dog']
2537

26-
def class_text_to_int(row_label, labels):
38+
# Edit this list to match the classes of your own dataset; the position of a
# name in the list (1-based) is the integer id written to the TFRecord.
labels = ['raccoon']


def class_text_to_int(row_label, label_list=None):
    """Return the 1-based class id of ``row_label``.

    Args:
        row_label: Class name as it appears in the annotation.
        label_list: Optional sequence of class names to look the label up in.
            Defaults to the module-level ``labels`` list.

    Returns:
        ``index + 1`` of ``row_label`` within the label list.

    Raises:
        ValueError: If ``row_label`` is not in the label list.
    """
    known = labels if label_list is None else label_list
    try:
        return known.index(row_label) + 1
    except ValueError:
        # Same exception type as the bare ``list.index`` call, but the message
        # names the offending label and the list it was checked against.
        raise ValueError(
            'unknown class label {!r}; expected one of {}'.format(row_label, list(known))
        ) from None
2843

2944
def split(df, group):
    """Group annotation rows by a column and pair each key with its rows.

    Args:
        df: DataFrame of annotation rows.
        group: Column name passed to ``DataFrame.groupby``.

    Returns:
        A list of ``data(filename, object)`` namedtuples, one per distinct
        value of the grouping column; ``object`` is the sub-DataFrame holding
        the rows of that group.
    """
    data = namedtuple('data', ['filename', 'object'])
    gb = df.groupby(group)
    # Iterating ``gb.groups`` yields each group key exactly once, so there is
    # no need to zip the keys with themselves.
    return [data(key, gb.get_group(key)) for key in gb.groups]
3348

49+
def _voc_int(text):
    """Convert a VOC numeric field to int, accepting values such as '12.0'."""
    return int(float(text))


def xml_to_csv(xml_anno, data_type):
    """Collect the boxes of one dataset split into a DataFrame.

    The image ids of the split are read from
    ``<parent of xml_anno>/ImageSets/Main/<data_type>.txt`` (one id per line)
    and the matching ``<id>.xml`` files are parsed from ``xml_anno``.

    Args:
        xml_anno: Path of the directory holding the PASCAL VOC xml files.
        data_type: Split name, one of 'trainvaltest', 'train', 'val',
            'trainval' or 'test'.

    Returns:
        pandas.DataFrame with one row per annotated object and the columns
        ``filename, width, height, class, xmin, ymin, xmax, ymax``.
    """
    txt_file = str(Path(xml_anno).parent / 'ImageSets/Main' / f'{data_type}.txt')
    # Close the split file deterministically and ignore blank lines, which
    # would otherwise turn into a bogus '.xml' path.
    with open(txt_file, 'r') as f:
        image_ids = [line.strip() for line in f if line.strip()]

    xml_list = []
    for image_id in image_ids:
        root = ET.parse(os.path.join(xml_anno, image_id + '.xml')).getroot()
        objects = root.findall('object')
        if not objects:
            # Nothing to emit for an image without annotated objects.
            continue

        filename = root.find('filename').text
        size = root.find('size')
        width = _voc_int(size.find('width').text)
        height = _voc_int(size.find('height').text)

        for member in objects:
            # Look fields up by tag name rather than by child position, so the
            # result does not depend on the order the annotation tool wrote them.
            box = member.find('bndbox')
            xml_list.append((
                filename,
                width,
                height,
                member.find('name').text,
                _voc_int(box.find('xmin').text),
                _voc_int(box.find('ymin').text),
                _voc_int(box.find('xmax').text),
                _voc_int(box.find('ymax').text),
            ))

    column_name = ['filename', 'width', 'height', 'class', 'xmin', 'ymin', 'xmax', 'ymax']
    return pd.DataFrame(xml_list, columns=column_name)
76+
77+
def create_tf_example(group, path, image_format=None):
    """Build a ``tf.train.Example`` for one image and all of its boxes.

    Args:
        group: A ``data(filename, object)`` tuple as produced by ``split``;
            ``object`` is a DataFrame with the ``xmin``/``ymin``/``xmax``/
            ``ymax``/``class`` rows of that image.
        path: Directory containing the image file named ``group.filename``.
        image_format: Format string stored under ``image/format``. When
            omitted, the ``--format`` command-line option is used if the
            module was started as a script, otherwise ``'jpg'``.

    Returns:
        tf.train.Example holding the encoded image bytes, the image size, the
        box corners divided by the image width/height, and the class names
        with their integer ids.
    """
    if image_format is None:
        # ``opt`` only exists when the file runs as a script; fall back to
        # 'jpg' so the function also works when the module is imported.
        image_format = getattr(globals().get('opt'), 'format', 'jpg')

    def _bytes_feature(values):
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=values))

    def _float_feature(values):
        return tf.train.Feature(float_list=tf.train.FloatList(value=values))

    def _int64_feature(values):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=values))

    with tf.io.gfile.GFile(os.path.join(path, '{}'.format(group.filename)), 'rb') as fid:
        encoded_jpg = fid.read()
    # The size comes from the decoded image header, not from the annotation.
    width, height = Image.open(io.BytesIO(encoded_jpg)).size

    filename = group.filename.encode('utf8')
    xmins = []
    xmaxs = []
    ymins = []
    ymaxs = []
    classes_text = []
    classes = []

    for _, row in group.object.iterrows():
        xmins.append(row['xmin'] / width)
        xmaxs.append(row['xmax'] / width)
        ymins.append(row['ymin'] / height)
        ymaxs.append(row['ymax'] / height)
        classes_text.append(row['class'].encode('utf8'))
        classes.append(class_text_to_int(row['class']))

    return tf.train.Example(features=tf.train.Features(feature={
        'image/height': _int64_feature([height]),
        'image/width': _int64_feature([width]),
        'image/filename': _bytes_feature([filename]),
        'image/source_id': _bytes_feature([filename]),
        'image/encoded': _bytes_feature([encoded_jpg]),
        'image/format': _bytes_feature([image_format.encode('utf8')]),
        'image/object/bbox/xmin': _float_feature(xmins),
        'image/object/bbox/xmax': _float_feature(xmaxs),
        'image/object/bbox/ymin': _float_feature(ymins),
        'image/object/bbox/ymax': _float_feature(ymaxs),
        'image/object/class/text': _bytes_feature(classes_text),
        'image/object/class/label': _int64_feature(classes),
    }))
34116

35-
def create_tf_example(group, path):
36-
with tf.io.gfile.GFile(os.path.join(path, '{}'.format(group.filename)), 'rb') as fid:
37-
encoded_jpg = fid.read()
38-
encoded_jpg_io = io.BytesIO(encoded_jpg)
39-
image = Image.open(encoded_jpg_io)
40-
width, height = image.size
41-
42-
filename = group.filename.encode('utf8')
43-
image_format = b'jpg'
44-
xmins = []
45-
xmaxs = []
46-
ymins = []
47-
ymaxs = []
48-
classes_text = []
49-
classes = []
50-
51-
for index, row in group.object.iterrows():
52-
xmins.append(row['xmin'] / width)
53-
xmaxs.append(row['xmax'] / width)
54-
ymins.append(row['ymin'] / height)
55-
ymaxs.append(row['ymax'] / height)
56-
classes_text.append(row['class'].encode('utf8'))
57-
classes.append(class_text_to_int(row['class'], group.filename))
58-
59-
tf_example = tf.train.Example(features=tf.train.Features(feature={
60-
'image/height': dataset_util.int64_feature(height),
61-
'image/width': dataset_util.int64_feature(width),
62-
'image/filename': dataset_util.bytes_feature(filename),
63-
'image/source_id': dataset_util.bytes_feature(filename),
64-
'image/encoded': dataset_util.bytes_feature(encoded_jpg),
65-
'image/format': dataset_util.bytes_feature(image_format),
66-
'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
67-
'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
68-
'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
69-
'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
70-
'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
71-
'image/object/class/label': dataset_util.int64_list_feature(classes),
72-
}))
73-
return tf_example
74-
75-
76-
def main(csv_input, output_path):
77-
writer = tf.io.TFRecordWriter(output_path)
78-
path = os.path.join(os.getcwd(), 'images')
79-
examples = pd.read_csv(csv_input)
80-
grouped = split(examples, 'filename')
81-
num=0
82-
for group in grouped:
83-
num+=1
84-
tf_example = create_tf_example(group, path)
85-
writer.write(tf_example.SerializeToString())
86-
if(num%100==0): #每完成100个转换,打印一次
87-
print(num)
88-
89-
writer.close()
90-
output_path = os.path.join(os.getcwd(), output_path)
91-
print('Successfully created the TFRecords: {}'.format(output_path))
92117

118+
def main(voc_root, output_name):
    """Write one TFRecord file per dataset split found under ``voc_root``.

    For every ``<split>.txt`` present in ``<voc_root>/ImageSets/Main`` (out of
    trainvaltest, train, val, trainval and test) the annotations of that split
    are converted and written to ``<voc_root>/<output_name>_<split>.tfrecord``.

    Args:
        voc_root: Dataset root containing ``JPEGImages``, ``Annotations`` and
            ``ImageSets/Main``.
        output_name: Prefix of the generated tfrecord file names.

    Raises:
        FileNotFoundError: If ``ImageSets/Main`` is missing or contains none
            of the expected split files.
    """
    img_path = os.path.join(voc_root, 'JPEGImages')
    anno_path = os.path.join(voc_root, 'Annotations')
    imgset_path = os.path.join(voc_root, 'ImageSets/Main')
    if not os.path.exists(imgset_path):
        raise FileNotFoundError('ImageSets/Main文件夹不存在,请通过脚本生成相应的文件!')

    txt_files = ['trainvaltest.txt', 'train.txt', 'val.txt', 'trainval.txt', 'test.txt']
    valid_txt = [
        os.path.splitext(k)[0]
        for k in txt_files
        if os.path.exists(os.path.join(imgset_path, k))
    ]
    if not valid_txt:
        raise FileNotFoundError('ImageSets/Main文件夹下不存在train.txt等文件,请检查数据集!')
    print(valid_txt)

    for data_type in valid_txt:
        output_path = os.path.join(voc_root, output_name + f'_{data_type}.tfrecord')
        # Parse the annotations before opening the writer so that a parse
        # failure does not leave an empty tfrecord file behind.
        examples = xml_to_csv(anno_path, data_type)
        grouped = split(examples, 'filename')

        # The context manager closes the writer even if a conversion fails.
        with tf.io.TFRecordWriter(output_path) as writer:
            for group in tqdm(grouped):
                tf_example = create_tf_example(group, img_path)
                writer.write(tf_example.SerializeToString())

        print('Successfully created the TFRecords: {}'.format(output_path))
93150

94151
if __name__ == '__main__':
    # Command-line entry point; argparse replaces the TF1 tf.app.run()/flags API.
    parser = argparse.ArgumentParser()
    # Both spellings are accepted so the flag matches the underscore style of
    # --output_name while the original --voc-root keeps working.
    parser.add_argument("--voc-root", "--voc_root", dest="voc_root", type=str, required=True,
                        help="PASCAL VOC 数据集路径,包含JPEGImages和Annotations两个文件夹")
    parser.add_argument("--output_name", type=str, default="voc2020",
                        help="tfrecord文件名称,默认保存在VOC根路径")
    parser.add_argument("--format", type=str, default="jpg", help="图像格式")
    # ``opt`` stays a module-level name because create_tf_example reads
    # ``opt.format`` from it.
    opt = parser.parse_args()
    main(opt.voc_root, opt.output_name)

voc2coco.py

Lines changed: 21 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ def get_image_info(ann_path, annotation_root, extract_num_from_imgid=True):
8686
if extract_num_from_imgid and isinstance(img_id, str):
8787
# 采用正则表达式,支持转换的文件命名:0001.png, cls_0021.png, cls0123.jpg, 00123abc.png等
8888
img_id = int(re.findall(r'\d+', img_id)[0])
89+
print(img_id)
8990

9091
size = annotation_root.find('size')
9192
width = int(size.findtext('width'))
@@ -261,25 +262,21 @@ def create_dir(ROOT:str):
261262
if not os.path.exists(ImgSets):
262263
os.mkdir(ImgSets)
263264
ImgSetsMain = os.path.join(ImgSets,'Main')
265+
# if os.path.exists(ImgSetsMain):
266+
# print('目录ImageSets/Main已经存在')
267+
# else:
264268
create_dir(ImgSetsMain)
265269

266270
COCOPROJ = os.path.join(voc_root, opt.coco_dir) # pascal voc转coco格式的存储路径
267271
create_dir(COCOPROJ)
268272

269-
COCOTRAIN = os.path.join(COCOPROJ,'train')
270-
create_dir(COCOTRAIN)
273+
txt_files = ['trainvaltest','train','val','trainval','test']
271274

272-
COCOVAL= os.path.join(COCOPROJ,'val')
273-
create_dir(COCOVAL)
274-
275-
COCOTRAINVAL = os.path.join(COCOPROJ,'trainval')
276-
create_dir(COCOTRAINVAL)
277-
278-
COCOTEST= os.path.join(COCOPROJ,'test')
279-
create_dir(COCOTEST)
280-
281-
COCOALL= os.path.join(COCOPROJ,'trainvaltest')
282-
create_dir(COCOALL)
275+
coco_dirs = []
276+
for dir_ in txt_files:
277+
DIR = os.path.join(COCOPROJ, dir_)
278+
coco_dirs.append(DIR)
279+
create_dir(DIR)
283280

284281
COCOANNO = os.path.join(COCOPROJ, 'annotations') # coco标注文件存放路径
285282
create_dir(COCOANNO)
@@ -298,29 +295,22 @@ def create_dir(ROOT:str):
298295
print('训练集数量: ',len(train))
299296
print('验证集数量: ',len(val))
300297
print('测试集数量: ',len(test))
298+
301299
def write_txt(txt_path, data):
    """Write every item of ``data`` to ``txt_path``, one item per line.

    Args:
        txt_path: Path of the text file to create or overwrite.
        data: Iterable of items; each is converted with ``str`` before writing.
    """
    with open(txt_path, 'w') as f:
        # One buffered call instead of two write() calls per item.
        f.writelines(str(d) + '\n' for d in data)
304+
306305
# 写入各个txt文件
307-
trainvaltest_txt = os.path.join(ImgSetsMain,'trainvaltest.txt')
308-
write_txt(trainvaltest_txt, files)
306+
datas = [files, train, val, trainval, test]
309307

310-
trainval_txt = os.path.join(ImgSetsMain,'trainval.txt')
311-
write_txt(trainval_txt, trainval)
312-
313-
train_txt = os.path.join(ImgSetsMain,'train.txt')
314-
write_txt(train_txt, train)
315-
316-
val_txt = os.path.join(ImgSetsMain,'val.txt')
317-
write_txt(val_txt, val)
318-
319-
test_txt = os.path.join(ImgSetsMain,'test.txt')
320-
write_txt(test_txt, test)
308+
for txt, data in zip(txt_files, datas):
309+
txt_path = os.path.join(ImgSetsMain, txt+'.txt')
310+
write_txt(txt_path, data)
321311

322312
# 遍历xml文件,得到所有标签值,并且保存为labels.txt
323-
if opt.labels:
313+
if opt.labels==True:
324314
print('从自定义标签文件读取!')
325315
labels = opt.labels
326316
else:
@@ -334,9 +324,10 @@ def write_txt(txt_path, data):
334324
label2id = get_label2id(labels_path=labels)
335325
print('标签值及其对应的编码值:',label2id)
336326

337-
for name,imgs,PATH in tqdm(zip(['trainvaltest','train','val','trainval','test'],
338-
[files, train,val,trainval,test],
339-
[COCOALL, COCOTRAIN, COCOVAL, COCOTRAINVAL, COCOTEST])):
327+
for name,imgs,PATH in tqdm(zip(txt_files,
328+
datas,
329+
coco_dirs)):
330+
340331
annotation_paths = []
341332
for img in imgs:
342333
annotation_paths.append(os.path.join(ANNO, img+'.xml'))

voc_gen_trainval_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,9 @@ def write_txt(txt_path, data):
8686
f.write(str(d))
8787
f.write('\n')
8888
# 写入各个txt文件
89+
trainvaltest_txt = os.path.join(ImgSetsMain,'trainvaltest.txt')
90+
write_txt(trainvaltest_txt, files)
91+
8992
trainval_txt = os.path.join(ImgSetsMain,'trainval.txt')
9093
write_txt(trainval_txt, trainval)
9194

0 commit comments

Comments
 (0)