
Commit 6324799

Author: Tete Xiao
Message: initial commit
1 parent: 51fe6fd

27 files changed: +24457 additions, -798 deletions

data/train.odgt

Lines changed: 20210 additions & 0 deletions
Large diffs are not rendered by default.

data/validation.odgt

Lines changed: 2000 additions & 0 deletions
Large diffs are not rendered by default.

dataset.py

Lines changed: 198 additions & 91 deletions
@@ -1,119 +1,226 @@
 import os
-import random
-import numpy as np
+import json
 import torch
-import torch.utils.data as torchdata
+import lib.utils.data as torchdata
+import cv2
 from torchvision import transforms
 from scipy.misc import imread, imresize
+import numpy as np

+# Round x up to the nearest multiple of p (so the result x' satisfies x' >= x)
+def round2nearest_multiple(x, p):
+    return ((x - 1) // p + 1) * p
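
This new helper underpins all the padding logic below: it rounds a size up to the next multiple of p. A quick sanity check with made-up values:

    round2nearest_multiple(683, 8)  # -> 688, the next multiple of 8
    round2nearest_multiple(512, 8)  # -> 512, already a multiple, unchanged
    round2nearest_multiple(1, 32)   # -> 32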

-class Dataset(torchdata.Dataset):
-    def __init__(self, txt, opt, max_sample=-1, is_train=1):
-        self.root_img = opt.root_img
-        self.root_seg = opt.root_seg
+class TrainDataset(torchdata.Dataset):
+    def __init__(self, odgt, opt, max_sample=-1, batch_per_gpu=1):
+        self.root_dataset = opt.root_dataset
         self.imgSize = opt.imgSize
-        self.segSize = opt.segSize
-        self.is_train = is_train
+        self.imgMaxSize = opt.imgMaxSize
+        self.random_flip = opt.random_flip
+        # max downsampling rate of the network, to avoid rounding during conv or pooling
+        self.padding_constant = opt.padding_constant
+        # downsampling rate of the segmentation label
+        self.segm_downsampling_rate = opt.segm_downsampling_rate
+        self.batch_per_gpu = batch_per_gpu
+
+        # classify images into two classes: 1. h > w and 2. h <= w
+        self.batch_record_list = [[], []]
+
+        # override dataset length when training with batch_per_gpu > 1
+        self.cur_idx = 0

         # mean and std
         self.img_transform = transforms.Compose([
-            transforms.Normalize(mean=[0.485, 0.456, 0.406],
-                                 std=[0.229, 0.224, 0.225])])
+            transforms.Normalize(mean=[102.9801, 115.9465, 122.7717], std=[1., 1., 1.])
+        ])

-        self.list_sample = [x.rstrip() for x in open(txt, 'r')]
+        self.list_sample = [json.loads(x.rstrip()) for x in open(odgt, 'r')]

-        if self.is_train:
-            random.shuffle(self.list_sample)
+        np.random.shuffle(self.list_sample)
         if max_sample > 0:
             self.list_sample = self.list_sample[0:max_sample]
-        num_sample = len(self.list_sample)
-        assert num_sample > 0
-        print('# samples: {}'.format(num_sample))
+        self.num_sample = len(self.list_sample)
+        assert self.num_sample > 0
+        print('# samples: {}'.format(self.num_sample))
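
Samples now come from the .odgt files added above (data/train.odgt, data/validation.odgt), parsed as one JSON record per line. Judging from the fields the loaders read (fpath_img, fpath_segm, height, width), a record plausibly looks like the following; the paths and sizes are invented placeholders, not taken from the shipped files:

    {"fpath_img": "images/training/ADE_train_00000001.jpg", "fpath_segm": "annotations/training/ADE_train_00000001.png", "width": 683, "height": 512}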
+
+    def _get_sub_batch(self):
+        while True:
+            # get a sample record
+            this_sample = self.list_sample[self.cur_idx]
+            if this_sample['height'] > this_sample['width']:
+                self.batch_record_list[0].append(this_sample)  # h > w, goes to the 1st class
+            else:
+                self.batch_record_list[1].append(this_sample)  # h <= w, goes to the 2nd class
+
+            # update current sample pointer
+            self.cur_idx += 1
+            if self.cur_idx >= self.num_sample:
+                self.cur_idx = 0
+                np.random.shuffle(self.list_sample)
+
+            if len(self.batch_record_list[0]) == self.batch_per_gpu:
+                batch_records = self.batch_record_list[0]
+                self.batch_record_list[0] = []
+                break
+            elif len(self.batch_record_list[1]) == self.batch_per_gpu:
+                batch_records = self.batch_record_list[1]
+                self.batch_record_list[1] = []
+                break
+        return batch_records
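
Bucketing records by orientation means every sub-batch is all-portrait or all-landscape, so padding each image up to the per-batch maximum wastes little memory. A small illustration with invented records:

    records = [{'height': 512, 'width': 683},   # landscape
               {'height': 768, 'width': 512},   # portrait
               {'height': 480, 'width': 640}]   # landscape
    # the two landscape images batch together at roughly 512 x 683 after padding;
    # mixing in the portrait one would inflate the whole batch to 768 x 683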

-    def _scale_and_crop(self, img, seg, cropSize, is_train):
-        h, w = img.shape[0], img.shape[1]
-
-        if is_train:
-            # random scale
-            scale = random.random() + 0.5  # 0.5-1.5
-            scale = max(scale, 1. * cropSize / (min(h, w) - 1))
+    def __getitem__(self, index):
+        # get sub-batch candidates
+        batch_records = self._get_sub_batch()
+
+        # resize all images so their short edges match the chosen size
+        if isinstance(self.imgSize, list):
+            this_short_size = np.random.choice(self.imgSize)
         else:
-            # scale to crop size
-            scale = 1. * cropSize / (min(h, w) - 1)
+            this_short_size = self.imgSize
+
+        # calculate the BATCH's height and width
+        # since we concat more than one sample, the batch's h and w shall be larger than EACH sample
+        batch_resized_size = np.zeros((self.batch_per_gpu, 2), np.int32)
+        for i in range(self.batch_per_gpu):
+            img_height, img_width = batch_records[i]['height'], batch_records[i]['width']
+            this_scale = min(this_short_size / min(img_height, img_width),
+                             self.imgMaxSize / max(img_height, img_width))
+            img_resized_height, img_resized_width = img_height * this_scale, img_width * this_scale
+            batch_resized_size[i, :] = img_resized_height, img_resized_width
+        batch_resized_height = np.max(batch_resized_size[:, 0])
+        batch_resized_width = np.max(batch_resized_size[:, 1])
+
+        # pad both the input image and the segmentation map to sizes h' and w' such that p | h' and p | w'
+        batch_resized_height = int(round2nearest_multiple(batch_resized_height, self.padding_constant))
+        batch_resized_width = int(round2nearest_multiple(batch_resized_width, self.padding_constant))
+
+        assert self.padding_constant >= self.segm_downsampling_rate, \
+            'padding constant must be equal to or larger than segm downsampling rate'
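
To make the sizing concrete, here is the arithmetic for one hypothetical sample, assuming this_short_size = 512, imgMaxSize = 1024 and padding_constant = 8 (illustrative values, not mandated by this commit):

    # one record of 300 x 400 (h x w), short edge scaled toward 512
    this_scale = min(512 / 300.0, 1024 / 400.0)   # min(1.707, 2.56) = 1.707
    resized_h, resized_w = int(300 * this_scale), int(400 * this_scale)  # 512, 682
    # round up so that padding_constant (8) divides both sides
    round2nearest_multiple(resized_h, 8), round2nearest_multiple(resized_w, 8)  # (512, 688)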
+        batch_images = torch.zeros(self.batch_per_gpu, 3, batch_resized_height, batch_resized_width)
+        batch_segms = torch.zeros(self.batch_per_gpu, batch_resized_height // self.segm_downsampling_rate,
+                                  batch_resized_width // self.segm_downsampling_rate).long()
+
+        for i in range(self.batch_per_gpu):
+            this_record = batch_records[i]
+
+            # load image and label
+            image_path = os.path.join(self.root_dataset, this_record['fpath_img'])
+            segm_path = os.path.join(self.root_dataset, this_record['fpath_segm'])
+            img = imread(image_path, mode='RGB')
+            segm = imread(segm_path)

-        img_scale = imresize(img, scale, interp='bilinear')
-        seg_scale = imresize(seg, scale, interp='nearest')
+            assert(img.ndim == 3)
+            assert(segm.ndim == 2)
+            assert(img.shape[0] == segm.shape[0])
+            assert(img.shape[1] == segm.shape[1])
+
+            if self.random_flip == True:
+                random_flip = np.random.choice([0, 1])
+                if random_flip == 1:
+                    img = cv2.flip(img, 1)
+                    segm = cv2.flip(segm, 1)
+
+            # note that each sample within a mini-batch has a different scale param
+            img = imresize(img, (batch_resized_size[i, 0], batch_resized_size[i, 1]), interp='bilinear')
+            segm = imresize(segm, (batch_resized_size[i, 0], batch_resized_size[i, 1]), interp='nearest')
+
+            # to avoid seg label misalignment
+            segm_rounded_height = round2nearest_multiple(segm.shape[0], self.segm_downsampling_rate)
+            segm_rounded_width = round2nearest_multiple(segm.shape[1], self.segm_downsampling_rate)
+            segm_rounded = np.zeros((segm_rounded_height, segm_rounded_width), dtype='uint8')
+            segm_rounded[:segm.shape[0], :segm.shape[1]] = segm
+
+            segm = imresize(segm_rounded, (segm_rounded.shape[0] // self.segm_downsampling_rate,
+                                           segm_rounded.shape[1] // self.segm_downsampling_rate),
+                            interp='nearest')
+            # image to float
+            img = img.astype(np.float32)[:, :, ::-1]  # RGB to BGR!!!
+            img = img.transpose((2, 0, 1))
+            img = self.img_transform(torch.from_numpy(img.copy()))
+
+            batch_images[i][:, :img.shape[1], :img.shape[2]] = img
+            batch_segms[i][:segm.shape[0], :segm.shape[1]] = torch.from_numpy(segm.astype(np.int)).long()
+
+        batch_segms = batch_segms - 1  # labels from -1 to 149
+        output = dict()
+        output['img_data'] = batch_images
+        output['seg_label'] = batch_segms
+        return output
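
Two conventions changed here: pixels now stay on the 0-255 scale and are normalized with Caffe-style BGR channel means and std 1 (hence the RGB-to-BGR flip), and the annotation value 0 ("unlabeled") becomes -1 after the shift so that a loss can skip it. A sketch of the label mapping, using the np/torch imports from this module; the criterion choice is an assumption, not shown in this commit:

    raw = np.array([[0, 1], [150, 7]])             # values as stored in the annotation PNG
    shifted = torch.from_numpy(raw).long() - 1     # -1 = unlabeled, 0..149 = the 150 classes
    criterion = torch.nn.NLLLoss(ignore_index=-1)  # a loss that would ignore the -1 pixels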

-        h_s, w_s = img_scale.shape[0], img_scale.shape[1]
-        if is_train:
-            # random crop
-            x1 = random.randint(0, w_s - cropSize)
-            y1 = random.randint(0, h_s - cropSize)
-        else:
-            # center crop
-            x1 = (w_s - cropSize) // 2
-            y1 = (h_s - cropSize) // 2
+    def __len__(self):
+        return int(1e6)  # a fake length; the trick is that every loader maintains its own sample list
+        # return self.num_sample

-        img_crop = img_scale[y1: y1 + cropSize, x1: x1 + cropSize, :]
-        seg_crop = seg_scale[y1: y1 + cropSize, x1: x1 + cropSize]
-        return img_crop, seg_crop

-    def _flip(self, img, seg):
-        img_flip = img[:, ::-1, :]
-        seg_flip = seg[:, ::-1]
-        return img_flip, seg_flip
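
Since __getitem__ already returns a fully padded sub-batch and __len__ reports a huge fake length, TrainDataset behaves as an endless stream of ready-made batches. One plausible way to consume it, assuming the vendored lib.utils.data mirrors torch.utils.data; batch_size=1 and the unwrapping collate are assumptions, not shown in this commit:

    train_set = TrainDataset('data/train.odgt', opt, batch_per_gpu=2)
    loader = torchdata.DataLoader(train_set,
                                  batch_size=1,                       # each item is already a batch
                                  collate_fn=lambda batch: batch[0],  # unwrap, do not re-stack
                                  num_workers=2)
    batch = next(iter(loader))  # dict with 'img_data' (2 x 3 x H x W) and 'seg_label'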
+class ValDataset(torchdata.Dataset):
+    def __init__(self, odgt, opt, max_sample=-1):
+        self.root_dataset = opt.root_dataset
+        self.imgSize = opt.imgSize
+        self.imgMaxSize = opt.imgMaxSize
+        # max downsampling rate of the network, to avoid rounding during conv or pooling
+        self.padding_constant = opt.padding_constant
+        # downsampling rate of the segmentation label
+        self.segm_downsampling_rate = opt.segm_downsampling_rate

-    def __getitem__(self, index):
-        img_basename = self.list_sample[index]
-        path_img = os.path.join(self.root_img, img_basename)
-        path_seg = os.path.join(self.root_seg,
-                                img_basename.replace('.jpg', '.png'))
+        # mean and std
+        self.img_transform = transforms.Compose([
+            transforms.Normalize(mean=[102.9801, 115.9465, 122.7717], std=[1., 1., 1.])
+        ])

-        assert os.path.exists(path_img), '[{}] does not exist'.format(path_img)
-        assert os.path.exists(path_seg), '[{}] does not exist'.format(path_seg)
+        self.list_sample = [json.loads(x.rstrip()) for x in open(odgt, 'r')]

-        # load image and label
-        try:
-            img = imread(path_img, mode='RGB')
-            seg = imread(path_seg)
-            assert(img.ndim == 3)
-            assert(seg.ndim == 2)
-            assert(img.shape[0] == seg.shape[0])
-            assert(img.shape[1] == seg.shape[1])
+        if max_sample > 0:
+            self.list_sample = self.list_sample[0:max_sample]
+        self.num_sample = len(self.list_sample)
+        assert self.num_sample > 0
+        print('# samples: {}'.format(self.num_sample))

-        # random scale, crop, flip
-        if self.imgSize > 0:
-            img, seg = self._scale_and_crop(img, seg,
-                                            self.imgSize, self.is_train)
-            if random.choice([-1, 1]) > 0:
-                img, seg = self._flip(img, seg)

+    def __getitem__(self, index):
+        this_record = self.list_sample[index]
+        # load image and label
+        image_path = os.path.join(self.root_dataset, this_record['fpath_img'])
+        segm_path = os.path.join(self.root_dataset, this_record['fpath_segm'])
+        img = imread(image_path, mode='RGB')
+        img = img[:, :, ::-1]  # RGB to BGR!!!
+        segm = imread(segm_path)
+
+        ori_height, ori_width, _ = img.shape
+
+        img_resized_list = []
+        for this_short_size in self.imgSize:
+            # calculate target height and width
+            scale = min(this_short_size / float(min(ori_height, ori_width)),
+                        self.imgMaxSize / float(max(ori_height, ori_width)))
+            target_height, target_width = int(ori_height * scale), int(ori_width * scale)
+
+            # to avoid rounding inside the network
+            target_height = round2nearest_multiple(target_height, self.padding_constant)
+            target_width = round2nearest_multiple(target_width, self.padding_constant)
+
+            # resize
+            img_resized = cv2.resize(img.copy(), (target_width, target_height))
+
             # image to float
-            img = img.astype(np.float32) / 255.
-            img = img.transpose((2, 0, 1))
-
-            if self.segSize > 0:
-                seg = imresize(seg, (self.segSize, self.segSize),
-                               interp='nearest')
-
-            # label to int from -1 to 149
-            seg = seg.astype(np.int) - 1
-
-            # to torch tensor
-            image = torch.from_numpy(img)
-            segmentation = torch.from_numpy(seg)
-        except Exception as e:
-            print('Failed loading image/segmentation [{}]: {}'
-                  .format(path_img, e))
-            # dummy data
-            image = torch.zeros(3, self.imgSize, self.imgSize)
-            segmentation = -1 * torch.ones(self.segSize, self.segSize).long()
-            return image, segmentation, img_basename
-
-        # substracted by mean and divided by std
-        image = self.img_transform(image)
-
-        return image, segmentation, img_basename
+            img_resized = img_resized.astype(np.float32)
+            img_resized = img_resized.transpose((2, 0, 1))
+            img_resized = self.img_transform(torch.from_numpy(img_resized))
+
+            img_resized = torch.unsqueeze(img_resized, 0)
+            img_resized_list.append(img_resized)
+
+        segm = torch.from_numpy(segm.astype(np.int)).long()
+
+        batch_segms = torch.unsqueeze(segm, 0)
+
+        batch_segms = batch_segms - 1  # labels from -1 to 149
+        output = dict()
+        output['img_ori'] = img.copy()
+        output['img_data'] = [x.contiguous() for x in img_resized_list]
+        output['seg_label'] = batch_segms.contiguous()
+        output['info'] = this_record['fpath_img']
+        return output

     def __len__(self):
-        return len(self.list_sample)
+        return self.num_sample
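
For each validation record the loader thus returns the original BGR image, one normalized tensor per test scale in opt.imgSize, and the ground truth. A short usage sketch; the opt values are assumed for illustration (e.g. imgSize = [300, 400, 500], imgMaxSize = 1000):

    val_set = ValDataset('data/validation.odgt', opt, max_sample=16)
    sample = val_set[0]
    sample['info']                         # the record's fpath_img
    sample['img_ori'].shape                # (H, W, 3) uint8 image, kept for visualization
    [x.shape for x in sample['img_data']]  # one (1, 3, h, w) tensor per test scale
    sample['seg_label'].shape              # (1, H, W), labels in -1..149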

demo_test.sh

Lines changed: 0 additions & 15 deletions
This file was deleted.
