Skip to content

Commit bfa070e

Browse files
committed
Compute weighted average of individual dev set losses
1 parent 911a1ce commit bfa070e

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

DeepSpeech.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -506,7 +506,7 @@ def __call__(self, progress, data, **kwargs):
506506

507507
pbar.finish()
508508
mean_loss = total_loss / step_count if step_count > 0 else 0.0
509-
return mean_loss
509+
return mean_loss, step_count
510510

511511
log_info('STARTING Optimization')
512512
train_start_time = datetime.utcnow()
@@ -516,19 +516,21 @@ def __call__(self, progress, data, **kwargs):
516516
for epoch in range(FLAGS.epochs):
517517
# Training
518518
log_progress('Training epoch %d...' % epoch)
519-
train_loss = run_set('train', epoch, train_init_op)
519+
train_loss, _ = run_set('train', epoch, train_init_op)
520520
log_progress('Finished training epoch %d - loss: %f' % (epoch, train_loss))
521521
checkpoint_saver.save(session, checkpoint_path, global_step=global_step)
522522

523523
if FLAGS.dev_files:
524524
# Validation
525525
dev_loss = 0.0
526+
total_steps = 0
526527
for csv, init_op in zip(dev_csvs, dev_init_ops):
527528
log_progress('Validating epoch %d on %s...' % (epoch, csv))
528-
set_loss = run_set('dev', epoch, init_op, dataset=csv)
529-
dev_loss += set_loss
529+
set_loss, steps = run_set('dev', epoch, init_op, dataset=csv)
530+
dev_loss += set_loss * steps
531+
total_steps += steps
530532
log_progress('Finished validating epoch %d on %s - loss: %f' % (epoch, csv, set_loss))
531-
dev_loss = dev_loss / len(dev_csvs)
533+
dev_loss = dev_loss / total_steps
532534

533535
dev_losses.append(dev_loss)
534536

0 commit comments

Comments
 (0)