11# -*- coding: utf-8 -*
22import torch
33from torch import optim , nn
4- from torch .autograd import Variable
54from torch .utils .data import DataLoader
65from model import Generator , Discriminator
76from option import BaseOptions
@@ -19,23 +18,18 @@ def main():
1918 shuffle = True , num_workers = opt .workers )
2019
2120 # ------------Set labels------------
22- real_I_label_D = torch .from_numpy ( np . ones ((opt .batch_size ,)) ).long ()
23- fake_I_label_D = torch .from_numpy ( np . zeros ((opt .batch_size ,)) ).long ()
24- fake_I_label_G = torch .from_numpy ( np . ones ((opt .batch_size ,)) ).long ()
21+ real_I_label_D = torch .ones ((opt .batch_size ,), ).long ()
22+ fake_I_label_D = torch .zeros ((opt .batch_size ,), ).long ()
23+ fake_I_label_G = torch .ones ((opt .batch_size ,), ).long ()
2524
2625 # ---Fixed z input to eval/visualize---
27- z_sample = torch .randn (opt .batch_size , opt .input_nz , 1 , 1 )
26+ z_sample = torch .randn (( opt .batch_size , opt .input_nz , 1 , 1 ) )
2827
2928 if torch .cuda .is_available ():
30- z_sample = Variable (z_sample ).cuda ()
31- real_I_label_D = Variable (real_I_label_D ).cuda ()
32- fake_I_label_D = Variable (fake_I_label_D ).cuda ()
33- fake_I_label_G = Variable (fake_I_label_G ).cuda ()
34- else :
35- z_sample = Variable (z_sample )
36- real_I_label_D = Variable (real_I_label_D )
37- fake_I_label_D = Variable (fake_I_label_D )
38- fake_I_label_G = Variable (fake_I_label_G )
29+ z_sample = z_sample .cuda ()
30+ real_I_label_D = real_I_label_D .cuda ()
31+ fake_I_label_D = fake_I_label_D .cuda ()
32+ fake_I_label_G = fake_I_label_G .cuda ()
3933
4034 # --------Define class object-------
4135 net_G = Generator (opt )
@@ -54,13 +48,10 @@ def main():
5448 for epoch in range (opt .max_epoch ):
5549 for batch , real_I in enumerate (dataloader ):
5650 # Prepare input data
57- z_G = torch .randn (opt .batch_size , opt .input_nz , 1 , 1 )
51+ z_G = torch .randn (( opt .batch_size , opt .input_nz , 1 , 1 ) )
5852 if torch .cuda .is_available ():
59- real_I = Variable (real_I ).cuda ()
60- z_G = Variable (z_G ).cuda ()
61- else :
62- real_I = Variable (real_I )
63- z_G = Variable (z_G )
53+ real_I = real_I .cuda ()
54+ z_G = z_G .cuda ()
6455
6556 # ------Train Discriminator------
6657 # Forward
0 commit comments