Skip to content

Commit 2e13232

Browse files
committed
before debugging
1 parent 64c3a9d commit 2e13232

File tree

5 files changed

+253
-91
lines changed

5 files changed

+253
-91
lines changed

src/lib/tester.py

Lines changed: 109 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
11
import os
22
import abc
3+
import sys
4+
import tqdm
5+
import pickle
36
import numpy as np
47
import scipy.misc
58
from lib import util
9+
from lib.external.dataset import factory
10+
from . import tester_util
611

712

813
class TesterABC(abc.ABC):
@@ -12,7 +17,7 @@ def __init__(self, global_args, tester_args):
1217
self.result_root = tester_args['result_root']
1318

1419
@ abc.abstractmethod
15-
def test(self, framework, data_loader, result_dir):
20+
def run(self, framework, data_loader, result_dir):
1621
pass
1722

1823

@@ -22,14 +27,12 @@ def __init__(self, global_args, tester_args):
2227
self.n_samples = tester_args['n_samples']
2328
self.max_boxes = tester_args['max_boxes']
2429
self.conf_thresh = tester_args['conf_thresh']
25-
self.coord_size = (global_args['coord_h'], global_args['coord_w'])
2630

27-
def test(self, framework, data_loader, result_dir):
31+
def run(self, framework, data_loader, result_dir):
2832
assert data_loader.batch_size == 1
2933
pre_proc = data_loader.dataset.pre_proc
3034
class_map = data_loader.dataset.get_number2name_map()
31-
save_dir = os.path.join(self.result_root, result_dir)
32-
util.make_dir(save_dir)
35+
util.make_dir(result_dir)
3336

3437
for i, data_dict in enumerate(data_loader):
3538
if i >= self.n_samples:
@@ -44,28 +47,109 @@ def test(self, framework, data_loader, result_dir):
4447
img_s = data_dict['img'][0]
4548
gt_boxes_s = data_dict['boxes'][0]
4649
gt_labels_s = data_dict['labels'][0]
47-
img_size = data_dict['img_size'][0]
48-
input_size = img_s.shape[:2]
4950

5051
sort_idx = 0
51-
gt_img_path = os.path.join(save_dir, '%03d_%d_%s.png' % (i, sort_idx, 'gt'))
52-
gt_img_s = testutil.draw_boxes(img_s, gt_boxes_s, None, gt_labels_s,
53-
class_map, self.conf_thresh, self.max_boxes)
52+
gt_img_path = os.path.join(result_dir, '%03d_%d_%s.png' % (i, sort_idx, 'gt'))
53+
gt_img_s = tester_util.draw_boxes(
54+
img_s, gt_boxes_s, None, gt_labels_s,
55+
class_map, self.conf_thresh, self.max_boxes)
5456
scipy.misc.imsave(gt_img_path, gt_img_s)
55-
del gt_img_s
57+
sort_idx += 1
5658

