Skip to content

Commit 32d861c

Browse files
committed
formatting, fix colorization rgb2bgr
1 parent e924bc0 commit 32d861c

File tree

3 files changed

+35
-34
lines changed

3 files changed

+35
-34
lines changed

eval.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@
88
import torch.nn as nn
99
from torch.autograd import Variable
1010
from scipy.io import loadmat
11-
from scipy.misc import imsave
12-
from scipy.ndimage import zoom
1311
# Our libs
1412
from dataset import ValDataset
1513
from models import ModelBuilder, SegmentationModule
@@ -84,7 +82,9 @@ def evaluate(segmentation_module, loader, args):
8482

8583
# visualization
8684
if args.visualize:
87-
visualize_result((batch_data['img_ori'], seg_label, batch_data['info']), preds, args)
85+
visualize_result(
86+
(batch_data['img_ori'], seg_label, batch_data['info']),
87+
preds, args)
8888

8989
iou = intersection_meter.sum / (union_meter.sum + 1e-10)
9090
for i, _iou in enumerate(iou):
@@ -136,7 +136,7 @@ def main(args):
136136
# Model related arguments
137137
parser.add_argument('--id', required=True,
138138
help="a name for identifying the model to load")
139-
parser.add_argument('--suffix', default='_best.pth',
139+
parser.add_argument('--suffix', default='_epoch_13.pth',
140140
help="which snapshot to load")
141141
parser.add_argument('--arch_encoder', default='resnet50_dilated8',
142142
help="architecture of net_encoder")
@@ -170,7 +170,7 @@ def main(args):
170170
# Misc arguments
171171
parser.add_argument('--ckpt', default='./ckpt',
172172
help='folder to output checkpoints')
173-
parser.add_argument('--visualize', default=0,
173+
parser.add_argument('--visualize', action='store_true',
174174
help='output visualization?')
175175
parser.add_argument('--result', default='./result',
176176
help='folder to output visualization results')
@@ -180,17 +180,15 @@ def main(args):
180180
args = parser.parse_args()
181181
print(args)
182182

183-
#torch.cuda.set_device(args.gpu_id)
184-
185-
# scales for evaluation
186-
# args.scales = (1, )
187-
# args.scales = (0.5, 0.75, 1, 1.25, 1.5)
183+
# torch.cuda.set_device(args.gpu_id)
188184

189185
# absolute paths of model weights
190186
args.weights_encoder = os.path.join(args.ckpt, args.id,
191187
'encoder' + args.suffix)
192188
args.weights_decoder = os.path.join(args.ckpt, args.id,
193189
'decoder' + args.suffix)
190+
assert os.path.exists(args.weights_encoder) and \
191+
os.path.exists(args.weights_encoder), 'checkpoint does not exitst!'
194192

195193
args.result = os.path.join(args.result, args.id)
196194
if not os.path.isdir(args.result):

train.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,16 @@
88
import numpy as np
99
import torch
1010
import torch.nn as nn
11-
from torch.autograd import Variable
12-
from scipy.io import loadmat
13-
from scipy.misc import imresize, imsave
1411
# Our libs
1512
from dataset import TrainDataset, ValDataset
1613
from models import ModelBuilder, SegmentationModule
1714
from utils import AverageMeter, colorEncode, accuracy
1815
from lib.nn import UserScatteredDataParallel, user_scattered_collate, patch_replication_callback
1916
import lib.utils.data as torchdata
2017

21-
import matplotlib
22-
matplotlib.use('Agg')
23-
import matplotlib.pyplot as plt
18+
# import matplotlib
19+
# matplotlib.use('Agg')
20+
# import matplotlib.pyplot as plt
2421

2522

2623
# train one epoch
@@ -270,15 +267,14 @@ def main(args):
270267
args.id += '-' + str(args.arch_decoder)
271268
args.id += '-ngpus' + str(args.num_gpus)
272269
args.id += '-batchSize' + str(args.batch_size)
273-
#args.id += '-imgSize' + str(args.imgSize)
274270
args.id += '-imgMaxSize' + str(args.imgMaxSize)
275-
args.id += '-padding_constant' + str(args.padding_constant)
276-
args.id += '-segm_downsampling_rate' + str(args.segm_downsampling_rate)
277-
args.id += '-lr_encoder' + str(args.lr_encoder)
278-
args.id += '-lr_decoder' + str(args.lr_decoder)
271+
args.id += '-paddingConst' + str(args.padding_constant)
272+
args.id += '-segmDownsampleRate' + str(args.segm_downsampling_rate)
273+
args.id += '-LR_encoder' + str(args.lr_encoder)
274+
args.id += '-LR_decoder' + str(args.lr_decoder)
279275
args.id += '-epoch' + str(args.num_epoch)
280276
args.id += '-decay' + str(args.weight_decay)
281-
args.id += '-fix_bn' + str(args.fix_bn)
277+
args.id += '-fixBN' + str(args.fix_bn)
282278
print('Model ID: {}'.format(args.id))
283279

284280
args.ckpt = os.path.join(args.ckpt, args.id)

utils.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import torch
1+
# import torch
22
import numpy as np
33

44

@@ -80,7 +80,7 @@ def unique(ar, return_index=False, return_inverse=False, return_counts=False):
8080
return ret
8181

8282

83-
def colorEncode(labelmap, colors):
83+
def colorEncode(labelmap, colors, mode='BGR'):
8484
labelmap = labelmap.astype('int')
8585
labelmap_rgb = np.zeros((labelmap.shape[0], labelmap.shape[1], 3),
8686
dtype=np.uint8)
@@ -90,7 +90,11 @@ def colorEncode(labelmap, colors):
9090
labelmap_rgb += (labelmap == label)[:, :, np.newaxis] * \
9191
np.tile(colors[label],
9292
(labelmap.shape[0], labelmap.shape[1], 1))
93-
return labelmap_rgb
93+
94+
if mode == 'BGR':
95+
return labelmap_rgb[:, :, ::-1]
96+
else:
97+
return labelmap_rgb
9498

9599

96100
def accuracy(preds, label):
@@ -100,22 +104,25 @@ def accuracy(preds, label):
100104
acc = float(acc_sum) / (valid_sum + 1e-10)
101105
return acc, valid_sum
102106

107+
103108
def intersectionAndUnion(imPred, imLab, numClass):
104109
imPred = np.asarray(imPred).copy()
105110
imLab = np.asarray(imLab).copy()
106111

107-
imPred += 1; imLab += 1
108-
# Remove classes from unlabeled pixels in gt image.
112+
imPred += 1
113+
imLab += 1
114+
# Remove classes from unlabeled pixels in gt image.
109115
# We should not penalize detections in unlabeled portions of the image.
110-
imPred = imPred * (imLab>0)
116+
imPred = imPred * (imLab > 0)
111117

112118
# Compute area intersection:
113-
intersection = imPred * (imPred==imLab)
114-
(area_intersection,_) = np.histogram(intersection, bins=numClass, range=(1, numClass))
119+
intersection = imPred * (imPred == imLab)
120+
(area_intersection, _) = np.histogram(
121+
intersection, bins=numClass, range=(1, numClass))
115122

116123
# Compute area union:
117-
(area_pred,_) = np.histogram(imPred, bins=numClass, range=(1, numClass))
118-
(area_lab,_) = np.histogram(imLab, bins=numClass, range=(1, numClass))
124+
(area_pred, _) = np.histogram(imPred, bins=numClass, range=(1, numClass))
125+
(area_lab, _) = np.histogram(imLab, bins=numClass, range=(1, numClass))
119126
area_union = area_pred + area_lab - area_intersection
120-
121-
return (area_intersection, area_union)
127+
128+
return (area_intersection, area_union)

0 commit comments

Comments
 (0)