63
63
"The test file for inference" )
64
64
flags .DEFINE_string ("inference_result_file" , "./inference_result.txt" ,
65
65
"The result file from inference" )
66
+ flags .DEFINE_boolean ("benchmark_mode" , False ,
67
+ "Reduce extra computation in benchmark mode" )
66
68
67
69
68
70
def main ():
@@ -418,28 +420,35 @@ def inference(inputs, is_train=True):
418
420
419
421
try :
420
422
while not coord .should_stop ():
421
- _ , loss_value , step = sess .run ([train_op , loss , global_step ])
422
-
423
- # Print state while training
424
- if step % FLAGS .steps_to_validate == 0 :
425
- train_accuracy_value , train_auc_value , validate_accuracy_value , validate_auc_value , summary_value = sess .run (
426
- [
427
- train_accuracy , train_auc , validate_accuracy , validate_auc ,
428
- summary_op
429
- ])
430
- end_time = datetime .datetime .now ()
431
- logging .info (
432
- "[{}] Step: {}, loss: {}, train_acc: {}, train_auc: {}, valid_acc: {}, valid_auc: {}" .
433
- format (end_time - start_time , step , loss_value ,
434
- train_accuracy_value , train_auc_value ,
435
- validate_accuracy_value , validate_auc_value ))
436
- writer .add_summary (summary_value , step )
437
- saver .save (sess , CHECKPOINT_FILE , global_step = step )
438
- start_time = end_time
423
+ if FLAGS .benchmark_mode :
424
+ sess .run (train_op )
425
+ else :
426
+ _ , step = sess .run ([train_op , global_step ])
427
+
428
+ # Print state while training
429
+ if step % FLAGS .steps_to_validate == 0 :
430
+ loss_value , train_accuracy_value , train_auc_value , validate_accuracy_value , validate_auc_value , summary_value = sess .run (
431
+ [
432
+ loss , train_accuracy , train_auc , validate_accuracy ,
433
+ validate_auc , summary_op
434
+ ])
435
+ end_time = datetime .datetime .now ()
436
+ logging .info (
437
+ "[{}] Step: {}, loss: {}, train_acc: {}, train_auc: {}, valid_acc: {}, valid_auc: {}" .
438
+ format (end_time - start_time , step , loss_value ,
439
+ train_accuracy_value , train_auc_value ,
440
+ validate_accuracy_value , validate_auc_value ))
441
+ writer .add_summary (summary_value , step )
442
+ saver .save (sess , CHECKPOINT_FILE , global_step = step )
443
+ start_time = end_time
439
444
except tf .errors .OutOfRangeError :
440
- # Export the model after training
441
- export_model (sess , saver , model_signature , FLAGS .model_path ,
442
- FLAGS .model_version )
445
+ if FLAGS .benchmark_mode :
446
+ print ("Finish training for benchmark" )
447
+ exit (0 )
448
+ else :
449
+ # Export the model after training
450
+ export_model (sess , saver , model_signature , FLAGS .model_path ,
451
+ FLAGS .model_version )
443
452
finally :
444
453
coord .request_stop ()
445
454
coord .join (threads )
@@ -578,4 +587,4 @@ def export_model(sess, saver, signature, model_path, model_version):
578
587
579
588
580
589
if __name__ == "__main__" :
581
- main ()
590
+ main ()
0 commit comments