59+
# draw_boxes
60+
pred_img_path = os.path.join(result_dir, '%03d_%d_%s.png' % (i, sort_idx, 'pred'))
61+
pred_img_s = tester_util.draw_boxes(
62+
img_s, pred_boxes_s, pred_confs_s, pred_labels_s,
63+
class_map, self.conf_thresh, self.max_boxes)
64+
scipy.misc.imsave(pred_img_path, pred_img_s)
65+
sort_idx += 1
66+
67+
68+
class QuantTester(TesterABC):
69+
def __init__(self, global_args, tester_args):
70+
super(QuantTester, self).__init__(global_args, tester_args)
71+
self.n_classes = global_args['n_classes']
72+
self.imdb_name = tester_args['dataset']
73+
self.iou_thresh = tester_args['iou_thresh']
74+
assert self.imdb_name in ('voc_2007_test', 'coco_2017_val', 'coco_2017_test-dev')
75+
76+
def test(self, framework, data_loader, result_dir_name):
77+
assert data_loader.batch_size == 1
78+
num_samples = data_loader.dataset.__len__()
79+
all_boxes = [[[] for _ in range(num_samples)] for _ in range(self.n_classes)]
80+
81+
# import cv2
82+
s = 0
83+
times = list()
84+
data_loader_pbar = tqdm(data_loader)
85+
for idx, data_dict in enumerate(data_loader_pbar):
86+
output_dict, result_dict = framework.infer_forward(data_dict, flip=self.flip_test)
87+
times.append(result_dict['time'])
88+
mean_time = np.mean(times[10:]) if len(times) > 10 else np.mean(times)
89+
data_loader_pbar.set_description('infer time: %.4f sec' % mean_time)
90+
91+
# total predict boxes shape : (batch, # pred box, 4)
92+
# total predict boxes confidence shape : (batch, # pred box, 1)
93+
# total predict boxes label shape : (batch, # pred box, 1)
94+
img_size_s = data_dict['img_size'].float()[0]
95+
input_size = data_dict['img'].shape[2:]
96+
97+
boxes_s = result_dict['boxes_l'][0]
98+
confs_s = result_dict['confs_l'][0]
99+
labels_s = result_dict['labels_l'][0]
100+
101+
boxes_s[:, [0, 2]] *= (img_size_s[1] / input_size[1])
102+
boxes_s[:, [1, 3]] *= (img_size_s[0] / input_size[0])
103+
boxes_s, confs_s, labels_s = \
104+
util.sort_boxes_s(boxes_s, confs_s, labels_s)
105+
106+
boxes_s = util.cvt_torch2numpy(boxes_s)
107+
confs_s = util.cvt_torch2numpy(confs_s)
108+
labels_s = util.cvt_torch2numpy(labels_s)
109+
110+
if len(confs_s.shape) == 1:
111+
confs_s = np.expand_dims(confs_s, axis=1)
112+
for i, (cls_box, cls_conf, cls_label) in \
113+
enumerate(zip(boxes_s, confs_s, labels_s)):
114+
115+
cls_box_with_conf = np.concatenate((cls_box, cls_conf), axis=0)
116+
cls_box_with_conf = np.expand_dims(cls_box_with_conf, axis=0)
117+
all_boxes[int(cls_label)][s].append(cls_box_with_conf)
118+
119+
for c in range(self.n_classes):
120+
all_boxes[c][s] = np.concatenate(all_boxes[c][s], axis=0) \
121+
if 0 < len(all_boxes[c][s]) else np.concatenate([[]], axis=0)
57122
data_dict.clear()
58-
output_dict.clear()
59-
result_dict.clear()
60-
return save_dir
61-
62-
63-
def save_img(save_dir, img_dict, key, size, range_ratio, idx, sort_idx, tag):
64-
if key in img_dict.keys():
65-
img = util.cvt_torch2numpy(img_dict[key])[0] * range_ratio
66-
if len(img.shape) == 3:
67-
img = np.squeeze(np.transpose(img, (1, 2, 0)))
68-
img = scipy.misc.imresize(img, size, interp='bilinear')
69-
img_path = os.path.join(save_dir, '%03d_%d_%s.png' % (idx, sort_idx, tag))
70-
scipy.misc.imsave(img_path, img)
71-
return sort_idx + 1
123+
124+
# create result directories
125+
if not os.path.exists(self.result_dir):
126+
os.mkdir(self.result_dir)
127+
result_dir_path = os.path.join(self.result_dir, result_dir_name)
128+
if not os.path.exists(result_dir_path):
129+
os.mkdir(result_dir_path)
130+
131+
dataset_root = data_loader.dataset.get_dataset_root()
132+
if isinstance(dataset_root, list):
133+
dataset_root = dataset_root[0]
134+
imdb = factory.get_imdb(self.imdb_name, dataset_root)
135+
136+
if 'coco' in self.imdb_name:
137+
sys_stdout = sys.stdout
138+
result_file_path = open(os.path.join(result_dir_path, 'ap_ar.txt'), 'w')
139+
sys.stdout = result_file_path
140+
imdb.evaluate_detections(all_boxes, result_dir_path)
141+
sys.stdout = sys_stdout
142+
result_file_path.close()
143+
144+
else:
145+
det_file_path = os.path.join(result_dir_path, 'detection_results.pkl')
146+
with open(det_file_path, 'wb') as det_file:
147+
pickle.dump(all_boxes, det_file, pickle.HIGHEST_PROTOCOL)
148+
result_msg = imdb.evaluate_detections(all_boxes, result_dir_path)
149+
result_file_path = os.path.join(result_dir_path, 'mean_ap.txt')
150+
with open(result_file_path, 'w') as file:
151+
file.write(result_msg)
152+
os.remove(det_file_path)
153+
154+
all_boxes.clear()
155+

