11import os
22import json
33import torch
4- import cv2
54from torchvision import transforms
65import numpy as np
7- import PIL
6+ from PIL import Image
87
98
def imresize(im, size, interp='bilinear'):
    """Resize a PIL image with the requested interpolation method.

    Args:
        im: image object exposing PIL's ``resize(size, resample)`` method.
        size: target size as ``(width, height)``, the order PIL expects.
        interp: one of ``'nearest'``, ``'bilinear'`` or ``'bicubic'``.

    Returns:
        Whatever ``im.resize`` returns (for a PIL image, a new resized image).

    Raises:
        ValueError: if ``interp`` is not one of the supported method names.
    """
    # Map the public method names to the attribute names of PIL's resampling
    # filters. Resolving them lazily lets an unsupported name be rejected
    # before PIL (or the image) is touched at all.
    filter_names = {
        'nearest': 'NEAREST',
        'bilinear': 'BILINEAR',
        'bicubic': 'BICUBIC',
    }
    if interp not in filter_names:
        raise ValueError('resample method undefined!')
    resample = getattr(Image, filter_names[interp])

    # Callers in this file pass numpy integers; hand PIL plain Python ints.
    width, height = size
    return im.resize((int(width), int(height)), resample)
2320
2421
2522class BaseDataset (torch .utils .data .Dataset ):
@@ -35,7 +32,7 @@ def __init__(self, odgt, opt, **kwargs):
3532
3633 # mean and std
3734 self .normalize = transforms .Normalize (
38- mean = [102.9801 , 115.9465 , 122.7717 ],
35+ mean = [122.7717 / 255. , 115.9465 / 255. , 102.9801 / 255. ],
3936 std = [1. , 1. , 1. ])
4037
4138 def parse_input_list (self , odgt , max_sample = - 1 , start_idx = - 1 , end_idx = - 1 ):
@@ -54,12 +51,17 @@ def parse_input_list(self, odgt, max_sample=-1, start_idx=-1, end_idx=-1):
5451 print ('# samples: {}' .format (self .num_sample ))
5552
def img_transform(self, img):
    """Turn an image into a normalized, channel-first float tensor.

    Args:
        img: image convertible by numpy into a 3-D array whose last axis is
            the channel axis, holding 0-255 values (callers in this file
            pass RGB PIL images).

    Returns:
        The result of ``self.normalize`` applied to a float32 tensor laid
        out channel-first, with values scaled to the 0-1 range.
    """
    # 0-255 to 0-1; converting straight to float32 avoids an extra
    # intermediate copy of the image in its original dtype.
    pixels = np.asarray(img, dtype=np.float32) / 255.
    # Move channels first and force a contiguous buffer for torch.
    channel_first = np.ascontiguousarray(pixels.transpose((2, 0, 1)))
    return self.normalize(torch.from_numpy(channel_first))
6259
def segm_transform(self, segm):
    """Convert a segmentation map into a long tensor of shifted labels.

    Args:
        segm: label map convertible by ``np.array`` (callers in this file
            pass a PIL image).

    Returns:
        An int64 torch tensor holding every input label decreased by one
        (per the original comment, labels then span -1 to 149).
    """
    # Convert to int64 on the numpy side: torch.from_numpy does not accept
    # every integer dtype numpy can produce, and the result is then already
    # a long tensor without a second cast.
    labels = np.array(segm, dtype=np.int64)
    return torch.from_numpy(labels) - 1
