Skip to content

Commit c969e99

Browse files
committed
add spark changes
1 parent e50e643 commit c969e99

File tree

6 files changed

+18
-9
lines changed

6 files changed

+18
-9
lines changed

estimator.py

-3
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,7 @@
6868
network_config = load(network_file)
6969
network_file.close()
7070
variables_to_exclude = network_config.pop("variables_to_exclude")
71-
print(variables_to_exclude)
7271
argscope_config = network_config.pop("argscope")
73-
print(argscope_config)
7472
if "prediction_fn" in network_config.keys():
7573
network_config["prediction_fn"] = getattr(tf.contrib.layers, network_config["prediction_fn"])
7674
if "activation_fn" in network_config.keys():
@@ -103,7 +101,6 @@ def input_fn(mode, dataset_dir, model_name, file_pattern, file_pattern_for_count
103101

104102
def model_fn(features, mode):
105103
train_mode = mode==tf.estimator.ModeKeys.TRAIN
106-
tf.summary.image("images",features['image/encoded'])
107104
tf.summary.histogram("final_image_hist", features['image/encoded'])
108105
#Create the model structure using network_fn :
109106
network = nets_factory.networks_map[model_name]

estimator_multilabel.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from tensorflow.python.platform import tf_logging as logging
44
import research.slim.nets.nets_factory as nets_factory
5-
from utils.gen_tfrec import get_dataset_multiclass, load_batch_estimator
5+
from utils.gen_tfrec import get_dataset_multilabel, load_batch_estimator
66

77
import os
88
import sys
@@ -65,7 +65,7 @@
6565
def input_fn(mode, dataset_dir,file_pattern, file_pattern_for_counting, labels_to_name, batch_size, image_size):
6666
train_mode = mode==tf.estimator.ModeKeys.TRAIN
6767
with tf.name_scope("dataset"):
68-
dataset = get_dataset_multiclass("train" if train_mode else "eval",
68+
dataset = get_dataset_multilabel("train" if train_mode else "eval",
6969
dataset_dir, file_pattern=file_pattern,
7070
file_pattern_for_counting=file_pattern_for_counting,
7171
labels_to_name=labels_to_name)

test.py

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import tensorflow as tf
2+
3+
x = tf.placeholder(tf.float32, name="input")
4+
# y =a*x + b
5+
a = tf.constant(2., dtype=tf.float32)
6+
b = tf.constant(0.5,dtype=tf.float32 )
7+
y = a*x+b
8+
9+
with tf.Session() as sess:
10+
c = sess.run([y], feed_dict={'input:0' : 0.2})
11+
print(c)

utils/gen_tfrec.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def parse_fn(example):
3030
dataset = dataset.map(parse_fn, num_parallel_calls=8)
3131
return dataset
3232

33-
def get_dataset_multiclass(phase_name, dataset_dir, file_pattern, file_pattern_for_counting, labels_to_name):
33+
def get_dataset_multilabel(phase_name, dataset_dir, file_pattern, file_pattern_for_counting, labels_to_name):
3434
"""Creates dataset based on phased_name(train or evaluation), datatset_dir. """
3535

3636
#On vérifie si phase_name est 'train' ou 'validation'

utils/images/visu_spark.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,9 @@ def load_images(filenames_pattern, train_size=1.):
2121
"""
2222
df = spark.read.load(filenames_pattern, format="image")
2323
a = df.withColumn("image.data", F.decode(df.image.data,'UTF-8'))\
24-
.drop("image.data").withColumnRenamed("image.data", "image.data")
25-
a["image.data"].show(1)
24+
.drop("image.data")\
25+
.withColumnRenamed("image.data", "image.data")
26+
a.collect(1)
2627
return df
2728

2829
def per_pixel_mean(dataframe):

yaml/config/config_multiclass.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ dataset_dir : "D:/MURA-v1.1/multiclass"
33
#Portion of GPU to attribute for training
44
gpu_p : 1.
55

6-
#Model name to call automatically:
6+
#Model name to call automatically:******
77
model_name : "mobilenet_v2"
88
variables_to_exclude : []
99
#Checkpoint directory (For transfer learning)

0 commit comments

Comments
 (0)