@@ -277,3 +277,37 @@ def get_samples(cls, roots, ids_dataset=None):
277277
278278 return train_samples , val_samples , test_samples
279279
280+
281+ class COCO_TestOnline (Dataset ):
282+ def __init__ (self , feat_path , ann_file , max_detections = 49 ):
283+ """
284+ feat_path: COCO官方划分的训练集和验证集的特征路径
285+ ann_file: 训练集或验证集的标注信息,用于获取image_id,进而检索出对应特征
286+ """
287+ super (COCO_TestOnline , self ).__init__ ()
288+
289+ # 读取图像信息
290+ with open (ann_file , 'r' ) as f :
291+ self .images_info = json .load (f )['images' ]
292+
293+ # 读取特征文件
294+ self .f = h5py .File (feat_path , 'r' )
295+
296+ # 记录特征数目
297+ self .max_detections = max_detections
298+
299+ def __len__ (self ):
300+ return len (self .images_info )
301+
302+ def __getitem__ (self , idx ):
303+ image_id = self .images_info [idx ]['id' ]
304+ precomp_data = self .f ['%d_grids' % image_id ][()]
305+
306+ delta = self .max_detections - precomp_data .shape [0 ]
307+ if delta > 0 :
308+ precomp_data = np .concatenate ([precomp_data , np .zeros ((delta , precomp_data .shape [1 ]))], axis = 0 )
309+ elif delta < 0 :
310+ precomp_data = precomp_data [:self .max_detections ]
311+
312+ return int (image_id ), precomp_data
313+
0 commit comments