Skip to content

Commit 1b0cca1

Browse files
committed
adding changes/code organization
1 parent c969e99 commit 1b0cca1

File tree

4 files changed

+15
-14
lines changed

4 files changed

+15
-14
lines changed

estimator.py

+5-8
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33

44
from tensorflow.python.platform import tf_logging as logging
5-
from tensorflow.python import debug as tf_debug
65

76
import research.slim.nets.nets_factory as nets_factory
87
from utils.gen_tfrec import load_batch, get_dataset, load_batch_dense, load_batch_estimator
@@ -14,9 +13,8 @@
1413

1514
#Open and read the yaml file:
1615
cwd = os.getcwd()
17-
stream = open(os.path.join(cwd, "yaml","config","config_multiclass.yaml"))
18-
data = load(stream)
19-
stream.close()
16+
with open(os.path.join(cwd, "yaml","config","config.yaml")) as stream:
17+
data = load(stream)
2018
#==================================#
2119
#=======Dataset Informations=======#
2220
#==================================#
@@ -64,9 +62,8 @@
6462
#==================================#
6563
#=======Network Informations=======#
6664
#==================================#
67-
network_file = open(os.path.join(cwd, "yaml", "cnn", model_name+".yaml"))
68-
network_config = load(network_file)
69-
network_file.close()
65+
with open(os.path.join(cwd, "yaml", "cnn", model_name+".yaml")) as network_file:
66+
network_config = load(network_file)
7067
variables_to_exclude = network_config.pop("variables_to_exclude")
7168
argscope_config = network_config.pop("argscope")
7269
if "prediction_fn" in network_config.keys():
@@ -173,7 +170,7 @@ def main():
173170
#Define max steps:
174171
max_step = num_epochs*num_batches_per_epoch
175172
#Define configuration non-distributed work:
176-
run_config = tf.estimator.RunConfig(model_dir=train_dir, save_checkpoints_steps=num_batches_per_epoch,keep_checkpoint_max=num_epochs)
173+
run_config = tf.estimator.RunConfig(save_checkpoints_steps=num_batches_per_epoch,keep_checkpoint_max=num_epochs)
177174
train_spec = tf.estimator.TrainSpec(input_fn=lambda:input_fn(tf.estimator.ModeKeys.TRAIN,
178175
dataset_dir, model_name, file_pattern,
179176
file_pattern_for_counting, names_to_labels,

utils/gen_tfrec.py

-1
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@ def parse_fn(example):
5151
}
5252
parsed_example = tf.parse_single_example(example, feature)
5353
parsed_example['image/encoded'] = tf.image.decode_image(parsed_example['image/encoded'], channels = 3)
54-
parsed_example['image/encoded'] = tf.image.convert_image_dtype(parsed_example['image/encoded'], dtype = tf.float32)
5554
return parsed_example
5655
dataset = dataset.map(parse_fn)
5756
return dataset

utils/images/visu_spark.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,17 @@ def load_images(filenames_pattern, train_size=1.):
1919
train_size : float number representing the train size
2020
(use Dataframe.split([train_size, 1 - train_size]))
2121
"""
22+
struct_keys = ["origin", "height", "width", "nChannels", "mode", "data"]
23+
2224
df = spark.read.load(filenames_pattern, format="image")
23-
a = df.withColumn("image.data", F.decode(df.image.data,'UTF-8'))\
25+
new_cols = [df["image"].getField(alpha) for alpha in struct_keys]
26+
new_frame = df.select(*new_cols)
27+
a = new_frame.withColumn("image.data", F.decode(new_frame.image.data,'UTF-8'))\
2428
.drop("image.data")\
25-
.withColumnRenamed("image.data", "image.data")
26-
a.collect(1)
27-
return df
29+
.withColumnRenamed("image.data", "data")
30+
a.describe().show()
31+
a.printSchema()
32+
return a
2833

2934
def per_pixel_mean(dataframe):
3035
"""

yaml/config/config.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#Dataset directory
2-
dataset_dir : "D:/MURA-v1.1"
2+
dataset_dir : "D:/MURA-v1.1/binary"
33
#Portion of GPU to attribute for training
44
gpu_p : 1.
55
#Model name, it serves to target the correct yaml file file in the "cnn" folder:

0 commit comments

Comments
 (0)