Skip to content

Commit b93ce6e

Browse files
authored
fix crossentropy size in chatbot to align mask
1 parent a46b643 commit b93ce6e

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

beginner_source/chatbot_tutorial.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -890,7 +890,7 @@ def forward(self, input_step, last_hidden, encoder_outputs):
890890

891891
def maskNLLLoss(inp, target, mask):
892892
nTotal = mask.sum()
893-
crossEntropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1)))
893+
crossEntropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1)).squeeze(1))
894894
loss = crossEntropy.masked_select(mask).mean()
895895
loss = loss.to(device)
896896
return loss, nTotal.item()

0 commit comments

Comments
 (0)