|
| 1 | +import os.path as osp |
| 2 | +import tempfile |
| 3 | + |
| 4 | +import mmcv |
| 5 | +import numpy as np |
| 6 | +from PIL import Image |
| 7 | + |
1 | 8 | from .builder import DATASETS |
2 | 9 | from .custom import CustomDataset |
3 | 10 |
|
@@ -82,3 +89,75 @@ def __init__(self, **kwargs): |
82 | 89 | seg_map_suffix='.png', |
83 | 90 | reduce_zero_label=True, |
84 | 91 | **kwargs) |
| 92 | + |
| 93 | + def results2img(self, results, imgfile_prefix, to_label_id): |
| 94 | + """Write the segmentation results to images. |
| 95 | +
|
| 96 | + Args: |
| 97 | + results (list[list | tuple | ndarray]): Testing results of the |
| 98 | + dataset. |
| 99 | + imgfile_prefix (str): The filename prefix of the png files. |
| 100 | + If the prefix is "somepath/xxx", |
| 101 | + the png files will be named "somepath/xxx.png". |
| 102 | + to_label_id (bool): whether convert output to label_id for |
| 103 | + submission |
| 104 | +
|
| 105 | + Returns: |
| 106 | + list[str: str]: result txt files which contains corresponding |
| 107 | + semantic segmentation images. |
| 108 | + """ |
| 109 | + mmcv.mkdir_or_exist(imgfile_prefix) |
| 110 | + result_files = [] |
| 111 | + prog_bar = mmcv.ProgressBar(len(self)) |
| 112 | + for idx in range(len(self)): |
| 113 | + result = results[idx] |
| 114 | + |
| 115 | + filename = self.img_infos[idx]['filename'] |
| 116 | + basename = osp.splitext(osp.basename(filename))[0] |
| 117 | + |
| 118 | + png_filename = osp.join(imgfile_prefix, f'{basename}.png') |
| 119 | + |
| 120 | + # The index range of official requirement is from 0 to 150. |
| 121 | + # But the index range of output is from 0 to 149. |
| 122 | + # That is because we set reduce_zero_label=True. |
| 123 | + result = result + 1 |
| 124 | + |
| 125 | + output = Image.fromarray(result.astype(np.uint8)) |
| 126 | + output.save(png_filename) |
| 127 | + result_files.append(png_filename) |
| 128 | + |
| 129 | + prog_bar.update() |
| 130 | + |
| 131 | + return result_files |
| 132 | + |
| 133 | + def format_results(self, results, imgfile_prefix=None, to_label_id=True): |
| 134 | + """Format the results into dir (standard format for ade20k evaluation). |
| 135 | +
|
| 136 | + Args: |
| 137 | + results (list): Testing results of the dataset. |
| 138 | + imgfile_prefix (str | None): The prefix of images files. It |
| 139 | + includes the file path and the prefix of filename, e.g., |
| 140 | + "a/b/prefix". If not specified, a temp file will be created. |
| 141 | + Default: None. |
| 142 | + to_label_id (bool): whether convert output to label_id for |
| 143 | + submission. Default: False |
| 144 | +
|
| 145 | + Returns: |
| 146 | + tuple: (result_files, tmp_dir), result_files is a list containing |
| 147 | + the image paths, tmp_dir is the temporal directory created |
| 148 | + for saving json/png files when img_prefix is not specified. |
| 149 | + """ |
| 150 | + |
| 151 | + assert isinstance(results, list), 'results must be a list' |
| 152 | + assert len(results) == len(self), ( |
| 153 | + 'The length of results is not equal to the dataset len: ' |
| 154 | + f'{len(results)} != {len(self)}') |
| 155 | + |
| 156 | + if imgfile_prefix is None: |
| 157 | + tmp_dir = tempfile.TemporaryDirectory() |
| 158 | + imgfile_prefix = tmp_dir.name |
| 159 | + else: |
| 160 | + tmp_dir = None |
| 161 | + |
| 162 | + result_files = self.results2img(results, imgfile_prefix, to_label_id) |
| 163 | + return result_files, tmp_dir |
0 commit comments