Skip to content

Commit 6987dfb

Browse files
Update dataset.py
1 parent c12e050 commit 6987dfb

File tree

1 file changed

+34
-0
lines changed

1 file changed

+34
-0
lines changed

data/dataset.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,3 +277,37 @@ def get_samples(cls, roots, ids_dataset=None):
277277

278278
return train_samples, val_samples, test_samples
279279

280+
281+
class COCO_TestOnline(Dataset):
282+
def __init__(self, feat_path, ann_file, max_detections=49):
283+
"""
284+
feat_path: COCO官方划分的训练集和验证集的特征路径
285+
ann_file: 训练集或验证集的标注信息,用于获取image_id,进而检索出对应特征
286+
"""
287+
super(COCO_TestOnline, self).__init__()
288+
289+
# 读取图像信息
290+
with open(ann_file, 'r') as f:
291+
self.images_info = json.load(f)['images']
292+
293+
# 读取特征文件
294+
self.f = h5py.File(feat_path, 'r')
295+
296+
# 记录特征数目
297+
self.max_detections = max_detections
298+
299+
def __len__(self):
300+
return len(self.images_info)
301+
302+
def __getitem__(self, idx):
303+
image_id = self.images_info[idx]['id']
304+
precomp_data = self.f['%d_grids' % image_id][()]
305+
306+
delta = self.max_detections - precomp_data.shape[0]
307+
if delta > 0:
308+
precomp_data = np.concatenate([precomp_data, np.zeros((delta, precomp_data.shape[1]))], axis=0)
309+
elif delta < 0:
310+
precomp_data = precomp_data[:self.max_detections]
311+
312+
return int(image_id), precomp_data
313+

0 commit comments

Comments
 (0)