64+
def round2nearest_multiple(self, x, p):
    """Round ``x`` up to the nearest multiple of ``p``.

    Args:
        x: value to round; the result is never smaller than ``x``.
        p: positive step whose multiple is returned.

    Returns:
        The smallest multiple of ``p`` that is greater than or equal to ``x``.
    """
    # Ceiling division via negated floor division. For integers this equals
    # ((x - 1) // p + 1) * p, but it also keeps the result >= x when x is
    # not an integer.
    return -(-x // p) * p
@@ -69,7 +71,6 @@ class TrainDataset(BaseDataset):
6971 def __init__ (self , root_dataset , odgt , opt , batch_per_gpu = 1 , ** kwargs ):
7072 super (TrainDataset , self ).__init__ (odgt , opt , ** kwargs )
7173 self .root_dataset = root_dataset
72- self .random_flip = opt .random_flip
7374 # down sampling rate of segm label
7475 self .segm_downsampling_rate = opt .segm_downsampling_rate
7576 self .batch_per_gpu = batch_per_gpu
@@ -124,71 +125,74 @@ def __getitem__(self, index):
124125
125126 # calculate the BATCH's height and width
126127 # since we concat more than one sample, the batch's h and w must be at least as large as EACH sample's
127- batch_resized_size = np .zeros ((self .batch_per_gpu , 2 ), np .int32 )
128+ batch_widths = np .zeros (self .batch_per_gpu , np .int32 )
129+ batch_heights = np .zeros (self .batch_per_gpu , np .int32 )
128130 for i in range (self .batch_per_gpu ):
129131 img_height , img_width = batch_records [i ]['height' ], batch_records [i ]['width' ]
130132 this_scale = min (
131133 this_short_size / min (img_height , img_width ), \
132134 self .imgMaxSize / max (img_height , img_width ))
133- img_resized_height , img_resized_width = img_height * this_scale , img_width * this_scale
134- batch_resized_size [i , :] = img_resized_height , img_resized_width
135- batch_resized_height = np .max (batch_resized_size [:, 0 ])
136- batch_resized_width = np .max (batch_resized_size [:, 1 ])
135+ batch_widths [i ] = img_width * this_scale
136+ batch_heights [i ] = img_height * this_scale
137137
138138 # Here we must pad both input image and segmentation map to size h' and w' so that p | h' and p | w'
139- batch_resized_height = int (self .round2nearest_multiple (batch_resized_height , self .padding_constant ))
140- batch_resized_width = int (self .round2nearest_multiple (batch_resized_width , self .padding_constant ))
141-
142- assert self .padding_constant >= self .segm_downsampling_rate ,\
143- 'padding constant must be equal or large than segm downsamping rate'
144- batch_images = torch .zeros (self .batch_per_gpu , 3 , batch_resized_height , batch_resized_width )
139+ batch_width = np .max (batch_widths )
140+ batch_height = np .max (batch_heights )
141+ batch_width = int (self .round2nearest_multiple (batch_width , self .padding_constant ))
142+ batch_height = int (self .round2nearest_multiple (batch_height , self .padding_constant ))
143+
144+ assert self .padding_constant >= self .segm_downsampling_rate , \
145+ 'padding constant must be equal or large than segm downsamping rate'
146+ batch_images = torch .zeros (
147+ self .batch_per_gpu , 3 , batch_height , batch_width )
145148 batch_segms = torch .zeros (
146- self .batch_per_gpu , batch_resized_height // self .segm_downsampling_rate , \
147- batch_resized_width // self .segm_downsampling_rate ).long ()
149+ self .batch_per_gpu ,
150+ batch_height // self .segm_downsampling_rate ,
151+ batch_width // self .segm_downsampling_rate ).long ()
148152
149153 for i in range (self .batch_per_gpu ):
150154 this_record = batch_records [i ]
151155
152156 # load image and label
153157 image_path = os .path .join (self .root_dataset , this_record ['fpath_img' ])
154158 segm_path = os .path .join (self .root_dataset , this_record ['fpath_segm' ])
155- img = cv2 .imread (image_path , cv2 .IMREAD_COLOR )
156- segm = cv2 .imread (segm_path , cv2 .IMREAD_GRAYSCALE )
157159
158- assert (img .ndim == 3 )
159- assert (segm .ndim == 2 )
160- assert (img .shape [0 ] == segm .shape [0 ])
161- assert (img .shape [1 ] == segm .shape [1 ])
160+ img = Image .open (image_path ).convert ('RGB' )
161+ segm = Image .open (segm_path )
162+ assert (segm .mode == "L" )
163+ assert (img .size [0 ] == segm .size [0 ])
164+ assert (img .size [1 ] == segm .size [1 ])
162165
163- if self .random_flip is True :
164- random_flip = np .random .choice ([0 , 1 ])
165- if random_flip == 1 :
166- img = cv2 .flip (img , 1 )
167- segm = cv2 .flip (segm , 1 )
166+ # random_flip
167+ if np .random .choice ([0 , 1 ]):
168+ img = img .transpose (Image .FLIP_LEFT_RIGHT )
169+ segm = segm .transpose (Image .FLIP_LEFT_RIGHT )
168170
169171 # note that each sample within a mini batch has different scale param
170- img = imresize (img , (batch_resized_size [i , 0 ], batch_resized_size [i , 1 ]), interp = 'bilinear' )
171- segm = imresize (segm , (batch_resized_size [i , 0 ], batch_resized_size [i , 1 ]), interp = 'nearest' )
172-
173- # to avoid seg label misalignment
174- segm_rounded_height = self .round2nearest_multiple (segm .shape [0 ], self .segm_downsampling_rate )
175- segm_rounded_width = self .round2nearest_multiple (segm .shape [1 ], self .segm_downsampling_rate )
176- segm_rounded = np .zeros ((segm_rounded_height , segm_rounded_width ), dtype = 'uint8' )
177- segm_rounded [:segm .shape [0 ], :segm .shape [1 ]] = segm
178-
172+ img = imresize (img , (batch_widths [i ], batch_heights [i ]), interp = 'bilinear' )
173+ segm = imresize (segm , (batch_widths [i ], batch_heights [i ]), interp = 'nearest' )
174+
175+ # further downsample seg label, need to avoid seg label misalignment
176+ segm_rounded_width = self .round2nearest_multiple (segm .size [0 ], self .segm_downsampling_rate )
177+ segm_rounded_height = self .round2nearest_multiple (segm .size [1 ], self .segm_downsampling_rate )
178+ segm_rounded = Image .new ('L' , (segm_rounded_width , segm_rounded_height ), 0 )
179+ segm_rounded .paste (segm , (0 , 0 ))
179180 segm = imresize (
180181 segm_rounded ,
181- (segm_rounded .shape [0 ] // self .segm_downsampling_rate , \
182- segm_rounded .shape [1 ] // self .segm_downsampling_rate ), \
182+ (segm_rounded .size [0 ] // self .segm_downsampling_rate , \
183+ segm_rounded .size [1 ] // self .segm_downsampling_rate ), \
183184 interp = 'nearest' )
184185
185- # image transform
186+ # image transform, to torch float tensor 3xHxW
186187 img = self .img_transform (img )
187188
189+ # segm transform, to torch long tensor HxW
190+ segm = self .segm_transform (segm )
191+
192+ # put into batch arrays
188193 batch_images [i ][:, :img .shape [1 ], :img .shape [2 ]] = img
189- batch_segms [i ][:segm .shape [0 ], :segm .shape [1 ]] = torch . from_numpy ( segm . astype ( np . int )). long ()
194+ batch_segms [i ][:segm .shape [0 ], :segm .shape [1 ]] = segm
190195
191- batch_segms = batch_segms - 1 # label from -1 to 149
192196 output = dict ()
193197 output ['img_data' ] = batch_images
194198 output ['seg_label' ] = batch_segms
@@ -209,10 +213,13 @@ def __getitem__(self, index):
209213 # load image and label
210214 image_path = os .path .join (self .root_dataset , this_record ['fpath_img' ])
211215 segm_path = os .path .join (self .root_dataset , this_record ['fpath_segm' ])
212- img = cv2 .imread (image_path , cv2 .IMREAD_COLOR )
213- segm = cv2 .imread (segm_path , cv2 .IMREAD_GRAYSCALE )
216+ img = Image .open (image_path ).convert ('RGB' )
217+ segm = Image .open (segm_path )
218+ assert (segm .mode == "L" )
219+ assert (img .size [0 ] == segm .size [0 ])
220+ assert (img .size [1 ] == segm .size [1 ])
214221
215- ori_height , ori_width , _ = img .shape
222+ ori_width , ori_height = img .size
216223
217224 img_resized_list = []
218225 for this_short_size in self .imgSizes :
@@ -222,24 +229,23 @@ def __getitem__(self, index):
222229 target_height , target_width = int (ori_height * scale ), int (ori_width * scale )
223230
224231 # to avoid rounding in network
225- target_height = self .round2nearest_multiple (target_height , self .padding_constant )
226232 target_width = self .round2nearest_multiple (target_width , self .padding_constant )
233+ target_height = self .round2nearest_multiple (target_height , self .padding_constant )
227234
228- # resize
229- img_resized = cv2 . resize (img . copy () , (target_width , target_height ))
235+ # resize images
236+ img_resized = imresize (img , (target_width , target_height ), interp = 'bilinear' )
230237
231- # image transform
238+ # image transform, to torch float tensor 3xHxW
232239 img_resized = self .img_transform (img_resized )
233-
234240 img_resized = torch .unsqueeze (img_resized , 0 )
235241 img_resized_list .append (img_resized )
236242
237- segm = torch .from_numpy (segm .astype (np .int )).long ()
243+ # segm transform, to torch long tensor HxW
244+ segm = self .segm_transform (segm )
238245 batch_segms = torch .unsqueeze (segm , 0 )
239246
240- batch_segms = batch_segms - 1 # label from -1 to 149
241247 output = dict ()
242- output ['img_ori' ] = img . copy ( )
248+ output ['img_ori' ] = np . array ( img )
243249 output ['img_data' ] = [x .contiguous () for x in img_resized_list ]
244250 output ['seg_label' ] = batch_segms .contiguous ()
245251 output ['info' ] = this_record ['fpath_img' ]
@@ -255,11 +261,11 @@ def __init__(self, odgt, opt, **kwargs):
255261
256262 def __getitem__ (self , index ):
257263 this_record = self .list_sample [index ]
258- # load image and label
264+ # load image
259265 image_path = this_record ['fpath_img' ]
260- img = cv2 . imread (image_path , cv2 . IMREAD_COLOR )
266+ img = Image . open (image_path ). convert ( 'RGB' )
261267
262- ori_height , ori_width , _ = img .shape
268+ ori_width , ori_height = img .size
263269
264270 img_resized_list = []
265271 for this_short_size in self .imgSizes :
@@ -269,19 +275,19 @@ def __getitem__(self, index):
269275 target_height , target_width = int (ori_height * scale ), int (ori_width * scale )
270276
271277 # to avoid rounding in network
272- target_height = self .round2nearest_multiple (target_height , self .padding_constant )
273278 target_width = self .round2nearest_multiple (target_width , self .padding_constant )
279+ target_height = self .round2nearest_multiple (target_height , self .padding_constant )
274280
275- # resize
276- img_resized = cv2 . resize (img . copy () , (target_width , target_height ))
281+ # resize images
282+ img_resized = imresize (img , (target_width , target_height ), interp = 'bilinear' )
277283
278- # image transform
284+ # image transform, to torch float tensor 3xHxW
279285 img_resized = self .img_transform (img_resized )
280286 img_resized = torch .unsqueeze (img_resized , 0 )
281287 img_resized_list .append (img_resized )
282288
283289 output = dict ()
284- output ['img_ori' ] = img . copy ( )
290+ output ['img_ori' ] = np . array ( img )
285291 output ['img_data' ] = [x .contiguous () for x in img_resized_list ]
286292 output ['info' ] = this_record ['fpath_img' ]
287293 return output
0 commit comments