Skip to content

Commit 6620cc5

Browse files
authored
Update metrics.py
1 parent a2fe223 commit 6620cc5

File tree

1 file changed

+57
-29
lines changed

1 file changed

+57
-29
lines changed

src/patchcore/metrics.py

Lines changed: 57 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,76 +1,104 @@
11
"""Anomaly metrics."""
22
import numpy as np
33
from sklearn import metrics
4-
4+
from sklearn.metrics import average_precision_score
55

66
def compute_imagewise_retrieval_metrics(
77
anomaly_prediction_weights, anomaly_ground_truth_labels
88
):
9-
"""
10-
Computes retrieval statistics (AUROC, FPR, TPR).
11-
12-
Args:
13-
anomaly_prediction_weights: [np.array or list] [N] Assignment weights
14-
per image. Higher indicates higher
15-
probability of being an anomaly.
16-
anomaly_ground_truth_labels: [np.array or list] [N] Binary labels - 1
17-
if image is an anomaly, 0 if not.
18-
"""
199
fpr, tpr, thresholds = metrics.roc_curve(
2010
anomaly_ground_truth_labels, anomaly_prediction_weights
2111
)
2212
auroc = metrics.roc_auc_score(
2313
anomaly_ground_truth_labels, anomaly_prediction_weights
2414
)
25-
return {"auroc": auroc, "fpr": fpr, "tpr": tpr, "threshold": thresholds}
26-
15+
# 新增图像 AP 计算
16+
image_ap = average_precision_score(
17+
anomaly_ground_truth_labels, anomaly_prediction_weights
18+
)
19+
return {
20+
"auroc": auroc,
21+
"image_ap": image_ap, # 新增
22+
"fpr": fpr,
23+
"tpr": tpr,
24+
"threshold": thresholds
25+
}
2726

28-
def compute_pixelwise_retrieval_metrics(anomaly_segmentations, ground_truth_masks):
27+
def compute_pro_score(anomaly_segmentations, ground_truth_masks):
2928
"""
30-
Computes pixel-wise statistics (AUROC, FPR, TPR) for anomaly segmentations
31-
and ground truth segmentation masks.
32-
33-
Args:
34-
anomaly_segmentations: [list of np.arrays or np.array] [NxHxW] Contains
35-
generated segmentation masks.
36-
ground_truth_masks: [list of np.arrays or np.array] [NxHxW] Contains
37-
predefined ground truth segmentation masks
29+
Compute PRO score (Per-Region Overlap) for anomaly segmentation.
3830
"""
3931
if isinstance(anomaly_segmentations, list):
4032
anomaly_segmentations = np.stack(anomaly_segmentations)
4133
if isinstance(ground_truth_masks, list):
4234
ground_truth_masks = np.stack(ground_truth_masks)
4335

36+
pro_scores = []
37+
for seg, mask in zip(anomaly_segmentations, ground_truth_masks):
38+
mask = mask.astype(np.bool_)
39+
seg = (seg > 0.5).astype(np.bool_) # 假设分割结果已归一化到 [0,1]
40+
41+
# 计算每个区域的 PRO
42+
labeled_mask, num_regions = ndimage.label(mask)
43+
for region_id in range(1, num_regions + 1):
44+
region_mask = (labeled_mask == region_id)
45+
overlap = np.logical_and(region_mask, seg)
46+
region_area = np.sum(region_mask)
47+
if region_area == 0:
48+
continue
49+
pro_score = np.sum(overlap) / region_area
50+
pro_scores.append(pro_score)
51+
52+
return np.mean(pro_scores) if pro_scores else 0.0
53+
54+
55+
def compute_pixelwise_retrieval_metrics(anomaly_segmentations, ground_truth_masks):
56+
if isinstance(anomaly_segmentations, list):
57+
anomaly_segmentations = np.stack(anomaly_segmentations)
58+
if isinstance(ground_truth_masks, list):
59+
ground_truth_masks = np.stack(ground_truth_masks)
60+
4461
flat_anomaly_segmentations = anomaly_segmentations.ravel()
45-
flat_ground_truth_masks = ground_truth_masks.ravel()
62+
flat_ground_truth_masks = ground_truth_masks.ravel().astype(np.int32)
4663

64+
# 计算像素级 AUROC
4765
fpr, tpr, thresholds = metrics.roc_curve(
48-
flat_ground_truth_masks.astype(int), flat_anomaly_segmentations
66+
flat_ground_truth_masks, flat_anomaly_segmentations
4967
)
5068
auroc = metrics.roc_auc_score(
51-
flat_ground_truth_masks.astype(int), flat_anomaly_segmentations
69+
flat_ground_truth_masks, flat_anomaly_segmentations
70+
)
71+
72+
# 计算像素级 AP
73+
pixel_ap = average_precision_score(
74+
flat_ground_truth_masks, flat_anomaly_segmentations
5275
)
5376

77+
# 计算 PRO 指标
78+
pro_score = compute_pro_score(anomaly_segmentations, ground_truth_masks)
79+
80+
# 其他原有逻辑(如最优阈值)保持不变
5481
precision, recall, thresholds = metrics.precision_recall_curve(
55-
flat_ground_truth_masks.astype(int), flat_anomaly_segmentations
82+
flat_ground_truth_masks, flat_anomaly_segmentations
5683
)
5784
F1_scores = np.divide(
5885
2 * precision * recall,
5986
precision + recall,
6087
out=np.zeros_like(precision),
6188
where=(precision + recall) != 0,
6289
)
63-
6490
optimal_threshold = thresholds[np.argmax(F1_scores)]
65-
predictions = (flat_anomaly_segmentations >= optimal_threshold).astype(int)
91+
predictions = (flat_anomaly_segmentations >= optimal_threshold).astype(np.int32)
6692
fpr_optim = np.mean(predictions > flat_ground_truth_masks)
6793
fnr_optim = np.mean(predictions < flat_ground_truth_masks)
6894

6995
return {
7096
"auroc": auroc,
97+
"pixel_ap": pixel_ap, # 新增
98+
"pro_score": pro_score, # 新增
7199
"fpr": fpr,
72100
"tpr": tpr,
73101
"optimal_threshold": optimal_threshold,
74102
"optimal_fpr": fpr_optim,
75-
"optimal_fnr": fnr_optim,
103+
"optimal_fnr": fnr_optim
76104
}

0 commit comments

Comments
 (0)