src/lib/tester_util.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import numpy as np
2+
import cv2
3+
4+
colors = ((64, 64, 64), (31, 119, 180), (174, 199, 232), (255, 127, 14),
5+
(255, 187, 120), (44, 160, 44), (152, 223, 138), (214, 39, 40),
6+
(255, 152, 150), (148, 103, 189), (197, 176, 213), (140, 86, 75),
7+
(196, 156, 148), (227, 119, 194), (247, 182, 210), (127, 127, 127),
8+
(199, 199, 199), (188, 189, 34), (219, 219, 141), (23, 190, 207),
9+
(158, 218, 229), (180, 119, 31))
10+
11+
12+
def draw_boxes(img_s, boxes_s, confs_s=None, labels_s=None,
13+
class_map=None, conf_thresh=0.0, max_boxes=100):
14+
15+
box_img_s = img_s.copy()
16+
n_draw_boxes = 0
17+
n_wrong_boxes = 0
18+
n_thresh_boxes = 0
19+
for i, box in enumerate(boxes_s):
20+
try:
21+
l, t = int(round(box[0])), int(round(box[1]))
22+
r, b = int(round(box[2])), int(round(box[3]))
23+
except IndexError:
24+
print(boxes_s)
25+
print(i, box)
26+
print('IndexError')
27+
exit()
28+
29+
if confs_s is not None:
30+
if conf_thresh > confs_s[i]:
31+
n_thresh_boxes += 1
32+
continue
33+
if (r - l <= 0) or (b - t <= 0):
34+
n_wrong_boxes += 1
35+
continue
36+
if n_draw_boxes >= max_boxes:
37+
continue
38+
39+
conf_str = '-' if confs_s is None else '%0.3f' % confs_s[i]
40+
if labels_s is None:
41+
lab_str, color = '-', colors[i % len(colors)]
42+
else:
43+
lab_i = int(labels_s[i])
44+
lab_str = str(lab_i) if class_map is None else class_map[lab_i]
45+
color = colors[lab_i % len(colors)]
46+
47+
box_img_s = cv2.rectangle(box_img_s, (l, t), (r, b), color, 2)
48+
l = int(l - 1 if l > 1 else r - 60)
49+
t = int(t - 8 if t > 8 else b)
50+
r, b = int(l + 60), int(t + 8)
51+
box_img_s = cv2.rectangle(box_img_s, (l, t), (r, b), color, cv2.FILLED)
52+
box_img_s = cv2.putText(box_img_s, '%s %s' % (conf_str, lab_str), (l + 1, t + 7),
53+
cv2.FONT_HERSHEY_SIMPLEX, 0.25, (255, 255, 255),
54+
1, cv2.LINE_AA)
55+
n_draw_boxes += 1
56+
57+
info_text = 'n_draw_b: %d, n_thr_b: %d, n_wrong_b: %d' % \
58+
(n_draw_boxes, n_thresh_boxes, n_wrong_boxes)
59+
if confs_s is not None:
60+
info_text += ', sum_of_conf: %.3f' % (np.sum(confs_s))
61+
else:
62+
info_text += ', sum_of_conf: -'
63+
64+
box_img_s = cv2.rectangle(box_img_s, (0, 0), (350, 11), (0, 0, 0), cv2.FILLED)
65+
box_img_s = cv2.putText(box_img_s, info_text, (5, 10), cv2.FONT_HERSHEY_SIMPLEX,
66+
0.25, (255, 255, 255), 1, cv2.LINE_AA)
67+
return box_img_s

src/run.py

Lines changed: 40 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,104 +1,79 @@
11
import os
22
import time
33
import torch
4-
import traceback
4+
# import traceback
55
from tensorboardX import SummaryWriter
6-
import util as run_util
76
import option
7+
import util
88

99

1010
def main():
1111
torch.multiprocessing.set_sharing_strategy('file_system')
12+
print('[RUN] parse arguments')
1213
args, framework, optimizer, data_loader_dict, tester_dict = option.parse_options()
1314

14-
print('CREATE RESULT DIRECTORIES')
15-
result_dir_dict = run_util.create_result_dir(args.result_dir, ['src', 'log', 'snapshot'])
16-
run_util.copy_file(args.bash_file, args.result_dir)
17-
run_util.copy_dir('./src', result_dir_dict['src'])
15+
print('[RUN] create result directories')
16+
result_dir_dict = util.create_result_dir(args.result_dir, ['src', 'log', 'snapshot', 'test'])
17+
util.copy_file(args.bash_file, args.result_dir)
18+
util.copy_dir('./src', result_dir_dict['src'])
1819

19-
print('CREATE LOGGERS')
20+
print('[RUN] create loggers')
2021
train_log_dir = os.path.join(result_dir_dict['log'], 'train')
2122
train_logger = SummaryWriter(train_log_dir)
2223

23-
print('')
24-
print('START TRAINING')
2524
print('[OPTIMIZER] learning rate:', optimizer.param_groups[0]['lr'])
2625
n_batches = data_loader_dict['train'].__len__()
2726
global_step = args.train_info_dict['init_iter']
2827
max_grad = args.train_info_dict['max_grad']
29-
skip_flag = False
3028

29+
print('')
3130
while True:
3231
start_time = time.time()
3332
for train_data_dict in data_loader_dict['train']:
3433
batch_time = time.time() - start_time
3534

