+ '''
+ Convert PASCAL VOC annotations to TFRecord with TensorFlow 2.x.
+ 1. Build a PASCAL VOC dataset with an annotation tool such as labelImg. Note: images live in the JPEGImages folder and the XML annotation files in the Annotations folder.
+ 2. Convert the XML files to CSV; this script does that internally via the xml_to_csv function.
+ 3. Convert the CSV rows to TFRecord. Note that the tf1.x and tf2.x APIs differ.
+
+ Reference: https://www.pythonf.cn/read/109620
+
+ Note: for a custom dataset, the labels list must be set accordingly.
+ '''
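
# Expected on-disk layout (assumed here, matching the paths used below):
#   <voc_root>/
#       JPEGImages/       the images
#       Annotations/      one PASCAL VOC .xml file per image
#       ImageSets/Main/   train.txt / val.txt / ... with one image id per line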
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import
import tensorflow as tf

from PIL import Image
- from object_detection.utils import dataset_util
+ # from object_detection.utils import dataset_util
from collections import namedtuple, OrderedDict
- import tqdm
+ from tqdm import tqdm
import argparse
-
+ import glob
+ import xml.etree.ElementTree as ET
+ from pathlib import Path
# flags = tf.app.flags
# flags.DEFINE_string('csv_input', '', 'Path to the CSV input')
# flags.DEFINE_string('output_path', '', 'Path to output TFRecord')
# FLAGS = flags.FLAGS
# TODO: replace this with a label map
- labels = ['cow', 'tvmonitor', 'car', 'aeroplane', 'sheep',
-           'motorbike', 'train', 'chair', 'person', 'sofa',
-           'pottedplant', 'diningtable', 'horse', 'bottle',
-           'boat', 'bus', 'bird', 'bicycle', 'cat', 'dog']
+ # labels = ['cow', 'tvmonitor', 'car', 'aeroplane', 'sheep',
+ #           'motorbike', 'train', 'chair', 'person', 'sofa',
+ #           'pottedplant', 'diningtable', 'horse', 'bottle',
+ #           'boat', 'bus', 'bird', 'bicycle', 'cat', 'dog']
37
26
- def class_text_to_int (row_label , labels ):
38
+ # 根据自定义数据集修改该列表
39
+ labels = ['raccoon' ]
40
+
41
+ def class_text_to_int (row_label ):
27
42
return labels .index (row_label )+ 1
28
43
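# class_text_to_int maps a class name to a 1-based id, e.g.
#   class_text_to_int('raccoon') -> 1
# (id 0 is conventionally reserved for the background class in detection pipelines)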
def split(df, group):
    data = namedtuple('data', ['filename', 'object'])
    gb = df.groupby(group)
    return [data(filename, gb.get_group(x)) for filename, x in zip(gb.groups.keys(), gb.groups)]
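
# split() yields one namedtuple per image: .filename is the image file name and
# .object is the sub-DataFrame holding all of that image's bounding-box rows.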

+ def xml_to_csv(xml_anno, data_type):
+     '''
+     xml_anno: path to the PASCAL VOC Annotations folder
+     data_type: one of ['trainvaltest', 'train', 'val', 'trainval', 'test']
+     '''
+     xml_list = []
+     # xml_files = []
+     txt_file = str(Path(xml_anno).parent / 'ImageSets/Main' / f'{data_type}.txt')
+     xml_files = [os.path.join(xml_anno, k.strip() + '.xml') for k in open(txt_file, 'r').readlines()]
+     # for xml_file in glob.glob(xml_anno + '/*.xml'):
+     for xml_file in xml_files:
+         tree = ET.parse(xml_file)
+         root = tree.getroot()
+         for member in root.findall('object'):
+             value = (root.find('filename').text,
+                      int(root.find('size')[0].text),
+                      int(root.find('size')[1].text),
+                      member[0].text,
+                      int(member[4][0].text),
+                      int(member[4][1].text),
+                      int(member[4][2].text),
+                      int(member[4][3].text)
+                      )
+             xml_list.append(value)
+     column_name = ['filename', 'width', 'height', 'class', 'xmin', 'ymin', 'xmax', 'ymax']
+     xml_df = pd.DataFrame(xml_list, columns=column_name)
+     return xml_df
+
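# The positional indexing above assumes labelImg's default PASCAL VOC layout,
# where each <object> lists name, pose, truncated, difficult, bndbox in that
# order (so member[0] is <name> and member[4] is <bndbox>), e.g.:
#   <annotation>
#     <filename>000001.jpg</filename>
#     <size><width>640</width><height>480</height><depth>3</depth></size>
#     <object>
#       <name>raccoon</name>
#       <pose>Unspecified</pose><truncated>0</truncated><difficult>0</difficult>
#       <bndbox><xmin>48</xmin><ymin>240</ymin><xmax>195</xmax><ymax>371</ymax></bndbox>
#     </object>
#   </annotation>
# Extra or reordered child tags would break this indexing.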
+ def create_tf_example(group, path):
+     with tf.io.gfile.GFile(os.path.join(path, '{}'.format(group.filename)), 'rb') as fid:
+         encoded_jpg = fid.read()
+     encoded_jpg_io = io.BytesIO(encoded_jpg)
+     image = Image.open(encoded_jpg_io)
+     width, height = image.size
+
+     filename = group.filename.encode('utf8')
+     image_format = opt.format.encode('utf8')
+     xmins = []
+     xmaxs = []
+     ymins = []
+     ymaxs = []
+     classes_text = []
+     classes = []
+
+     for index, row in group.object.iterrows():
+         xmins.append(row['xmin'] / width)
+         xmaxs.append(row['xmax'] / width)
+         ymins.append(row['ymin'] / height)
+         ymaxs.append(row['ymax'] / height)
+         classes_text.append(row['class'].encode('utf8'))
+         classes.append(class_text_to_int(row['class']))
+
+     tf_example = tf.train.Example(features=tf.train.Features(feature={
+         'image/height': tf.train.Feature(int64_list=tf.train.Int64List(value=[height])),
+         'image/width': tf.train.Feature(int64_list=tf.train.Int64List(value=[width])),
+         'image/filename': tf.train.Feature(bytes_list=tf.train.BytesList(value=[filename])),
+         'image/source_id': tf.train.Feature(bytes_list=tf.train.BytesList(value=[filename])),
+         'image/encoded': tf.train.Feature(bytes_list=tf.train.BytesList(value=[encoded_jpg])),
+         'image/format': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_format])),
+         'image/object/bbox/xmin': tf.train.Feature(float_list=tf.train.FloatList(value=xmins)),
+         'image/object/bbox/xmax': tf.train.Feature(float_list=tf.train.FloatList(value=xmaxs)),
+         'image/object/bbox/ymin': tf.train.Feature(float_list=tf.train.FloatList(value=ymins)),
+         'image/object/bbox/ymax': tf.train.Feature(float_list=tf.train.FloatList(value=ymaxs)),
+         'image/object/class/text': tf.train.Feature(bytes_list=tf.train.BytesList(value=classes_text)),
+         'image/object/class/label': tf.train.Feature(int64_list=tf.train.Int64List(value=classes)),
+     }))
+     return tf_example
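
# Illustrative sketch (not part of this commit): reading one of the generated
# TFRecord files back as a sanity check. The feature keys mirror those written
# above; the file name 'voc2020_train.tfrecord' is a placeholder.
import tensorflow as tf

feature_spec = {
    'image/encoded': tf.io.FixedLenFeature([], tf.string),
    'image/height': tf.io.FixedLenFeature([], tf.int64),
    'image/width': tf.io.FixedLenFeature([], tf.int64),
    'image/object/bbox/xmin': tf.io.VarLenFeature(tf.float32),
    'image/object/class/label': tf.io.VarLenFeature(tf.int64),
}

def _parse(record):
    return tf.io.parse_single_example(record, feature_spec)

ds = tf.data.TFRecordDataset('voc2020_train.tfrecord').map(_parse)
for example in ds.take(1):
    print(example['image/height'].numpy(), example['image/width'].numpy())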

- def create_tf_example(group, path):
-     with tf.io.gfile.GFile(os.path.join(path, '{}'.format(group.filename)), 'rb') as fid:
-         encoded_jpg = fid.read()
-     encoded_jpg_io = io.BytesIO(encoded_jpg)
-     image = Image.open(encoded_jpg_io)
-     width, height = image.size
-
-     filename = group.filename.encode('utf8')
-     image_format = b'jpg'
-     xmins = []
-     xmaxs = []
-     ymins = []
-     ymaxs = []
-     classes_text = []
-     classes = []
-
-     for index, row in group.object.iterrows():
-         xmins.append(row['xmin'] / width)
-         xmaxs.append(row['xmax'] / width)
-         ymins.append(row['ymin'] / height)
-         ymaxs.append(row['ymax'] / height)
-         classes_text.append(row['class'].encode('utf8'))
-         classes.append(class_text_to_int(row['class'], group.filename))
-
-     tf_example = tf.train.Example(features=tf.train.Features(feature={
-         'image/height': dataset_util.int64_feature(height),
-         'image/width': dataset_util.int64_feature(width),
-         'image/filename': dataset_util.bytes_feature(filename),
-         'image/source_id': dataset_util.bytes_feature(filename),
-         'image/encoded': dataset_util.bytes_feature(encoded_jpg),
-         'image/format': dataset_util.bytes_feature(image_format),
-         'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
-         'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
-         'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
-         'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
-         'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
-         'image/object/class/label': dataset_util.int64_list_feature(classes),
-     }))
-     return tf_example
-
-
- def main(csv_input, output_path):
-     writer = tf.io.TFRecordWriter(output_path)
-     path = os.path.join(os.getcwd(), 'images')
-     examples = pd.read_csv(csv_input)
-     grouped = split(examples, 'filename')
-     num = 0
-     for group in grouped:
-         num += 1
-         tf_example = create_tf_example(group, path)
-         writer.write(tf_example.SerializeToString())
-         if (num % 100 == 0):  # print progress every 100 conversions
-             print(num)
-
-     writer.close()
-     output_path = os.path.join(os.getcwd(), output_path)
-     print('Successfully created the TFRecords: {}'.format(output_path))

+ def main(voc_root, output_name):
+     img_path = os.path.join(voc_root, 'JPEGImages')
+     # examples = pd.read_csv(csv_input)
+     imgset_path = os.path.join(voc_root, 'ImageSets/Main')
+     if not os.path.exists(imgset_path):
+         raise Exception('The ImageSets/Main folder does not exist; please generate the split files first!')
+     txt_files = ['trainvaltest.txt', 'train.txt', 'val.txt', 'trainval.txt', 'test.txt']
+
+     valid_txt = []
+     for k in txt_files:
+         txt = os.path.join(imgset_path, k)
+         if os.path.exists(txt):
+             valid_txt.append(k[:-4])
+
+     if valid_txt:
+         print(valid_txt)
+     else:
+         raise Exception('No train.txt or similar split file found under ImageSets/Main; please check the dataset!')
+
+     for data_type in valid_txt:
+         output_path = output_name + f'_{data_type}.tfrecord'
+         output_path = os.path.join(voc_root, output_path)
+         writer = tf.io.TFRecordWriter(output_path)
+         examples = xml_to_csv(os.path.join(voc_root, 'Annotations'), data_type)
+         grouped = split(examples, 'filename')
+
+         for group in tqdm(grouped):
+             tf_example = create_tf_example(group, img_path)
+             writer.write(tf_example.SerializeToString())
+
+         writer.close()
+         print('Successfully created the TFRecords: {}'.format(output_path))
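
# Illustrative sketch (not part of this commit): generating the ImageSets/Main
# split files that main() expects. The 0.8 train/val ratio and the split names
# are assumptions; each .txt file holds one image id (file stem) per line.
import random
from pathlib import Path

def make_splits(voc_root, train_ratio=0.8, seed=0):
    stems = sorted(p.stem for p in Path(voc_root, 'Annotations').glob('*.xml'))
    random.Random(seed).shuffle(stems)
    n_train = int(len(stems) * train_ratio)
    out_dir = Path(voc_root, 'ImageSets/Main')
    out_dir.mkdir(parents=True, exist_ok=True)
    splits = {'train': stems[:n_train], 'val': stems[n_train:], 'trainval': stems}
    for name, ids in splits.items():
        (out_dir / f'{name}.txt').write_text('\n'.join(ids) + '\n')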

if __name__ == '__main__':
    # tf.app.run()
    parser = argparse.ArgumentParser()
-     parser.add_argument("--csv_input", type=str, required=True, help="path to the input CSV file")
-     parser.add_argument("--output_path", type=str, default="pascal_voc2007.tfrecord", help="path of the output TFRecord file, saved in the current directory by default")
+     parser.add_argument("--voc-root", type=str, required=True, help="path to the PASCAL VOC dataset root, containing the JPEGImages and Annotations folders")
+     parser.add_argument("--output_name", type=str, default="voc2020", help="base name of the TFRecord files, saved under the VOC root by default")
+     parser.add_argument("--format", type=str, default="jpg", help="image format")
    opt = parser.parse_args()
-     main(opt.csv_input, opt.output_path)
+     main(opt.voc_root, opt.output_name)
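
# Illustrative usage (the script name and paths are placeholders):
#   python voc_to_tfrecord.py --voc-root /data/VOCdevkit/VOC2007 --output_name voc2020 --format jpg
# One TFRecord is written per split file found in ImageSets/Main, e.g.
# /data/VOCdevkit/VOC2007/voc2020_train.tfrecord and .../voc2020_val.tfrecord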