
Commit 969634b

bugfix memory leak
1 parent 449af34 commit 969634b

File tree

1 file changed: +3 −3 lines changed


model.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -190,13 +190,13 @@ def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None):
                 "inf"
             )  # float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
             scores = scores.masked_fill(mask, min_value)
-            self.attn = torch.softmax(scores, dim=-1).masked_fill(
+            attn = torch.softmax(scores, dim=-1).masked_fill(
                 mask, 0.0
             )  # (batch, head, time1, time2)
         else:
-            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
+            attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
 
-        p_attn = self.dropout(self.attn)
+        p_attn = self.dropout(attn)
         x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
         x = (
             x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
```
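
For context on why this three-line change fixes a memory leak: assigning the softmax output to `self.attn` keeps the full `(batch, head, time1, time2)` attention tensor referenced on the module after `forward_attention` returns, and during training it also keeps the autograd graph behind that tensor alive, so the memory is only released when the next forward pass overwrites the attribute. Rebinding it to a local variable lets everything be freed as soon as the call finishes. Below is a minimal sketch of the two patterns; the module names, dropout rate, and shapes are illustrative assumptions, not the repository's actual classes:

```python
import torch
import torch.nn as nn


class LeakyAttention(nn.Module):
    """Pre-fix pattern: caches the attention map on the module."""

    def __init__(self):
        super().__init__()
        self.dropout = nn.Dropout(0.1)
        self.attn = None

    def forward(self, scores, value):
        # Storing the softmax output on `self` keeps the tensor -- and, if
        # `scores` requires grad, the autograd graph hanging off it --
        # referenced until the next forward pass overwrites it.
        self.attn = torch.softmax(scores, dim=-1)
        return torch.matmul(self.dropout(self.attn), value)


class FixedAttention(nn.Module):
    """Post-fix pattern: uses a local variable, mirroring the commit."""

    def __init__(self):
        super().__init__()
        self.dropout = nn.Dropout(0.1)

    def forward(self, scores, value):
        # A local binding is dropped when forward() returns, so the
        # attention tensor (and any graph behind it) is freed promptly.
        attn = torch.softmax(scores, dim=-1)
        return torch.matmul(self.dropout(attn), value)


# Hypothetical shapes: scores (batch, head, time1, time2),
# value (batch, head, time2, d_k).
scores = torch.randn(2, 4, 8, 8, requires_grad=True)
value = torch.randn(2, 4, 8, 16)
out = FixedAttention()(scores, value)  # (2, 4, 8, 16)
```

The attribute was presumably cached so the attention weights could be inspected after a forward pass (e.g. for visualization); the fix trades that convenience for not pinning the tensor between calls. If inspection were still needed, caching `attn.detach()` would keep the weights without retaining the autograd graph.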
