Skip to content

Commit 131d980

Browse files
author
KaiyangZhou
committed
fix typo
1 parent 5e83fdc commit 131d980

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

dassl/engine/dg/daeldg.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def forward_backward(self, batch):
9494
domain = torch.split(domain, self.split_batch, 0)
9595
domain = [d[0].item() for d in domain]
9696

97-
loss = 0
97+
loss_x = 0
9898
loss_cr = 0
9999
acc = 0
100100

@@ -106,7 +106,7 @@ def forward_backward(self, batch):
106106

107107
# Learning expert
108108
pred_i = self.E(i, feat_i)
109-
loss += (-label_i * torch.log(pred_i + 1e-5)).sum(1).mean()
109+
loss_x += (-label_i * torch.log(pred_i + 1e-5)).sum(1).mean()
110110
expert_label_i = pred_i.detach()
111111
acc += compute_accuracy(pred_i.detach(),
112112
label_i.max(1)[1])[0].item()
@@ -121,12 +121,12 @@ def forward_backward(self, batch):
121121
cr_pred = cr_pred.mean(1)
122122
loss_cr += ((cr_pred - expert_label_i)**2).sum(1).mean()
123123

124-
loss /= self.n_domain
124+
loss_x /= self.n_domain
125125
loss_cr /= self.n_domain
126126
acc /= self.n_domain
127127

128128
loss = 0
129-
loss += loss
129+
loss += loss_x
130130
loss += loss_cr
131131
self.model_backward_and_update(loss)
132132

0 commit comments

Comments
 (0)