@@ -117,6 +117,8 @@
     'momentum', 0.9,
     'The momentum for the MomentumOptimizer and RMSPropOptimizer.')
 
+tf.app.flags.DEFINE_float('rmsprop_momentum', 0.9, 'Momentum.')
+
 tf.app.flags.DEFINE_float('rmsprop_decay', 0.9, 'Decay term for RMSProp.')
 
 #######################
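The new `rmsprop_momentum` flag gives RMSProp its own momentum knob instead of reusing the shared `momentum` flag documented above. A minimal sketch of how a TF 1.x flag like this is defined and read (the `main` body here is illustrative, not from the script):

```python
import tensorflow as tf

# Same style of definition the diff adds; defaults mirror the diff.
tf.app.flags.DEFINE_float('rmsprop_momentum', 0.9, 'Momentum.')
tf.app.flags.DEFINE_float('rmsprop_decay', 0.9, 'Decay term for RMSProp.')

FLAGS = tf.app.flags.FLAGS


def main(_):
  # Flags are parsed from sys.argv before main runs, so e.g.
  #   python sketch.py --rmsprop_momentum=0.95
  # overrides the default above.
  print('rmsprop_momentum =', FLAGS.rmsprop_momentum)


if __name__ == '__main__':
  tf.app.run()
```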
@@ -301,14 +303,15 @@ def _configure_optimizer(learning_rate):
     optimizer = tf.train.RMSPropOptimizer(
         learning_rate,
         decay=FLAGS.rmsprop_decay,
-        momentum=FLAGS.momentum,
+        momentum=FLAGS.rmsprop_momentum,
         epsilon=FLAGS.opt_epsilon)
   elif FLAGS.optimizer == 'sgd':
     optimizer = tf.train.GradientDescentOptimizer(learning_rate)
   else:
     raise ValueError('Optimizer [%s] was not recognized', FLAGS.optimizer)
   return optimizer
 
+
 def _get_init_fn():
   """Returns a function run by the chief worker to warm-start the training.
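With the flag in place, `_configure_optimizer` now feeds RMSProp its own momentum, so tuning `--momentum` for the MomentumOptimizer no longer silently changes RMSProp. A standalone sketch of the resulting construction, with constants standing in for the FLAGS values (the epsilon here is illustrative):

```python
import tensorflow as tf

learning_rate = 0.01      # stand-in for the configured learning rate
rmsprop_decay = 0.9       # FLAGS.rmsprop_decay
rmsprop_momentum = 0.9    # FLAGS.rmsprop_momentum (the new flag)
opt_epsilon = 1.0         # FLAGS.opt_epsilon

# Matches the call in _configure_optimizer after this change.
optimizer = tf.train.RMSPropOptimizer(
    learning_rate,
    decay=rmsprop_decay,
    momentum=rmsprop_momentum,
    epsilon=opt_epsilon)
```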
@@ -450,20 +453,19 @@ def main(_):
     ####################
     def clone_fn(batch_queue):
       """Allows data parallelism by creating multiple clones of network_fn."""
-      with tf.device(deploy_config.inputs_device()):
-        images, labels = batch_queue.dequeue()
+      images, labels = batch_queue.dequeue()
       logits, end_points = network_fn(images)
 
       #############################
       # Specify the loss function #
       #############################
       if 'AuxLogits' in end_points:
-        tf.losses.softmax_cross_entropy(
-            logits=end_points['AuxLogits'], onehot_labels=labels,
-            label_smoothing=FLAGS.label_smoothing, weights=0.4, scope='aux_loss')
-      tf.losses.softmax_cross_entropy(
-          logits=logits, onehot_labels=labels,
-          label_smoothing=FLAGS.label_smoothing, weights=1.0)
+        slim.losses.softmax_cross_entropy(
+            end_points['AuxLogits'], labels,
+            label_smoothing=FLAGS.label_smoothing, weights=0.4,
+            scope='aux_loss')
+      slim.losses.softmax_cross_entropy(
+          logits, labels, label_smoothing=FLAGS.label_smoothing, weights=1.0)
       return end_points
 
     # Gather initial summaries.
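Two things change in `clone_fn`: the explicit `inputs_device()` placement around the dequeue is dropped, and the losses move to `slim.losses.softmax_cross_entropy`, whose positional signature is `(logits, onehot_labels, ...)` and which registers each loss in slim's losses collection. A runnable sketch of the auxiliary-head pattern, with placeholders standing in for the tensors `network_fn` and the input pipeline would produce (`label_smoothing=0.1` is an illustrative stand-in for `FLAGS.label_smoothing`):

```python
import tensorflow as tf

slim = tf.contrib.slim

num_classes = 1001
logits = tf.placeholder(tf.float32, [None, num_classes])
aux_logits = tf.placeholder(tf.float32, [None, num_classes])  # 'AuxLogits' head
onehot_labels = tf.placeholder(tf.float32, [None, num_classes])

# Auxiliary classifier (e.g. Inception's side head) contributes at weight 0.4.
slim.losses.softmax_cross_entropy(
    aux_logits, onehot_labels, label_smoothing=0.1, weights=0.4,
    scope='aux_loss')
# The main head contributes at full weight.
slim.losses.softmax_cross_entropy(
    logits, onehot_labels, label_smoothing=0.1, weights=1.0)

# Both calls registered their losses in the LOSSES collection, so the total
# (plus any regularization losses) can be gathered in one place.
total_loss = slim.losses.get_total_loss()
```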
@@ -515,10 +517,9 @@ def clone_fn(batch_queue):
       optimizer = tf.train.SyncReplicasOptimizer(
           opt=optimizer,
           replicas_to_aggregate=FLAGS.replicas_to_aggregate,
+          total_num_replicas=FLAGS.worker_replicas,
           variable_averages=variable_averages,
-          variables_to_average=moving_average_variables,
-          replica_id=tf.constant(FLAGS.task, tf.int32, shape=()),
-          total_num_replicas=FLAGS.worker_replicas)
+          variables_to_average=moving_average_variables)
     elif FLAGS.moving_average_decay:
       # Update ops executed locally by trainer.
       update_ops.append(variable_averages.apply(moving_average_variables))
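The `SyncReplicasOptimizer` call is updated to the newer TF 1.x constructor, which no longer takes a `replica_id` argument (the diff drops the `tf.constant(FLAGS.task, ...)` that used to supply it). A minimal sketch of the updated wiring, with stand-ins for the script's FLAGS and graph state:

```python
import tensorflow as tf

replicas_to_aggregate = 2   # FLAGS.replicas_to_aggregate
worker_replicas = 2         # FLAGS.worker_replicas

base_optimizer = tf.train.GradientDescentOptimizer(0.01)
variable_averages = tf.train.ExponentialMovingAverage(0.9999)
moving_average_variables = tf.moving_average_variables()

# Same keyword arguments as the diff; note the absence of replica_id.
optimizer = tf.train.SyncReplicasOptimizer(
    opt=base_optimizer,
    replicas_to_aggregate=replicas_to_aggregate,
    total_num_replicas=worker_replicas,
    variable_averages=variable_averages,
    variables_to_average=moving_average_variables)
```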