6
6
# Editor : VIM
7
7
# File name : train.py
8
8
# Author : YunYang1994
9
- # Created date: 2019-10-12 17:44:30
9
+ # Created date: 2019-10-14 19:12:36
10
10
# Description :
11
11
#
12
12
#================================================================
15
15
import cv2
16
16
import random
17
17
import tensorflow as tf
18
- from config import colormap , classes
19
18
import numpy as np
20
- from PIL import Image
19
+ from fcn8s import FCN8s
21
20
from scipy import misc
22
-
21
+ from config import colormap , classes , rgb_mean , rgb_std
23
22
24
23
25
24
def create_image_label_path_generator (images_filepath , labels_filepath ):
26
25
image_paths = open (images_filepath ).readlines ()
27
26
all_label_txts = os .listdir (labels_filepath )
28
- print (all_label_txts )
29
27
image_label_paths = []
30
28
for label_txt in all_label_txts :
31
29
label_name = label_txt [:- 4 ]
@@ -35,61 +33,88 @@ def create_image_label_path_generator(images_filepath, labels_filepath):
35
33
image_name = image_path .split ("/" )[- 1 ][:- 4 ]
36
34
if label_name == image_name :
37
35
image_label_paths .append ((image_path , label_path ))
38
- print (image_label_paths )
39
36
while True :
40
37
random .shuffle (image_label_paths )
41
38
for i in range (len (image_label_paths )):
42
39
yield image_label_paths [i ]
43
40
44
- image_label_path_generator = create_image_label_path_generator (
45
- "./data/train_image.txt" , "./data/train_labels"
46
- )
47
41
48
42
def process_image_label (image_path , label_path ):
49
43
# image = misc.imread(image_path)
50
44
image = cv2 .imread (image_path )
51
45
image = cv2 .resize (image , (224 , 224 ), interpolation = cv2 .INTER_NEAREST )
52
46
image = cv2 .cvtColor (image , cv2 .COLOR_BGR2RGB )
53
- image = np .array (image )
47
+ # data augmentation here
48
+ # pass
49
+ # image transformation here
50
+ image = (image / 255. - rgb_mean ) / rgb_std
54
51
55
52
label = open (label_path ).readlines ()
56
53
label = [np .array (line .rstrip ().split (" " )) for line in label ]
57
54
label = np .array (label , dtype = np .int )
58
55
label = cv2 .resize (label , (224 , 224 ), interpolation = cv2 .INTER_NEAREST )
56
+ label = label .astype (np .int )
59
57
60
58
return image , label
61
59
62
60
63
-
64
- def TrainGenerator (batch_size ):
61
+ def DataGenerator (train_image_txt , train_labels_dir , batch_size ):
65
62
"""
66
63
generate image and mask at the same time
67
64
"""
68
- pass
69
-
70
-
71
- for epoch in range (4 ):
72
- image_path , label_path = next (image_label_path_generator )
73
- # print(image_path, label_path)
74
- image , label = process_image_label (image_path , label_path )
75
-
65
+ image_label_path_generator = create_image_label_path_generator (
66
+ train_image_txt , train_labels_dir
67
+ )
68
+ while True :
69
+ images = np .zeros (shape = [batch_size , 224 , 224 , 3 ])
70
+ labels = np .zeros (shape = [batch_size , 224 , 224 ], dtype = np .float )
71
+ for i in range (batch_size ):
72
+ image_path , label_path = next (image_label_path_generator )
73
+ image , label = process_image_label (image_path , label_path )
74
+ images [i ], labels [i ] = image , label
75
+ yield images , labels
76
+
77
+
78
+ def visual_result (image , label ):
79
+ image = (image * rgb_std + rgb_mean ) * 255
80
+ image , label = image .astype (np .int ), label .astype (np .int )
76
81
H , W , C = image .shape
77
82
new_label = np .zeros (shape = [H , W , C ])
78
83
cls = []
79
84
for i in range (H ):
80
85
for j in range (W ):
81
- new_label [i , j ] = np .array (colormap [label [i ,j ]])
82
- cls .append (label [i , j ])
86
+ cls_idx = label [i , j ]
87
+ new_label [i , j ] = np .array (colormap [cls_idx ])
88
+ cls .append (cls_idx )
83
89
84
- show_image = 0.7 * new_label + 0.3 * image
85
- write_image = np .zeros (shape = [224 , 448 , 3 ])
86
- write_image [:, :224 , :] = image
87
- write_image [:, 224 :, :] = show_image
90
+ # show_image = 0.7*new_label + 0.3*image
91
+ show_image = np .zeros (shape = [224 , 448 , 3 ])
88
92
cls = set (cls )
89
- # for x in cls:
90
- # print(classes[x])
91
- # misc.imshow(show_image)
92
- misc .imshow (write_image )
93
- misc .imsave ("%d.jpg" % epoch , write_image )
93
+ for x in cls :
94
+ print (classes [x ])
95
+ show_image [:, :224 , :] = image
96
+ show_image [:, 224 :, :] = new_label
97
+ misc .imshow (show_image )
98
+
99
+ TrainSet = DataGenerator ("./data/train_image.txt" , "./data/train_labels" , 2 )
100
+ TestSet = DataGenerator ("./data/test_image.txt" , "./data/test_labels" , 2 )
101
+
102
+ model = FCN8s ()
103
+ callback = tf .keras .callbacks .ModelCheckpoint ("model.h5" , verbose = 1 , save_weights_only = True )
104
+ model .compile (optimizer = tf .keras .optimizers .Adam (lr = 1e-4 ),
105
+ callback = callback ,
106
+ loss = 'sparse_categorical_crossentropy' ,
107
+ metrics = ['accuracy' ])
108
+ model .fit_generator (TrainSet , steps_per_epoch = 6000 , epochs = 30 )
109
+ model .save_weights ("model.h5" )
110
+
111
+ # data = np.arange(224*224*3).reshape([1,224,224,3]).astype(np.float)
112
+ # model(data)
113
+ # model.load_weights("model.h5")
114
+
115
+ for x , y in TrainSet :
116
+ result = model (x )
117
+ pred_label = tf .argmax (result , axis = - 1 )
118
+ visual_result (x [0 ], pred_label [0 ].numpy ())
94
119
95
120
0 commit comments