Skip to content

Commit 1a75dfd

Browse files
committed
added transfer instructions
1 parent 1e751c6 commit 1a75dfd

File tree

5 files changed

+9
-3
lines changed

5 files changed

+9
-3
lines changed

1. O'Reilly Training.ipynb

+9-3
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
"outputs": [],
6464
"source": [
6565
"model_path = './models/tensorflow'\n",
66+
"model_path_transfer = './models/tf_final'\n",
6667
"feature_path = './data/feats.npy'\n",
6768
"annotation_path = './data/results_20130124.token'"
6869
]
@@ -281,7 +282,7 @@
281282
"momentum = 0.9\n",
282283
"n_epochs = 150\n",
283284
"\n",
284-
"def train(learning_rate=0.001, continue_training=False):\n",
285+
"def train(learning_rate=0.001, continue_training=False, transfer=True):\n",
285286
" \n",
286287
" tf.reset_default_graph()\n",
287288
"\n",
@@ -309,7 +310,10 @@
309310
" tf.global_variables_initializer().run()\n",
310311
"\n",
311312
" if continue_training:\n",
312-
" saver.restore(sess,tf.train.latest_checkpoint(model_path))\n",
313+
" if not transfer:\n",
314+
" saver.restore(sess,tf.train.latest_checkpoint(model_path))\n",
315+
" else:\n",
316+
" saver.restore(sess,tf.train.latest_checkpoint(model_path_transfer))\n",
313317
" losses=[]\n",
314318
" for epoch in range(n_epochs):\n",
315319
" for start, end in zip( range(0, len(index), batch_size), range(batch_size, len(index), batch_size)):\n",
@@ -349,7 +353,9 @@
349353
"outputs": [],
350354
"source": [
351355
"try:\n",
352-
" train(.001,True)\n",
356+
" #train(.001,False,False) #train from scratch\n",
357+
" train(.001,True,True) #continue training from pretrained weights @epoch500\n",
358+
" #train(.001,True,False) #train from previously saved weights \n",
353359
"except KeyboardInterrupt:\n",
354360
" print('Exiting Training')"
355361
]
File renamed without changes.
File renamed without changes.
File renamed without changes.

0 commit comments

Comments
 (0)