Skip to content

Commit 74e0da7

Browse files
committed
Avoid printing when in benchmark mode
1 parent 0142c61 commit 74e0da7

File tree

2 files changed

+8
-8
lines changed

2 files changed

+8
-8
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ Following are the supported features.
4040
- [x] Optimizers
4141
- [x] Learning rate decay
4242
- [x] Batch normalization
43+
- [x] Benchmark mode
4344
- [x] [Distributed training](./distributed/)
4445

4546
## Usage

sparse_classifier.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -370,15 +370,13 @@ def inference(sparse_ids, sparse_values, is_train=True):
370370

371371
try:
372372
while not coord.should_stop():
373-
_, loss_value, step = sess.run([train_op, loss, global_step])
374-
375-
# Print state while training
376-
if step % FLAGS.steps_to_validate == 0:
373+
if FLAGS.benchmark_mode:
374+
sess.run(train_op)
375+
else:
377376
_, step = sess.run([train_op, global_step])
378377

379-
if FLAGS.benchmark_mode:
380-
logging.info("The step: {}".format(step))
381-
else:
378+
# Print state while training
379+
if step % FLAGS.steps_to_validate == 0:
382380
loss_value, train_accuracy_value, train_auc_value, validate_accuracy_value, auc_value, summary_value = sess.run(
383381
[
384382
loss, train_accuracy, train_auc, validate_accuracy,
@@ -396,7 +394,8 @@ def inference(sparse_ids, sparse_values, is_train=True):
396394
start_time = end_time
397395
except tf.errors.OutOfRangeError:
398396
if FLAGS.benchmark_mode:
399-
print("Finish training")
397+
print("Finish training for benchmark")
398+
exit(0)
400399
else:
401400
# Export the model after training
402401
export_model(sess, saver, model_signature, FLAGS.model_path,

0 commit comments

Comments
 (0)