@@ -18,9 +18,9 @@ def main():
1818 shuffle = True , num_workers = opt .workers )
1919
2020 # ------------Set labels------------
21- real_I_label_D = torch .ones ((opt .batch_size ,), requires_grad = True )
22- fake_I_label_D = torch .zeros ((opt .batch_size ,), requires_grad = True )
23- fake_I_label_G = torch .ones ((opt .batch_size ,), requires_grad = True )
21+ real_I_label_D = torch .ones ((opt .batch_size ,)). long ( )
22+ fake_I_label_D = torch .zeros ((opt .batch_size ,)). long ( )
23+ fake_I_label_G = torch .ones ((opt .batch_size ,)). long ( )
2424
2525 # ---Fixed z input to eval/visualize---
2626 z_sample = torch .from_numpy (np .random .uniform (- 1 , 1 , size = (opt .batch_size , opt .input_nz , 1 , 1 )))
@@ -62,8 +62,8 @@ def main():
6262 fake_I_logits = net_D (fake_I .detach ())
6363 # ****** detach阻断梯度继续反向计算,计算多余梯度,反正优化器里也不会优化之前的参数******
6464 real_I_logits = net_D (real_I )
65- loss_D = 0.5 * criterion_D (fake_I_logits , fake_I_label_D . detach (). long () ) + \
66- 0.5 * criterion_D (real_I_logits , real_I_label_D . detach (). long () )
65+ loss_D = 0.5 * criterion_D (fake_I_logits , fake_I_label_D ) + \
66+ 0.5 * criterion_D (real_I_logits , real_I_label_D )
6767 # Backward
6868 optim_D .zero_grad ()
6969 loss_D .backward ()
@@ -73,7 +73,7 @@ def main():
7373 # Forward
7474 # fake_I = net_G(z_G) # 此时net_G没有变,z_G和上一个fake_I都是一样的
7575 fake_I_logits = net_D (fake_I )
76- loss_G = criterion_G (fake_I_logits , fake_I_label_G . detach (). long () )
76+ loss_G = criterion_G (fake_I_logits , fake_I_label_G )
7777 # Backward
7878 optim_G .zero_grad ()
7979 loss_G .backward ()
@@ -85,7 +85,7 @@ def main():
8585 fake_I_logits = net_D (fake_I )
8686 # ****** 此处fake_I没有加detach(),是因为更新G网络,梯度要从D到G一直反向计算到初始位置,******
8787 # ****** 然后在优化器里面只更新G的参数。计算D的梯度只是为了把loss反传到G ******
88- loss_G = criterion_G (fake_I_logits , fake_I_label_G . detach (). long () )
88+ loss_G = criterion_G (fake_I_logits , fake_I_label_G )
8989 # Backward
9090 optim_G .zero_grad ()
9191 loss_G .backward ()
@@ -101,7 +101,7 @@ def main():
101101 save_imgs (fake_I , os .path .join (opt .sample_dir ,
102102 'train_epoch_%d_batch_%d.png' % (epoch , batch )))
103103 fake_I_logits = net_D (fake_I )
104- loss_G = criterion_G (fake_I_logits , fake_I_label_G . detach (). long () )
104+ loss_G = criterion_G (fake_I_logits , fake_I_label_G )
105105 print ('Eval loss %.4f' % loss_G .item ())
106106
107107 # Save model every epoch
0 commit comments