Skip to content

Commit 215617d

Browse files
authored
Merge pull request CSAILVision#190 from CSAILVision/rgb
use RGB instead of BGR
2 parents 42b7567 + 5a22d74 commit 215617d

File tree

9 files changed

+95
-94
lines changed

9 files changed

+95
-94
lines changed

README.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,9 @@ Color encoding of semantic categories can be found here:
1818
https://docs.google.com/spreadsheets/d/1se8YEtb2detS7OuPE86fXGyD269pMycAWe2mtKUj2W8/edit?usp=sharing
1919

2020
## Updates
21-
- We use configuration files to store most options which were in argument parser. The definitions of options are detailed in ```config/defaults.py```.
2221
- HRNet model is now supported.
22+
- We use configuration files to store most options which were in argument parser. The definitions of options are detailed in ```config/defaults.py```.
23+
- We conform to PyTorch practice in data preprocessing (RGB in [0, 1], subtract mean, divide by std).
2324

2425

2526
## Highlights
@@ -61,7 +62,7 @@ Decoder:
6162
- UPerNet (Pyramid Pooling + FPN head, see [UperNet](https://arxiv.org/abs/1807.10221) for details.)
6263

6364
## Performance:
64-
IMPORTANT: We use our self-trained base model on ImageNet. The model takes the input in BGR form (consistent with opencv) instead of RGB form as used by default implementation of PyTorch. The base model will be automatically downloaded when needed.
65+
IMPORTANT: The base ResNet in our repository is a customized version (different from the one in torchvision). The base models will be automatically downloaded when needed.
6566

6667
<table><tbody>
6768
<th valign="bottom">Architecture</th>

config/ade20k-resnet101dilated-ppm_deepsup.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ MODEL:
1616

1717
TRAIN:
1818
batch_size_per_gpu: 2
19-
num_epoch: 20
19+
num_epoch: 25
2020
start_epoch: 0
2121
epoch_iters: 5000
2222
optim: "SGD"
@@ -33,10 +33,10 @@ TRAIN:
3333

3434
VAL:
3535
visualize: False
36-
checkpoint: "epoch_20.pth"
36+
checkpoint: "epoch_25.pth"
3737

3838
TEST:
39-
checkpoint: "epoch_20.pth"
39+
checkpoint: "epoch_25.pth"
4040
result: "./"
4141

4242
DIR: "ckpt/ade20k-resnet50dilated-ppm_deepsup"

config/ade20k-resnet50-upernet.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ MODEL:
1616

1717
TRAIN:
1818
batch_size_per_gpu: 2
19-
num_epoch: 40
19+
num_epoch: 30
2020
start_epoch: 0
2121
epoch_iters: 5000
2222
optim: "SGD"
@@ -33,10 +33,10 @@ TRAIN:
3333

3434
VAL:
3535
visualize: False
36-
checkpoint: "epoch_40.pth"
36+
checkpoint: "epoch_30.pth"
3737

3838
TEST:
39-
checkpoint: "epoch_40.pth"
39+
checkpoint: "epoch_30.pth"
4040
result: "./"
4141

4242
DIR: "ckpt/ade20k-resnet50-upernet"

dataset.py

Lines changed: 76 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,22 @@
11
import os
22
import json
33
import torch
4-
import cv2
54
from torchvision import transforms
65
import numpy as np
7-
import PIL
6+
from PIL import Image
87

98

109
def imresize(im, size, interp='bilinear'):
1110
if interp == 'nearest':
12-
resample = PIL.Image.NEAREST
11+
resample = Image.NEAREST
1312
elif interp == 'bilinear':
14-
resample = PIL.Image.BILINEAR
13+
resample = Image.BILINEAR
1514
elif interp == 'bicubic':
16-
resample = PIL.Image.BICUBIC
15+
resample = Image.BICUBIC
1716
else:
1817
raise Exception('resample method undefined!')
1918

20-
return np.array(
21-
PIL.Image.fromarray(im).resize((size[1], size[0]), resample)
22-
)
19+
return im.resize(size, resample)
2320

2421

2522
class BaseDataset(torch.utils.data.Dataset):
@@ -35,8 +32,8 @@ def __init__(self, odgt, opt, **kwargs):
3532

3633
# mean and std
3734
self.normalize = transforms.Normalize(
38-
mean=[102.9801, 115.9465, 122.7717],
39-
std=[1., 1., 1.])
35+
mean=[0.485, 0.456, 0.406],
36+
std=[0.229, 0.224, 0.225])
4037

4138
def parse_input_list(self, odgt, max_sample=-1, start_idx=-1, end_idx=-1):
4239
if isinstance(odgt, list):
@@ -54,12 +51,17 @@ def parse_input_list(self, odgt, max_sample=-1, start_idx=-1, end_idx=-1):
5451
print('# samples: {}'.format(self.num_sample))
5552

5653
def img_transform(self, img):
57-
# image to float
58-
img = img.astype(np.float32)
54+
# 0-255 to 0-1
55+
img = np.float32(np.array(img)) / 255.
5956
img = img.transpose((2, 0, 1))
6057
img = self.normalize(torch.from_numpy(img.copy()))
6158
return img
6259

60+
def segm_transform(self, segm):
61+
# to tensor, -1 to 149
62+
segm = torch.from_numpy(np.array(segm)).long() - 1
63+
return segm
64+
6365
# Round x to the nearest multiple of p and x' >= x
6466
def round2nearest_multiple(self, x, p):
6567
return ((x - 1) // p + 1) * p
@@ -69,7 +71,6 @@ class TrainDataset(BaseDataset):
6971
def __init__(self, root_dataset, odgt, opt, batch_per_gpu=1, **kwargs):
7072
super(TrainDataset, self).__init__(odgt, opt, **kwargs)
7173
self.root_dataset = root_dataset
72-
self.random_flip = opt.random_flip
7374
# down sampling rate of segm label
7475
self.segm_downsampling_rate = opt.segm_downsampling_rate
7576
self.batch_per_gpu = batch_per_gpu
@@ -124,71 +125,74 @@ def __getitem__(self, index):
124125

125126
# calculate the BATCH's height and width
126127
# since we concat more than one sample, the batch's h and w shall be larger than EACH sample
127-
batch_resized_size = np.zeros((self.batch_per_gpu, 2), np.int32)
128+
batch_widths = np.zeros(self.batch_per_gpu, np.int32)
129+
batch_heights = np.zeros(self.batch_per_gpu, np.int32)
128130
for i in range(self.batch_per_gpu):
129131
img_height, img_width = batch_records[i]['height'], batch_records[i]['width']
130132
this_scale = min(
131133
this_short_size / min(img_height, img_width), \
132134
self.imgMaxSize / max(img_height, img_width))
133-
img_resized_height, img_resized_width = img_height * this_scale, img_width * this_scale
134-
batch_resized_size[i, :] = img_resized_height, img_resized_width
135-
batch_resized_height = np.max(batch_resized_size[:, 0])
136-
batch_resized_width = np.max(batch_resized_size[:, 1])
135+
batch_widths[i] = img_width * this_scale
136+
batch_heights[i] = img_height * this_scale
137137

138138
# Here we must pad both input image and segmentation map to size h' and w' so that p | h' and p | w'
139-
batch_resized_height = int(self.round2nearest_multiple(batch_resized_height, self.padding_constant))
140-
batch_resized_width = int(self.round2nearest_multiple(batch_resized_width, self.padding_constant))
141-
142-
assert self.padding_constant >= self.segm_downsampling_rate,\
143-
'padding constant must be equal or large than segm downsamping rate'
144-
batch_images = torch.zeros(self.batch_per_gpu, 3, batch_resized_height, batch_resized_width)
139+
batch_width = np.max(batch_widths)
140+
batch_height = np.max(batch_heights)
141+
batch_width = int(self.round2nearest_multiple(batch_width, self.padding_constant))
142+
batch_height = int(self.round2nearest_multiple(batch_height, self.padding_constant))
143+
144+
assert self.padding_constant >= self.segm_downsampling_rate, \
145+
'padding constant must be equal or large than segm downsamping rate'
146+
batch_images = torch.zeros(
147+
self.batch_per_gpu, 3, batch_height, batch_width)
145148
batch_segms = torch.zeros(
146-
self.batch_per_gpu, batch_resized_height // self.segm_downsampling_rate, \
147-
batch_resized_width // self.segm_downsampling_rate).long()
149+
self.batch_per_gpu,
150+
batch_height // self.segm_downsampling_rate,
151+
batch_width // self.segm_downsampling_rate).long()
148152

149153
for i in range(self.batch_per_gpu):
150154
this_record = batch_records[i]
151155

152156
# load image and label
153157
image_path = os.path.join(self.root_dataset, this_record['fpath_img'])
154158
segm_path = os.path.join(self.root_dataset, this_record['fpath_segm'])
155-
img = cv2.imread(image_path, cv2.IMREAD_COLOR)
156-
segm = cv2.imread(segm_path, cv2.IMREAD_GRAYSCALE)
157159

158-
assert(img.ndim == 3)
159-
assert(segm.ndim == 2)
160-
assert(img.shape[0] == segm.shape[0])
161-
assert(img.shape[1] == segm.shape[1])
160+
img = Image.open(image_path).convert('RGB')
161+
segm = Image.open(segm_path)
162+
assert(segm.mode == "L")
163+
assert(img.size[0] == segm.size[0])
164+
assert(img.size[1] == segm.size[1])
162165

163-
if self.random_flip is True:
164-
random_flip = np.random.choice([0, 1])
165-
if random_flip == 1:
166-
img = cv2.flip(img, 1)
167-
segm = cv2.flip(segm, 1)
166+
# random_flip
167+
if np.random.choice([0, 1]):
168+
img = img.transpose(Image.FLIP_LEFT_RIGHT)
169+
segm = segm.transpose(Image.FLIP_LEFT_RIGHT)
168170

169171
# note that each sample within a mini batch has different scale param
170-
img = imresize(img, (batch_resized_size[i, 0], batch_resized_size[i, 1]), interp='bilinear')
171-
segm = imresize(segm, (batch_resized_size[i, 0], batch_resized_size[i, 1]), interp='nearest')
172-
173-
# to avoid seg label misalignment
174-
segm_rounded_height = self.round2nearest_multiple(segm.shape[0], self.segm_downsampling_rate)
175-
segm_rounded_width = self.round2nearest_multiple(segm.shape[1], self.segm_downsampling_rate)
176-
segm_rounded = np.zeros((segm_rounded_height, segm_rounded_width), dtype='uint8')
177-
segm_rounded[:segm.shape[0], :segm.shape[1]] = segm
178-
172+
img = imresize(img, (batch_widths[i], batch_heights[i]), interp='bilinear')
173+
segm = imresize(segm, (batch_widths[i], batch_heights[i]), interp='nearest')
174+
175+
# further downsample seg label, need to avoid seg label misalignment
176+
segm_rounded_width = self.round2nearest_multiple(segm.size[0], self.segm_downsampling_rate)
177+
segm_rounded_height = self.round2nearest_multiple(segm.size[1], self.segm_downsampling_rate)
178+
segm_rounded = Image.new('L', (segm_rounded_width, segm_rounded_height), 0)
179+
segm_rounded.paste(segm, (0, 0))
179180
segm = imresize(
180181
segm_rounded,
181-
(segm_rounded.shape[0] // self.segm_downsampling_rate, \
182-
segm_rounded.shape[1] // self.segm_downsampling_rate), \
182+
(segm_rounded.size[0] // self.segm_downsampling_rate, \
183+
segm_rounded.size[1] // self.segm_downsampling_rate), \
183184
interp='nearest')
184185

185-
# image transform
186+
# image transform, to torch float tensor 3xHxW
186187
img = self.img_transform(img)
187188

189+
# segm transform, to torch long tensor HxW
190+
segm = self.segm_transform(segm)
191+
192+
# put into batch arrays
188193
batch_images[i][:, :img.shape[1], :img.shape[2]] = img
189-
batch_segms[i][:segm.shape[0], :segm.shape[1]] = torch.from_numpy(segm.astype(np.int)).long()
194+
batch_segms[i][:segm.shape[0], :segm.shape[1]] = segm
190195

191-
batch_segms = batch_segms - 1 # label from -1 to 149
192196
output = dict()
193197
output['img_data'] = batch_images
194198
output['seg_label'] = batch_segms
@@ -209,10 +213,13 @@ def __getitem__(self, index):
209213
# load image and label
210214
image_path = os.path.join(self.root_dataset, this_record['fpath_img'])
211215
segm_path = os.path.join(self.root_dataset, this_record['fpath_segm'])
212-
img = cv2.imread(image_path, cv2.IMREAD_COLOR)
213-
segm = cv2.imread(segm_path, cv2.IMREAD_GRAYSCALE)
216+
img = Image.open(image_path).convert('RGB')
217+
segm = Image.open(segm_path)
218+
assert(segm.mode == "L")
219+
assert(img.size[0] == segm.size[0])
220+
assert(img.size[1] == segm.size[1])
214221

215-
ori_height, ori_width, _ = img.shape
222+
ori_width, ori_height = img.size
216223

217224
img_resized_list = []
218225
for this_short_size in self.imgSizes:
@@ -222,24 +229,23 @@ def __getitem__(self, index):
222229
target_height, target_width = int(ori_height * scale), int(ori_width * scale)
223230

224231
# to avoid rounding in network
225-
target_height = self.round2nearest_multiple(target_height, self.padding_constant)
226232
target_width = self.round2nearest_multiple(target_width, self.padding_constant)
233+
target_height = self.round2nearest_multiple(target_height, self.padding_constant)
227234

228-
# resize
229-
img_resized = cv2.resize(img.copy(), (target_width, target_height))
235+
# resize images
236+
img_resized = imresize(img, (target_width, target_height), interp='bilinear')
230237

231-
# image transform
238+
# image transform, to torch float tensor 3xHxW
232239
img_resized = self.img_transform(img_resized)
233-
234240
img_resized = torch.unsqueeze(img_resized, 0)
235241
img_resized_list.append(img_resized)
236242

237-
segm = torch.from_numpy(segm.astype(np.int)).long()
243+
# segm transform, to torch long tensor HxW
244+
segm = self.segm_transform(segm)
238245
batch_segms = torch.unsqueeze(segm, 0)
239246

240-
batch_segms = batch_segms - 1 # label from -1 to 149
241247
output = dict()
242-
output['img_ori'] = img.copy()
248+
output['img_ori'] = np.array(img)
243249
output['img_data'] = [x.contiguous() for x in img_resized_list]
244250
output['seg_label'] = batch_segms.contiguous()
245251
output['info'] = this_record['fpath_img']
@@ -255,11 +261,11 @@ def __init__(self, odgt, opt, **kwargs):
255261

256262
def __getitem__(self, index):
257263
this_record = self.list_sample[index]
258-
# load image and label
264+
# load image
259265
image_path = this_record['fpath_img']
260-
img = cv2.imread(image_path, cv2.IMREAD_COLOR)
266+
img = Image.open(image_path).convert('RGB')
261267

262-
ori_height, ori_width, _ = img.shape
268+
ori_width, ori_height = img.size
263269

264270
img_resized_list = []
265271
for this_short_size in self.imgSizes:
@@ -269,19 +275,19 @@ def __getitem__(self, index):
269275
target_height, target_width = int(ori_height * scale), int(ori_width * scale)
270276

271277
# to avoid rounding in network
272-
target_height = self.round2nearest_multiple(target_height, self.padding_constant)
273278
target_width = self.round2nearest_multiple(target_width, self.padding_constant)
279+
target_height = self.round2nearest_multiple(target_height, self.padding_constant)
274280

275-
# resize
276-
img_resized = cv2.resize(img.copy(), (target_width, target_height))
281+
# resize images
282+
img_resized = imresize(img, (target_width, target_height), interp='bilinear')
277283

278-
# image transform
284+
# image transform, to torch float tensor 3xHxW
279285
img_resized = self.img_transform(img_resized)
280286
img_resized = torch.unsqueeze(img_resized, 0)
281287
img_resized_list.append(img_resized)
282288

283289
output = dict()
284-
output['img_ori'] = img.copy()
290+
output['img_ori'] = np.array(img)
285291
output['img_data'] = [x.contiguous() for x in img_resized_list]
286292
output['info'] = this_record['fpath_img']
287293
return output

demo_test.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
# Image and model names
44
TEST_IMG=ADE_val_00001519.jpg
5-
MODEL_PATH=baseline-resnet50dilated-ppm_deepsup
5+
MODEL_PATH=ade20k-resnet50dilated-ppm_deepsup
66
RESULT_PATH=./
77

88
ENCODER=$MODEL_PATH/encoder_epoch_20.pth
@@ -28,4 +28,4 @@ python3 -u test.py \
2828
--cfg config/ade20k-resnet50dilated-ppm_deepsup.yaml \
2929
DIR $MODEL_PATH \
3030
TEST.result ./ \
31-
TEST.suffix _epoch_20.pth
31+
TEST.checkpoint epoch_20.pth

eval.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from utils import AverageMeter, colorEncode, accuracy, intersectionAndUnion, setup_logger
1616
from lib.nn import user_scattered_collate, async_copy_to
1717
from lib.utils import as_numpy
18-
import cv2
18+
from PIL import Image
1919
from tqdm import tqdm
2020

2121
colors = loadmat('data/color150.mat')['colors']
@@ -35,10 +35,7 @@ def visualize_result(data, pred, dir_result):
3535
axis=1).astype(np.uint8)
3636

3737
img_name = info.split('/')[-1]
38-
cv2.imwrite(
39-
os.path.join(dir_result, img_name.replace('.jpg', '.png')),
40-
im_vis
41-
)
38+
Image.fromarray(im_vis).save(os.path.join(dir_result, img_name.replace('.jpg', '.png')))
4239

4340

4441
def evaluate(segmentation_module, loader, cfg, gpu):

eval_multipro.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from utils import AverageMeter, colorEncode, accuracy, intersectionAndUnion, parse_devices, setup_logger
1717
from lib.nn import user_scattered_collate, async_copy_to
1818
from lib.utils import as_numpy
19-
import cv2
19+
from PIL import Image
2020
from tqdm import tqdm
2121

2222
colors = loadmat('data/color150.mat')['colors']
@@ -36,10 +36,7 @@ def visualize_result(data, pred, dir_result):
3636
axis=1).astype(np.uint8)
3737

3838
img_name = info.split('/')[-1]
39-
cv2.imwrite(
40-
os.path.join(dir_result, img_name.replace('.jpg', '.png')),
41-
im_vis
42-
)
39+
Image.fromarray(im_vis).save(os.path.join(dir_result, img_name.replace('.jpg', '.png')))
4340

4441

4542
def evaluate(segmentation_module, loader, cfg, gpu_id, result_queue):

0 commit comments

Comments
 (0)