Skip to content

Commit 6a0dda1

Browse files
parkjaeman authored and Taylor Robie committed
Fix bug on distributed training in mnist using MirroredStrategy API (tensorflow#5183)
* Fix bug on distributed training in mnist using MirroredStrategy API
* Remove unnecessary code and change distribution strategy source
  - Remove multi-gpu
  - Remove TowerOptimizer
  - Change from MirroredStrategy to distribution_utils.get_distribution_strategy
1 parent 0d105c3 commit 6a0dda1

File tree

1 file changed

+10
-18
lines changed

1 file changed

+10
-18
lines changed

official/mnist/mnist.py

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ def create_model(data_format):
8989

9090
def define_mnist_flags():
9191
flags_core.define_base()
92+
flags_core.define_performance(num_parallel_calls=False)
9293
flags_core.define_image()
9394
flags.adopt_module_key_flags(flags_core)
9495
flags_core.set_defaults(data_dir='/tmp/mnist_data',
@@ -119,10 +120,6 @@ def model_fn(features, labels, mode, params):
119120
if mode == tf.estimator.ModeKeys.TRAIN:
120121
optimizer = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE)
121122

122-
# If we are running multi-GPU, we need to wrap the optimizer.
123-
if params.get('multi_gpu'):
124-
optimizer = tf.contrib.estimator.TowerOptimizer(optimizer)
125-
126123
logits = model(image, training=True)
127124
loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
128125
accuracy = tf.metrics.accuracy(
@@ -162,21 +159,16 @@ def run_mnist(flags_obj):
162159
model_helpers.apply_clean(flags_obj)
163160
model_function = model_fn
164161

165-
# Get number of GPUs as defined by the --num_gpus flags and the number of
166-
# GPUs available on the machine.
167-
num_gpus = flags_core.get_num_gpus(flags_obj)
168-
multi_gpu = num_gpus > 1
162+
session_config = tf.ConfigProto(
163+
inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
164+
intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
165+
allow_soft_placement=True)
169166

170-
if multi_gpu:
171-
# Validate that the batch size can be split into devices.
172-
distribution_utils.per_device_batch_size(flags_obj.batch_size, num_gpus)
167+
distribution_strategy = distribution_utils.get_distribution_strategy(
168+
flags_core.get_num_gpus(flags_obj), flags_obj.all_reduce_alg)
173169

174-
# There are two steps required if using multi-GPU: (1) wrap the model_fn,
175-
# and (2) wrap the optimizer. The first happens here, and (2) happens
176-
# in the model_fn itself when the optimizer is defined.
177-
model_function = tf.contrib.estimator.replicate_model_fn(
178-
model_fn, loss_reduction=tf.losses.Reduction.MEAN,
179-
devices=["/device:GPU:%d" % d for d in range(num_gpus)])
170+
run_config = tf.estimator.RunConfig(
171+
train_distribute=distribution_strategy, session_config=session_config)
180172

181173
data_format = flags_obj.data_format
182174
if data_format is None:
@@ -185,9 +177,9 @@ def run_mnist(flags_obj):
185177
mnist_classifier = tf.estimator.Estimator(
186178
model_fn=model_function,
187179
model_dir=flags_obj.model_dir,
180+
config=run_config,
188181
params={
189182
'data_format': data_format,
190-
'multi_gpu': multi_gpu
191183
})
192184

193185
# Set up training and evaluation input functions.

0 commit comments

Comments (0)