
Commit 8cd59a0

Update metrics.py
1 parent 1cba672 commit 8cd59a0

File tree: 1 file changed (+46, -18 lines)


src/patchcore/metrics.py

Lines changed: 46 additions & 18 deletions
@@ -1,9 +1,13 @@
+"""Anomaly metrics."""
 import numpy as np
-from sklearn.metrics import average_precision_score
+from sklearn import metrics
 
-def compute_image_ap(anomaly_prediction_weights, anomaly_ground_truth_labels):
+
+def compute_imagewise_retrieval_metrics(
+    anomaly_prediction_weights, anomaly_ground_truth_labels
+):
     """
-    Computes the average precision (AP) for image-wise anomaly detection.
+    Computes retrieval statistics (AUROC, FPR, TPR).
 
     Args:
         anomaly_prediction_weights: [np.array or list] [N] Assignment weights
@@ -12,11 +16,19 @@ def compute_image_ap(anomaly_prediction_weights, anomaly_ground_truth_labels):
         anomaly_ground_truth_labels: [np.array or list] [N] Binary labels - 1
             if image is an anomaly, 0 if not.
     """
-    return average_precision_score(anomaly_ground_truth_labels, anomaly_prediction_weights)
+    fpr, tpr, thresholds = metrics.roc_curve(
+        anomaly_ground_truth_labels, anomaly_prediction_weights
+    )
+    auroc = metrics.roc_auc_score(
+        anomaly_ground_truth_labels, anomaly_prediction_weights
+    )
+    return {"auroc": auroc, "fpr": fpr, "tpr": tpr, "threshold": thresholds}
+
 
-def compute_pixel_ap(anomaly_segmentations, ground_truth_masks):
+def compute_pixelwise_retrieval_metrics(anomaly_segmentations, ground_truth_masks):
     """
-    Computes the average precision (AP) for pixel-wise anomaly detection.
+    Computes pixel-wise statistics (AUROC, FPR, TPR) for anomaly segmentations
+    and ground truth segmentation masks.
 
     Args:
         anomaly_segmentations: [list of np.arrays or np.array] [NxHxW] Contains
@@ -32,17 +44,33 @@ def compute_pixel_ap(anomaly_segmentations, ground_truth_masks):
     flat_anomaly_segmentations = anomaly_segmentations.ravel()
     flat_ground_truth_masks = ground_truth_masks.ravel()
 
-    return average_precision_score(flat_ground_truth_masks.astype(int), flat_anomaly_segmentations)
+    fpr, tpr, thresholds = metrics.roc_curve(
+        flat_ground_truth_masks.astype(int), flat_anomaly_segmentations
+    )
+    auroc = metrics.roc_auc_score(
+        flat_ground_truth_masks.astype(int), flat_anomaly_segmentations
+    )
 
-def compute_pro(anomaly_segmentations, ground_truth_masks):
-    """
-    Computes the PRO score. This is a placeholder function and needs to be implemented.
+    precision, recall, thresholds = metrics.precision_recall_curve(
+        flat_ground_truth_masks.astype(int), flat_anomaly_segmentations
+    )
+    F1_scores = np.divide(
+        2 * precision * recall,
+        precision + recall,
+        out=np.zeros_like(precision),
+        where=(precision + recall) != 0,
+    )
 
-    Args:
-        anomaly_segmentations: [list of np.arrays or np.array] [NxHxW] Contains
-            generated segmentation masks.
-        ground_truth_masks: [list of np.arrays or np.array] [NxHxW] Contains
-            predefined ground truth segmentation masks
-    """
-    # Placeholder implementation, replace with actual PRO calculation
-    return np.random.rand()
+    optimal_threshold = thresholds[np.argmax(F1_scores)]
+    predictions = (flat_anomaly_segmentations >= optimal_threshold).astype(int)
+    fpr_optim = np.mean(predictions > flat_ground_truth_masks)
+    fnr_optim = np.mean(predictions < flat_ground_truth_masks)
+
+    return {
+        "auroc": auroc,
+        "fpr": fpr,
+        "tpr": tpr,
+        "optimal_threshold": optimal_threshold,
+        "optimal_fpr": fpr_optim,
+        "optimal_fnr": fnr_optim,
+    }
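
With this commit, both metric functions return a dictionary of retrieval statistics instead of a single average-precision scalar, and the pixel-wise variant also selects the threshold that maximizes F1 along the precision-recall curve. Below is a minimal usage sketch; the toy scores, labels, and heatmaps are illustrative assumptions (not part of the commit), and the import assumes src/ is on the Python path:

import numpy as np

from patchcore.metrics import (
    compute_imagewise_retrieval_metrics,
    compute_pixelwise_retrieval_metrics,
)

# Hypothetical image-level anomaly scores and binary labels (1 = anomaly).
scores = np.array([0.1, 0.9, 0.4, 0.8])
labels = np.array([0, 1, 0, 1])
image_stats = compute_imagewise_retrieval_metrics(scores, labels)
print(image_stats["auroc"])  # 1.0: the toy scores separate the two classes perfectly

# Hypothetical [N, H, W] anomaly heatmaps, with masks derived from the
# heatmaps themselves so the pixel-wise AUROC is also 1.0 on this toy data.
rng = np.random.default_rng(0)
heatmaps = rng.random((4, 8, 8))
masks = (heatmaps > 0.7).astype(int)
pixel_stats = compute_pixelwise_retrieval_metrics(heatmaps, masks)
print(pixel_stats["auroc"], pixel_stats["optimal_threshold"])

Note that "optimal_fpr" and "optimal_fnr" are computed with np.mean over all pixels, i.e. the fraction of all pixels that are false positives (respectively false negatives) at the F1-optimal threshold.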
