Skip to content

Commit 1398800

Browse files
fix scores mask
1 parent 2f7b80e commit 1398800

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

inference/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -585,8 +585,8 @@ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
585585
else:
586586
group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1)
587587
indices = group_scores.topk(self.topk_groups, dim=-1)[1]
588-
mask = torch.zeros_like(scores[..., 0]).scatter_(1, indices, True)
589-
scores = (scores * mask.unsqueeze(-1)).flatten(1)
588+
mask = scores.new_ones(x.size(0), self.n_groups, dtype=bool).scatter_(1, indices, False)
589+
scores = scores.masked_fill_(mask.unsqueeze(-1), float("-inf")).flatten(1)
590590
indices = torch.topk(scores, self.topk, dim=-1)[1]
591591
weights = original_scores.gather(1, indices)
592592
if self.score_func == "sigmoid":

0 commit comments

Comments
 (0)