88import torch .nn as nn
99from torch .autograd import Variable
1010from scipy .io import loadmat
11- from scipy .misc import imsave
12- from scipy .ndimage import zoom
1311# Our libs
1412from dataset import ValDataset
1513from models import ModelBuilder , SegmentationModule
@@ -84,7 +82,9 @@ def evaluate(segmentation_module, loader, args):
8482
8583 # visualization
8684 if args .visualize :
87- visualize_result ((batch_data ['img_ori' ], seg_label , batch_data ['info' ]), preds , args )
85+ visualize_result (
86+ (batch_data ['img_ori' ], seg_label , batch_data ['info' ]),
87+ preds , args )
8888
8989 iou = intersection_meter .sum / (union_meter .sum + 1e-10 )
9090 for i , _iou in enumerate (iou ):
@@ -136,7 +136,7 @@ def main(args):
136136 # Model related arguments
137137 parser .add_argument ('--id' , required = True ,
138138 help = "a name for identifying the model to load" )
139- parser .add_argument ('--suffix' , default = '_best .pth' ,
139+ parser .add_argument ('--suffix' , default = '_epoch_13 .pth' ,
140140 help = "which snapshot to load" )
141141 parser .add_argument ('--arch_encoder' , default = 'resnet50_dilated8' ,
142142 help = "architecture of net_encoder" )
@@ -170,7 +170,7 @@ def main(args):
170170 # Misc arguments
171171 parser .add_argument ('--ckpt' , default = './ckpt' ,
172172 help = 'folder to output checkpoints' )
173- parser .add_argument ('--visualize' , default = 0 ,
173+ parser .add_argument ('--visualize' , action = 'store_true' ,
174174 help = 'output visualization?' )
175175 parser .add_argument ('--result' , default = './result' ,
176176 help = 'folder to output visualization results' )
@@ -180,17 +180,15 @@ def main(args):
180180 args = parser .parse_args ()
181181 print (args )
182182
183- #torch.cuda.set_device(args.gpu_id)
184-
185- # scales for evaluation
186- # args.scales = (1, )
187- # args.scales = (0.5, 0.75, 1, 1.25, 1.5)
183+ # torch.cuda.set_device(args.gpu_id)
188184
189185 # absolute paths of model weights
190186 args .weights_encoder = os .path .join (args .ckpt , args .id ,
191187 'encoder' + args .suffix )
192188 args .weights_decoder = os .path .join (args .ckpt , args .id ,
193189 'decoder' + args .suffix )
190+ assert os .path .exists (args .weights_encoder ) and \
191+ os .path .exists (args .weights_encoder ), 'checkpoint does not exitst!'
194192
195193 args .result = os .path .join (args .result , args .id )
196194 if not os .path .isdir (args .result ):
0 commit comments