File tree Expand file tree Collapse file tree 1 file changed +27
-0
lines changed Expand file tree Collapse file tree 1 file changed +27
-0
lines changed Original file line number Diff line number Diff line change
1
+ import torch
2
+ import torch .nn as nn
3
+ import torch .nn .functional as F
4
+
5
+
6
class DistillLoss(nn.Module):
    """Knowledge-distillation loss: alpha-weighted sum of a temperature-scaled
    KL term (student vs. teacher soft targets) and a hard-label cross-entropy term.

    Args:
        alpha: weight of the distillation (KL) term; the CE term gets (1 - alpha).
        temperature: softmax temperature T; the KL term is scaled by T*T as is
            standard so its gradient magnitude matches the CE term.
        k: if given, keep only each sample's top-k teacher logits and zero the
            rest before computing the soft targets; None disables masking.
    """

    def __init__(self, alpha, temperature, k=None):
        super().__init__()
        self.alpha = alpha
        # Kept so callers can anneal self.alpha and still recover the start value.
        self.start_alpha = alpha
        self.temperature = temperature
        # KLDivLoss expects log-probabilities as input and probabilities as target.
        self.kl_loss = nn.KLDivLoss(reduction="batchmean")
        self.TT = self.temperature * self.temperature
        self.ce = nn.CrossEntropyLoss()
        self.K = k

    def forward(self,
                student_out: torch.Tensor,
                teacher_out: torch.Tensor,
                label: torch.Tensor) -> torch.Tensor:
        """Compute the combined distillation loss.

        Args:
            student_out: student logits, shape (batch, num_classes) — assumed 2-D.
            teacher_out: teacher logits, same shape as student_out.
            label: ground-truth class indices, shape (batch,).

        Returns:
            Scalar tensor: alpha * T^2 * KL + (1 - alpha) * CE.
        """
        if self.K is not None:
            # Clone so the caller's teacher tensor is not mutated in place.
            teacher_out = teacher_out.clone()
            # Indices of the (num_classes - K) smallest teacher logits per row.
            _, index = teacher_out.topk(student_out.shape[1] - self.K,
                                        dim=1, largest=False)
            # scatter_ zeroes per-row columns. The previous
            # `teacher_out[index] = 0.0` misused the class indices as batch-row
            # indices, zeroing wrong rows (or crashing out of bounds).
            # NOTE(review): zeroed logits still contribute exp(0)=1 to the
            # softmax; -inf would truly drop them — left as-is per the TODO.
            teacher_out.scatter_(1, index, 0.0)  # TODO maybe uniform random value
        l0 = self.kl_loss(F.log_softmax(student_out / self.temperature, dim=1),
                          F.softmax(teacher_out / self.temperature, dim=1))
        l1 = self.ce(student_out, label)
        return l0 * self.alpha * self.TT + l1 * (1.0 - self.alpha)
You can’t perform that action at this time.
0 commit comments