11import numpy as np
22from sklearn import metrics
33from sklearn .metrics import average_precision_score
4- import scipy .ndimage as ndimage
5- from tqdm import tqdm
4+ import scipy .ndimage as ndimage
65
76def compute_imagewise_retrieval_metrics (
87 anomaly_prediction_weights , anomaly_ground_truth_labels
@@ -19,7 +18,7 @@ def compute_imagewise_retrieval_metrics(
1918 )
2019 return {
2120 "auroc" : auroc ,
22- "image_ap" : image_ap ,
21+ "image_ap" : image_ap , # 新增
2322 "fpr" : fpr ,
2423 "tpr" : tpr ,
2524 "threshold" : thresholds
@@ -36,77 +35,72 @@ def compute_pro_score(anomaly_segmentations, ground_truth_masks, threshold=0.5):
3635
3736 for i in range (num_images ):
3837 anomaly_mask = (anomaly_segmentations [i ] > threshold ).astype (np .float32 )
39- gt_mask = ground_truth_masks [i ].astype (np .float32 )
38+ gt_mask = ground_truth_masks [i ].astype (np .float32 )
4039
4140 labeled_mask , num_regions = ndimage .label (gt_mask )
4241 region_pro_scores = []
4342 for region_id in range (1 , num_regions + 1 ):
4443 region_mask = (labeled_mask == region_id ).astype (np .float32 )
45- overlap = np .sum (anomaly_mask * region_mask )
44+ overlap = np .sum (anomaly_mask * region_mask )
4645 region_area = np .sum (region_mask )
4746 if region_area == 0 :
4847 continue
4948 pro_score = overlap / region_area
5049 region_pro_scores .append (pro_score )
51-
50+
5251 if region_pro_scores :
5352 pro_scores .append (np .mean (region_pro_scores ))
5453
5554 return np .mean (pro_scores ) if pro_scores else 0.0
5655
57- def compute_pro (predictions , flat_ground_truth_masks , original_shape ):
58- num_images = original_shape [0 ]
59- height = original_shape [1 ]
60- width = original_shape [2 ]
61- predictions = predictions .reshape ((num_images , height , width ))
62- flat_ground_truth_masks = flat_ground_truth_masks .reshape ((num_images , height , width ))
63- return compute_pro_score (predictions , flat_ground_truth_masks )
64-
65- def compute_pixelwise_retrieval_metrics (anomaly_segmentations , ground_truth_masks ):
56+ def compute_pixelwise_retrieval_metrics (anomaly_segmentations , ground_truth_masks , anomaly_ratio = 0.1 ):
6657 if isinstance (anomaly_segmentations , list ):
6758 anomaly_segmentations = np .stack (anomaly_segmentations )
6859 if isinstance (ground_truth_masks , list ):
6960 ground_truth_masks = np .stack (ground_truth_masks )
7061
7162 flat_anomaly_segmentations = anomaly_segmentations .ravel ()
72- flat_ground_truth_masks = ground_truth_masks .ravel ()
63+ flat_ground_truth_masks = ground_truth_masks .ravel (). astype ( np . int32 )
7364
7465 fpr , tpr , thresholds = metrics .roc_curve (
75- flat_ground_truth_masks . astype ( int ) , flat_anomaly_segmentations
66+ flat_ground_truth_masks , flat_anomaly_segmentations
7667 )
7768 auroc = metrics .roc_auc_score (
78- flat_ground_truth_masks . astype ( int ) , flat_anomaly_segmentations
69+ flat_ground_truth_masks , flat_anomaly_segmentations
7970 )
8071
81- # 确保 flat_ground_truth_masks 是二元标签
82- flat_ground_truth_masks_binary = flat_ground_truth_masks .astype (int )
8372 pixel_ap = average_precision_score (
84- flat_ground_truth_masks_binary , flat_anomaly_segmentations
73+ flat_ground_truth_masks , flat_anomaly_segmentations
8574 )
8675
87- thresholds = np .linspace (0 , 1 , 100 )
88- pros = []
89- original_shape = anomaly_segmentations .shape
90-
91- for threshold in tqdm (thresholds , desc = "Computing PRO for different thresholds" ):
92- predictions = (flat_anomaly_segmentations >= threshold ).astype (int )
93- pro = compute_pro (predictions , flat_ground_truth_masks , original_shape )
94- pros .append (pro )
76+ # 自适应阈值选择
77+ threshold = np .quantile (flat_anomaly_segmentations , 1 - anomaly_ratio )
78+ predictions = (flat_anomaly_segmentations >= threshold ).astype (int )
9579
96- best_threshold = thresholds [np .argmax (pros )]
97- best_pro = np .max (pros )
80+ pro_score = compute_pro_score (predictions .reshape (anomaly_segmentations .shape ), ground_truth_masks , threshold = 0.5 )
9881
99- predictions = (flat_anomaly_segmentations >= best_threshold ).astype (int )
100- fpr_optim = np .mean (predictions > flat_ground_truth_masks )
101- fnr_optim = np .mean (predictions < flat_ground_truth_masks )
82+ precision , recall , thresholds = metrics .precision_recall_curve (
83+ flat_ground_truth_masks , flat_anomaly_segmentations
84+ )
85+ F1_scores = np .divide (
86+ 2 * precision * recall ,
87+ precision + recall ,
88+ out = np .zeros_like (precision ),
89+ where = (precision + recall ) != 0 ,
90+ )
91+ optimal_threshold = thresholds [np .argmax (F1_scores )]
92+ predictions_optimal = (flat_anomaly_segmentations >= optimal_threshold ).astype (np .int32 )
93+ fpr_optim = np .mean (predictions_optimal > flat_ground_truth_masks )
94+ fnr_optim = np .mean (predictions_optimal < flat_ground_truth_masks )
10295
10396 return {
10497 "auroc" : auroc ,
105- "pixel_ap" : pixel_ap ,
106- "pro_score" : best_pro ,
98+ "pixel_ap" : pixel_ap ,
99+ "pro_score" : pro_score ,
107100 "fpr" : fpr ,
108101 "tpr" : tpr ,
109- "optimal_threshold" : best_threshold ,
102+ "optimal_threshold" : optimal_threshold ,
110103 "optimal_fpr" : fpr_optim ,
111- "optimal_fnr" : fnr_optim
104+ "optimal_fnr" : fnr_optim ,
105+ "adaptive_threshold" : threshold
112106 }
0 commit comments