Skip to content

Commit 94f3552

Browse files
authored
Add files via upload
1 parent bdcda4d commit 94f3552

14 files changed

+2279
-0
lines changed

anchor-cluster.py

+115
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
#coding=utf-8
2+
import xml.etree.ElementTree as ET
3+
import numpy as np
4+
import glob
5+
6+
def iou(box, clusters):
7+
"""
8+
计算一个ground truth边界盒和k个先验框(Anchor)的交并比(IOU)值。
9+
参数box: 元组或者数据,代表ground truth的长宽。
10+
参数clusters: 形如(k,2)的numpy数组,其中k是聚类Anchor框的个数
11+
返回:ground truth和每个Anchor框的交并比。
12+
"""
13+
x = np.minimum(clusters[:, 0], box[0])
14+
y = np.minimum(clusters[:, 1], box[1])
15+
if np.count_nonzero(x == 0) > 0 or np.count_nonzero(y == 0) > 0:
16+
raise ValueError("Box has no area")
17+
intersection = x * y
18+
box_area = box[0] * box[1]
19+
cluster_area = clusters[:, 0] * clusters[:, 1]
20+
iou_ = intersection / (box_area + cluster_area - intersection)
21+
return iou_
22+
23+
24+
def avg_iou(boxes, clusters):
25+
"""
26+
计算一个ground truth和k个Anchor的交并比的均值。
27+
"""
28+
return np.mean([np.max(iou(boxes[i], clusters)) for i in range(boxes.shape[0])])
29+
30+
def kmeans(boxes, k, dist=np.median):
31+
"""
32+
利用IOU值进行K-means聚类
33+
参数boxes: 形状为(r, 2)的ground truth框,其中r是ground truth的个数
34+
参数k: Anchor的个数
35+
参数dist: 距离函数
36+
返回值:形状为(k, 2)的k个Anchor框
37+
"""
38+
# 即是上面提到的r
39+
rows = boxes.shape[0]
40+
# 距离数组,计算每个ground truth和k个Anchor的距离
41+
distances = np.empty((rows, k))
42+
# 上一次每个ground truth"距离"最近的Anchor索引
43+
last_clusters = np.zeros((rows,))
44+
# 设置随机数种子
45+
np.random.seed()
46+
47+
# 初始化聚类中心,k个簇,从r个ground truth随机选k个
48+
clusters = boxes[np.random.choice(rows, k, replace=False)]
49+
# 开始聚类
50+
while True:
51+
# 计算每个ground truth和k个Anchor的距离,用1-IOU(box,anchor)来计算
52+
for row in range(rows):
53+
distances[row] = 1 - iou(boxes[row], clusters)
54+
# 对每个ground truth,选取距离最小的那个Anchor,并存下索引
55+
nearest_clusters = np.argmin(distances, axis=1)
56+
# 如果当前每个ground truth"距离"最近的Anchor索引和上一次一样,聚类结束
57+
if (last_clusters == nearest_clusters).all():
58+
break
59+
# 更新簇中心为簇里面所有的ground truth框的均值
60+
for cluster in range(k):
61+
clusters[cluster] = dist(boxes[nearest_clusters == cluster], axis=0)
62+
# 更新每个ground truth"距离"最近的Anchor索引
63+
last_clusters = nearest_clusters
64+
65+
return clusters
66+
67+
# 加载自己的数据集,只需要所有labelimg标注出来的xml文件即可
68+
def load_dataset(path):
69+
dataset = []
70+
for xml_file in glob.glob("{}/*xml".format(path)):
71+
tree = ET.parse(xml_file)
72+
# 图片高度
73+
height = int(tree.findtext("./size/height"))
74+
# 图片宽度
75+
width = int(tree.findtext("./size/width"))
76+
77+
for obj in tree.iter("object"):
78+
# 偏移量
79+
xmin = int(obj.findtext("bndbox/xmin")) / width
80+
ymin = int(obj.findtext("bndbox/ymin")) / height
81+
xmax = int(obj.findtext("bndbox/xmax")) / width
82+
ymax = int(obj.findtext("bndbox/ymax")) / height
83+
xmin = np.float64(xmin)
84+
ymin = np.float64(ymin)
85+
xmax = np.float64(xmax)
86+
ymax = np.float64(ymax)
87+
if xmax == xmin or ymax == ymin:
88+
print(xml_file)
89+
# 将Anchor的长宽放入dateset,运行kmeans获得Anchor
90+
dataset.append([xmax - xmin, ymax - ymin])
91+
return np.array(dataset)
92+
93+
if __name__ == '__main__':
94+
import argparse
95+
import os
96+
parser = argparse.ArgumentParser()
97+
parser.add_argument('--voc-root', help="VOC格式数据集路径", type=str)
98+
parser.add_argument('--clusters', help="anchor数量", type=int, default=9)
99+
parser.add_argument('--input-size', help="输入网络大小", type=str, default=416)
100+
101+
args = parser.parse_args()
102+
103+
ANNOTATIONS_PATH = os.path.join(args.voc_root,'Annotations') # xml文件所在文件夹
104+
CLUSTERS = args.clusters #聚类数量,anchor数量
105+
INPUTDIM = args.input_size #输入网络大小
106+
107+
data = load_dataset(ANNOTATIONS_PATH)
108+
out = kmeans(data, k=CLUSTERS)
109+
print('Boxes:')
110+
print(np.array(out)*INPUTDIM)
111+
print("Accuracy: {:.2f}%".format(avg_iou(data, out) * 100))
112+
final_anchors = np.around(out[:, 0] / out[:, 1], decimals=2).tolist()
113+
print("Before Sort Ratios:\n {}".format(final_anchors))
114+
print("After Sort Ratios:\n {}".format(sorted(final_anchors)))
115+

