55import matplotlib .pyplot as plt
66import numpy as np
77from matplotlib .ticker import MultipleLocator
8- from mmengine import Config , DictAction
9- from mmengine .utils import ProgressBar , load
8+ from mmengine .config import Config , DictAction
9+ from mmengine .registry import init_default_scope
10+ from mmengine .utils import mkdir_or_exist , progressbar
11+ from PIL import Image
1012
11- from mmseg .datasets import build_dataset
13+ from mmseg .registry import DATASETS
14+
15+ init_default_scope ('mmseg' )
1216
1317
1418def parse_args ():
1519 parser = argparse .ArgumentParser (
1620 description = 'Generate confusion matrix from segmentation results' )
1721 parser .add_argument ('config' , help = 'test config file path' )
1822 parser .add_argument (
19- 'prediction_path' , help = 'prediction path where test .pkl result' )
23+ 'prediction_path' , help = 'prediction path where test folder result' )
2024 parser .add_argument (
2125 'save_dir' , help = 'directory where confusion matrix will be saved' )
2226 parser .add_argument (
@@ -50,15 +54,23 @@ def calculate_confusion_matrix(dataset, results):
5054 dataset (Dataset): Test or val dataset.
5155 results (list[ndarray]): A list of segmentation results in each image.
5256 """
53- n = len (dataset .CLASSES )
57+ n = len (dataset .METAINFO [ 'classes' ] )
5458 confusion_matrix = np .zeros (shape = [n , n ])
5559 assert len (dataset ) == len (results )
56- prog_bar = ProgressBar (len (results ))
60+ ignore_index = dataset .ignore_index
61+ reduce_zero_label = dataset .reduce_zero_label
62+ prog_bar = progressbar .ProgressBar (len (results ))
5763 for idx , per_img_res in enumerate (results ):
5864 res_segm = per_img_res
59- gt_segm = dataset .get_gt_seg_map_by_idx (idx )
65+ gt_segm = dataset [idx ]['data_samples' ] \
66+ .gt_sem_seg .data .squeeze ().numpy ().astype (np .uint8 )
67+ gt_segm , res_segm = gt_segm .flatten (), res_segm .flatten ()
68+ if reduce_zero_label :
69+ gt_segm = gt_segm - 1
70+ to_ignore = gt_segm == ignore_index
71+
72+ gt_segm , res_segm = gt_segm [~ to_ignore ], res_segm [~ to_ignore ]
6073 inds = n * gt_segm + res_segm
61- inds = inds .flatten ()
6274 mat = np .bincount (inds , minlength = n ** 2 ).reshape (n , n )
6375 confusion_matrix += mat
6476 prog_bar .update ()
@@ -70,7 +82,7 @@ def plot_confusion_matrix(confusion_matrix,
7082 save_dir = None ,
7183 show = True ,
7284 title = 'Normalized Confusion Matrix' ,
73- color_theme = 'winter ' ):
85+ color_theme = 'OrRd ' ):
7486 """Draw confusion matrix with matplotlib.
7587
7688 Args:
@@ -89,14 +101,15 @@ def plot_confusion_matrix(confusion_matrix,
89101
90102 num_classes = len (labels )
91103 fig , ax = plt .subplots (
92- figsize = (2 * num_classes , 2 * num_classes * 0.8 ), dpi = 180 )
104+ figsize = (2 * num_classes , 2 * num_classes * 0.8 ), dpi = 300 )
93105 cmap = plt .get_cmap (color_theme )
94106 im = ax .imshow (confusion_matrix , cmap = cmap )
95- plt .colorbar (mappable = im , ax = ax )
107+ colorbar = plt .colorbar (mappable = im , ax = ax )
108+ colorbar .ax .tick_params (labelsize = 20 ) # 设置 colorbar 标签的字体大小
96109
97- title_font = {'weight' : 'bold' , 'size' : 12 }
110+ title_font = {'weight' : 'bold' , 'size' : 20 }
98111 ax .set_title (title , fontdict = title_font )
99- label_font = {'size' : 10 }
112+ label_font = {'size' : 40 }
100113 plt .ylabel ('Ground Truth Label' , fontdict = label_font )
101114 plt .xlabel ('Prediction Label' , fontdict = label_font )
102115
@@ -116,8 +129,8 @@ def plot_confusion_matrix(confusion_matrix,
116129 # draw label
117130 ax .set_xticks (np .arange (num_classes ))
118131 ax .set_yticks (np .arange (num_classes ))
119- ax .set_xticklabels (labels )
120- ax .set_yticklabels (labels )
132+ ax .set_xticklabels (labels , fontsize = 20 )
133+ ax .set_yticklabels (labels , fontsize = 20 )
121134
122135 ax .tick_params (
123136 axis = 'x' , bottom = False , top = True , labelbottom = False , labeltop = True )
@@ -135,13 +148,14 @@ def plot_confusion_matrix(confusion_matrix,
135148 ) if not np .isnan (confusion_matrix [i , j ]) else - 1 ),
136149 ha = 'center' ,
137150 va = 'center' ,
138- color = 'w ' ,
139- size = 7 )
151+ color = 'k ' ,
152+ size = 20 )
140153
141154 ax .set_ylim (len (confusion_matrix ) - 0.5 , - 0.5 ) # matplotlib>3.1.1
142155
143156 fig .tight_layout ()
144157 if save_dir is not None :
158+ mkdir_or_exist (save_dir )
145159 plt .savefig (
146160 os .path .join (save_dir , 'confusion_matrix.png' ), format = 'png' )
147161 if show :
@@ -155,25 +169,24 @@ def main():
155169 if args .cfg_options is not None :
156170 cfg .merge_from_dict (args .cfg_options )
157171
158- results = load (args .prediction_path )
172+ results = []
173+ for img in sorted (os .listdir (args .prediction_path )):
174+ img = os .path .join (args .prediction_path , img )
175+ image = Image .open (img )
176+ image = np .copy (image )
177+ results .append (image )
159178
160179 assert isinstance (results , list )
161180 if isinstance (results [0 ], np .ndarray ):
162181 pass
163182 else :
164183 raise TypeError ('invalid type of prediction results' )
165184
166- if isinstance (cfg .data .test , dict ):
167- cfg .data .test .test_mode = True
168- elif isinstance (cfg .data .test , list ):
169- for ds_cfg in cfg .data .test :
170- ds_cfg .test_mode = True
171-
172- dataset = build_dataset (cfg .data .test )
185+ dataset = DATASETS .build (cfg .test_dataloader .dataset )
173186 confusion_matrix = calculate_confusion_matrix (dataset , results )
174187 plot_confusion_matrix (
175188 confusion_matrix ,
176- dataset .CLASSES ,
189+ dataset .METAINFO [ 'classes' ] ,
177190 save_dir = args .save_dir ,
178191 show = args .show ,
179192 title = args .title ,
0 commit comments