|
13 | 13 | slim = tf.contrib.slim
|
14 | 14 |
|
15 | 15 | #Open and read the yaml file:
|
16 |
| -stream = open(os.path.join(os.getcwd(), "yaml","config","config_multiclass.yaml")) |
| 16 | +cwd = os.getcwd() |
| 17 | +stream = open(os.path.join(cwd, "yaml","config","config_multiclass.yaml")) |
17 | 18 | data = load(stream)
|
18 | 19 | stream.close()
|
19 | 20 | #==================================#
|
|
26 | 27 | checkpoint_dir= data["checkpoint_dir"]
|
27 | 28 | checkpoint_pattern = data["checkpoint_pattern"]
|
28 | 29 | checkpoint_file = os.path.join(checkpoint_dir, checkpoint_pattern)
|
29 |
| -train_dir = os.path.join(os.getcwd(), "train_"+model_name) |
| 30 | +train_dir = os.path.join(cwd, "train_"+ model_name) |
30 | 31 | #Define the checkpoint state to determine initialization: from pre-trained weigths or recovery
|
31 | 32 | ckpt_state = tf.train.get_checkpoint_state(train_dir)
|
32 | 33 | #TODO: Place image_size on yaml/cnn/model_name.yaml
|
|
57 | 58 | num_batches_per_epoch = int(num_samples / batch_size)
|
58 | 59 | #num_batches = num_steps for one epcoh
|
59 | 60 | decay_steps = int(num_epochs_before_decay * num_batches_per_epoch)
|
60 |
| -#==================================# |
| 61 | +#==================================#. |
61 | 62 | #==================================#
|
62 | 63 |
|
63 | 64 | #==================================#
|
64 | 65 | #=======Network Informations=======#
|
65 | 66 | #==================================#
|
66 |
| -network_file = open(os.path.join(os.getcwd(), "yaml", "cnn", model_name+".yaml")) |
| 67 | +network_file = open(os.path.join(cwd, "yaml", "cnn", model_name+".yaml")) |
67 | 68 | network_config = load(network_file)
|
68 | 69 | network_file.close()
|
69 | 70 | variables_to_exclude = network_config.pop("variables_to_exclude")
|
| 71 | +print(variables_to_exclude) |
70 | 72 | argscope_config = network_config.pop("argscope")
|
| 73 | +print(argscope_config) |
71 | 74 | if "prediction_fn" in network_config.keys():
|
72 | 75 | network_config["prediction_fn"] = getattr(tf.contrib.layers, network_config["prediction_fn"])
|
73 | 76 | if "activation_fn" in network_config.keys():
|
|
79 | 82 |
|
80 | 83 | #Create log_dir:argscope_config
|
81 | 84 | if not os.path.exists(train_dir):
|
82 |
| - os.mkdir(os.path.join(os.getcwd(),train_dir)) |
| 85 | + os.mkdir(train_dir) |
83 | 86 | #===================================================================== Training ===========================================================================#
|
84 | 87 | #Adding the graph:
|
85 | 88 | #Set the verbosity to INFO level
|
|
0 commit comments