check_voc.py

+35
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from pathlib import Path
2+
import os
3+
import argparse
4+
5+
def check_files(ann_root, img_root):
6+
if os.path.exists(ann_root):
7+
ann = Path(ann_root)
8+
else:
9+
raise Exception("标注文件路径错误")
10+
if os.path.exists(img_root):
11+
img = Path(img_root)
12+
else:
13+
raise Exception("图像文件路径错误")
14+
ann_files = []
15+
img_files = []
16+
for an, im in zip(ann.iterdir(),img.iterdir()):
17+
ann_files.append(an.stem)
18+
img_files.append(im.stem)
19+
20+
if set(ann_files)==set(img_files):
21+
print('标注文件和图像文件匹配')
22+
else:
23+
print('标注文件和图像文件不匹配')
24+
25+
if __name__ == "__main__":
26+
27+
parser = argparse.ArgumentParser()
28+
parser.add_argument('--voc-root', type=str, required=True,
29+
help='VOC格式数据集根目录,该目录下必须包含JPEGImages和Annotations这两个文件夹')
30+
opt = parser.parse_args()
31+
32+
IMG_DIR = os.path.join(opt.voc_root, "JPEGImages")
33+
XML_DIR = os.path.join(opt.voc_root, "Annotations")
34+
35+
check_files(XML_DIR, IMG_DIR)

gen_tfrecord.py

