Skip to content

Commit 4362d3f

Browse files
committed
deprecate Variable
1 parent b88c05b commit 4362d3f

File tree

1 file changed

+11
-20
lines changed

1 file changed

+11
-20
lines changed

main.py

Lines changed: 11 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# -*- coding: utf-8 -*
22
import torch
33
from torch import optim, nn
4-
from torch.autograd import Variable
54
from torch.utils.data import DataLoader
65
from model import Generator, Discriminator
76
from option import BaseOptions
@@ -19,23 +18,18 @@ def main():
1918
shuffle=True, num_workers=opt.workers)
2019

2120
# ------------Set labels------------
22-
real_I_label_D = torch.from_numpy(np.ones((opt.batch_size,))).long()
23-
fake_I_label_D = torch.from_numpy(np.zeros((opt.batch_size,))).long()
24-
fake_I_label_G = torch.from_numpy(np.ones((opt.batch_size,))).long()
21+
real_I_label_D = torch.ones((opt.batch_size,),).long()
22+
fake_I_label_D = torch.zeros((opt.batch_size,),).long()
23+
fake_I_label_G = torch.ones((opt.batch_size,),).long()
2524

2625
# ---Fixed z input to eval/visualize---
27-
z_sample = torch.randn(opt.batch_size, opt.input_nz, 1, 1)
26+
z_sample = torch.randn((opt.batch_size, opt.input_nz, 1, 1))
2827

2928
if torch.cuda.is_available():
30-
z_sample = Variable(z_sample).cuda()
31-
real_I_label_D = Variable(real_I_label_D).cuda()
32-
fake_I_label_D = Variable(fake_I_label_D).cuda()
33-
fake_I_label_G = Variable(fake_I_label_G).cuda()
34-
else:
35-
z_sample = Variable(z_sample)
36-
real_I_label_D = Variable(real_I_label_D)
37-
fake_I_label_D = Variable(fake_I_label_D)
38-
fake_I_label_G = Variable(fake_I_label_G)
29+
z_sample = z_sample.cuda()
30+
real_I_label_D = real_I_label_D.cuda()
31+
fake_I_label_D = fake_I_label_D.cuda()
32+
fake_I_label_G = fake_I_label_G.cuda()
3933

4034
# --------Define class object-------
4135
net_G = Generator(opt)
@@ -54,13 +48,10 @@ def main():
5448
for epoch in range(opt.max_epoch):
5549
for batch, real_I in enumerate(dataloader):
5650
# Prepare input data
57-
z_G = torch.randn(opt.batch_size, opt.input_nz, 1, 1)
51+
z_G = torch.randn((opt.batch_size, opt.input_nz, 1, 1))
5852
if torch.cuda.is_available():
59-
real_I = Variable(real_I).cuda()
60-
z_G = Variable(z_G).cuda()
61-
else:
62-
real_I = Variable(real_I)
63-
z_G = Variable(z_G)
53+
real_I = real_I.cuda()
54+
z_G = z_G.cuda()
6455

6556
# ------Train Discriminator------
6657
# Forward

0 commit comments

Comments
 (0)