Skip to content

Commit db2dba8

Browse files
authored
Update metrics.py
it will be okay for adding the pixel_ap,img_ap,pixel_pro
1 parent b2b72c5 commit db2dba8

File tree

1 file changed

+32
-38
lines changed

1 file changed

+32
-38
lines changed

src/patchcore/metrics.py

Lines changed: 32 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
import numpy as np
22
from sklearn import metrics
33
from sklearn.metrics import average_precision_score
4-
import scipy.ndimage as ndimage
5-
from tqdm import tqdm
4+
import scipy.ndimage as ndimage
65

76
def compute_imagewise_retrieval_metrics(
87
anomaly_prediction_weights, anomaly_ground_truth_labels
@@ -19,7 +18,7 @@ def compute_imagewise_retrieval_metrics(
1918
)
2019
return {
2120
"auroc": auroc,
22-
"image_ap": image_ap,
21+
"image_ap": image_ap, # 新增
2322
"fpr": fpr,
2423
"tpr": tpr,
2524
"threshold": thresholds
@@ -36,77 +35,72 @@ def compute_pro_score(anomaly_segmentations, ground_truth_masks, threshold=0.5):
3635

3736
for i in range(num_images):
3837
anomaly_mask = (anomaly_segmentations[i] > threshold).astype(np.float32)
39-
gt_mask = ground_truth_masks[i].astype(np.float32)
38+
gt_mask = ground_truth_masks[i].astype(np.float32)
4039

4140
labeled_mask, num_regions = ndimage.label(gt_mask)
4241
region_pro_scores = []
4342
for region_id in range(1, num_regions + 1):
4443
region_mask = (labeled_mask == region_id).astype(np.float32)
45-
overlap = np.sum(anomaly_mask * region_mask)
44+
overlap = np.sum(anomaly_mask * region_mask)
4645
region_area = np.sum(region_mask)
4746
if region_area == 0:
4847
continue
4948
pro_score = overlap / region_area
5049
region_pro_scores.append(pro_score)
51-
50+
5251
if region_pro_scores:
5352
pro_scores.append(np.mean(region_pro_scores))
5453

5554
return np.mean(pro_scores) if pro_scores else 0.0
5655

57-
def compute_pro(predictions, flat_ground_truth_masks, original_shape):
58-
num_images = original_shape[0]
59-
height = original_shape[1]
60-
width = original_shape[2]
61-
predictions = predictions.reshape((num_images, height, width))
62-
flat_ground_truth_masks = flat_ground_truth_masks.reshape((num_images, height, width))
63-
return compute_pro_score(predictions, flat_ground_truth_masks)
64-
65-
def compute_pixelwise_retrieval_metrics(anomaly_segmentations, ground_truth_masks):
56+
def compute_pixelwise_retrieval_metrics(anomaly_segmentations, ground_truth_masks, anomaly_ratio=0.1):
6657
if isinstance(anomaly_segmentations, list):
6758
anomaly_segmentations = np.stack(anomaly_segmentations)
6859
if isinstance(ground_truth_masks, list):
6960
ground_truth_masks = np.stack(ground_truth_masks)
7061

7162
flat_anomaly_segmentations = anomaly_segmentations.ravel()
72-
flat_ground_truth_masks = ground_truth_masks.ravel()
63+
flat_ground_truth_masks = ground_truth_masks.ravel().astype(np.int32)
7364

7465
fpr, tpr, thresholds = metrics.roc_curve(
75-
flat_ground_truth_masks.astype(int), flat_anomaly_segmentations
66+
flat_ground_truth_masks, flat_anomaly_segmentations
7667
)
7768
auroc = metrics.roc_auc_score(
78-
flat_ground_truth_masks.astype(int), flat_anomaly_segmentations
69+
flat_ground_truth_masks, flat_anomaly_segmentations
7970
)
8071

81-
# 确保 flat_ground_truth_masks 是二元标签
82-
flat_ground_truth_masks_binary = flat_ground_truth_masks.astype(int)
8372
pixel_ap = average_precision_score(
84-
flat_ground_truth_masks_binary, flat_anomaly_segmentations
73+
flat_ground_truth_masks, flat_anomaly_segmentations
8574
)
8675

87-
thresholds = np.linspace(0, 1, 100)
88-
pros = []
89-
original_shape = anomaly_segmentations.shape
90-
91-
for threshold in tqdm(thresholds, desc="Computing PRO for different thresholds"):
92-
predictions = (flat_anomaly_segmentations >= threshold).astype(int)
93-
pro = compute_pro(predictions, flat_ground_truth_masks, original_shape)
94-
pros.append(pro)
76+
# 自适应阈值选择
77+
threshold = np.quantile(flat_anomaly_segmentations, 1 - anomaly_ratio)
78+
predictions = (flat_anomaly_segmentations >= threshold).astype(int)
9579

96-
best_threshold = thresholds[np.argmax(pros)]
97-
best_pro = np.max(pros)
80+
pro_score = compute_pro_score(predictions.reshape(anomaly_segmentations.shape), ground_truth_masks, threshold=0.5)
9881

99-
predictions = (flat_anomaly_segmentations >= best_threshold).astype(int)
100-
fpr_optim = np.mean(predictions > flat_ground_truth_masks)
101-
fnr_optim = np.mean(predictions < flat_ground_truth_masks)
82+
precision, recall, thresholds = metrics.precision_recall_curve(
83+
flat_ground_truth_masks, flat_anomaly_segmentations
84+
)
85+
F1_scores = np.divide(
86+
2 * precision * recall,
87+
precision + recall,
88+
out=np.zeros_like(precision),
89+
where=(precision + recall) != 0,
90+
)
91+
optimal_threshold = thresholds[np.argmax(F1_scores)]
92+
predictions_optimal = (flat_anomaly_segmentations >= optimal_threshold).astype(np.int32)
93+
fpr_optim = np.mean(predictions_optimal > flat_ground_truth_masks)
94+
fnr_optim = np.mean(predictions_optimal < flat_ground_truth_masks)
10295

10396
return {
10497
"auroc": auroc,
105-
"pixel_ap": pixel_ap,
106-
"pro_score": best_pro,
98+
"pixel_ap": pixel_ap,
99+
"pro_score": pro_score,
107100
"fpr": fpr,
108101
"tpr": tpr,
109-
"optimal_threshold": best_threshold,
102+
"optimal_threshold": optimal_threshold,
110103
"optimal_fpr": fpr_optim,
111-
"optimal_fnr": fnr_optim
104+
"optimal_fnr": fnr_optim,
105+
"adaptive_threshold": threshold
112106
}

0 commit comments

Comments
 (0)