Skip to content

Commit 62ec8ed

Browse files
committed
adding distribute strategy
1 parent d82004d commit 62ec8ed

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

estimator.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -169,8 +169,11 @@ def model_fn(features, mode):
169169
def main():
170170
#Define max steps:
171171
max_step = num_epochs*num_batches_per_epoch
172+
strategy = tf.distribute.MirroredStrategy()
173+
172174
#Define configuration non-distributed work:
173-
run_config = tf.estimator.RunConfig(save_checkpoints_steps=num_batches_per_epoch,keep_checkpoint_max=num_epochs)
175+
run_config = tf.estimator.RunConfig(save_checkpoints_steps=num_batches_per_epoch,keep_checkpoint_max=num_epochs,
176+
train_distribute=strategy, eval_distribute=strategy)
174177
train_spec = tf.estimator.TrainSpec(input_fn=lambda:input_fn(tf.estimator.ModeKeys.TRAIN,
175178
dataset_dir, model_name, file_pattern,
176179
file_pattern_for_counting, names_to_labels,

0 commit comments

Comments
 (0)