Skip to content

Commit 0142c61

Browse files
committed
Add bechmark mode for sparse classifier
1 parent fdd603d commit 0142c61

File tree

2 files changed

+29
-18
lines changed

2 files changed

+29
-18
lines changed

dense_classifier.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -578,4 +578,4 @@ def export_model(sess, saver, signature, model_path, model_version):
578578

579579

580580
if __name__ == "__main__":
581-
main()
581+
main()

sparse_classifier.py

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@
6363
"The test file for inference")
6464
flags.DEFINE_string("inference_result_file", "./inference_result.txt",
6565
"The result file from inference")
66+
flags.DEFINE_boolean("benchmark_mode", False,
67+
"Reduce extra computation in benchmark mode")
6668

6769

6870
def main():
@@ -372,24 +374,33 @@ def inference(sparse_ids, sparse_values, is_train=True):
372374

373375
# Print state while training
374376
if step % FLAGS.steps_to_validate == 0:
375-
train_accuracy_value, train_auc_value, validate_accuracy_value, auc_value, summary_value = sess.run(
376-
[
377-
train_accuracy, train_auc, validate_accuracy, validate_auc,
378-
summary_op
379-
])
380-
end_time = datetime.datetime.now()
381-
logging.info(
382-
"[{}] Step: {}, loss: {}, train_acc: {}, train_auc: {}, valid_acc: {}, valid_auc: {}".
383-
format(end_time - start_time, step, loss_value,
384-
train_accuracy_value, train_auc_value,
385-
validate_accuracy_value, auc_value))
386-
writer.add_summary(summary_value, step)
387-
saver.save(sess, CHECKPOINT_FILE, global_step=step)
388-
start_time = end_time
377+
_, step = sess.run([train_op, global_step])
378+
379+
if FLAGS.benchmark_mode:
380+
logging.info("The step: {}".format(step))
381+
else:
382+
loss_value, train_accuracy_value, train_auc_value, validate_accuracy_value, auc_value, summary_value = sess.run(
383+
[
384+
loss, train_accuracy, train_auc, validate_accuracy,
385+
validate_auc, summary_op
386+
])
387+
end_time = datetime.datetime.now()
388+
389+
logging.info(
390+
"[{}] Step: {}, loss: {}, train_acc: {}, train_auc: {}, valid_acc: {}, valid_auc: {}".
391+
format(end_time - start_time, step, loss_value,
392+
train_accuracy_value, train_auc_value,
393+
validate_accuracy_value, auc_value))
394+
writer.add_summary(summary_value, step)
395+
saver.save(sess, CHECKPOINT_FILE, global_step=step)
396+
start_time = end_time
389397
except tf.errors.OutOfRangeError:
390-
# Export the model after training
391-
export_model(sess, saver, model_signature, FLAGS.model_path,
392-
FLAGS.model_version)
398+
if FLAGS.benchmark_mode:
399+
print("Finish training")
400+
else:
401+
# Export the model after training
402+
export_model(sess, saver, model_signature, FLAGS.model_path,
403+
FLAGS.model_version)
393404
finally:
394405
coord.request_stop()
395406
coord.join(threads)

0 commit comments

Comments
 (0)