|
63 | 63 | "The test file for inference")
|
64 | 64 | flags.DEFINE_string("inference_result_file", "./inference_result.txt",
|
65 | 65 | "The result file from inference")
|
| 66 | +flags.DEFINE_boolean("benchmark_mode", False, |
| 67 | + "Reduce extra computation in benchmark mode") |
66 | 68 |
|
67 | 69 |
|
68 | 70 | def main():
|
@@ -372,24 +374,33 @@ def inference(sparse_ids, sparse_values, is_train=True):
|
372 | 374 |
|
373 | 375 | # Print state while training
|
374 | 376 | if step % FLAGS.steps_to_validate == 0:
|
375 |
| - train_accuracy_value, train_auc_value, validate_accuracy_value, auc_value, summary_value = sess.run( |
376 |
| - [ |
377 |
| - train_accuracy, train_auc, validate_accuracy, validate_auc, |
378 |
| - summary_op |
379 |
| - ]) |
380 |
| - end_time = datetime.datetime.now() |
381 |
| - logging.info( |
382 |
| - "[{}] Step: {}, loss: {}, train_acc: {}, train_auc: {}, valid_acc: {}, valid_auc: {}". |
383 |
| - format(end_time - start_time, step, loss_value, |
384 |
| - train_accuracy_value, train_auc_value, |
385 |
| - validate_accuracy_value, auc_value)) |
386 |
| - writer.add_summary(summary_value, step) |
387 |
| - saver.save(sess, CHECKPOINT_FILE, global_step=step) |
388 |
| - start_time = end_time |
| 377 | + _, step = sess.run([train_op, global_step]) |
| 378 | + |
| 379 | + if FLAGS.benchmark_mode: |
| 380 | + logging.info("The step: {}".format(step)) |
| 381 | + else: |
| 382 | + loss_value, train_accuracy_value, train_auc_value, validate_accuracy_value, auc_value, summary_value = sess.run( |
| 383 | + [ |
| 384 | + loss, train_accuracy, train_auc, validate_accuracy, |
| 385 | + validate_auc, summary_op |
| 386 | + ]) |
| 387 | + end_time = datetime.datetime.now() |
| 388 | + |
| 389 | + logging.info( |
| 390 | + "[{}] Step: {}, loss: {}, train_acc: {}, train_auc: {}, valid_acc: {}, valid_auc: {}". |
| 391 | + format(end_time - start_time, step, loss_value, |
| 392 | + train_accuracy_value, train_auc_value, |
| 393 | + validate_accuracy_value, auc_value)) |
| 394 | + writer.add_summary(summary_value, step) |
| 395 | + saver.save(sess, CHECKPOINT_FILE, global_step=step) |
| 396 | + start_time = end_time |
389 | 397 | except tf.errors.OutOfRangeError:
|
390 |
| - # Export the model after training |
391 |
| - export_model(sess, saver, model_signature, FLAGS.model_path, |
392 |
| - FLAGS.model_version) |
| 398 | + if FLAGS.benchmark_mode: |
| 399 | + print("Finish training") |
| 400 | + else: |
| 401 | + # Export the model after training |
| 402 | + export_model(sess, saver, model_signature, FLAGS.model_path, |
| 403 | + FLAGS.model_version) |
393 | 404 | finally:
|
394 | 405 | coord.request_stop()
|
395 | 406 | coord.join(threads)
|
|
0 commit comments