@@ -89,6 +89,7 @@ def create_model(data_format):
89
89
90
90
def define_mnist_flags ():
91
91
flags_core .define_base ()
92
+ flags_core .define_performance (num_parallel_calls = False )
92
93
flags_core .define_image ()
93
94
flags .adopt_module_key_flags (flags_core )
94
95
flags_core .set_defaults (data_dir = '/tmp/mnist_data' ,
@@ -119,10 +120,6 @@ def model_fn(features, labels, mode, params):
119
120
if mode == tf .estimator .ModeKeys .TRAIN :
120
121
optimizer = tf .train .AdamOptimizer (learning_rate = LEARNING_RATE )
121
122
122
- # If we are running multi-GPU, we need to wrap the optimizer.
123
- if params .get ('multi_gpu' ):
124
- optimizer = tf .contrib .estimator .TowerOptimizer (optimizer )
125
-
126
123
logits = model (image , training = True )
127
124
loss = tf .losses .sparse_softmax_cross_entropy (labels = labels , logits = logits )
128
125
accuracy = tf .metrics .accuracy (
@@ -162,21 +159,16 @@ def run_mnist(flags_obj):
162
159
model_helpers .apply_clean (flags_obj )
163
160
model_function = model_fn
164
161
165
- # Get number of GPUs as defined by the --num_gpus flags and the number of
166
- # GPUs available on the machine.
167
- num_gpus = flags_core . get_num_gpus ( flags_obj )
168
- multi_gpu = num_gpus > 1
162
+ session_config = tf . ConfigProto (
163
+ inter_op_parallelism_threads = flags_obj . inter_op_parallelism_threads ,
164
+ intra_op_parallelism_threads = flags_obj . intra_op_parallelism_threads ,
165
+ allow_soft_placement = True )
169
166
170
- if multi_gpu :
171
- # Validate that the batch size can be split into devices.
172
- distribution_utils .per_device_batch_size (flags_obj .batch_size , num_gpus )
167
+ distribution_strategy = distribution_utils .get_distribution_strategy (
168
+ flags_core .get_num_gpus (flags_obj ), flags_obj .all_reduce_alg )
173
169
174
- # There are two steps required if using multi-GPU: (1) wrap the model_fn,
175
- # and (2) wrap the optimizer. The first happens here, and (2) happens
176
- # in the model_fn itself when the optimizer is defined.
177
- model_function = tf .contrib .estimator .replicate_model_fn (
178
- model_fn , loss_reduction = tf .losses .Reduction .MEAN ,
179
- devices = ["/device:GPU:%d" % d for d in range (num_gpus )])
170
+ run_config = tf .estimator .RunConfig (
171
+ train_distribute = distribution_strategy , session_config = session_config )
180
172
181
173
data_format = flags_obj .data_format
182
174
if data_format is None :
@@ -185,9 +177,9 @@ def run_mnist(flags_obj):
185
177
mnist_classifier = tf .estimator .Estimator (
186
178
model_fn = model_function ,
187
179
model_dir = flags_obj .model_dir ,
180
+ config = run_config ,
188
181
params = {
189
182
'data_format' : data_format ,
190
- 'multi_gpu' : multi_gpu
191
183
})
192
184
193
185
# Set up training and evaluation input functions.
0 commit comments