|
1 | 1 | import os |
2 | | -import random |
3 | | -import numpy as np |
| 2 | +import json |
4 | 3 | import torch |
5 | | -import torch.utils.data as torchdata |
| 4 | +import lib.utils.data as torchdata |
| 5 | +import cv2 |
6 | 6 | from torchvision import transforms |
7 | 7 | from scipy.misc import imread, imresize |
| 8 | +import numpy as np |
8 | 9 |
|
| 10 | +# Round x up to the nearest multiple of p, i.e. the smallest x' with x' >= x and p | x' |
| 11 | +def round2nearest_multiple(x, p): |
| 12 | + return ((x - 1) // p + 1) * p |
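| | +# quick sanity examples (plain arithmetic, nothing assumed beyond the code above): |
| | +# round2nearest_multiple(33, 8) -> 40; round2nearest_multiple(32, 8) -> 32 (exact multiples pass through) |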
9 | 13 |
|
10 | | -class Dataset(torchdata.Dataset): |
11 | | - def __init__(self, txt, opt, max_sample=-1, is_train=1): |
12 | | - self.root_img = opt.root_img |
13 | | - self.root_seg = opt.root_seg |
| 14 | +class TrainDataset(torchdata.Dataset): |
| 15 | + def __init__(self, odgt, opt, max_sample=-1, batch_per_gpu=1): |
| 16 | + self.root_dataset = opt.root_dataset |
14 | 17 | self.imgSize = opt.imgSize |
15 | | - self.segSize = opt.segSize |
16 | | - self.is_train = is_train |
| 18 | + self.imgMaxSize = opt.imgMaxSize |
| 19 | + self.random_flip = opt.random_flip |
| 20 | + # max downsampling rate of the network, to avoid rounding during conv or pooling |
| 21 | + self.padding_constant = opt.padding_constant |
| 22 | + # downsampling rate of the segm label |
| 23 | + self.segm_downsampling_rate = opt.segm_downsampling_rate |
| 24 | + self.batch_per_gpu = batch_per_gpu |
| 25 | + |
| 26 | + # classify images into two classes: 1. h > w and 2. h <= w |
| 27 | + self.batch_record_list = [[], []] |
| 28 | + |
| 29 | + # override dataset length when training with batch_per_gpu > 1 |
| 30 | + self.cur_idx = 0 |
17 | 31 |
|
18 | 32 | # mean and std |
19 | 33 | self.img_transform = transforms.Compose([ |
20 | | - transforms.Normalize(mean=[0.485, 0.456, 0.406], |
21 | | - std=[0.229, 0.224, 0.225])]) |
| 34 | + transforms.Normalize(mean=[102.9801, 115.9465, 122.7717], std=[1., 1., 1.]) |
| 35 | + ]) |
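| | + # (note: these appear to be Caffe-style BGR pixel means with std left at 1; |
| | + # the RGB -> BGR flip applied to images below matches that channel order) |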
22 | 36 |
|
23 | | - self.list_sample = [x.rstrip() for x in open(txt, 'r')] |
| 37 | + self.list_sample = [json.loads(x.rstrip()) for x in open(odgt, 'r')] |
24 | 38 |
|
25 | | - if self.is_train: |
26 | | - random.shuffle(self.list_sample) |
| 39 | + np.random.shuffle(self.list_sample) |
27 | 40 | if max_sample > 0: |
28 | 41 | self.list_sample = self.list_sample[0:max_sample] |
29 | | - num_sample = len(self.list_sample) |
30 | | - assert num_sample > 0 |
31 | | - print('# samples: {}'.format(num_sample)) |
| 42 | + self.num_sample = len(self.list_sample) |
| 43 | + assert self.num_sample > 0 |
| 44 | + print('# samples: {}'.format(self.num_sample)) |
| 45 | + |
| 46 | + def _get_sub_batch(self): |
| 47 | + while True: |
| 48 | + # get a sample record |
| 49 | + this_sample = self.list_sample[self.cur_idx] |
| 50 | + if this_sample['height'] > this_sample['width']: |
| 51 | + self.batch_record_list[0].append(this_sample) # h > w, go to 1st class |
| 52 | + else: |
| 53 | + self.batch_record_list[1].append(this_sample) # h <= w, go to 2nd class |
| 54 | + |
| 55 | + # update current sample pointer |
| 56 | + self.cur_idx += 1 |
| 57 | + if self.cur_idx >= self.num_sample: |
| 58 | + self.cur_idx = 0 |
| 59 | + np.random.shuffle(self.list_sample) |
| 60 | + |
| 61 | + if len(self.batch_record_list[0]) == self.batch_per_gpu: |
| 62 | + batch_records = self.batch_record_list[0] |
| 63 | + self.batch_record_list[0] = [] |
| 64 | + break |
| 65 | + elif len(self.batch_record_list[1]) == self.batch_per_gpu: |
| 66 | + batch_records = self.batch_record_list[1] |
| 67 | + self.batch_record_list[1] = [] |
| 68 | + break |
| 69 | + return batch_records |
32 | 70 |
|
33 | | - def _scale_and_crop(self, img, seg, cropSize, is_train): |
34 | | - h, w = img.shape[0], img.shape[1] |
35 | | - |
36 | | - if is_train: |
37 | | - # random scale |
38 | | - scale = random.random() + 0.5 # 0.5-1.5 |
39 | | - scale = max(scale, 1. * cropSize / (min(h, w) - 1)) |
| 71 | + def __getitem__(self, index): |
| 72 | + # get sub-batch candidates |
| 73 | + batch_records = self._get_sub_batch() |
| 74 | + |
| 75 | + # resize all images' short edges to the chosen size |
| 76 | + if isinstance(self.imgSize, list): |
| 77 | + this_short_size = np.random.choice(self.imgSize) |
40 | 78 | else: |
41 | | - # scale to crop size |
42 | | - scale = 1. * cropSize / (min(h, w) - 1) |
| 79 | + this_short_size = self.imgSize |
| 80 | + |
| 81 | + # calculate the BATCH's height and width |
| 82 | + # since we concatenate more than one sample, the batch's h and w must be at least as large as EACH sample's |
| 83 | + batch_resized_size = np.zeros((self.batch_per_gpu, 2), np.int32) |
| 84 | + for i in range(self.batch_per_gpu): |
| 85 | + img_height, img_width = batch_records[i]['height'], batch_records[i]['width'] |
| 86 | + this_scale = min(this_short_size / float(min(img_height, img_width)), |
| 87 | + self.imgMaxSize / float(max(img_height, img_width))) |
| 88 | + img_resized_height, img_resized_width = img_height * this_scale, img_width * this_scale |
| 89 | + batch_resized_size[i, :] = img_resized_height, img_resized_width |
| 90 | + batch_resized_height = np.max(batch_resized_size[:, 0]) |
| 91 | + batch_resized_width = np.max(batch_resized_size[:, 1]) |
| 92 | + |
| 93 | + # Here we must pad both input image and segmentation map to size h' and w' so that p | h' and p | w' |
| 94 | + batch_resized_height = int(round2nearest_multiple(batch_resized_height, self.padding_constant)) |
| 95 | + batch_resized_width = int(round2nearest_multiple(batch_resized_width, self.padding_constant)) |
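| | + # e.g. with padding_constant = 8, a 713 x 937 batch is padded up to 720 x 944 |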
| 96 | + |
| 97 | + assert self.padding_constant >= self.segm_downsampling_rate, \ |
| 98 | + 'padding constant must be equal to or larger than segm downsampling rate' |
| 99 | + batch_images = torch.zeros(self.batch_per_gpu, 3, batch_resized_height, batch_resized_width) |
| 100 | + batch_segms = torch.zeros(self.batch_per_gpu, batch_resized_height // self.segm_downsampling_rate, \ |
| 101 | + batch_resized_width // self.segm_downsampling_rate).long() |
| 102 | + |
| 103 | + for i in range(self.batch_per_gpu): |
| 104 | + this_record = batch_records[i] |
| 105 | + |
| 106 | + # load image and label |
| 107 | + image_path = os.path.join(self.root_dataset, this_record['fpath_img']) |
| 108 | + segm_path = os.path.join(self.root_dataset, this_record['fpath_segm']) |
| 109 | + img = imread(image_path, mode='RGB') |
| 110 | + segm = imread(segm_path) |
43 | 111 |
|
44 | | - img_scale = imresize(img, scale, interp='bilinear') |
45 | | - seg_scale = imresize(seg, scale, interp='nearest') |
| 112 | + assert(img.ndim == 3) |
| 113 | + assert(segm.ndim == 2) |
| 114 | + assert(img.shape[0] == segm.shape[0]) |
| 115 | + assert(img.shape[1] == segm.shape[1]) |
| 116 | + |
| 117 | + if self.random_flip: |
| 118 | + random_flip = np.random.choice([0, 1]) |
| 119 | + if random_flip == 1: |
| 120 | + img = cv2.flip(img, 1) |
| 121 | + segm = cv2.flip(segm, 1) |
| 122 | + |
| 123 | + # note that each sample within a mini-batch may have a different scale param |
| 124 | + img = imresize(img, (batch_resized_size[i, 0], batch_resized_size[i, 1]), interp='bilinear') |
| 125 | + segm = imresize(segm, (batch_resized_size[i, 0], batch_resized_size[i, 1]), interp='nearest') |
| 126 | + |
| 127 | + # to avoid seg label misalignment |
| 128 | + segm_rounded_height = round2nearest_multiple(segm.shape[0], self.segm_downsampling_rate) |
| 129 | + segm_rounded_width = round2nearest_multiple(segm.shape[1], self.segm_downsampling_rate) |
| 130 | + segm_rounded = np.zeros((segm_rounded_height, segm_rounded_width), dtype='uint8') |
| 131 | + segm_rounded[:segm.shape[0], :segm.shape[1]] = segm |
| 132 | + |
| 133 | + segm = imresize(segm_rounded, (segm_rounded.shape[0] // self.segm_downsampling_rate, \ |
| 134 | + segm_rounded.shape[1] // self.segm_downsampling_rate), \ |
| 135 | + interp='nearest') |
| 136 | + # image to float |
| 137 | + img = img.astype(np.float32)[:, :, ::-1] # RGB to BGR!!! |
| 138 | + img = img.transpose((2, 0, 1)) |
| 139 | + img = self.img_transform(torch.from_numpy(img.copy())) |
| 140 | + |
| 141 | + batch_images[i][:, :img.shape[1], :img.shape[2]] = img |
| 142 | + batch_segms[i][:segm.shape[0], :segm.shape[1]] = torch.from_numpy(segm.astype(np.int)).long() |
| 143 | + |
| 144 | + batch_segms = batch_segms - 1 # label from -1 to 149 |
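| | + # (assumption: label 0 marks "unlabeled" pixels, so the shift maps it to -1 for the loss to ignore) |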
| 145 | + output = dict() |
| 146 | + output['img_data'] = batch_images |
| 147 | + output['seg_label'] = batch_segms |
| 148 | + return output |
46 | 149 |
|
47 | | - h_s, w_s = img_scale.shape[0], img_scale.shape[1] |
48 | | - if is_train: |
49 | | - # random crop |
50 | | - x1 = random.randint(0, w_s - cropSize) |
51 | | - y1 = random.randint(0, h_s - cropSize) |
52 | | - else: |
53 | | - # center crop |
54 | | - x1 = (w_s - cropSize) // 2 |
55 | | - y1 = (h_s - cropSize) // 2 |
| 150 | + def __len__(self): |
| 151 | + return int(1e6) # a fake length: the trick is that every loader maintains its own shuffled sample list |
| 152 | + # return self.num_sample |
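| | +# A minimal usage sketch (hypothetical: `opt` and the .odgt path below are placeholders). |
| | +# Each __getitem__ call already returns a packed sub-batch of `batch_per_gpu` samples, |
| | +# so the surrounding DataLoader would run with batch_size=1: |
| | +# dataset_train = TrainDataset('./data/train.odgt', opt, batch_per_gpu=2) |
| | +# loader_train = torchdata.DataLoader(dataset_train, batch_size=1, shuffle=False, num_workers=1) |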
56 | 153 |
|
57 | | - img_crop = img_scale[y1: y1 + cropSize, x1: x1 + cropSize, :] |
58 | | - seg_crop = seg_scale[y1: y1 + cropSize, x1: x1 + cropSize] |
59 | | - return img_crop, seg_crop |
60 | 154 |
|
61 | | - def _flip(self, img, seg): |
62 | | - img_flip = img[:, ::-1, :] |
63 | | - seg_flip = seg[:, ::-1] |
64 | | - return img_flip, seg_flip |
| 155 | +class ValDataset(torchdata.Dataset): |
| 156 | + def __init__(self, odgt, opt, max_sample=-1): |
| 157 | + self.root_dataset = opt.root_dataset |
| 158 | + self.imgSize = opt.imgSize |
| 159 | + self.imgMaxSize = opt.imgMaxSize |
| 160 | + # max downsampling rate of the network, to avoid rounding during conv or pooling |
| 161 | + self.padding_constant = opt.padding_constant |
| 162 | + # downsampling rate of the segm label |
| 163 | + self.segm_downsampling_rate = opt.segm_downsampling_rate |
65 | 164 |
|
66 | | - def __getitem__(self, index): |
67 | | - img_basename = self.list_sample[index] |
68 | | - path_img = os.path.join(self.root_img, img_basename) |
69 | | - path_seg = os.path.join(self.root_seg, |
70 | | - img_basename.replace('.jpg', '.png')) |
| 165 | + # mean and std |
| 166 | + self.img_transform = transforms.Compose([ |
| 167 | + transforms.Normalize(mean=[102.9801, 115.9465, 122.7717], std=[1., 1., 1.]) |
| 168 | + ]) |
71 | 169 |
|
72 | | - assert os.path.exists(path_img), '[{}] does not exist'.format(path_img) |
73 | | - assert os.path.exists(path_seg), '[{}] does not exist'.format(path_seg) |
| 170 | + self.list_sample = [json.loads(x.rstrip()) for x in open(odgt, 'r')] |
74 | 171 |
|
75 | | - # load image and label |
76 | | - try: |
77 | | - img = imread(path_img, mode='RGB') |
78 | | - seg = imread(path_seg) |
79 | | - assert(img.ndim == 3) |
80 | | - assert(seg.ndim == 2) |
81 | | - assert(img.shape[0] == seg.shape[0]) |
82 | | - assert(img.shape[1] == seg.shape[1]) |
| 172 | + if max_sample > 0: |
| 173 | + self.list_sample = self.list_sample[0:max_sample] |
| 174 | + self.num_sample = len(self.list_sample) |
| 175 | + assert self.num_sample > 0 |
| 176 | + print('# samples: {}'.format(self.num_sample)) |
83 | 177 |
|
84 | | - # random scale, crop, flip |
85 | | - if self.imgSize > 0: |
86 | | - img, seg = self._scale_and_crop(img, seg, |
87 | | - self.imgSize, self.is_train) |
88 | | - if random.choice([-1, 1]) > 0: |
89 | | - img, seg = self._flip(img, seg) |
90 | 178 |
|
| 179 | + def __getitem__(self, index): |
| 180 | + this_record = self.list_sample[index] |
| 181 | + # load image and label |
| 182 | + image_path = os.path.join(self.root_dataset, this_record['fpath_img']) |
| 183 | + segm_path = os.path.join(self.root_dataset, this_record['fpath_segm']) |
| 184 | + img = imread(image_path, mode='RGB') |
| 185 | + img = img[:, :, ::-1] # RGB to BGR!!! |
| 186 | + segm = imread(segm_path) |
| 187 | + |
| 188 | + ori_height, ori_width, _ = img.shape |
| 189 | + |
| 190 | + img_resized_list = [] |
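| | + # build a multi-scale pyramid: one resized copy of the image per short-edge size in self.imgSize |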
| 191 | + for this_short_size in self.imgSize: |
| 192 | + # calculate target height and width |
| 193 | + scale = min(this_short_size / float(min(ori_height, ori_width)), |
| 194 | + self.imgMaxSize / float(max(ori_height, ori_width))) |
| 195 | + target_height, target_width = int(ori_height * scale), int(ori_width * scale) |
| 196 | + |
| 197 | + # to avoid rounding in network |
| 198 | + target_height = round2nearest_multiple(target_height, self.padding_constant) |
| 199 | + target_width = round2nearest_multiple(target_width, self.padding_constant) |
| 200 | + |
| 201 | + # resize |
| 202 | + img_resized = cv2.resize(img.copy(), (target_width, target_height)) |
| 203 | + |
91 | 204 | # image to float |
92 | | - img = img.astype(np.float32) / 255. |
93 | | - img = img.transpose((2, 0, 1)) |
94 | | - |
95 | | - if self.segSize > 0: |
96 | | - seg = imresize(seg, (self.segSize, self.segSize), |
97 | | - interp='nearest') |
98 | | - |
99 | | - # label to int from -1 to 149 |
100 | | - seg = seg.astype(np.int) - 1 |
101 | | - |
102 | | - # to torch tensor |
103 | | - image = torch.from_numpy(img) |
104 | | - segmentation = torch.from_numpy(seg) |
105 | | - except Exception as e: |
106 | | - print('Failed loading image/segmentation [{}]: {}' |
107 | | - .format(path_img, e)) |
108 | | - # dummy data |
109 | | - image = torch.zeros(3, self.imgSize, self.imgSize) |
110 | | - segmentation = -1 * torch.ones(self.segSize, self.segSize).long() |
111 | | - return image, segmentation, img_basename |
112 | | - |
113 | | - # substracted by mean and divided by std |
114 | | - image = self.img_transform(image) |
115 | | - |
116 | | - return image, segmentation, img_basename |
| 205 | + img_resized = img_resized.astype(np.float32) |
| 206 | + img_resized = img_resized.transpose((2, 0, 1)) |
| 207 | + img_resized = self.img_transform(torch.from_numpy(img_resized)) |
| 208 | + |
| 209 | + img_resized = torch.unsqueeze(img_resized, 0) |
| 210 | + img_resized_list.append(img_resized) |
| 211 | + |
| 212 | + segm = torch.from_numpy(segm.astype(np.int)).long() |
| 213 | + |
| 214 | + batch_segms = torch.unsqueeze(segm, 0) |
| 215 | + |
| 216 | + batch_segms = batch_segms - 1 # label from -1 to 149 |
| 217 | + output = dict() |
| 218 | + output['img_ori'] = img.copy() |
| 219 | + output['img_data'] = [x.contiguous() for x in img_resized_list] |
| 220 | + output['seg_label'] = batch_segms.contiguous() |
| 221 | + output['info'] = this_record['fpath_img'] |
| 222 | + return output |
117 | 223 |
|
118 | 224 | def __len__(self): |
119 | | - return len(self.list_sample) |
| 225 | + return self.num_sample |
| 226 | + |