Skip to content

Commit 58eeea6

Browse files
committed
Add test file
1 parent edf3c3f commit 58eeea6

File tree

1 file changed

+44
-0
lines changed

1 file changed

+44
-0
lines changed

test.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
from model import EDSR
2+
import scipy.misc
3+
import argparse
4+
import data
5+
import os
6+
parser = argparse.ArgumentParser()
7+
parser.add_argument("--dataset",default="data/General-100")
8+
parser.add_argument("--imgsize",default=100,type=int)
9+
parser.add_argument("--scale",default=2,type=int)
10+
parser.add_argument("--layers",default=32,type=int)
11+
parser.add_argument("--featuresize",default=256,type=int)
12+
parser.add_argument("--batchsize",default=10,type=int)
13+
parser.add_argument("--savedir",default="saved_models")
14+
parser.add_argument("--iterations",default=1000,type=int)
15+
parser.add_argument("--numimgs",default=5,type=int)
16+
parser.add_argument("--outdir",default="out")
17+
parser.add_argument("--image")
18+
args = parser.parse_args()
19+
if not os.path.exists(args.outdir):
20+
os.mkdir(args.outdir)
21+
data.load_dataset(args.dataset)
22+
down_size = args.imgsize/args.scale
23+
network = EDSR(down_size,args.layers,args.featuresize,scale=args.scale)
24+
network.resume(args.savedir)
25+
if args.image:
26+
y = data.crop_center(scipy.misc.imread(args.image),args.imgsize,args.imgsize)
27+
x = [scipy.misc.imresize(y,(down_size,down_size))]
28+
y = [y]
29+
else:
30+
x,y=data.get_batch(args.numimgs,args.imgsize,down_size)
31+
inputs = x
32+
outputs = network.predict(x)
33+
correct = y
34+
if args.image:
35+
scipy.misc.imsave(args.outdir+"/input"+args.image,inputs[0])
36+
scipy.misc.imsave(args.outdir+"/output"+args.image,outputs[0])
37+
scipy.misc.imsave(args.outdir+"/correct"+args.image,correct[0])
38+
else:
39+
for i in range(len(inputs)):
40+
scipy.misc.imsave(args.outdir+"/input"+str(i)+".png",inputs[i])
41+
for i in range(len(outputs)):
42+
scipy.misc.imsave(args.outdir+"/output"+str(i)+".png",outputs[i])
43+
for i in range(len(correct)):
44+
scipy.misc.imsave(args.outdir+"/correct"+str(i)+".png",correct[i])

0 commit comments

Comments
 (0)