
Commit aed84cd

Sync DistillLoss.
1 parent 683fe89 commit aed84cd

File tree

1 file changed: 27 additions, 0 deletions


loss/distill.py

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class DistillLoss(nn.Module):
    def __init__(self, alpha, temperature, k=None):
        super(DistillLoss, self).__init__()
        self.alpha = alpha
        self.start_alpha = alpha
        self.temperature = temperature
        self.kl_loss = nn.KLDivLoss(reduction="batchmean")
        self.TT = self.temperature * self.temperature
        self.ce = nn.CrossEntropyLoss()
        self.K = k

    def forward(self,
                student_out: torch.Tensor,
                teacher_out: torch.Tensor,
                label: torch.Tensor):
        if self.K is not None:
            # Keep only the K largest teacher logits per sample; zero the remaining
            # (num_classes - K) entries along the class dimension.
            _, index = teacher_out.topk(student_out.shape[1] - self.K, dim=1, largest=False)
            teacher_out = teacher_out.scatter(1, index, 0.0)  # TODO maybe uniform random value
        # Temperature-scaled KL divergence between student and teacher distributions.
        l0 = self.kl_loss(F.log_softmax(student_out / self.temperature, dim=1),
                          F.softmax(teacher_out / self.temperature, dim=1))
        # Standard cross-entropy against the ground-truth labels.
        l1 = self.ce(student_out, label)
        return l0 * self.alpha * self.TT + l1 * (1.0 - self.alpha)
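
The added module combines a temperature-scaled KL term between the student and the (optionally top-k truncated) teacher logits, weighted by alpha * T^2, with a cross-entropy term against the labels weighted by (1 - alpha). Below is a minimal usage sketch, not part of the commit: the alpha/temperature/k values and tensor shapes are illustrative, and the import assumes the repository root is on the Python path.

    # Illustrative values only; distilling a 10-class student against a teacher,
    # keeping the top-3 teacher logits per sample.
    import torch

    from loss.distill import DistillLoss

    criterion = DistillLoss(alpha=0.7, temperature=4.0, k=3)

    student_out = torch.randn(8, 10, requires_grad=True)  # student logits [batch, classes]
    teacher_out = torch.randn(8, 10)                       # teacher logits [batch, classes]
    label = torch.randint(0, 10, (8,))                     # ground-truth class indices

    loss = criterion(student_out, teacher_out, label)
    loss.backward()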

0 commit comments