|
14 | 14 | from mmseg.core.evaluation import get_classes, get_palette |
15 | 15 | from mmseg.datasets import (DATASETS, ADE20KDataset, CityscapesDataset, |
16 | 16 | ConcatDataset, CustomDataset, LoveDADataset, |
17 | | - PascalVOCDataset, RepeatDataset, build_dataset) |
| 17 | + MultiImageMixDataset, PascalVOCDataset, |
| 18 | + RepeatDataset, build_dataset) |
18 | 19 |
|
19 | 20 |
|
20 | 21 | def test_classes(): |
@@ -95,6 +96,66 @@ def test_dataset_wrapper(): |
95 | 96 | assert repeat_dataset[27] == 7 |
96 | 97 | assert len(repeat_dataset) == 10 * len(dataset_a) |
97 | 98 |
|
| 99 | + img_scale = (60, 60) |
| 100 | + pipeline = [ |
| 101 | + # dict(type='Mosaic', img_scale=img_scale, pad_val=255), |
| 102 | + # need to merge mosaic |
| 103 | + dict(type='RandomFlip', prob=0.5), |
| 104 | + dict(type='Resize', img_scale=img_scale, keep_ratio=False), |
| 105 | + ] |
| 106 | + |
| 107 | + CustomDataset.load_annotations = MagicMock() |
| 108 | + results = [] |
| 109 | + for _ in range(2): |
| 110 | +        height = np.random.randint(10, 30) |
| 111 | +        width = np.random.randint(10, 30) |
| 112 | +        img = np.ones((height, width, 3)) |
| 113 | +        gt_semantic_seg = np.random.randint(5, size=(height, width)) |
| 114 | +        results.append(dict(gt_semantic_seg=gt_semantic_seg, img=img)) |
| 115 | + |
| 116 | + classes = ['0', '1', '2', '3', '4'] |
| 117 | + palette = [(0, 0, 0), (1, 1, 1), (2, 2, 2), (3, 3, 3), (4, 4, 4)] |
| 118 | + CustomDataset.__getitem__ = MagicMock(side_effect=lambda idx: results[idx]) |
| 119 | + dataset_a = CustomDataset( |
| 120 | + img_dir=MagicMock(), |
| 121 | + pipeline=[], |
| 122 | + test_mode=True, |
| 123 | + classes=classes, |
| 124 | + palette=palette) |
| 125 | + len_a = 2 |
| 126 | + cat_ids_list_a = [ |
| 127 | + np.random.randint(0, 80, num).tolist() |
| 128 | + for num in np.random.randint(1, 20, len_a) |
| 129 | + ] |
| 130 | + dataset_a.data_infos = MagicMock() |
| 131 | + dataset_a.data_infos.__len__.return_value = len_a |
| 132 | + dataset_a.get_cat_ids = MagicMock( |
| 133 | + side_effect=lambda idx: cat_ids_list_a[idx]) |
| 134 | + |
| 135 | + multi_image_mix_dataset = MultiImageMixDataset(dataset_a, pipeline) |
| 136 | + assert len(multi_image_mix_dataset) == len(dataset_a) |
| 137 | + |
| 138 | + for idx in range(len_a): |
| 139 | + results_ = multi_image_mix_dataset[idx] |
| 140 | + |
| 141 | + # test skip_type_keys |
| 142 | + multi_image_mix_dataset = MultiImageMixDataset( |
| 143 | +        dataset_a, pipeline, skip_type_keys=('RandomFlip', )) |
| 144 | + for idx in range(len_a): |
| 145 | + results_ = multi_image_mix_dataset[idx] |
| 146 | + assert results_['img'].shape == (img_scale[0], img_scale[1], 3) |
| 147 | + |
| 148 | + skip_type_keys = ('RandomFlip', 'Resize') |
| 149 | + multi_image_mix_dataset.update_skip_type_keys(skip_type_keys) |
| 150 | + for idx in range(len_a): |
| 151 | + results_ = multi_image_mix_dataset[idx] |
| 152 | + assert results_['img'].shape[:2] != img_scale |
| 153 | + |
| 154 | + # test pipeline |
| 155 | + with pytest.raises(TypeError): |
| 156 | + pipeline = [['Resize']] |
| 157 | + multi_image_mix_dataset = MultiImageMixDataset(dataset_a, pipeline) |
| 158 | + |
98 | 159 |
|
99 | 160 | def test_custom_dataset(): |
100 | 161 | img_norm_cfg = dict( |
|
0 commit comments