|
| 1 | +from model import EDSR |
| 2 | +import scipy.misc |
| 3 | +import argparse |
| 4 | +import data |
| 5 | +import os |
| 6 | +parser = argparse.ArgumentParser() |
| 7 | +parser.add_argument("--dataset",default="data/General-100") |
| 8 | +parser.add_argument("--imgsize",default=100,type=int) |
| 9 | +parser.add_argument("--scale",default=2,type=int) |
| 10 | +parser.add_argument("--layers",default=32,type=int) |
| 11 | +parser.add_argument("--featuresize",default=256,type=int) |
| 12 | +parser.add_argument("--batchsize",default=10,type=int) |
| 13 | +parser.add_argument("--savedir",default="saved_models") |
| 14 | +parser.add_argument("--iterations",default=1000,type=int) |
| 15 | +parser.add_argument("--numimgs",default=5,type=int) |
| 16 | +parser.add_argument("--outdir",default="out") |
| 17 | +parser.add_argument("--image") |
| 18 | +args = parser.parse_args() |
| 19 | +if not os.path.exists(args.outdir): |
| 20 | + os.mkdir(args.outdir) |
| 21 | +data.load_dataset(args.dataset) |
| 22 | +down_size = args.imgsize/args.scale |
| 23 | +network = EDSR(down_size,args.layers,args.featuresize,scale=args.scale) |
| 24 | +network.resume(args.savedir) |
| 25 | +if args.image: |
| 26 | + y = data.crop_center(scipy.misc.imread(args.image),args.imgsize,args.imgsize) |
| 27 | + x = [scipy.misc.imresize(y,(down_size,down_size))] |
| 28 | + y = [y] |
| 29 | +else: |
| 30 | + x,y=data.get_batch(args.numimgs,args.imgsize,down_size) |
| 31 | +inputs = x |
| 32 | +outputs = network.predict(x) |
| 33 | +correct = y |
| 34 | +if args.image: |
| 35 | + scipy.misc.imsave(args.outdir+"/input"+args.image,inputs[0]) |
| 36 | + scipy.misc.imsave(args.outdir+"/output"+args.image,outputs[0]) |
| 37 | + scipy.misc.imsave(args.outdir+"/correct"+args.image,correct[0]) |
| 38 | +else: |
| 39 | + for i in range(len(inputs)): |
| 40 | + scipy.misc.imsave(args.outdir+"/input"+str(i)+".png",inputs[i]) |
| 41 | + for i in range(len(outputs)): |
| 42 | + scipy.misc.imsave(args.outdir+"/output"+str(i)+".png",outputs[i]) |
| 43 | + for i in range(len(correct)): |
| 44 | + scipy.misc.imsave(args.outdir+"/correct"+str(i)+".png",correct[i]) |
0 commit comments