@@ -370,15 +370,13 @@ def inference(sparse_ids, sparse_values, is_train=True):
370
370
371
371
try :
372
372
while not coord .should_stop ():
373
- _ , loss_value , step = sess .run ([train_op , loss , global_step ])
374
-
375
- # Print state while training
376
- if step % FLAGS .steps_to_validate == 0 :
373
+ if FLAGS .benchmark_mode :
374
+ sess .run (train_op )
375
+ else :
377
376
_ , step = sess .run ([train_op , global_step ])
378
377
379
- if FLAGS .benchmark_mode :
380
- logging .info ("The step: {}" .format (step ))
381
- else :
378
+ # Print state while training
379
+ if step % FLAGS .steps_to_validate == 0 :
382
380
loss_value , train_accuracy_value , train_auc_value , validate_accuracy_value , auc_value , summary_value = sess .run (
383
381
[
384
382
loss , train_accuracy , train_auc , validate_accuracy ,
@@ -396,7 +394,8 @@ def inference(sparse_ids, sparse_values, is_train=True):
396
394
start_time = end_time
397
395
except tf .errors .OutOfRangeError :
398
396
if FLAGS .benchmark_mode :
399
- print ("Finish training" )
397
+ print ("Finish training for benchmark" )
398
+ exit (0 )
400
399
else :
401
400
# Export the model after training
402
401
export_model (sess , saver , model_signature , FLAGS .model_path ,
0 commit comments