36-
try:
37-
if not skip_flag:
38-
if global_step in args.snapshot_iters:
39-
snapshot_dir = os.path.join(result_dir_dict['snapshot'], '%07d' % global_step)
40-
run_util.save_snapshot(framework.network, optimizer, snapshot_dir)
41-
42-
# if global_step in args.test_iters:
43-
# snapshot_dir = os.path.join(result_dir_dict['snapshot'], '%07d' % global_step)
44-
# test_networks(args, tester_dict, framework, test_set_loader, global_step)
45-
46-
if args.train_info_dict['max_iter'] <= global_step:
47-
break
48-
49-
# update_learning_rate(optimizer, args.lr_decay_schd_dict, global_step)
50-
# update_loss_weights({'loss_func': loss_func}, args.lw_schd_dict, global_step)
51-
52-
try:
53-
start_time = time.time()
54-
_, train_loss_dict = framework.train_forward(train_data_dict)
55-
run.update_networks(network, optimizer, train_loss_dict, max_grad)
56-
train_time = time.time() - start_time
57-
58-
# if valid_loss_dict is not None:
59-
if global_step % args.train_info_dict['print_intv'] == 0:
60-
iter_str = '[%d/%d] ' % (global_step, args.train_info_dict['max_iter'])
61-
info_str = 'n_batches:%d batch_time:%0.3f train_time:%0.3f' % \
62-
(n_batches, batch_time, train_time)
63-
train_str = util.cvt_dict2str(train_loss_dict)
64-
65-
print(iter_str + info_str)
66-
print('[train] ' + train_str + '\n')
35+
if global_step in args.snapshot_iters:
36+
snapshot_dir = os.path.join(result_dir_dict['snapshot'], '%07d' % global_step)
37+
util.save_snapshot(framework.network, optimizer, snapshot_dir)
6738

68-
for key, value in train_loss_dict.items():
69-
train_logger.add_scalar(key, value, global_step)
39+
if global_step in args.test_iters:
40+
test_dir = os.path.join(result_dir_dict['test'], '%07d' % global_step)
41+
util.run_testers(tester_dict, framework, data_loader_dict['test'], test_dir)
7042

71-
train_loss_dict.clear()
72-
del train_loss_dict
43+
if args.train_info_dict['max_iter'] <= global_step:
44+
break
7345

74-
skip_flag = False
75-
global_step += 1
46+
if global_step in args.lr_decay_schd_dict.keys():
47+
util.update_learning_rate(optimizer, args.lr_decay_schd_dict[global_step])
7648

77-
except Exception as e:
78-
print('[WARNING] %s' % (str(e)))
79-
print('[%d/%d] skip this mini-batch\n' % (global_step, args.train_info_dict['max_iter']))
49+
start_time = time.time()
50+
_, train_loss_dict = framework.train_forward(train_data_dict)
51+
util.update_network(framework.network, optimizer, train_loss_dict, max_grad)
52+
train_time = time.time() - start_time
8053

81-
if 'memory' in str(e):
82-
optimizer.zero_grad()
83-
optimizer.step()
54+
if global_step % args.train_info_dict['print_intv'] == 0:
55+
iter_str = '[TRAINING] %d/%d:' % (global_step, args.train_info_dict['max_iter'])
56+
info_str = 'n_batches: %d, batch_time: %0.3f, train_time: %0.3f' % \
57+
(n_batches, batch_time, train_time)
58+
train_str = util.cvt_dict2str(train_loss_dict)
59+
print(iter_str + '\n' + info_str + '\n' + train_str + '\n')
8460

85-
train_loss_dict.clear()
86-
del train_loss_dict
87-
torch.cuda.empty_cache()
61+
for key, value in train_loss_dict.items():
62+
train_logger.add_scalar(key, value, global_step)
8863

89-
skip_flag = True
64+
train_loss_dict.clear()
65+
train_data_dict.clear()
66+
del train_loss_dict, train_data_dict
9067

91-
train_data_dict.clear()
92-
start_time = time.time()
68+
global_step += 1
69+
start_time = time.time()
9370

94-
except Exception as e:
95-
traceback.print_tb(e.__traceback__)
96-
print('[ERROR] %s' % (str(e)))
97-
snapshot_dir = os.path.join(result_dir_dict['snapshot'], '%07d' % global_step)
98-
run_util.save_snapshot(framework.network, optimizer, snapshot_dir)
99-
exit()
71+
if args.train_info_dict['max_iter'] <= global_step:
72+
break
10073
train_logger.close()
10174

10275

10376
if __name__ == '__main__':
10477
main()
78+
79+

0 commit comments

Comments
 (0)