@@ -506,7 +506,7 @@ def __call__(self, progress, data, **kwargs):
506506
507507 pbar .finish ()
508508 mean_loss = total_loss / step_count if step_count > 0 else 0.0
509- return mean_loss
509+ return mean_loss , step_count
510510
511511 log_info ('STARTING Optimization' )
512512 train_start_time = datetime .utcnow ()
@@ -516,19 +516,21 @@ def __call__(self, progress, data, **kwargs):
516516 for epoch in range (FLAGS .epochs ):
517517 # Training
518518 log_progress ('Training epoch %d...' % epoch )
519- train_loss = run_set ('train' , epoch , train_init_op )
519+ train_loss , _ = run_set ('train' , epoch , train_init_op )
520520 log_progress ('Finished training epoch %d - loss: %f' % (epoch , train_loss ))
521521 checkpoint_saver .save (session , checkpoint_path , global_step = global_step )
522522
523523 if FLAGS .dev_files :
524524 # Validation
525525 dev_loss = 0.0
526+ total_steps = 0
526527 for csv , init_op in zip (dev_csvs , dev_init_ops ):
527528 log_progress ('Validating epoch %d on %s...' % (epoch , csv ))
528- set_loss = run_set ('dev' , epoch , init_op , dataset = csv )
529- dev_loss += set_loss
529+ set_loss , steps = run_set ('dev' , epoch , init_op , dataset = csv )
530+ dev_loss += set_loss * steps
531+ total_steps += steps
530532 log_progress ('Finished validating epoch %d on %s - loss: %f' % (epoch , csv , set_loss ))
531- dev_loss = dev_loss / len ( dev_csvs )
533+ dev_loss = dev_loss / total_steps
532534
533535 dev_losses .append (dev_loss )
534536
0 commit comments