We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent d82004d commit 62ec8edCopy full SHA for 62ec8ed
estimator.py
@@ -169,8 +169,11 @@ def model_fn(features, mode):
169
def main():
170
#Define max steps:
171
max_step = num_epochs*num_batches_per_epoch
172
+ strategy = tf.distribute.MirroredStrategy()
173
+
174
#Define configuration non-distributed work:
- run_config = tf.estimator.RunConfig(save_checkpoints_steps=num_batches_per_epoch,keep_checkpoint_max=num_epochs)
175
+ run_config = tf.estimator.RunConfig(save_checkpoints_steps=num_batches_per_epoch,keep_checkpoint_max=num_epochs,
176
+ train_distribute=strategy, eval_distribute=strategy)
177
train_spec = tf.estimator.TrainSpec(input_fn=lambda:input_fn(tf.estimator.ModeKeys.TRAIN,
178
dataset_dir, model_name, file_pattern,
179
file_pattern_for_counting, names_to_labels,
0 commit comments