
Commit c09e739

FCN train

YunYang1994 authored and committed
1 parent fbb6979 commit c09e739

6 files changed: +75 -42 lines changed

5-Image_Segmentation/FCN/README.md

+4 -3

@@ -12,7 +12,7 @@ $ wget http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar
 ```
 Extract all of these tars into one directory and rename them, which should have the following basic structure.
 ```bashrc
-VOC # path: /home/yang/test/VOC/
+VOC # path: /home/yang/dataset/VOC
 ├── test
 |    └──VOCdevkit
 |        └──VOC2007 (from VOCtest_06-Nov-2007.tar)
@@ -21,9 +21,10 @@ VOC # path: /home/yang/test/VOC/
         └──VOC2007 (from VOCtrainval_06-Nov-2007.tar)
         └──VOC2012 (from VOCtrainval_11-May-2012.tar)
 ```
-Then you need to make some transformation.
+Finally you need to make some transformation and train it.
 ```bashrc
-$ python parser_voc.py --data_path /home/yang/test/VOC
+$ python parser_voc.py --data_path /home/yang/dataset/VOC
+$ python train.py
 ```
 
 |![image](https://user-images.githubusercontent.com/30433053/66732790-d4d56680-ee8f-11e9-9120-07b0e8aa53d4.jpg)|![image](https://user-images.githubusercontent.com/30433053/66732791-d69f2a00-ee8f-11e9-9c5d-16cc84bc7e9e.jpg)|![image](https://user-images.githubusercontent.com/30433053/66732795-da32b100-ee8f-11e9-9d85-f0ddba7a3ab1.jpg)|
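The two commands above assume the directory layout shown in the tree. A small sketch (not part of the commit; the root path is just the author's example) for checking that the folders parser_voc.py will look for actually exist:

```python
import os

# Hypothetical pre-flight check before running parser_voc.py: verify the
# SegmentationClass folders exist under the layout shown above.
VOC_path = "/home/yang/dataset/VOC"   # author's example path; adjust to yours
for mode, year in [("train", 2007), ("train", 2012), ("test", 2007)]:
    seg_dir = os.path.join(VOC_path, "%s/VOCdevkit/VOC%d/SegmentationClass" % (mode, year))
    print(seg_dir, "->", "ok" if os.path.isdir(seg_dir) else "missing")
```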

5-Image_Segmentation/FCN/config.py

+4 -1

@@ -11,6 +11,8 @@
 #
 #================================================================
 
+import numpy as np
+
 classes = ['background','aeroplane','bicycle','bird','boat',
            'bottle','bus','car','cat','chair','cow','diningtable',
            'dog','horse','motorbike','person','potted plant',
@@ -22,4 +24,5 @@
            [64,128,128],[192,128,128],[0,64,0],[128,64,0],
            [0,192,0],[128,192,0],[0,64,128]]
 
-image_size = [224, 224]
+rgb_mean = np.array([0.485, 0.456, 0.406])
+rgb_std = np.array([0.229, 0.224, 0.225])
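The removed `image_size` constant is replaced by per-channel statistics (the widely used ImageNet mean and std). A minimal sketch of how train.py applies them, with a placeholder image array; the forward direction mirrors `process_image_label` and the inverse mirrors `visual_result`:

```python
import numpy as np
from config import rgb_mean, rgb_std

# Placeholder image; in train.py this is a resized RGB frame read by OpenCV.
image = np.random.randint(0, 256, size=(224, 224, 3)).astype(np.float64)

normalized = (image / 255. - rgb_mean) / rgb_std      # as in process_image_label
restored = (normalized * rgb_std + rgb_mean) * 255.   # as in visual_result
```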

5-Image_Segmentation/FCN/fcn8s.py

+1 -3

@@ -84,7 +84,6 @@ def call(self, x, training=False):
         h = self.conv4_3(h)
         h = self.pool4(h)
         pool4 = h # 1/16
-        print(pool4.shape)
 
         h = self.conv5_1(h)
         h = self.conv5_2(h)
@@ -121,6 +120,5 @@ def call(self, x, training=False):
         h = self.upscore8(h)
         h = h[:, 31:31+x.shape[1], 31:31+x.shape[2], :] # channel last
 
-        return h
-
+        return tf.nn.softmax(h, axis=-1)
 
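With this change `call` returns per-pixel class probabilities rather than raw logits, which lines up with the loss chosen in train.py: Keras' `sparse_categorical_crossentropy` defaults to `from_logits=False`, i.e. it expects probabilities. A tiny illustrative example with made-up numbers (one pixel, three classes):

```python
import tensorflow as tf

# One pixel, three hypothetical classes; values are illustrative only.
logits = tf.constant([[2.0, 0.5, -1.0]])
probs = tf.nn.softmax(logits, axis=-1)      # what call() now returns per pixel

# sparse_categorical_crossentropy expects probabilities by default
# (from_logits=False), so feeding the softmax output is consistent.
loss = tf.keras.losses.sparse_categorical_crossentropy(tf.constant([0]), probs)
print(float(loss[0]))                       # -log(probs[0, 0])
```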

5-Image_Segmentation/FCN/parser_voc.py

+4 -4

@@ -17,14 +17,14 @@
 from scipy import misc
 
 VOC_path = "/home/yang/dataset/VOC"
-train_labels = "./data/train_labels"
 
 if not os.path.exists("./data"): os.mkdir("./data")
 if not os.path.exists("./data/train_labels"): os.mkdir("./data/train_labels")
+if not os.path.exists("./data/test_labels"): os.mkdir("./data/test_labels")
 
-train_image_write = open(os.path.join(os.getcwd(), "data/train_image.txt"), "w")
 
 for mode in ["train", "test"]:
+    image_write = open(os.path.join(os.getcwd(), "data/%s_image.txt" %mode), "w")
     for year in [2007, 2012]:
         if mode == "test" and year == 2012: continue
         train_label_folder = os.path.join(VOC_path, "%s/VOCdevkit/VOC%d/SegmentationClass" %(mode, year))
@@ -35,10 +35,10 @@
         label_name = train_label_image[:-4]
         image_path = os.path.join(train_image_folder, label_name + ".jpg")
         if not os.path.exists(image_path): continue
-        train_image_write.writelines(image_path+"\n")
+        image_write.writelines(image_path+"\n")
         label_path = os.path.join(train_label_folder, train_label_image)
         label_image = np.array(misc.imread(label_path))
-        write_label = open("./data/train_labels/"+label_name+".txt", 'w')
+        write_label = open(("./data/%s_labels/" % mode)+label_name+".txt", 'w')
         print("=> processing %s" %label_path)
         H, W, C = label_image.shape
         for i in range(H):
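The parser now writes an image list and a label folder for both splits (`data/train_image.txt` / `data/test_image.txt`, `data/train_labels/` / `data/test_labels/`). Each label `.txt` holds one row of space-separated class indices per image row, which is how train.py reads it back. A read-back sketch (the file name is only an example):

```python
import numpy as np

# Read one generated label file back the same way train.py does.
lines = open("./data/train_labels/2007_000032.txt").readlines()   # example name
label = np.array([line.rstrip().split(" ") for line in lines], dtype=np.int32)
print(label.shape, label.max())   # (H, W) and the highest class id present
```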

5-Image_Segmentation/FCN/test.py

+6

@@ -11,4 +11,10 @@
 #
 #================================================================
 
+import numpy as np
+from fcn8s import FCN8s
 
+data = np.arange(224*224*3).reshape([1, 224,224,3]).astype(np.float)
+
+model = FCN8s()
+y = model(data)
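test.py is only a forward-pass smoke test. A possible extension (not in the commit) that also checks the output shape and the new softmax behaviour; it assumes FCN8s emits one channel per class, i.e. 21 for the `classes` list in config.py:

```python
import numpy as np
import tensorflow as tf
from fcn8s import FCN8s

model = FCN8s()
data = np.arange(224 * 224 * 3).reshape([1, 224, 224, 3]).astype(np.float32)
y = model(data)

print(y.shape)                                 # expected: (1, 224, 224, 21)
print(float(tf.reduce_sum(y[0, 0, 0, :])))     # ~1.0 after the softmax change
```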

5-Image_Segmentation/FCN/train.py

+56 -31

@@ -6,7 +6,7 @@
 # Editor : VIM
 # File name : train.py
 # Author : YunYang1994
-# Created date: 2019-10-12 17:44:30
+# Created date: 2019-10-14 19:12:36
 # Description :
 #
 #================================================================
@@ -15,17 +15,15 @@
 import cv2
 import random
 import tensorflow as tf
-from config import colormap, classes
 import numpy as np
-from PIL import Image
+from fcn8s import FCN8s
 from scipy import misc
-
+from config import colormap, classes, rgb_mean, rgb_std
 
 
 def create_image_label_path_generator(images_filepath, labels_filepath):
     image_paths = open(images_filepath).readlines()
     all_label_txts = os.listdir(labels_filepath)
-    print(all_label_txts)
     image_label_paths = []
     for label_txt in all_label_txts:
         label_name = label_txt[:-4]
@@ -35,61 +33,88 @@ def create_image_label_path_generator(images_filepath, labels_filepath):
         image_name = image_path.split("/")[-1][:-4]
         if label_name == image_name:
             image_label_paths.append((image_path, label_path))
-    print(image_label_paths)
     while True:
         random.shuffle(image_label_paths)
         for i in range(len(image_label_paths)):
             yield image_label_paths[i]
 
-image_label_path_generator = create_image_label_path_generator(
-    "./data/train_image.txt", "./data/train_labels"
-)
 
 def process_image_label(image_path, label_path):
     # image = misc.imread(image_path)
     image = cv2.imread(image_path)
     image = cv2.resize(image, (224, 224), interpolation=cv2.INTER_NEAREST)
     image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-    image = np.array(image)
+    # data augmentation here
+    # pass
+    # image transformation here
+    image = (image / 255. - rgb_mean) / rgb_std
 
     label = open(label_path).readlines()
     label = [np.array(line.rstrip().split(" ")) for line in label]
     label = np.array(label, dtype=np.int)
     label = cv2.resize(label, (224, 224), interpolation=cv2.INTER_NEAREST)
+    label = label.astype(np.int)
 
     return image, label
 
 
-
-def TrainGenerator(batch_size):
+def DataGenerator(train_image_txt, train_labels_dir, batch_size):
     """
     generate image and mask at the same time
     """
-    pass
-
-
-for epoch in range(4):
-    image_path, label_path = next(image_label_path_generator)
-    # print(image_path, label_path)
-    image, label = process_image_label(image_path, label_path)
-
+    image_label_path_generator = create_image_label_path_generator(
+        train_image_txt, train_labels_dir
+    )
+    while True:
+        images = np.zeros(shape=[batch_size, 224, 224, 3])
+        labels = np.zeros(shape=[batch_size, 224, 224], dtype=np.float)
+        for i in range(batch_size):
+            image_path, label_path = next(image_label_path_generator)
+            image, label = process_image_label(image_path, label_path)
+            images[i], labels[i] = image, label
+        yield images, labels
+
+
+def visual_result(image, label):
+    image = (image * rgb_std + rgb_mean) * 255
+    image, label = image.astype(np.int), label.astype(np.int)
     H, W, C = image.shape
     new_label = np.zeros(shape=[H, W, C])
     cls = []
     for i in range(H):
         for j in range(W):
-            new_label[i, j] = np.array(colormap[label[i,j]])
-            cls.append(label[i, j])
+            cls_idx = label[i, j]
+            new_label[i, j] = np.array(colormap[cls_idx])
+            cls.append(cls_idx)
 
-    show_image = 0.7*new_label + 0.3*image
-    write_image = np.zeros(shape=[224, 448, 3])
-    write_image[:, :224, :] = image
-    write_image[:, 224:, :] = show_image
+    # show_image = 0.7*new_label + 0.3*image
+    show_image = np.zeros(shape=[224, 448, 3])
     cls = set(cls)
-    # for x in cls:
-    #     print(classes[x])
-    # misc.imshow(show_image)
-    misc.imshow(write_image)
-    misc.imsave("%d.jpg"%epoch, write_image)
+    for x in cls:
+        print(classes[x])
+    show_image[:, :224, :] = image
+    show_image[:, 224:, :] = new_label
+    misc.imshow(show_image)
+
+TrainSet = DataGenerator("./data/train_image.txt", "./data/train_labels", 2)
+TestSet = DataGenerator("./data/test_image.txt", "./data/test_labels", 2)
+
+model = FCN8s()
+callback = tf.keras.callbacks.ModelCheckpoint("model.h5", verbose=1, save_weights_only=True)
+model.compile(optimizer=tf.keras.optimizers.Adam(lr=1e-4),
+              callback=callback,
+              loss='sparse_categorical_crossentropy',
+              metrics=['accuracy'])
+model.fit_generator(TrainSet, steps_per_epoch=6000, epochs=30)
+model.save_weights("model.h5")
+
+# data = np.arange(224*224*3).reshape([1,224,224,3]).astype(np.float)
+# model(data)
+# model.load_weights("model.h5")
+
+for x, y in TrainSet:
+    result = model(x)
+    pred_label = tf.argmax(result, axis=-1)
+    visual_result(x[0], pred_label[0].numpy())
 
 
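One note on the training block: `Model.compile()` in tf.keras does not take a `callback` argument; checkpoint callbacks are passed to the fit call instead. A sketch of that wiring, reusing the script's own variables and keeping everything else as in the diff:

```python
# ModelCheckpoint is a fit-time callback, not a compile option.
model.compile(optimizer=tf.keras.optimizers.Adam(lr=1e-4),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model.fit_generator(TrainSet, steps_per_epoch=6000, epochs=30,
                    callbacks=[callback])
```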