+100
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
from __future__ import division
2+
from __future__ import print_function
3+
from __future__ import absolute_import
4+
5+
import os
6+
import io
7+
import pandas as pd
8+
import tensorflow as tf
9+
10+
from PIL import Image
11+
from object_detection.utils import dataset_util
12+
from collections import namedtuple, OrderedDict
13+
import tqdm
14+
import argparse
15+
16+
# flags = tf.app.flags
17+
# flags.DEFINE_string('csv_input', '', 'Path to the CSV input')
18+
# flags.DEFINE_string('output_path', '', 'Path to output TFRecord')
19+
# FLAGS = flags.FLAGS
20+
# TO-DO replace this with label map
21+
labels = ['cow', 'tvmonitor', 'car', 'aeroplane', 'sheep',
22+
'motorbike', 'train', 'chair', 'person', 'sofa',
23+
'pottedplant', 'diningtable', 'horse', 'bottle',
24+
'boat', 'bus', 'bird', 'bicycle', 'cat', 'dog']
25+
26+
def class_text_to_int(row_label, labels):
27+
return labels.index(row_label)+1
28+
29+
def split(df, group):
30+
data = namedtuple('data', ['filename', 'object'])
31+
gb = df.groupby(group)
32+
return [data(filename, gb.get_group(x)) for filename, x in zip(gb.groups.keys(), gb.groups)]
33+
34+
35+
def create_tf_example(group, path):
36+
with tf.io.gfile.GFile(os.path.join(path, '{}'.format(group.filename)), 'rb') as fid:
37+
encoded_jpg = fid.read()
38+
encoded_jpg_io = io.BytesIO(encoded_jpg)
39+
image = Image.open(encoded_jpg_io)
40+
width, height = image.size
41+
42+
filename = group.filename.encode('utf8')
43+
image_format = b'jpg'
44+
xmins = []
45+
xmaxs = []
46+
ymins = []
47+
ymaxs = []
48+
classes_text = []
49+
classes = []
50+
51+
for index, row in group.object.iterrows():
52+
xmins.append(row['xmin'] / width)
53+
xmaxs.append(row['xmax'] / width)
54+
ymins.append(row['ymin'] / height)
55+
ymaxs.append(row['ymax'] / height)
56+
classes_text.append(row['class'].encode('utf8'))
57+
classes.append(class_text_to_int(row['class'], group.filename))
58+
59+
tf_example = tf.train.Example(features=tf.train.Features(feature={
60+
'image/height': dataset_util.int64_feature(height),
61+
'image/width': dataset_util.int64_feature(width),
62+
'image/filename': dataset_util.bytes_feature(filename),
63+
'image/source_id': dataset_util.bytes_feature(filename),
64+
'image/encoded': dataset_util.bytes_feature(encoded_jpg),
65+
'image/format': dataset_util.bytes_feature(image_format),
66+
'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
67+
'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
68+
'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
69+
'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
70+
'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
71+
'image/object/class/label': dataset_util.int64_list_feature(classes),
72+
}))
73+
return tf_example
74+
75+
76+
def main(csv_input, output_path):
77+
writer = tf.io.TFRecordWriter(output_path)
78+
path = os.path.join(os.getcwd(), 'images')
79+
examples = pd.read_csv(csv_input)
80+
grouped = split(examples, 'filename')
81+
num=0
82+
for group in grouped:
83+
num+=1
84+
tf_example = create_tf_example(group, path)
85+
writer.write(tf_example.SerializeToString())
86+
if(num%100==0): #每完成100个转换,打印一次
87+
print(num)
88+
89+
writer.close()
90+
output_path = os.path.join(os.getcwd(), output_path)
91+
print('Successfully created the TFRecords: {}'.format(output_path))
92+
93+
94+
if __name__ == '__main__':
95+
# tf.app.run()
96+
parser = argparse.ArgumentParser()
97+
parser.add_argument("--csv_input", type=str, required=True, help="csv文件路径")
98+
parser.add_argument("--output_path", type=str, default="pascal_voc2007.tfrecord", help="tfrecord文件数据路径,默认保存在当前路径")
99+
opt = parser.parse_args()
100+
main(opt.csv_input, opt.output_path)

gen_yolo_train_test.py

