Commit f67fca0

loader report total_batch instead of total_record (tusen-ai#279)
1 parent 8f9c6ab commit f67fca0

File tree

7 files changed: +28 -37 lines changed


core/detection_input.py

Lines changed: 8 additions & 15 deletions

@@ -590,7 +590,7 @@ class Loader(mx.io.DataIter):

     def __init__(self, roidb, transform, data_name, label_name, batch_size=1,
                  shuffle=False, num_worker=None, num_collector=None,
-                 worker_queue_depth=None, collector_queue_depth=None, kv=None, valid_count=-1):
+                 worker_queue_depth=None, collector_queue_depth=None, valid_count=-1):
         """
         This Iter will provide roi data to Fast R-CNN network
         :param roidb: must be preprocessed

@@ -600,11 +600,6 @@ def __init__(self, roidb, transform, data_name, label_name, batch_size=1,
         """
         super().__init__(batch_size=batch_size)

-        if kv:
-            (self.rank, self.num_worker) = (kv.rank, kv.num_workers)
-        else:
-            (self.rank, self.num_worker) = (0, 1)
-
         # data processing utilities
         if isinstance(transform, dict):
             self.transform = transform["sample"]

@@ -653,8 +648,8 @@ def index(self):
         return self.total_index[:self.valid_count]

     @property
-    def total_record(self):
-        return len(self.index) // self.batch_size * self.batch_size
+    def total_batch(self):
+        return len(self.index) // self.batch_size

     @property
     def provide_data(self):

@@ -830,8 +825,7 @@ def __init__(self, roidb, transform, data_name, label_name, batch_size=1,
                               num_worker=num_worker,
                               num_collector=num_collector,
                               worker_queue_depth=worker_queue_depth,
-                              collector_queue_depth=collector_queue_depth,
-                              kv=kv)
+                              collector_queue_depth=collector_queue_depth)
             loaders.append(h_loader)
         if len(v_roidb_part) >= batch_size:
             v_loader = Loader(roidb=v_roidb_part,

@@ -844,18 +838,17 @@ def __init__(self, roidb, transform, data_name, label_name, batch_size=1,
                               num_worker=num_worker,
                               num_collector=num_collector,
                               worker_queue_depth=worker_queue_depth,
-                              collector_queue_depth=collector_queue_depth,
-                              kv=kv)
+                              collector_queue_depth=collector_queue_depth)
             loaders.append(v_loader)
         assert len(loaders) > 0, "at least one loader should be constructed"
         self.__loader = SequentialLoader(loaders)

     @property
-    def total_record(self):
-        return sum([it.total_record for it in self.__loader.iters])
+    def total_batch(self):
+        return sum([it.total_batch for it in self.__loader.iters])

     def __len__(self):
-        return self.total_record
+        return self.total_batch

     def __getattr__(self, attr):
         # delegate unknown keys to underlying iterators
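
The substance of this file's change: `total_record` (the record count rounded down to complete batches) becomes `total_batch` (the number of complete batches), `__len__` follows suit, and the `kv` argument together with the rank bookkeeping derived from it is dropped from `Loader`. A minimal standalone sketch, with hypothetical numbers, of how the two quantities relate:

batch_size = 4
num_records = 103   # hypothetical number of indexed roidb entries

total_record = num_records // batch_size * batch_size   # old property: 100 records
total_batch = num_records // batch_size                 # new property: 25 batches

assert total_record == total_batch * batch_size         # same information, batch-denominated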

detection_infer_speed.py

Lines changed: 2 additions & 2 deletions (whitespace-only: trailing spaces removed)

@@ -38,7 +38,7 @@ def parse_args():
    data_batch = mx.io.DataBatch(data=[data, im_info, im_id, rec_id])

    '''
-    there are some conflicts between `mergebn` and `attach_quantized_node` in graph_optimize.py 
+    there are some conflicts between `mergebn` and `attach_quantized_node` in graph_optimize.py
    when mergebn ahead of attach_quantized_node
    such as `Symbol.ComposeKeyword`
    '''

@@ -52,7 +52,7 @@ def parse_args():
    # raise NotImplementedError
    _, out_shape, _ = sym.get_internals().infer_shape(**worker_data_shape)
    out_shape_dictoinary = dict(zip(sym.get_internals().list_outputs(), out_shape))
-    sym = attach_quantize_node(sym, out_shape_dictoinary, pQuant.WeightQuantizeParam, 
+    sym = attach_quantize_node(sym, out_shape_dictoinary, pQuant.WeightQuantizeParam,
                               pQuant.ActQuantizeParam, pQuant.quantized_op)
    sym.save(pTest.model.prefix + "_infer_speed.json")

detection_test.py

Lines changed: 3 additions & 4 deletions

@@ -84,10 +84,9 @@ def parse_args():
                    num_worker=4,
                    num_collector=2,
                    worker_queue_depth=2,
-                    collector_queue_depth=2,
-                    kv=None)
+                    collector_queue_depth=2)

-    print("total number of images: {}".format(loader.total_record))
+    print("total number of images: {}".format(loader.total_batch))

    data_names = [k[0] for k in loader.provide_data]

@@ -163,7 +162,7 @@ def data_enqueue(loader, data_queue):
    enqueue_worker.daemon = True
    enqueue_worker.start()

-    for _ in range(loader.total_record):
+    for _ in range(loader.total_batch):
        r = result_queue.get()

        rid, id, info, cls, box = r
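
The loop-bound changes in the test scripts keep the consumer in step with the producer: the enqueue worker emits one result per batch, so exactly `total_batch` calls to `result_queue.get()` are needed. One caveat worth noting: the log label still reads "total number of images" while now printing a batch count; the two coincide only when the test loader runs one image per batch. A standalone sketch, with hypothetical queue plumbing, of why the bound matters:

import queue
import threading

batch_size, total_batch = 2, 5
total_record = total_batch * batch_size   # 10 records grouped into complete batches
result_queue = queue.Queue()

def data_enqueue():
    # the producer emits exactly one result per batch, as in the test scripts
    for batch_id in range(total_batch):
        result_queue.put(batch_id)

threading.Thread(target=data_enqueue, daemon=True).start()

# matches the producer; range(total_record) would block forever on the
# sixth get() whenever batch_size > 1
for _ in range(total_batch):
    r = result_queue.get()
    print(r)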

detection_train.py

Lines changed: 2 additions & 2 deletions

@@ -186,7 +186,7 @@ def train_net(config):
    eval_metrics = mx.metric.CompositeEvalMetric(metric_list)

    # callback
-    batch_end_callback = [callback.Speedometer(train_data.batch_size, frequent=pGen.log_frequency)]
+    batch_end_callback = [callback.Speedometer(train_data.batch_size, len(train_data) * (end_epoch - begin_epoch), frequent=pGen.log_frequency)]
    batch_end_callback += pModel.batch_end_callbacks or []
    epoch_end_callback = callback.do_checkpoint(model_prefix)
    sym.save(model_prefix + ".json")

@@ -196,7 +196,7 @@ def train_net(config):
    base_lr = pOpt.optimizer.lr * kv.num_workers
    lr_factor = pOpt.schedule.lr_factor or 0.1

-    iter_per_epoch = len(train_data) // input_batch_size
+    iter_per_epoch = len(train_data)
    total_iter = iter_per_epoch * (end_epoch - begin_epoch)
    lr_iter = [total_iter + it if it < 0 else it for it in lr_iter]
    lr_iter = [it // kv.num_workers for it in lr_iter]
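
Since `len(train_data)` now returns a batch count directly, the old division by `input_batch_size` goes away, and the same product `len(train_data) * (end_epoch - begin_epoch)` feeds both the lr schedule and the new Speedometer argument. A worked sketch of the schedule arithmetic, with hypothetical numbers:

begin_epoch, end_epoch = 0, 6
iter_per_epoch = 500                  # what len(train_data) returns after this commit
num_kv_workers = 2                    # hypothetical kv.num_workers

total_iter = iter_per_epoch * (end_epoch - begin_epoch)            # 3000
lr_iter = [-800, -200]                # hypothetical schedule; negative = offset from the end
lr_iter = [total_iter + it if it < 0 else it for it in lr_iter]    # [2200, 2800]
lr_iter = [it // num_kv_workers for it in lr_iter]                 # [1100, 1400] per-worker steps
print(total_iter, lr_iter)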

mask_test.py

Lines changed: 3 additions & 4 deletions

@@ -86,10 +86,9 @@ def parse_args():
                    num_worker=4,
                    num_collector=2,
                    worker_queue_depth=2,
-                    collector_queue_depth=2,
-                    kv=None)
+                    collector_queue_depth=2)

-    print("total number of images: {}".format(loader.total_record))
+    print("total number of images: {}".format(loader.total_batch))

    data_names = [k[0] for k in loader.provide_data]

@@ -158,7 +157,7 @@ def data_enqueue(loader, data_queue):
    enqueue_worker.daemon = True
    enqueue_worker.start()

-    for index in range(loader.total_record):
+    for _ in range(loader.total_batch):
        r = result_queue.get()

        rid, id, info, post_cls_score, post_box, post_cls, mask, mask_score = r

rpn_test.py

Lines changed: 6 additions & 7 deletions (the otherwise-identical line pairs are trailing-whitespace cleanups)

@@ -83,22 +83,21 @@ def parse_args():
                    num_worker=4,
                    num_collector=2,
                    worker_queue_depth=2,
-                    collector_queue_depth=2,
-                    kv=None)
+                    collector_queue_depth=2)

-    print("total number of images: {}".format(loader.total_record))
+    print("total number of images: {}".format(loader.total_batch))

    data_names = [k[0] for k in loader.provide_data]

    if index_split == 0:
        arg_params, aux_params = load_checkpoint(pTest.model.prefix, pTest.model.epoch)
        if pModel.process_weight is not None:
            pModel.process_weight(sym, arg_params, aux_params)
- 
-        # merge batch normalization 
+
+        # merge batch normalization
        from utils.graph_optimize import merge_bn
        sym, arg_params, aux_params = merge_bn(sym, arg_params, aux_params)
- 
+
    for i in pKv.gpus:
        ctx = mx.gpu(i)
        mod = DetModule(sym, data_names=data_names, context=ctx)

@@ -131,7 +130,7 @@ def data_enqueue(loader, data_queue):
    enqueue_worker.daemon = True
    enqueue_worker.start()

-    for _ in range(loader.total_record):
+    for _ in range(loader.total_batch):
        r = result_queue.get()

        rid, id, info, box, score = r

utils/callback.py

Lines changed: 4 additions & 3 deletions

@@ -4,8 +4,9 @@


class Speedometer(object):
-    def __init__(self, batch_size, frequent=50):
+    def __init__(self, batch_size, total_iter, frequent=50):
        self.batch_size = batch_size
+        self.total_iter = total_iter
        self.frequent = frequent
        self.init = False
        self.tic = 0

@@ -23,8 +24,8 @@ def __call__(self, param):
                speed = self.frequent * self.batch_size / (time.time() - self.tic)
                if param.eval_metric is not None:
                    name, value = param.eval_metric.get()
-                    s = "Epoch[%d] Batch [%d]\tIter: %d\tLr: %.5f\tSpeed: %.2f samples/sec\tTrain-" % \
-                        (param.epoch, count, param.iter, param.lr, speed)
+                    s = "Epoch[%d] Batch [%d]\tIter: %d/%d\tLr: %.5f\tSpeed: %.2f samples/sec\tTrain-" % \
+                        (param.epoch, count, param.iter, self.total_iter, param.lr, speed)
                    for n, v in zip(name, value):
                        s += "%s=%f,\t" % (n, v)
                    logging.info(s)
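
With `total_iter` threaded through to the callback, the Speedometer reports absolute progress instead of a bare iteration counter. A standalone sketch, with hypothetical numbers, of the log line the new format string produces:

import logging

logging.basicConfig(level=logging.INFO, format="%(message)s")

total_iter = 3000    # len(train_data) * (end_epoch - begin_epoch), as passed at the call site
epoch, count, cur_iter, lr, speed = 0, 50, 50, 0.01, 320.0

# the updated format: "Iter: %d/%d" instead of "Iter: %d"
s = "Epoch[%d] Batch [%d]\tIter: %d/%d\tLr: %.5f\tSpeed: %.2f samples/sec\tTrain-" % \
    (epoch, count, cur_iter, total_iter, lr, speed)
logging.info(s)   # Epoch[0] Batch [50]  Iter: 50/3000  Lr: 0.01000  Speed: 320.00 samples/sec  Train-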
