File tree Expand file tree Collapse file tree 1 file changed +4
-4
lines changed Expand file tree Collapse file tree 1 file changed +4
-4
lines changed Original file line number Diff line number Diff line change @@ -94,7 +94,7 @@ def forward_backward(self, batch):
9494 domain = torch .split (domain , self .split_batch , 0 )
9595 domain = [d [0 ].item () for d in domain ]
9696
97- loss = 0
97+ loss_x = 0
9898 loss_cr = 0
9999 acc = 0
100100
@@ -106,7 +106,7 @@ def forward_backward(self, batch):
106106
107107 # Learning expert
108108 pred_i = self .E (i , feat_i )
109- loss += (- label_i * torch .log (pred_i + 1e-5 )).sum (1 ).mean ()
109+ loss_x += (- label_i * torch .log (pred_i + 1e-5 )).sum (1 ).mean ()
110110 expert_label_i = pred_i .detach ()
111111 acc += compute_accuracy (pred_i .detach (),
112112 label_i .max (1 )[1 ])[0 ].item ()
@@ -121,12 +121,12 @@ def forward_backward(self, batch):
121121 cr_pred = cr_pred .mean (1 )
122122 loss_cr += ((cr_pred - expert_label_i )** 2 ).sum (1 ).mean ()
123123
124- loss /= self .n_domain
124+ loss_x /= self .n_domain
125125 loss_cr /= self .n_domain
126126 acc /= self .n_domain
127127
128128 loss = 0
129- loss += loss
129+ loss += loss_x
130130 loss += loss_cr
131131 self .model_backward_and_update (loss )
132132
You can’t perform that action at this time.
0 commit comments