1
1
import tensorflow as tf
2
2
3
3
from tensorflow .python .platform import tf_logging as logging
4
+
4
5
import research .slim .nets .mobilenet_v1 as mobilenet_v1
5
- import research . slim . datasets . imagenet as imagenet
6
- from research . slim . preprocessing import inception_preprocessing
6
+
7
+ from utils . gen_utils import load_batch , get_dataset
7
8
8
9
import os
9
10
import time
13
14
14
15
flags = tf .app .flags
15
16
flags .DEFINE_string ('dataset_dir' ,None ,'String: Your dataset directory' )
16
- flags .DEFINE_string ('eval_dir' , None , 'String: Your train directory' )
17
-
17
+ flags .DEFINE_string ('ckpt_number' ,'0' ,'String: ckpt file number' )
18
18
FLAGS = flags .FLAGS
19
19
#=======Dataset Informations=======#
20
20
dataset_dir = FLAGS .dataset_dir
21
+ ckpt_num = FLAGS .ckpt_number
22
+
23
+
24
+ main_dir = "./train_fruit"
25
+ log_dir = main_dir + "/log_eval"
21
26
22
- log_dir = "./log_eval"
23
27
24
- log_eval = './train_fruit/log_eval'
25
28
#Emplacement du checkpoint file
26
- checkpoint_dir = "./ model.ckpt-36480"
29
+ checkpoint_dir = main_dir + "/training/ model.ckpt-" + ckpt_num
27
30
image_size = 224
28
31
#Nombre de classes à prédire
29
32
num_class = 65
30
33
31
- file_pattern = "fruit360_ %s_*.tfrecord"
32
- file_pattern_for_counting = "fruit360 "
34
+ file_pattern = "fruit_ %s_*.tfrecord"
35
+ file_pattern_for_counting = "fruit "
33
36
#Création d'un dictionnaire pour reférer à chaque label
34
37
labels_to_name = {0 :'Apple Braeburn' ,
35
38
1 :'Apple Golden 1' ,
99
102
}
100
103
#Create a dictionary that will help people understand your dataset better. This is required by the Dataset class later.
101
104
102
- items_to_descriptions = {
103
- 'image' : 'A 3-channel RGB coloured flower image that is either... ....' ,
104
- 'label' : 'A label that is as such -- fruits'
105
- }
106
-
107
105
#=======Training Informations======#
108
106
#Nombre d'époques pour l'entraîen
109
107
num_epochs = 1
110
108
111
109
#State your batch size
112
110
batch_size = 16
113
111
114
-
115
-
116
-
117
- def get_dataset (phase_name , dataset_dir , file_pattern = file_pattern , file_pattern_for_counting = file_pattern_for_counting ):
118
- """Creates dataset based on phased_name(train or validation), datatset_dir. """
119
-
120
- #On vérifie si phase_name est 'train' ou 'validation'
121
- if phase_name not in ['train' , 'validation' ]:
122
- raise ValueError ('The phase_name %s is not recognized. Please input either train or validation as the phase_name' % (phase_name ))
123
-
124
- file_pattern_path = os .path .join (dataset_dir , file_pattern % (phase_name ))
125
-
126
- #Compte le nombre total d'examples dans tous les fichiers
127
- num_samples = 0
128
- file_pattern_for_counting = file_pattern_for_counting + '_' + phase_name
129
- tfrecords_to_count = [os .path .join (dataset_dir , file ) for file in os .listdir (dataset_dir ) if file .startswith (file_pattern_for_counting )]
130
- for tfrecord_file in tfrecords_to_count :
131
- for record in tf .python_io .tf_record_iterator (tfrecord_file ):
132
- num_samples += 1
133
-
134
- #Création d'un "reader", de type TFrecord pour ce cas précis:
135
- reader = tf .TFRecordReader
136
-
137
- #Create the keys_to_features dictionary for the decoder
138
- feature = {
139
- 'image/encoded' :tf .FixedLenFeature ((), tf .string , default_value = '' ),
140
- 'image/format' : tf .FixedLenFeature ((), tf .string , default_value = 'jpg' ),
141
- 'image/class/label' :tf .FixedLenFeature ((), tf .int64 ,default_value = tf .zeros ([], dtype = tf .int64 )),
142
- }
143
-
144
- #Create the items_to_handlers dictionary for the decoder.
145
- items_to_handlers = {
146
- 'image' : slim .tfexample_decoder .Image (),
147
- 'label' : slim .tfexample_decoder .Tensor ('image/class/label' ),
148
- }
149
-
150
- #Decoder, provided by slim
151
- decoder = slim .tfexample_decoder .TFExampleDecoder (feature , items_to_handlers )
152
-
153
- labels_map = labels_to_name
154
-
155
- #create the dataset:
156
- dataset = slim .dataset .Dataset (
157
- data_sources = file_pattern_path ,
158
- decoder = decoder ,
159
- reader = reader ,
160
- num_readers = 4 ,
161
- num_samples = num_samples ,
162
- num_classes = num_class ,
163
- labels_to_name = labels_map ,
164
- items_to_descriptions = items_to_descriptions )
165
-
166
- return dataset
167
-
168
-
169
- def load_batch (dataset , batch_size , height = image_size , width = image_size ,is_training = False ):
170
-
171
- """ Fucntion for loading a train batch
172
- OUTPUTS:
173
- - images(Tensor): a Tensor of the shape (batch_size, height, width, channels) that contain one batch of images
174
- - labels(Tensor): the batch's labels with the shape (batch_size,) (requires one_hot_encoding).
175
- """
176
-
177
- #First, create a provider given by slim:
178
- provider = slim .dataset_data_provider .DatasetDataProvider (
179
- dataset ,
180
- common_queue_capacity = 24 + 3 * batch_size ,
181
- common_queue_min = 24
182
- )
183
-
184
- raw_image , label = provider .get (['image' ,'label' ])
185
-
186
- #Preprocessing using inception_preprocessing:
187
- image = inception_preprocessing .preprocess_image (raw_image , height , width , is_training )
188
- one_hot_labels = slim .one_hot_encoding (label , dataset .num_classes )
189
- #As for the raw images, we just do a simple reshape to batch it up
190
- raw_image = tf .expand_dims (raw_image , 0 )
191
- raw_image = tf .image .resize_nearest_neighbor (raw_image , [height , width ])
192
- raw_image = tf .squeeze (raw_image )
193
-
194
- #Batch up the image by enqueing the tensors internally in a FIFO queue and dequeueing many elements with tf.train.batch.
195
- images , raw_images , one_hot_labels , labels = tf .train .batch (
196
- [image , raw_image , one_hot_labels , label ],
197
- batch_size = batch_size ,
198
- num_threads = 4 ,
199
- capacity = 4 * batch_size ,
200
- allow_smaller_final_batch = True )
201
-
202
- return images , raw_images , one_hot_labels , labels
203
-
204
-
205
112
def run ():
206
113
#Create log_dir:
207
114
if not os .path .exists (log_dir ):
@@ -213,8 +120,8 @@ def run():
213
120
global_step_cs = tf .train .get_or_create_global_step ()
214
121
# Adding the graph:
215
122
216
- dataset = get_dataset ("validation" , dataset_dir , file_pattern = file_pattern )
217
- images ,_ , oh_labels , labels = load_batch (dataset , batch_size )
123
+ dataset = get_dataset ("validation" , dataset_dir , file_pattern = file_pattern , file_pattern_for_counting = file_pattern_for_counting , labels_to_name = labels_to_name )
124
+ images ,_ , oh_labels , labels = load_batch (dataset , batch_size , image_size , image_size , is_training = False )
218
125
219
126
#Calcul of batches/epoch, number of steps after decay learning rate
220
127
num_batches_per_epoch = int (dataset .num_samples / batch_size )
@@ -225,25 +132,23 @@ def run():
225
132
"""with slim.arg_scope(mobilenet_v1.mobilenet_v1_arg_scope(is_training=True)):"""
226
133
logits , end_points = mobilenet_v1 .mobilenet_v1_050 (images , num_classes = dataset .num_classes , is_training = False )
227
134
variables_to_restore = slim .get_variables_to_restore ()
228
-
229
-
230
-
231
-
232
- #Defining accuracy and regulization ops:
135
+
136
+ #Defining accuracy and predictions:
233
137
234
138
predictions = tf .argmax (end_points ['Predictions' ], 1 )
235
139
labels = tf .squeeze (labels )
236
- accuracy = tf .reduce_mean (tf .cast (tf .equal (labels , predictions ), tf .float32 ))
237
140
probabilities = end_points ['Predictions' ]
238
141
239
142
#Define the metrics to evaluate
240
143
names_to_values , names_to_updates = slim .metrics .aggregate_metric_map ({
241
144
'Accuracy' : slim .metrics .streaming_accuracy (predictions , labels ),
242
- 'Recall_5' : slim .metrics .streaming_recall_at_k (logits , labels , 5 ),
243
145
})
244
-
146
+ for name , value in names_to_values .items ():
147
+ summary_name = 'eval/%s' % name
148
+ op = tf .summary .scalar (summary_name , value , collections = [])
149
+ op = tf .Print (op , [value ], summary_name )
150
+ tf .add_to_collection (tf .GraphKeys .SUMMARIES , op )
245
151
#Define and merge summaries:
246
- tf .summary .scalar ('Accuracy' , accuracy )
247
152
tf .summary .histogram ('Predictions' , probabilities )
248
153
summary_op = tf .summary .merge_all ()
249
154
0 commit comments