Skip to content

Commit 486c083

Browse files
committed
fix sensevoice loss_rich
1 parent e775c38 commit 486c083

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -691,10 +691,11 @@ def forward(
691691
encoder_out[:, :4, :], text[:, :4]
692692
)
693693

694-
loss = loss_ctc
694+
loss = loss_ctc + loss_rich
695695
# Collect total loss stats
696-
stats["loss"] = torch.clone(loss.detach()) if loss_ctc is not None else None
696+
stats["loss_ctc"] = torch.clone(loss_ctc.detach()) if loss_ctc is not None else None
697697
stats["loss_rich"] = torch.clone(loss_rich.detach()) if loss_rich is not None else None
698+
stats["loss"] = torch.clone(loss.detach()) if loss is not None else None
698699
stats["acc_rich"] = acc_rich
699700

700701
# force_gatherable: to-device and to-tensor if scalar for DataParallel

0 commit comments

Comments
 (0)