Skip to content

Commit 5fea650

Browse files
committed
Add benchmark mode for dense classifier
1 parent 74e0da7 commit 5fea650

File tree

1 file changed

+31
-22
lines changed

1 file changed

+31
-22
lines changed

dense_classifier.py

Lines changed: 31 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@
6363
"The test file for inference")
6464
flags.DEFINE_string("inference_result_file", "./inference_result.txt",
6565
"The result file from inference")
66+
flags.DEFINE_boolean("benchmark_mode", False,
67+
"Reduce extra computation in benchmark mode")
6668

6769

6870
def main():
@@ -418,28 +420,35 @@ def inference(inputs, is_train=True):
418420

419421
try:
420422
while not coord.should_stop():
421-
_, loss_value, step = sess.run([train_op, loss, global_step])
422-
423-
# Print state while training
424-
if step % FLAGS.steps_to_validate == 0:
425-
train_accuracy_value, train_auc_value, validate_accuracy_value, validate_auc_value, summary_value = sess.run(
426-
[
427-
train_accuracy, train_auc, validate_accuracy, validate_auc,
428-
summary_op
429-
])
430-
end_time = datetime.datetime.now()
431-
logging.info(
432-
"[{}] Step: {}, loss: {}, train_acc: {}, train_auc: {}, valid_acc: {}, valid_auc: {}".
433-
format(end_time - start_time, step, loss_value,
434-
train_accuracy_value, train_auc_value,
435-
validate_accuracy_value, validate_auc_value))
436-
writer.add_summary(summary_value, step)
437-
saver.save(sess, CHECKPOINT_FILE, global_step=step)
438-
start_time = end_time
423+
if FLAGS.benchmark_mode:
424+
sess.run(train_op)
425+
else:
426+
_, step = sess.run([train_op, global_step])
427+
428+
# Print state while training
429+
if step % FLAGS.steps_to_validate == 0:
430+
loss_value, train_accuracy_value, train_auc_value, validate_accuracy_value, validate_auc_value, summary_value = sess.run(
431+
[
432+
loss, train_accuracy, train_auc, validate_accuracy,
433+
validate_auc, summary_op
434+
])
435+
end_time = datetime.datetime.now()
436+
logging.info(
437+
"[{}] Step: {}, loss: {}, train_acc: {}, train_auc: {}, valid_acc: {}, valid_auc: {}".
438+
format(end_time - start_time, step, loss_value,
439+
train_accuracy_value, train_auc_value,
440+
validate_accuracy_value, validate_auc_value))
441+
writer.add_summary(summary_value, step)
442+
saver.save(sess, CHECKPOINT_FILE, global_step=step)
443+
start_time = end_time
439444
except tf.errors.OutOfRangeError:
440-
# Export the model after training
441-
export_model(sess, saver, model_signature, FLAGS.model_path,
442-
FLAGS.model_version)
445+
if FLAGS.benchmark_mode:
446+
print("Finish training for benchmark")
447+
exit(0)
448+
else:
449+
# Export the model after training
450+
export_model(sess, saver, model_signature, FLAGS.model_path,
451+
FLAGS.model_version)
443452
finally:
444453
coord.request_stop()
445454
coord.join(threads)
@@ -578,4 +587,4 @@ def export_model(sess, saver, signature, model_path, model_version):
578587

579588

580589
if __name__ == "__main__":
581-
main()
590+
main()

0 commit comments

Comments
 (0)