
Commit 29dfa61 (parent 63de0ad)

Change program architecture

- Regroup data_utils and utils_csv into one folder (utils)
- Add gen_utils with the functions load_batch and get_dataset
- Move train_fruit and eval_fruit to the main directory

9 files changed: +145 −244 lines
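Only three of the nine changed files appear in this excerpt, so the resulting layout is partly inferred; based on the commit message and the import paths in the diff below, it presumably looks roughly like this (entries marked "assumed" are not confirmed by the excerpt):

    __init__.py
    eval_fruit.py          # renamed from train_fruit/eval.py (diff below)
    train_fruit.py         # training script, assumed moved next to eval_fruit.py
    DenseNet/
        __init__.py
    utils/
        gen_utils.py       # new: load_batch, get_dataset
        data_utils.py      # assumed regrouped here
        utils_csv.py       # assumed regrouped here
    train_fruit/           # still referenced for logs and checkpoints (main_dir in the diff)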

DenseNet/__init__.py

Whitespace-only changes.

__init__.py

Whitespace-only changes.

train_fruit/eval.py renamed to eval_fruit.py (+21 −116)
@@ -1,9 +1,10 @@
 import tensorflow as tf
 
 from tensorflow.python.platform import tf_logging as logging
+
 import research.slim.nets.mobilenet_v1 as mobilenet_v1
-import research.slim.datasets.imagenet as imagenet
-from research.slim.preprocessing import inception_preprocessing
+
+from utils.gen_utils import load_batch, get_dataset
 
 import os
 import time
@@ -13,23 +14,25 @@
 
 flags = tf.app.flags
 flags.DEFINE_string('dataset_dir',None,'String: Your dataset directory')
-flags.DEFINE_string('eval_dir', None, 'String: Your train directory')
-
+flags.DEFINE_string('ckpt_number','0','String: ckpt file number')
 FLAGS = flags.FLAGS
 #=======Dataset Informations=======#
 dataset_dir = FLAGS.dataset_dir
+ckpt_num = FLAGS.ckpt_number
+
+
+main_dir = "./train_fruit"
+log_dir= main_dir + "/log_eval"
 
-log_dir="./log_eval"
 
-log_eval = './train_fruit/log_eval'
 #Checkpoint file location
-checkpoint_dir = "./model.ckpt-36480"
+checkpoint_dir = main_dir +"/training/model.ckpt-"+ ckpt_num
 image_size = 224
 #Number of classes to predict
 num_class = 65
 
-file_pattern = "fruit360_%s_*.tfrecord"
-file_pattern_for_counting = "fruit360"
+file_pattern = "fruit_%s_*.tfrecord"
+file_pattern_for_counting = "fruit"
 #Create a dictionary to refer to each label
 labels_to_name = {0:'Apple Braeburn',
                   1:'Apple Golden 1',
@@ -99,109 +102,13 @@
                   }
 #Create a dictionary that will help people understand your dataset better. This is required by the Dataset class later.
 
-items_to_descriptions = {
-    'image': 'A 3-channel RGB coloured flower image that is either... ....',
-    'label': 'A label that is as such -- fruits'
-}
-
 #=======Training Informations======#
 #Number of epochs for training
 num_epochs = 1
 
 #State your batch size
 batch_size = 16
 
-
-
-
-def get_dataset(phase_name, dataset_dir, file_pattern=file_pattern, file_pattern_for_counting=file_pattern_for_counting):
-    """Creates a dataset based on phase_name (train or validation) and dataset_dir."""
-
-    #Check that phase_name is 'train' or 'validation'
-    if phase_name not in ['train', 'validation']:
-        raise ValueError('The phase_name %s is not recognized. Please input either train or validation as the phase_name' % (phase_name))
-
-    file_pattern_path = os.path.join(dataset_dir, file_pattern%(phase_name))
-
-    #Count the total number of examples in all the files
-    num_samples = 0
-    file_pattern_for_counting = file_pattern_for_counting + '_' + phase_name
-    tfrecords_to_count = [os.path.join(dataset_dir, file) for file in os.listdir(dataset_dir) if file.startswith(file_pattern_for_counting)]
-    for tfrecord_file in tfrecords_to_count:
-        for record in tf.python_io.tf_record_iterator(tfrecord_file):
-            num_samples += 1
-
-    #Create a "reader", of type TFRecord in this case:
-    reader = tf.TFRecordReader
-
-    #Create the keys_to_features dictionary for the decoder
-    feature = {
-        'image/encoded':tf.FixedLenFeature((), tf.string, default_value=''),
-        'image/format': tf.FixedLenFeature((), tf.string, default_value='jpg'),
-        'image/class/label':tf.FixedLenFeature((), tf.int64,default_value=tf.zeros([], dtype=tf.int64)),
-    }
-
-    #Create the items_to_handlers dictionary for the decoder.
-    items_to_handlers = {
-        'image': slim.tfexample_decoder.Image(),
-        'label': slim.tfexample_decoder.Tensor('image/class/label'),
-    }
-
-    #Decoder, provided by slim
-    decoder = slim.tfexample_decoder.TFExampleDecoder(feature, items_to_handlers)
-
-    labels_map = labels_to_name
-
-    #create the dataset:
-    dataset = slim.dataset.Dataset(
-        data_sources = file_pattern_path,
-        decoder = decoder,
-        reader = reader,
-        num_readers = 4,
-        num_samples = num_samples,
-        num_classes = num_class,
-        labels_to_name = labels_map,
-        items_to_descriptions = items_to_descriptions)
-
-    return dataset
-
-
-def load_batch(dataset, batch_size, height=image_size, width=image_size, is_training=False):
-
-    """Function for loading a train batch
-    OUTPUTS:
-    - images(Tensor): a Tensor of the shape (batch_size, height, width, channels) that contains one batch of images
-    - labels(Tensor): the batch's labels with the shape (batch_size,) (requires one_hot_encoding).
-    """
-
-    #First, create a provider given by slim:
-    provider = slim.dataset_data_provider.DatasetDataProvider(
-        dataset,
-        common_queue_capacity = 24 + 3*batch_size,
-        common_queue_min = 24
-    )
-
-    raw_image, label = provider.get(['image','label'])
-
-    #Preprocessing using inception_preprocessing:
-    image = inception_preprocessing.preprocess_image(raw_image, height, width, is_training)
-    one_hot_labels = slim.one_hot_encoding(label, dataset.num_classes)
-    #As for the raw images, we just do a simple reshape to batch them up
-    raw_image = tf.expand_dims(raw_image, 0)
-    raw_image = tf.image.resize_nearest_neighbor(raw_image, [height, width])
-    raw_image = tf.squeeze(raw_image)
-
-    #Batch up the images by enqueueing the tensors internally in a FIFO queue and dequeueing many elements with tf.train.batch.
-    images, raw_images, one_hot_labels, labels = tf.train.batch(
-        [image, raw_image, one_hot_labels, label],
-        batch_size = batch_size,
-        num_threads = 4,
-        capacity = 4 * batch_size,
-        allow_smaller_final_batch = True)
-
-    return images, raw_images, one_hot_labels, labels
-
-
 def run():
     #Create log_dir:
     if not os.path.exists(log_dir):
@@ -213,8 +120,8 @@ def run():
         global_step_cs = tf.train.get_or_create_global_step()
         # Adding the graph:
 
-        dataset = get_dataset("validation", dataset_dir, file_pattern=file_pattern)
-        images,_, oh_labels, labels = load_batch(dataset, batch_size)
+        dataset = get_dataset("validation", dataset_dir, file_pattern=file_pattern, file_pattern_for_counting=file_pattern_for_counting, labels_to_name=labels_to_name)
+        images,_, oh_labels, labels = load_batch(dataset, batch_size, image_size, image_size, is_training=False)
 
         #Calculation of batches/epoch, number of steps after which to decay the learning rate
         num_batches_per_epoch = int(dataset.num_samples / batch_size)
@@ -225,25 +132,23 @@
         """with slim.arg_scope(mobilenet_v1.mobilenet_v1_arg_scope(is_training=True)):"""
         logits, end_points = mobilenet_v1.mobilenet_v1_050(images, num_classes = dataset.num_classes, is_training = False)
         variables_to_restore = slim.get_variables_to_restore()
-
-
-
-
-        #Defining accuracy and regularization ops:
+
+        #Defining accuracy and predictions:
 
         predictions = tf.argmax(end_points['Predictions'], 1)
         labels = tf.squeeze(labels)
-        accuracy = tf.reduce_mean(tf.cast(tf.equal(labels, predictions), tf.float32))
         probabilities = end_points['Predictions']
 
         #Define the metrics to evaluate
         names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
            'Accuracy': slim.metrics.streaming_accuracy(predictions, labels),
-            'Recall_5': slim.metrics.streaming_recall_at_k(logits, labels, 5),
        })
-
+        for name, value in names_to_values.items():
+            summary_name = 'eval/%s' % name
+            op = tf.summary.scalar(summary_name, value, collections=[])
+            op = tf.Print(op, [value], summary_name)
+            tf.add_to_collection(tf.GraphKeys.SUMMARIES, op)
        #Define and merge summaries:
-        tf.summary.scalar('Accuracy', accuracy)
        tf.summary.histogram('Predictions', probabilities)
        summary_op = tf.summary.merge_all()
 
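A note on the metric change in the last hunk: slim streaming metrics such as streaming_accuracy return a (value, update_op) pair, so the values in names_to_values only become meaningful after the update ops have been run over the validation split. The rest of run() is not shown in this excerpt; a minimal sketch of how names_to_updates could be consumed, assuming slim's evaluation helper is used (an assumption, not necessarily the author's actual loop):

    # Sketch only -- assumes the surrounding run() context from the diff above
    # (dataset, batch_size, checkpoint_dir, log_dir, names_to_updates,
    #  summary_op, variables_to_restore).
    num_evals = int(dataset.num_samples / batch_size)

    slim.evaluation.evaluate_once(
        master='',
        checkpoint_path=checkpoint_dir,
        logdir=log_dir,
        num_evals=num_evals,                      # one pass over the split
        eval_op=list(names_to_updates.values()),  # run every streaming update op per batch
        summary_op=summary_op,
        variables_to_restore=variables_to_restore)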
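The new utils/gen_utils.py itself is not shown in this excerpt. Based on the helpers removed from eval.py above and the keyword arguments used at the new call sites (file_pattern_for_counting, labels_to_name), its contents presumably look roughly like the sketch below; the exact function bodies, the inception_preprocessing import and the items_to_descriptions text are assumptions.

    # Hypothetical sketch of utils/gen_utils.py (the real file is part of this
    # commit but not shown in this excerpt).
    import os

    import tensorflow as tf

    from research.slim.preprocessing import inception_preprocessing

    slim = tf.contrib.slim


    def get_dataset(phase_name, dataset_dir, file_pattern, file_pattern_for_counting, labels_to_name):
        """Builds a slim.dataset.Dataset for the 'train' or 'validation' split."""
        if phase_name not in ['train', 'validation']:
            raise ValueError('The phase_name %s is not recognized.' % phase_name)

        file_pattern_path = os.path.join(dataset_dir, file_pattern % phase_name)

        # Count the total number of examples across all TFRecord shards of this split.
        num_samples = 0
        prefix = file_pattern_for_counting + '_' + phase_name
        tfrecords = [os.path.join(dataset_dir, f) for f in os.listdir(dataset_dir)
                     if f.startswith(prefix)]
        for tfrecord_file in tfrecords:
            for _ in tf.python_io.tf_record_iterator(tfrecord_file):
                num_samples += 1

        keys_to_features = {
            'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
            'image/format': tf.FixedLenFeature((), tf.string, default_value='jpg'),
            'image/class/label': tf.FixedLenFeature((), tf.int64,
                                                    default_value=tf.zeros([], dtype=tf.int64)),
        }
        items_to_handlers = {
            'image': slim.tfexample_decoder.Image(),
            'label': slim.tfexample_decoder.Tensor('image/class/label'),
        }
        decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)

        return slim.dataset.Dataset(
            data_sources=file_pattern_path,
            decoder=decoder,
            reader=tf.TFRecordReader,
            num_readers=4,
            num_samples=num_samples,
            num_classes=len(labels_to_name),
            labels_to_name=labels_to_name,
            items_to_descriptions={'image': 'A 3-channel RGB fruit image',
                                   'label': 'An integer class label'})


    def load_batch(dataset, batch_size, height, width, is_training=False):
        """Returns batched (images, raw_images, one_hot_labels, labels) tensors."""
        provider = slim.dataset_data_provider.DatasetDataProvider(
            dataset,
            common_queue_capacity=24 + 3 * batch_size,
            common_queue_min=24)
        raw_image, label = provider.get(['image', 'label'])

        # Inception-style preprocessing for the network input.
        image = inception_preprocessing.preprocess_image(raw_image, height, width, is_training)
        one_hot_labels = slim.one_hot_encoding(label, dataset.num_classes)

        # Keep a simply resized copy of the raw image for visualisation.
        raw_image = tf.expand_dims(raw_image, 0)
        raw_image = tf.image.resize_nearest_neighbor(raw_image, [height, width])
        raw_image = tf.squeeze(raw_image)

        # Queue up full batches with tf.train.batch.
        return tf.train.batch(
            [image, raw_image, one_hot_labels, label],
            batch_size=batch_size,
            num_threads=4,
            capacity=4 * batch_size,
            allow_smaller_final_batch=True)

With a module shaped like this, eval_fruit.py only needs the single import shown in the first hunk of the diff.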