+110
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
'''
2+
Pascal VOC格式数据集生成ImageSets/Main/train.txt,val.txt,trainval.ttx和test.txt
3+
'''
4+
from pathlib import Path
5+
import os
6+
import sys
7+
# from voc2coco import voc_root
8+
import xml.etree.ElementTree as ET
9+
import random
10+
import argparse
11+
from sklearn.model_selection import train_test_split
12+
from sklearn.utils import shuffle
13+
import shutil
14+
15+
def mkdir(path):
16+
# 去除首位空格
17+
path = path.strip()
18+
# 去除尾部 \ 符号
19+
path = path.rstrip("\\")
20+
# 判断路径是否存在
21+
# 存在 True
22+
# 不存在 False
23+
isExists = os.path.exists(path)
24+
# 判断结果
25+
if not isExists:
26+
# 如果不存在则创建目录
27+
# 创建目录操作函数
28+
os.makedirs(path)
29+
print(path + ' 创建成功')
30+
return True
31+
else:
32+
# 如果目录存在则不创建,并提示目录已存在
33+
print(path + ' 目录已存在')
34+
return False
35+
36+
def write_txt(txt_path, data):
37+
'''写入txt文件'''
38+
with open(txt_path,'w') as f:
39+
for d in data:
40+
f.write(str(d))
41+
f.write('\n')
42+
43+
if __name__ == '__main__':
44+
45+
parser = argparse.ArgumentParser()
46+
parser.add_argument('--yolo-root', type=str, required=True,
47+
help='YOLO格式数据集根目录,该目录下必须包含images和labels这两个文件夹')
48+
parser.add_argument('--from_voc',type=bool, default=False,
49+
help='从VOC数据集中的ImageSets/Main文件夹下提取')
50+
parser.add_argument('--voc-root',type=str,
51+
help='VOC数据集路径,需要包含ImageSets/Main文件夹')
52+
parser.add_argument('--test-ratio',type=float, default=0.2,
53+
help='验证集比例,默认为0.2')
54+
parser.add_argument('--ext', type=str, default='.png',
55+
help='YOLO图像数据后缀,注意带"." ' )
56+
opt = parser.parse_args()
57+
58+
yolo_root = opt.yolo_root
59+
print('YOLO格式数据集路径:', yolo_root)
60+
61+
ANNO = os.path.join(yolo_root, 'labels')
62+
JPEG = os.path.join(yolo_root, 'images')
63+
64+
if opt.from_voc:
65+
print('从VOC数据集中分割数据集')
66+
if not opt.voc_root:
67+
raise Exception('需要提供VOC格式路径')
68+
voc_root = opt.voc_root
69+
voc_sets = os.path.join(voc_root,'ImageSets/Main')
70+
if not os.path.exists(voc_sets):
71+
raise Exception('VOC数据集不存在ImageSets/Main路径')
72+
else:
73+
file_lists = list(Path(voc_sets).iterdir())
74+
for file in file_lists:
75+
img_ids = [x.strip() for x in open(file,'r').readlines()]
76+
img_full_path = [os.path.join(JPEG, img_id+opt.ext) for img_id in img_ids]
77+
file_to_write = os.path.join(yolo_root,file.name)
78+
write_txt(file_to_write, img_full_path)
79+
else:
80+
print('从YOLO数据集中按比例随机分割数据集')
81+
p = Path(JPEG)
82+
files = []
83+
for file in p.iterdir():
84+
# name,sufix = file.name.split('.')
85+
if file.name.split('.')[1]==opt.ext[1:]:
86+
files.append(str(file))
87+
# print(name, sufix)
88+
print('数据集长度:',len(files))
89+
files = shuffle(files)
90+
ratio = opt.test_ratio
91+
trainval, test = train_test_split(files, test_size=ratio)
92+
train, val = train_test_split(trainval,test_size=0.2)
93+
print('训练集数量: ',len(train))
94+
print('验证集数量: ',len(val))
95+
print('测试集数量: ',len(test))
96+
97+
98+
# 写入各个txt文件
99+
trainval_txt = os.path.join(yolo_root,'trainval.txt')
100+
write_txt(trainval_txt, trainval)
101+
102+
train_txt = os.path.join(yolo_root,'train.txt')
103+
write_txt(train_txt, train)
104+
105+
val_txt = os.path.join(yolo_root,'val.txt')
106+
write_txt(val_txt, val)
107+
108+
test_txt = os.path.join(yolo_root,'test.txt')
109+
write_txt(test_txt, test)
110+

0 commit comments

Comments
 (0)