|
63 | 63 | "outputs": [],
|
64 | 64 | "source": [
|
65 | 65 | "model_path = './models/tensorflow'\n",
|
| 66 | + "model_path_transfer = './models/tf_final'\n", |
66 | 67 | "feature_path = './data/feats.npy'\n",
|
67 | 68 | "annotation_path = './data/results_20130124.token'"
|
68 | 69 | ]
|
|
281 | 282 | "momentum = 0.9\n",
|
282 | 283 | "n_epochs = 150\n",
|
283 | 284 | "\n",
|
284 |
| - "def train(learning_rate=0.001, continue_training=False):\n", |
| 285 | + "def train(learning_rate=0.001, continue_training=False, transfer=True):\n", |
285 | 286 | " \n",
|
286 | 287 | " tf.reset_default_graph()\n",
|
287 | 288 | "\n",
|
|
309 | 310 | " tf.global_variables_initializer().run()\n",
|
310 | 311 | "\n",
|
311 | 312 | " if continue_training:\n",
|
312 |
| - " saver.restore(sess,tf.train.latest_checkpoint(model_path))\n", |
| 313 | + " if not transfer:\n", |
| 314 | + " saver.restore(sess,tf.train.latest_checkpoint(model_path))\n", |
| 315 | + " else:\n", |
| 316 | + " saver.restore(sess,tf.train.latest_checkpoint(model_path_transfer))\n", |
313 | 317 | " losses=[]\n",
|
314 | 318 | " for epoch in range(n_epochs):\n",
|
315 | 319 | " for start, end in zip( range(0, len(index), batch_size), range(batch_size, len(index), batch_size)):\n",
|
|
349 | 353 | "outputs": [],
|
350 | 354 | "source": [
|
351 | 355 | "try:\n",
|
352 |
| - " train(.001,True)\n", |
| 356 | + " #train(.001,False,False) #train from scratch\n", |
| 357 | + " train(.001,True,True) #continue training from pretrained weights @epoch500\n", |
| 358 | + " #train(.001,True,False) #train from previously saved weights \n", |
353 | 359 | "except KeyboardInterrupt:\n",
|
354 | 360 | " print('Exiting Training')"
|
355 | 361 | ]
|
|
0 commit comments