Commit 2aa2cfb

pi thresholding
1 parent cf78f05 commit 2aa2cfb

File tree

7 files changed (+40, -210 lines)


run_mmod_coco.sh renamed to run_mmod.sh

Lines changed: 3 additions & 3 deletions
@@ -1,14 +1,14 @@
 #!/usr/bin/env bash

 python3 ./src/run.py \
-    --bash_file="./run_mmod_coco.sh" \
-    --result_dir="./result/`(date "+%Y%m%d%H%M%S")`-coco-320x320-mmod_res50fpn-preset" \
+    --bash_file="./run_mmod.sh" \
+    --result_dir="./result/`(date "+%Y%m%d%H%M%S")`-320x320-res50fpn" \
     \
     --global_args="{
         'n_classes': 81, 'batch_size': 32,
         'img_h': 320, 'img_w': 320,
         'coord_h': 10, 'coord_w': 10,
-        'devices': [0, 1], 'main_device': 0,
+        'devices': [0], 'main_device': 0,
     }" \
     --network_args="{
         'pretrained': True, 'backbone': 'res50fpn', 'fmap_ch': 256,
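
Note: the --global_args and --network_args values are Python dict literals passed to src/run.py as shell strings. How run.py parses them is not visible in this diff; the sketch below only assumes the strings are literal-eval'able, which the format above supports (hypothetical, for illustration only):

import ast

# Hypothetical parsing of the dict-literal string built by run_mmod.sh.
# Whether src/run.py actually uses ast.literal_eval is an assumption.
global_args_str = """{
    'n_classes': 81, 'batch_size': 32,
    'img_h': 320, 'img_w': 320,
    'coord_h': 10, 'coord_w': 10,
    'devices': [0], 'main_device': 0,
}"""
global_args = ast.literal_eval(global_args_str)
print(global_args['devices'])   # [0] after this commit (was [0, 1])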

run_mmod_voc.sh

Lines changed: 0 additions & 67 deletions
This file was deleted.

src/lib/dataset.py

Lines changed: 0 additions & 88 deletions
@@ -1,18 +1,15 @@
-import os
 import abc
 import traceback
 import numpy as np
 import scipy.misc
 from shutil import copyfile
-from xml.etree import ElementTree
 from torch.utils.data.dataset import Dataset
 from .pre_proc import get_pre_proc_dict
 from lib.external.dataset.roidb import combined_roidb


 def get_dataset_dict():
     return {
-        'voc': VOCDataset,
         'coco': COCODataset,
     }

@@ -39,91 +36,6 @@ def get_dataset_roots(self):
         pass


-class VOCDataset(DatasetABC):
-    def __init__(self, global_args, dataset_args):
-        super(VOCDataset, self).__init__(global_args, dataset_args)
-
-        img_pathes = list()
-        anno_pathes = list()
-        for root_dir, set_type in zip(self.roots, self.types):
-            set_path = os.path.join(root_dir, 'ImageSets', 'Main', '%s.txt' % set_type)
-            img_path_form = os.path.join(root_dir, 'JPEGImages', '%s.jpg')
-            anno_path_form = os.path.join(root_dir, 'Annotations', '%s.xml')
-
-            with open(set_path) as file:
-                for img_name in file.readlines():
-                    img_name = img_name.strip('\n')
-                    img_pathes.append(img_path_form % img_name)
-                    anno_pathes.append(anno_path_form % img_name)
-
-        self.img_pathes = np.array(img_pathes).astype(np.string_)
-        self.anno_pathes = np.array(anno_pathes).astype(np.string_)
-
-        self.name2number_map = {
-            'background': 0,
-            'aeroplane': 1, 'bicycle': 2, 'bird': 3, 'boat': 4,
-            'bottle': 5, 'bus': 6, 'car': 7, 'cat': 8, 'chair': 9,
-            'cow': 10, 'diningtable': 11, 'dog': 12, 'horse': 13,
-            'motorbike': 14, 'person': 15, 'pottedplant': 16,
-            'sheep': 17, 'sofa': 18, 'train': 19, 'tvmonitor': 20}
-        self.number2name_map = {
-            0: 'background',
-            1: 'aeroplane', 2: 'bicycle', 3: 'bird', 4: 'boat',
-            5: 'bottle', 6: 'bus', 7: 'car', 8: 'cat', 9: 'chair',
-            10: 'cow', 11: 'diningtable', 12: 'dog', 13: 'horse',
-            14: 'motorbike', 15: 'person', 16: 'pottedplant',
-            17: 'sheep', 18: 'sofa', 19: 'train', 20: 'tvmonitor'}
-
-    def __len__(self):
-        return len(self.img_pathes)
-
-    def __getitem__(self, data_idx):
-        img = scipy.misc.imread(self.img_pathes[data_idx])
-        anno = ElementTree.parse(self.anno_pathes[data_idx]).getroot()
-        boxes, labels = self.__parse_anno__(anno)
-
-        sample_dict = {'img': img, 'boxes': boxes, 'labels': labels}
-        try:
-            sample_dict = self.pre_proc.process(sample_dict)
-        except Exception:
-            print(traceback.print_exc())
-            print('- %s\n' % self.img_pathes[data_idx])
-            copyfile(self.img_pathes[data_idx].replace('coco2017', 'coco2017-2'), self.img_pathes[data_idx])
-            sample_dict = self.__getitem__(data_idx)
-        return sample_dict
-
-    def __getitem_tmp__(self, data_idx):
-        img = scipy.misc.imread(self.img_pathes[data_idx])
-        anno = ElementTree.parse(self.anno_pathes[data_idx]).getroot()
-        boxes, labels = self.__parse_anno__(anno)
-
-        sample_dict = {'img': img, 'boxes': boxes, 'labels': labels}
-        sample_dict = self.pre_proc.process(sample_dict)
-        return sample_dict
-
-    def __parse_anno__(self, anno):
-        boxes = list()
-        labels = list()
-        for obj in anno.findall('object'):
-            bndbox = obj.find('bndbox')
-            boxes.append([
-                float(bndbox.find('xmin').text), float(bndbox.find('ymin').text),
-                float(bndbox.find('xmax').text), float(bndbox.find('ymax').text)])
-            labels.append(np.array(self.name2number_map[obj.find('name').text]))
-        boxes = np.array(boxes)
-        labels = np.array(labels)
-        return boxes, labels
-
-    def get_name2number_map(self):
-        return self.name2number_map
-
-    def get_number2name_map(self):
-        return self.number2name_map
-
-    def get_dataset_roots(self):
-        return self.roots
-
-
 class COCODataset(DatasetABC):
     def __init__(self, global_args, dataset_args):
         super(COCODataset, self).__init__(global_args, dataset_args)
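
Note: with VOCDataset and its Pascal VOC XML parsing gone, get_dataset_dict() maps only 'coco' to COCODataset. A minimal sketch of looking the registry up after this change; build_dataset is a hypothetical helper and the dataset_args contents are placeholders, not the repository's actual configuration (import path assumes src/ is on sys.path, as for the other lib imports in this repo):

from lib.dataset import get_dataset_dict

def build_dataset(name, global_args, dataset_args):
    # Registry now contains only {'coco': COCODataset}; the 'voc' entry was removed.
    dataset_dict = get_dataset_dict()
    if name not in dataset_dict:
        raise KeyError("unsupported dataset '%s' (VOC support was removed)" % name)
    return dataset_dict[name](global_args, dataset_args)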

src/lib/external/sync_batchnorm/batchnorm.py

Lines changed: 2 additions & 1 deletion
@@ -370,7 +370,8 @@ def convert_model(module):
     if isinstance(module, torch.nn.DataParallel):
         mod = module.module
         mod = convert_model(mod)
-        mod = DataParallelWithCallback(mod)
+        mod = DataParallelWithCallback(mod, device_ids=module.device_ids)
+        # mod = DataParallelWithCallback(mod)
         return mod

     mod = module
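
Note: the change threads the original wrapper's device_ids into DataParallelWithCallback, so a model converted for synchronized BatchNorm keeps running on the same GPUs it was wrapped with (here, the single device configured in run_mmod.sh) instead of all visible GPUs. A hedged usage sketch; the network and device list are placeholders:

import torch.nn as nn
from lib.external.sync_batchnorm.batchnorm import convert_model

# Placeholder network; any module containing BatchNorm layers works the same way.
net = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.BatchNorm2d(16), nn.ReLU())
net = nn.DataParallel(net.cuda(), device_ids=[0])
# convert_model swaps BatchNorm2d for its synchronized variant and, after this
# commit, re-wraps with the same device_ids as the original DataParallel wrapper.
net = convert_model(net)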

src/lib/external/sync_batchnorm/batchnorm_reimpl.py

Lines changed: 0 additions & 1 deletion
@@ -71,4 +71,3 @@ def forward(self, input_):
                   self.weight.unsqueeze(1) + self.bias.unsqueeze(1))

         return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous()
-

src/lib/network_util.py

Lines changed: 6 additions & 12 deletions
@@ -10,7 +10,6 @@ def init_modules_xavier(module_list):
                isinstance(m, nn.ConvTranspose2d) or \
                isinstance(m, nn.BatchNorm2d):
            if m.weight is not None:
-                # print('init')
                nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                m.bias.data.zero_()
@@ -66,20 +65,18 @@ def create_box_coord_map(output_size, output_ch, coord_range):
     # gauss_ch: 4 --> ((0, 1, 2, 3), ...)
     ch_map = np.array(list(range(output_ch)))

-    # coord_w: 100 --> unit_intv_w: 20 = 100 / (4 + 1)
+    # coord_w: 10 --> unit_intv_w: 2 = 10 / (4 + 1)
     unit_intv_w = coord_range[1] / (output_ch + 1.0)
     unit_intv_h = coord_range[0] / (output_ch + 1.0)

-    # ((0, 1, 2, 3) + 1) * 20 == (20, 40, 60, 80)
+    # ((0, 1, 2, 3) + 1) * 2 == (2, 4, 6, 8)
     w_map = (ch_map + 1) * unit_intv_w
     h_map = (ch_map + 1) * unit_intv_h

-    # ((20, 40, 60, 80) / 100)^2 == (0.04, 0.16, 0.36, 0.64)
-    # (0.04, 0.16, 0.36, 0.64) * 100 == (4, 16, 36, 64)
-    # w_map = ((w_map / coord_range[1]) ** 2) * coord_range[1]
-    # h_map = ((h_map / coord_range[0]) ** 2) * coord_range[0]
-    w_map = (w_map / coord_range[1]) * coord_range[1]
-    h_map = (h_map / coord_range[0]) * coord_range[0]
+    # ((2, 4, 6, 8) / 10)^2 == (0.04, 0.16, 0.36, 0.64)
+    # (0.04, 0.16, 0.36, 0.64) * 10 == (0.4, 1.6, 3.6, 6.4)
+    w_map = ((w_map / coord_range[1]) ** 2) * coord_range[1]
+    h_map = ((h_map / coord_range[0]) ** 2) * coord_range[0]

     w_map = w_map.reshape((output_ch, 1, 1))
     h_map = h_map.reshape((output_ch, 1, 1))
@@ -93,10 +90,7 @@ def create_box_coord_map(output_size, output_ch, coord_range):


 def create_limit_scale(batch_size, output_sizes, coord_range, limit_factor):
-    # n_lv_mix_comps = [output_size[0] * output_size[1] for output_size in output_sizes]
-
     lv_x_limit_scales, lv_y_limit_scales = list(), list()
-    # for i, n_lv_mix_comp in enumerate(n_lv_mix_comps):
     for output_size in output_sizes:
         x_limit_scale = (coord_range[1] / output_size[1]) * limit_factor
         y_limit_scale = (coord_range[0] / output_size[0]) * limit_factor
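
Note: the re-enabled squared spacing biases the per-channel coordinate anchors toward small values: with output_ch = 4 and coord_w = 10 (the values in the updated comments), the w anchors change from (2, 4, 6, 8) to (0.4, 1.6, 3.6, 6.4). A standalone check of that arithmetic:

import numpy as np

output_ch, coord_w = 4, 10.0                       # values from the diff comments
ch_map = np.arange(output_ch)                      # [0 1 2 3]

unit_intv_w = coord_w / (output_ch + 1.0)          # 2.0
w_linear = (ch_map + 1) * unit_intv_w              # [2. 4. 6. 8.]    (old behaviour)
w_squared = ((w_linear / coord_w) ** 2) * coord_w  # [0.4 1.6 3.6 6.4] (new behaviour)
print(w_linear, w_squared)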

src/lib/post_proc.py

Lines changed: 29 additions & 38 deletions
@@ -15,66 +15,57 @@ def __init__(self, global_args, post_proc_args):
         self.max_boxes = post_proc_args['max_boxes']

     def __filter_cls_boxes_s__(self, boxes_s, confs_s, pi_s):
-        cls_boxes_sl = list()
-        cls_confs_sl = list()
-        cls_labels_sl = list()
+        boxes_sl = list()
+        confs_sl = list()
+        labels_sl = list()

         norm_pi_s = pi_s / torch.max(pi_s)
+        keep_idxes = torch.nonzero(norm_pi_s > self.pi_thresh).view(-1)
+        boxes_s = boxes_s[:, keep_idxes]
+        confs_s = confs_s[:, keep_idxes]
+
         for c in range(self.n_classes - 1):
-            cls_boxes_sc = boxes_s[c]
-            cls_confs_sc = confs_s[c]
-            # cls_pi_sc = norm_pi_s.clone()
+            boxes_sc = boxes_s[c]
+            confs_sc = confs_s[c]

-            if len(cls_boxes_sc) == 0:
+            if len(boxes_sc) == 0:
                 continue

-            # print(cls_boxes_sc.shape)
-            keep_idxes = torch.nonzero(norm_pi_s > self.pi_thresh).view(-1)
-            cls_boxes_sc = cls_boxes_sc[keep_idxes]
-            cls_confs_sc = cls_confs_sc[keep_idxes]
-
-            # print(cls_boxes_sc.shape)
-            keep_idxes = torch.nonzero(cls_confs_sc > self.conf_thresh).view(-1)
-            cls_boxes_sc = cls_boxes_sc[keep_idxes]
-            cls_confs_sc = cls_confs_sc[keep_idxes]
+            keep_idxes = torch.nonzero(confs_sc > self.conf_thresh).view(-1)
+            boxes_sc = boxes_sc[keep_idxes]
+            confs_sc = confs_sc[keep_idxes]
             if keep_idxes.shape[0] == 0:
                 continue
-            # print(cls_boxes_sc.shape)
-            # print('')

             if self.nms_thresh <= 0.0:
-                cls_boxes_sc, cls_confs_sc = lib_util.sort_boxes_s(cls_boxes_sc, cls_confs_sc)
-                cls_boxes_sc, cls_confs_sc = cls_boxes_sc[:1], cls_confs_sc[:1]
+                boxes_sc, confs_sc = lib_util.sort_boxes_s(boxes_sc, confs_sc)
+                boxes_sc, confs_sc = boxes_sc[:1], confs_sc[:1]
             elif self.nms_thresh > 1.0:
                 pass
             else:
-                keep_idxes = nms(cls_boxes_sc, cls_confs_sc, self.nms_thresh)
+                keep_idxes = nms(boxes_sc, confs_sc, self.nms_thresh)
                 keep_idxes = keep_idxes.long().view(-1)
-                cls_boxes_sc = cls_boxes_sc[keep_idxes]
-                cls_confs_sc = cls_confs_sc[keep_idxes].unsqueeze(dim=1)
+                boxes_sc = boxes_sc[keep_idxes]
+                confs_sc = confs_sc[keep_idxes].unsqueeze(dim=1)

-            labels_css = torch.zeros(cls_confs_sc.shape).float().cuda()
+            labels_css = torch.zeros(confs_sc.shape).float().cuda()
             labels_css += c

-            cls_boxes_sl.append(cls_boxes_sc)
-            cls_confs_sl.append(cls_confs_sc)
-            cls_labels_sl.append(labels_css)
-            # exit()
+            boxes_sl.append(boxes_sc)
+            confs_sl.append(confs_sc)
+            labels_sl.append(labels_css)

-        if len(cls_boxes_sl) > 0:
-            boxes_s = torch.cat(cls_boxes_sl, dim=0)
-            confs_s = torch.cat(cls_confs_sl, dim=0)
-            labels_s = torch.cat(cls_labels_sl, dim=0)
+        if len(boxes_sl) > 0:
+            boxes_s = torch.cat(boxes_sl, dim=0)
+            confs_s = torch.cat(confs_sl, dim=0)
+            labels_s = torch.cat(labels_sl, dim=0)
         else:
            boxes_s = torch.zeros((1, 4)).float().cuda()
            confs_s = torch.zeros((1, 1)).float().cuda()
            labels_s = torch.zeros((1, 1)).float().cuda()
-
-        boxes_s, confs_s, labels_s = lib_util.sort_boxes_s(boxes_s, confs_s, labels_s)
         return boxes_s, confs_s, labels_s

     def forward(self, mu, prob, pi):
-        # print('mu', torch.min(mu), torch.max(mu))
         boxes = mu.transpose(1, 2).clone()
         boxes[:, :, [0, 2]] = boxes[:, :, [0, 2]] * (self.input_size[1] / self.coord_range[1])
         boxes[:, :, [1, 3]] = boxes[:, :, [1, 3]] * (self.input_size[0] / self.coord_range[0])
@@ -85,7 +76,7 @@ def forward(self, mu, prob, pi):
         boxes_l, confs_l, labels_l = list(), list(), list()
         for i, (boxes_s, confs_s) in enumerate(zip(boxes, confs)):
             boxes_s, confs_s, labels_s = self.__filter_cls_boxes_s__(boxes_s, confs_s, pi[i, 0])
-            boxes_l.append(boxes_s[:self.max_boxes])
-            confs_l.append(confs_s[:self.max_boxes])
-            labels_l.append(labels_s[:self.max_boxes] + 1)
+            boxes_l.append(boxes_s)
+            confs_l.append(confs_s)
+            labels_l.append(labels_s + 1)
         return boxes_l, confs_l, labels_l
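
Note: this is the change the commit message refers to. The normalized mixture-weight (pi) threshold is now applied once per sample, slicing the component dimension for every class at the same time before the per-class loop, instead of being recomputed inside the loop; the per-image max_boxes truncation and the final re-sort in forward() were dropped alongside it. A self-contained sketch of the new ordering on dummy tensors; the shapes and threshold values are illustrative, not the repository's configuration:

import torch

n_classes, n_comps = 5, 12                        # illustrative sizes
pi_thresh, conf_thresh = 0.1, 0.5                 # illustrative thresholds

boxes_s = torch.rand(n_classes - 1, n_comps, 4)   # per-class boxes, one sample
confs_s = torch.rand(n_classes - 1, n_comps)      # per-class confidences
pi_s = torch.rand(n_comps)                        # mixture weights, one sample

# New order: threshold the normalized pi once, shrinking every class at the same time.
norm_pi_s = pi_s / torch.max(pi_s)
keep = torch.nonzero(norm_pi_s > pi_thresh).view(-1)
boxes_s, confs_s = boxes_s[:, keep], confs_s[:, keep]

# The per-class confidence threshold (and NMS, omitted here) still runs afterwards.
for c in range(n_classes - 1):
    keep_c = torch.nonzero(confs_s[c] > conf_thresh).view(-1)
    print('class %d keeps %d of %d components' % (c, keep_c.numel(), boxes_s.shape[1]))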
