Skip to content

Commit e50e643

Browse files
committed
preparing to add jinja/automatic way to generate correct yamls
1 parent e5ed408 commit e50e643

File tree

5 files changed

+19
-13
lines changed

5 files changed

+19
-13
lines changed

estimator.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
slim = tf.contrib.slim
1414

1515
#Open and read the yaml file:
16-
stream = open(os.path.join(os.getcwd(), "yaml","config","config_multiclass.yaml"))
16+
cwd = os.getcwd()
17+
stream = open(os.path.join(cwd, "yaml","config","config_multiclass.yaml"))
1718
data = load(stream)
1819
stream.close()
1920
#==================================#
@@ -26,7 +27,7 @@
2627
checkpoint_dir= data["checkpoint_dir"]
2728
checkpoint_pattern = data["checkpoint_pattern"]
2829
checkpoint_file = os.path.join(checkpoint_dir, checkpoint_pattern)
29-
train_dir = os.path.join(os.getcwd(), "train_"+model_name)
30+
train_dir = os.path.join(cwd, "train_"+ model_name)
3031
#Define the checkpoint state to determine initialization: from pre-trained weigths or recovery
3132
ckpt_state = tf.train.get_checkpoint_state(train_dir)
3233
#TODO: Place image_size on yaml/cnn/model_name.yaml
@@ -57,17 +58,19 @@
5758
num_batches_per_epoch = int(num_samples / batch_size)
5859
#num_batches = num_steps for one epcoh
5960
decay_steps = int(num_epochs_before_decay * num_batches_per_epoch)
60-
#==================================#
61+
#==================================#.
6162
#==================================#
6263

6364
#==================================#
6465
#=======Network Informations=======#
6566
#==================================#
66-
network_file = open(os.path.join(os.getcwd(), "yaml", "cnn", model_name+".yaml"))
67+
network_file = open(os.path.join(cwd, "yaml", "cnn", model_name+".yaml"))
6768
network_config = load(network_file)
6869
network_file.close()
6970
variables_to_exclude = network_config.pop("variables_to_exclude")
71+
print(variables_to_exclude)
7072
argscope_config = network_config.pop("argscope")
73+
print(argscope_config)
7174
if "prediction_fn" in network_config.keys():
7275
network_config["prediction_fn"] = getattr(tf.contrib.layers, network_config["prediction_fn"])
7376
if "activation_fn" in network_config.keys():
@@ -79,7 +82,7 @@
7982

8083
#Create log_dir:argscope_config
8184
if not os.path.exists(train_dir):
82-
os.mkdir(os.path.join(os.getcwd(),train_dir))
85+
os.mkdir(train_dir)
8386
#===================================================================== Training ===========================================================================#
8487
#Adding the graph:
8588
#Set the verbosity to INFO level

utils/images/visu_spark.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import pyspark
2+
import pyspark.sql.functions as F
23

34

45
spark = pyspark.sql.SparkSession \
@@ -19,8 +20,9 @@ def load_images(filenames_pattern, train_size=1.):
1920
(use Dataframe.split([train_size, 1 - train_size]))
2021
"""
2122
df = spark.read.load(filenames_pattern, format="image")
22-
df.select("image.origin", "image.width", "image.height").show()
23-
23+
a = df.withColumn("image.data", F.decode(df.image.data,'UTF-8'))\
24+
.drop("image.data").withColumnRenamed("image.data", "image.data")
25+
a["image.data"].show(1)
2426
return df
2527

2628
def per_pixel_mean(dataframe):

yaml/config/config.yaml

+3-4
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,16 @@
22
dataset_dir : "D:/MURA-v1.1"
33
#Portion of GPU to attribute for training
44
gpu_p : 1.
5-
#Model name to call automatically:
6-
model_name : "mobilenet_v2_140"
5+
#Model name, it serves to target the correct yaml file file in the "cnn" folder:
6+
model_name : "mobilenet_v2"
77
#Checkpoint directory (For transfer learning)
88
checkpoint_dir : "D:/mobilenet"
99
checkpoint_pattern : "mobilenet_v2_1.4_224.ckpt"
1010
#File pattern to recognize
1111
file_pattern : "mura_*.tfrecord"
1212
file_pattern_for_counting : "mura"
1313
#Num samples in the training dataset
14-
#Chest-X ray num_samples
15-
14+
#MURA ray num_samples
1615
num_samples : 36807
1716
#Mapping from class to id
1817
names_to_labels : {

yaml/config/config_multilabel.yaml

+2-2
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
dataset_dir : "D:/chest"
33
#Portion of GPU to attribute for training
44
gpu_p : 1.
5-
#Model name to call automatically:
6-
model_name : "mobilenet_v2_140"
5+
#Model name, it serves to target the correct yaml file file in the "cnn" folder:
6+
model_name : "mobilenet_v2"
77
#Checkpoint directory (For transfer learning)
88
checkpoint_dir : "D:/mobilenet"
99
checkpoint_pattern : "mobilenet_v2_1.4_224.ckpt"

yaml/config/config_multitask.yaml

+2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
dataset_dir : "D:/chest"
33
#Portion of GPU to attribute for training
44
gpu_p : 1.
5+
#Model name, it serves to target the correct yaml file file in the "cnn" folder:
6+
model_name : "mobilenet_v2"
57
#Checkpoint directory (For transfer learning)
68
checkpoint_dir : "D:/mobilenet"
79
checkpoint_pattern : "mobilenet_v2_1.4_224.ckpt"

0 commit comments

Comments
 (0)