Commit 71c559e

replace cv2 with PIL
1 parent 42b7567 commit 71c559e

4 files changed: 84 additions & 82 deletions

dataset.py

Lines changed: 75 additions & 69 deletions
@@ -1,25 +1,22 @@
 import os
 import json
 import torch
-import cv2
 from torchvision import transforms
 import numpy as np
-import PIL
+from PIL import Image
 
 
 def imresize(im, size, interp='bilinear'):
     if interp == 'nearest':
-        resample = PIL.Image.NEAREST
+        resample = Image.NEAREST
     elif interp == 'bilinear':
-        resample = PIL.Image.BILINEAR
+        resample = Image.BILINEAR
     elif interp == 'bicubic':
-        resample = PIL.Image.BICUBIC
+        resample = Image.BICUBIC
     else:
         raise Exception('resample method undefined!')
 
-    return np.array(
-        PIL.Image.fromarray(im).resize((size[1], size[0]), resample)
-    )
+    return im.resize(size, resample)
 
 
 class BaseDataset(torch.utils.data.Dataset):
@@ -35,7 +32,7 @@ def __init__(self, odgt, opt, **kwargs):
 
         # mean and std
         self.normalize = transforms.Normalize(
-            mean=[102.9801, 115.9465, 122.7717],
+            mean=[122.7717/255., 115.9465/255., 102.9801/255.],
             std=[1., 1., 1.])
 
     def parse_input_list(self, odgt, max_sample=-1, start_idx=-1, end_idx=-1):
@@ -54,12 +51,17 @@ def parse_input_list(self, odgt, max_sample=-1, start_idx=-1, end_idx=-1):
         print('# samples: {}'.format(self.num_sample))
 
     def img_transform(self, img):
-        # image to float
-        img = img.astype(np.float32)
+        # 0-255 to 0-1
+        img = np.float32(np.array(img)) / 255.
         img = img.transpose((2, 0, 1))
         img = self.normalize(torch.from_numpy(img.copy()))
         return img
 
+    def segm_transform(self, segm):
+        # to tensor, -1 to 149
+        segm = torch.from_numpy(np.array(segm)).long() - 1
+        return segm
+
     # Round x to the nearest multiple of p and x' >= x
     def round2nearest_multiple(self, x, p):
         return ((x - 1) // p + 1) * p
@@ -69,7 +71,6 @@ class TrainDataset(BaseDataset):
     def __init__(self, root_dataset, odgt, opt, batch_per_gpu=1, **kwargs):
         super(TrainDataset, self).__init__(odgt, opt, **kwargs)
         self.root_dataset = root_dataset
-        self.random_flip = opt.random_flip
         # down sampling rate of segm label
         self.segm_downsampling_rate = opt.segm_downsampling_rate
         self.batch_per_gpu = batch_per_gpu
@@ -124,71 +125,74 @@ def __getitem__(self, index):
 
         # calculate the BATCH's height and width
         # since we concat more than one samples, the batch's h and w shall be larger than EACH sample
-        batch_resized_size = np.zeros((self.batch_per_gpu, 2), np.int32)
+        batch_widths = np.zeros(self.batch_per_gpu, np.int32)
+        batch_heights = np.zeros(self.batch_per_gpu, np.int32)
         for i in range(self.batch_per_gpu):
             img_height, img_width = batch_records[i]['height'], batch_records[i]['width']
             this_scale = min(
                 this_short_size / min(img_height, img_width), \
                 self.imgMaxSize / max(img_height, img_width))
-            img_resized_height, img_resized_width = img_height * this_scale, img_width * this_scale
-            batch_resized_size[i, :] = img_resized_height, img_resized_width
-        batch_resized_height = np.max(batch_resized_size[:, 0])
-        batch_resized_width = np.max(batch_resized_size[:, 1])
+            batch_widths[i] = img_width * this_scale
+            batch_heights[i] = img_height * this_scale
 
         # Here we must pad both input image and segmentation map to size h' and w' so that p | h' and p | w'
-        batch_resized_height = int(self.round2nearest_multiple(batch_resized_height, self.padding_constant))
-        batch_resized_width = int(self.round2nearest_multiple(batch_resized_width, self.padding_constant))
-
-        assert self.padding_constant >= self.segm_downsampling_rate,\
-            'padding constant must be equal or large than segm downsamping rate'
-        batch_images = torch.zeros(self.batch_per_gpu, 3, batch_resized_height, batch_resized_width)
+        batch_width = np.max(batch_widths)
+        batch_height = np.max(batch_heights)
+        batch_width = int(self.round2nearest_multiple(batch_width, self.padding_constant))
+        batch_height = int(self.round2nearest_multiple(batch_height, self.padding_constant))
+
+        assert self.padding_constant >= self.segm_downsampling_rate, \
+            'padding constant must be equal or large than segm downsamping rate'
+        batch_images = torch.zeros(
+            self.batch_per_gpu, 3, batch_height, batch_width)
         batch_segms = torch.zeros(
-            self.batch_per_gpu, batch_resized_height // self.segm_downsampling_rate, \
-            batch_resized_width // self.segm_downsampling_rate).long()
+            self.batch_per_gpu,
+            batch_height // self.segm_downsampling_rate,
+            batch_width // self.segm_downsampling_rate).long()
 
         for i in range(self.batch_per_gpu):
             this_record = batch_records[i]
 
             # load image and label
             image_path = os.path.join(self.root_dataset, this_record['fpath_img'])
             segm_path = os.path.join(self.root_dataset, this_record['fpath_segm'])
-            img = cv2.imread(image_path, cv2.IMREAD_COLOR)
-            segm = cv2.imread(segm_path, cv2.IMREAD_GRAYSCALE)
 
-            assert(img.ndim == 3)
-            assert(segm.ndim == 2)
-            assert(img.shape[0] == segm.shape[0])
-            assert(img.shape[1] == segm.shape[1])
+            img = Image.open(image_path).convert('RGB')
+            segm = Image.open(segm_path)
+            assert(segm.mode == "L")
+            assert(img.size[0] == segm.size[0])
+            assert(img.size[1] == segm.size[1])
 
-            if self.random_flip is True:
-                random_flip = np.random.choice([0, 1])
-                if random_flip == 1:
-                    img = cv2.flip(img, 1)
-                    segm = cv2.flip(segm, 1)
+            # random_flip
+            if np.random.choice([0, 1]):
+                img = img.transpose(Image.FLIP_LEFT_RIGHT)
+                segm = segm.transpose(Image.FLIP_LEFT_RIGHT)
 
             # note that each sample within a mini batch has different scale param
-            img = imresize(img, (batch_resized_size[i, 0], batch_resized_size[i, 1]), interp='bilinear')
-            segm = imresize(segm, (batch_resized_size[i, 0], batch_resized_size[i, 1]), interp='nearest')
-
-            # to avoid seg label misalignment
-            segm_rounded_height = self.round2nearest_multiple(segm.shape[0], self.segm_downsampling_rate)
-            segm_rounded_width = self.round2nearest_multiple(segm.shape[1], self.segm_downsampling_rate)
-            segm_rounded = np.zeros((segm_rounded_height, segm_rounded_width), dtype='uint8')
-            segm_rounded[:segm.shape[0], :segm.shape[1]] = segm
-
+            img = imresize(img, (batch_widths[i], batch_heights[i]), interp='bilinear')
+            segm = imresize(segm, (batch_widths[i], batch_heights[i]), interp='nearest')
+
+            # further downsample seg label, need to avoid seg label misalignment
+            segm_rounded_width = self.round2nearest_multiple(segm.size[0], self.segm_downsampling_rate)
+            segm_rounded_height = self.round2nearest_multiple(segm.size[1], self.segm_downsampling_rate)
+            segm_rounded = Image.new('L', (segm_rounded_width, segm_rounded_height), 0)
+            segm_rounded.paste(segm, (0, 0))
             segm = imresize(
                 segm_rounded,
-                (segm_rounded.shape[0] // self.segm_downsampling_rate, \
-                 segm_rounded.shape[1] // self.segm_downsampling_rate), \
+                (segm_rounded.size[0] // self.segm_downsampling_rate, \
+                 segm_rounded.size[1] // self.segm_downsampling_rate), \
                 interp='nearest')
 
-            # image transform
+            # image transform, to torch float tensor 3xHxW
             img = self.img_transform(img)
 
+            # segm transform, to torch long tensor HxW
+            segm = self.segm_transform(segm)
+
+            # put into batch arrays
             batch_images[i][:, :img.shape[1], :img.shape[2]] = img
-            batch_segms[i][:segm.shape[0], :segm.shape[1]] = torch.from_numpy(segm.astype(np.int)).long()
+            batch_segms[i][:segm.shape[0], :segm.shape[1]] = segm
 
-        batch_segms = batch_segms - 1 # label from -1 to 149
         output = dict()
         output['img_data'] = batch_images
         output['seg_label'] = batch_segms
@@ -209,10 +213,13 @@ def __getitem__(self, index):
         # load image and label
         image_path = os.path.join(self.root_dataset, this_record['fpath_img'])
         segm_path = os.path.join(self.root_dataset, this_record['fpath_segm'])
-        img = cv2.imread(image_path, cv2.IMREAD_COLOR)
-        segm = cv2.imread(segm_path, cv2.IMREAD_GRAYSCALE)
+        img = Image.open(image_path).convert('RGB')
+        segm = Image.open(segm_path)
+        assert(segm.mode == "L")
+        assert(img.size[0] == segm.size[0])
+        assert(img.size[1] == segm.size[1])
 
-        ori_height, ori_width, _ = img.shape
+        ori_width, ori_height = img.size
 
         img_resized_list = []
         for this_short_size in self.imgSizes:
@@ -222,24 +229,23 @@ def __getitem__(self, index):
             target_height, target_width = int(ori_height * scale), int(ori_width * scale)
 
             # to avoid rounding in network
-            target_height = self.round2nearest_multiple(target_height, self.padding_constant)
             target_width = self.round2nearest_multiple(target_width, self.padding_constant)
+            target_height = self.round2nearest_multiple(target_height, self.padding_constant)
 
-            # resize
-            img_resized = cv2.resize(img.copy(), (target_width, target_height))
+            # resize images
+            img_resized = imresize(img, (target_width, target_height), interp='bilinear')
 
-            # image transform
+            # image transform, to torch float tensor 3xHxW
             img_resized = self.img_transform(img_resized)
-
             img_resized = torch.unsqueeze(img_resized, 0)
             img_resized_list.append(img_resized)
 
-        segm = torch.from_numpy(segm.astype(np.int)).long()
+        # segm transform, to torch long tensor HxW
+        segm = self.segm_transform(segm)
         batch_segms = torch.unsqueeze(segm, 0)
 
-        batch_segms = batch_segms - 1 # label from -1 to 149
         output = dict()
-        output['img_ori'] = img.copy()
+        output['img_ori'] = np.array(img)
         output['img_data'] = [x.contiguous() for x in img_resized_list]
         output['seg_label'] = batch_segms.contiguous()
         output['info'] = this_record['fpath_img']
@@ -255,11 +261,11 @@ def __init__(self, odgt, opt, **kwargs):
 
     def __getitem__(self, index):
         this_record = self.list_sample[index]
-        # load image and label
+        # load image
         image_path = this_record['fpath_img']
-        img = cv2.imread(image_path, cv2.IMREAD_COLOR)
+        img = Image.open(image_path).convert('RGB')
 
-        ori_height, ori_width, _ = img.shape
+        ori_width, ori_height = img.size
 
         img_resized_list = []
         for this_short_size in self.imgSizes:
@@ -269,19 +275,19 @@ def __getitem__(self, index):
             target_height, target_width = int(ori_height * scale), int(ori_width * scale)
 
             # to avoid rounding in network
-            target_height = self.round2nearest_multiple(target_height, self.padding_constant)
             target_width = self.round2nearest_multiple(target_width, self.padding_constant)
+            target_height = self.round2nearest_multiple(target_height, self.padding_constant)
 
-            # resize
-            img_resized = cv2.resize(img.copy(), (target_width, target_height))
+            # resize images
+            img_resized = imresize(img, (target_width, target_height), interp='bilinear')
 
-            # image transform
+            # image transform, to torch float tensor 3xHxW
             img_resized = self.img_transform(img_resized)
             img_resized = torch.unsqueeze(img_resized, 0)
             img_resized_list.append(img_resized)
 
         output = dict()
-        output['img_ori'] = img.copy()
+        output['img_ori'] = np.array(img)
         output['img_data'] = [x.contiguous() for x in img_resized_list]
         output['info'] = this_record['fpath_img']
         return output
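
Note on the normalization change in BaseDataset: cv2 returned BGR uint8 arrays in 0-255, while PIL yields RGB, and img_transform now scales to 0-1 before normalizing, so the Normalize means are the old BGR values reversed into RGB order and divided by 255. A minimal sketch of that equivalence (the 2x2 bgr array is hypothetical, for illustration only):

import numpy as np

# hypothetical BGR uint8 pixels, as cv2.imread would have returned them
bgr = np.array([[[10, 20, 30], [40, 50, 60]],
                [[70, 80, 90], [100, 110, 120]]], dtype=np.uint8)

# old pipeline: float BGR in 0-255 minus the BGR means
old = bgr.astype(np.float32) - np.array([102.9801, 115.9465, 122.7717], np.float32)

# new pipeline: RGB in 0-1 minus the reversed means over 255
rgb = bgr[:, :, ::-1]
new = rgb.astype(np.float32) / 255. - np.array([122.7717, 115.9465, 102.9801], np.float32) / 255.

# identical up to channel order and a global 1/255 scale
assert np.allclose(new, old[:, :, ::-1] / 255., atol=1e-5)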

eval.py

Lines changed: 2 additions & 5 deletions
@@ -15,7 +15,7 @@
 from utils import AverageMeter, colorEncode, accuracy, intersectionAndUnion, setup_logger
 from lib.nn import user_scattered_collate, async_copy_to
 from lib.utils import as_numpy
-import cv2
+from PIL import Image
 from tqdm import tqdm
 
 colors = loadmat('data/color150.mat')['colors']
@@ -35,10 +35,7 @@ def visualize_result(data, pred, dir_result):
         axis=1).astype(np.uint8)
 
     img_name = info.split('/')[-1]
-    cv2.imwrite(
-        os.path.join(dir_result, img_name.replace('.jpg', '.png')),
-        im_vis
-    )
+    Image.fromarray(im_vis).save(os.path.join(dir_result, img_name.replace('.jpg', '.png')))
 
 
 def evaluate(segmentation_module, loader, cfg, gpu):
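
A subtlety behind this swap: cv2.imwrite interprets arrays as BGR, while PIL.Image.fromarray interprets them as RGB, so the saved colors stay correct only because the loader now reads images with PIL (RGB) upstream. A tiny sketch, with a hypothetical im_vis array:

import numpy as np
from PIL import Image

# hypothetical H x W x 3 uint8 visualization, assembled in RGB order
im_vis = np.zeros((64, 128, 3), dtype=np.uint8)
im_vis[:, :64] = (255, 0, 0)  # left half red

# PIL writes the array as RGB; cv2.imwrite would have rendered this blue
Image.fromarray(im_vis).save('vis.png')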

eval_multipro.py

Lines changed: 4 additions & 5 deletions
@@ -16,7 +16,7 @@
 from utils import AverageMeter, colorEncode, accuracy, intersectionAndUnion, parse_devices, setup_logger
 from lib.nn import user_scattered_collate, async_copy_to
 from lib.utils import as_numpy
-import cv2
+from PIL import Image
 from tqdm import tqdm
 
 colors = loadmat('data/color150.mat')['colors']
@@ -36,10 +36,7 @@ def visualize_result(data, pred, dir_result):
         axis=1).astype(np.uint8)
 
     img_name = info.split('/')[-1]
-    cv2.imwrite(
-        os.path.join(dir_result, img_name.replace('.jpg', '.png')),
-        im_vis
-    )
+    Image.fromarray(im_vis).save(os.path.join(dir_result, img_name.replace('.jpg', '.png')))
 
 
 def evaluate(segmentation_module, loader, cfg, gpu_id, result_queue):
@@ -112,6 +109,8 @@ def worker(cfg, gpu_id, start_idx, end_idx, result_queue):
         weights=cfg.MODEL.weights_decoder,
         use_softmax=True)
 
+    net_encoder.features[0][0].weight.data = net_encoder.features[0][0].weight.data[:, (2,1,0), :, :] * 255.
+
     crit = nn.NLLLoss(ignore_index=-1)
 
     segmentation_module = SegmentationModule(net_encoder, net_decoder, crit)
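
The added weight-remapping line appears to compensate for the new input convention: with inputs now RGB in 0-1 rather than BGR in 0-255, permuting the first conv's input channels with (2,1,0) and scaling by 255 leaves its outputs unchanged for weights pretrained under the old convention. Reading features[0][0] as the encoder's first conv is an assumption here; a sketch of the equivalence on a standalone conv:

import torch
import torch.nn as nn

torch.manual_seed(0)
conv = nn.Conv2d(3, 8, 3, padding=1, bias=False)  # stand-in for the first encoder conv

bgr = torch.rand(1, 3, 16, 16) * 255.   # old-style input: BGR, 0-255
rgb = bgr[:, (2, 1, 0), :, :] / 255.    # new-style input: RGB, 0-1

out_old = conv(bgr)

# remap exactly as the commit does: reorder input channels, scale by 255
conv.weight.data = conv.weight.data[:, (2, 1, 0), :, :] * 255.
out_new = conv(rgb)

# same activations either way (the permutation is its own inverse)
assert torch.allclose(out_old, out_new, atol=1e-3)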

test.py

Lines changed: 3 additions & 3 deletions
@@ -14,7 +14,7 @@
 from utils import colorEncode, find_recursive, setup_logger
 from lib.nn import user_scattered_collate, async_copy_to
 from lib.utils import as_numpy
-import cv2
+from PIL import Image
 from tqdm import tqdm
 from config import cfg
 
@@ -48,8 +48,8 @@ def visualize_result(data, pred, cfg):
     im_vis = np.concatenate((img, pred_color), axis=1)
 
     img_name = info.split('/')[-1]
-    cv2.imwrite(os.path.join(cfg.TEST.result,
-                img_name.replace('.jpg', '.png')), im_vis)
+    Image.fromarray(im_vis).save(
+        os.path.join(cfg.TEST.result, img_name.replace('.jpg', '.png')))
 
 
 def test(segmentation_module, loader, gpu):
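
One convention to keep in mind across all the imresize calls in this commit: PIL's Image.resize takes its size argument as (width, height), whereas the old numpy/cv2 code indexed (height, width); the rewritten imresize in dataset.py passes size straight through to PIL, so every caller now supplies (w, h). A quick sketch of that ordering:

from PIL import Image

im = Image.new('RGB', (320, 240))            # PIL sizes are (width, height)
out = im.resize((160, 120), Image.BILINEAR)  # imresize(im, (w, h)) behaves the same
assert out.size == (160, 120)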
