Skip to content

Commit d885ed3

Browse files
committed
deprecate Variable
1 parent 9f89b85 commit d885ed3

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

main.py

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -18,9 +18,9 @@ def main():
1818
shuffle=True, num_workers=opt.workers)
1919

2020
# ------------Set labels------------
21-
real_I_label_D = torch.ones((opt.batch_size,), requires_grad=True)
22-
fake_I_label_D = torch.zeros((opt.batch_size,), requires_grad=True)
23-
fake_I_label_G = torch.ones((opt.batch_size,), requires_grad=True)
21+
real_I_label_D = torch.ones((opt.batch_size,)).long()
22+
fake_I_label_D = torch.zeros((opt.batch_size,)).long()
23+
fake_I_label_G = torch.ones((opt.batch_size,)).long()
2424

2525
# ---Fixed z input to eval/visualize---
2626
z_sample = torch.from_numpy(np.random.uniform(-1, 1, size=(opt.batch_size , opt.input_nz, 1, 1)))
@@ -62,8 +62,8 @@ def main():
6262
fake_I_logits = net_D(fake_I.detach())
6363
# ****** detach阻断梯度继续反向计算,计算多余梯度,反正优化器里也不会优化之前的参数******
6464
real_I_logits = net_D(real_I)
65-
loss_D = 0.5 * criterion_D(fake_I_logits, fake_I_label_D.detach().long()) + \
66-
0.5 * criterion_D(real_I_logits, real_I_label_D.detach().long())
65+
loss_D = 0.5 * criterion_D(fake_I_logits, fake_I_label_D) + \
66+
0.5 * criterion_D(real_I_logits, real_I_label_D)
6767
# Backward
6868
optim_D.zero_grad()
6969
loss_D.backward()
@@ -73,7 +73,7 @@ def main():
7373
# Forward
7474
# fake_I = net_G(z_G) # 此时net_G没有变,z_G和上一个fake_I都是一样的
7575
fake_I_logits = net_D(fake_I)
76-
loss_G = criterion_G(fake_I_logits, fake_I_label_G.detach().long())
76+
loss_G = criterion_G(fake_I_logits, fake_I_label_G)
7777
# Backward
7878
optim_G.zero_grad()
7979
loss_G.backward()
@@ -85,7 +85,7 @@ def main():
8585
fake_I_logits = net_D(fake_I)
8686
# ****** 此处fake_I没有加detach(),是因为更新G网络,梯度要从D到G一直反向计算到初始位置,******
8787
# ****** 然后在优化器里面只更新G的参数。计算D的梯度只是为了把loss反传到G ******
88-
loss_G = criterion_G(fake_I_logits, fake_I_label_G.detach().long())
88+
loss_G = criterion_G(fake_I_logits, fake_I_label_G)
8989
# Backward
9090
optim_G.zero_grad()
9191
loss_G.backward()
@@ -101,7 +101,7 @@ def main():
101101
save_imgs(fake_I, os.path.join(opt.sample_dir,
102102
'train_epoch_%d_batch_%d.png' % (epoch, batch)))
103103
fake_I_logits = net_D(fake_I)
104-
loss_G = criterion_G(fake_I_logits, fake_I_label_G.detach().long())
104+
loss_G = criterion_G(fake_I_logits, fake_I_label_G)
105105
print('Eval loss %.4f' % loss_G.item())
106106

107107
# Save model every epoch

0 commit comments

Comments (0)