Commit 5601dff

Update metrics.py
1 parent 990d01a commit 5601dff

1 file changed (+18, −46)


src/patchcore/metrics.py

Lines changed: 18 additions & 46 deletions
@@ -1,13 +1,9 @@
-"""Anomaly metrics."""
 import numpy as np
-from sklearn import metrics
+from sklearn.metrics import average_precision_score
 
 
-def compute_imagewise_retrieval_metrics(
-    anomaly_prediction_weights, anomaly_ground_truth_labels
-):
+def compute_image_ap(anomaly_prediction_weights, anomaly_ground_truth_labels):
     """
-    Computes retrieval statistics (AUROC, FPR, TPR).
+    Computes the average precision (AP) for image-wise anomaly detection.
 
     Args:
         anomaly_prediction_weights: [np.array or list] [N] Assignment weights
@@ -16,19 +12,11 @@ def compute_imagewise_retrieval_metrics(
         anomaly_ground_truth_labels: [np.array or list] [N] Binary labels - 1
             if image is an anomaly, 0 if not.
     """
-    fpr, tpr, thresholds = metrics.roc_curve(
-        anomaly_ground_truth_labels, anomaly_prediction_weights
-    )
-    auroc = metrics.roc_auc_score(
-        anomaly_ground_truth_labels, anomaly_prediction_weights
-    )
-    return {"auroc": auroc, "fpr": fpr, "tpr": tpr, "threshold": thresholds}
-
+    return average_precision_score(anomaly_ground_truth_labels, anomaly_prediction_weights)
 
-def compute_pixelwise_retrieval_metrics(anomaly_segmentations, ground_truth_masks):
+def compute_pixel_ap(anomaly_segmentations, ground_truth_masks):
     """
-    Computes pixel-wise statistics (AUROC, FPR, TPR) for anomaly segmentations
-    and ground truth segmentation masks.
+    Computes the average precision (AP) for pixel-wise anomaly detection.
 
     Args:
         anomaly_segmentations: [list of np.arrays or np.array] [NxHxW] Contains
@@ -44,33 +32,17 @@ def compute_pixelwise_retrieval_metrics(anomaly_segmentations, ground_truth_masks):
     flat_anomaly_segmentations = anomaly_segmentations.ravel()
     flat_ground_truth_masks = ground_truth_masks.ravel()
 
-    fpr, tpr, thresholds = metrics.roc_curve(
-        flat_ground_truth_masks.astype(int), flat_anomaly_segmentations
-    )
-    auroc = metrics.roc_auc_score(
-        flat_ground_truth_masks.astype(int), flat_anomaly_segmentations
-    )
-
-    precision, recall, thresholds = metrics.precision_recall_curve(
-        flat_ground_truth_masks.astype(int), flat_anomaly_segmentations
-    )
-    F1_scores = np.divide(
-        2 * precision * recall,
-        precision + recall,
-        out=np.zeros_like(precision),
-        where=(precision + recall) != 0,
-    )
+    return average_precision_score(flat_ground_truth_masks.astype(int), flat_anomaly_segmentations)
 
-    optimal_threshold = thresholds[np.argmax(F1_scores)]
-    predictions = (flat_anomaly_segmentations >= optimal_threshold).astype(int)
-    fpr_optim = np.mean(predictions > flat_ground_truth_masks)
-    fnr_optim = np.mean(predictions < flat_ground_truth_masks)
+def compute_pro(anomaly_segmentations, ground_truth_masks):
+    """
+    Computes the PRO score. This is a placeholder function and needs to be implemented.
 
-    return {
-        "auroc": auroc,
-        "fpr": fpr,
-        "tpr": tpr,
-        "optimal_threshold": optimal_threshold,
-        "optimal_fpr": fpr_optim,
-        "optimal_fnr": fnr_optim,
-    }
+    Args:
+        anomaly_segmentations: [list of np.arrays or np.array] [NxHxW] Contains
+            generated segmentation masks.
+        ground_truth_masks: [list of np.arrays or np.array] [NxHxW] Contains
+            predefined ground truth segmentation masks
+    """
+    # Placeholder implementation, replace with actual PRO calculation
+    return np.random.rand()
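Both renamed helpers are now thin wrappers around sklearn's average_precision_score, returning a single scalar AP instead of the previous dictionaries of ROC statistics. A hypothetical call site, not taken from the repository (array shapes follow the docstrings; the data below is made up for illustration):

# Hypothetical usage of the renamed helpers; shapes follow the docstrings.
import numpy as np
from patchcore.metrics import compute_image_ap, compute_pixel_ap

image_scores = np.array([0.1, 0.9, 0.4, 0.8])  # [N] per-image anomaly scores
image_labels = np.array([0, 1, 0, 1])          # [N] binary ground truth

seg_maps = np.random.rand(4, 32, 32)           # [NxHxW] anomaly heatmaps
gt_masks = np.zeros((4, 32, 32), dtype=bool)   # [NxHxW] binary masks
gt_masks[:, 8:16, 8:16] = True                 # one square anomaly per image

print(compute_image_ap(image_scores, image_labels))  # scalar AP in [0, 1]
print(compute_pixel_ap(seg_maps, gt_masks))          # scalar AP in [0, 1]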

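compute_pro itself is left as a stub that returns a random value. Purely for reference, below is a rough sketch of one common PRO (per-region overlap) formulation: binarize the anomaly maps at a sweep of thresholds, average the overlap with each connected ground-truth component, and integrate the resulting curve over the false positive rate up to a cap (0.3 is a frequent choice). Nothing here is part of the commit; the name compute_pro_sketch, the scipy dependency, and the fixed threshold grid are all assumptions.

# Not part of this commit: a rough sketch of one common PRO formulation,
# assuming scipy is installed. Names and defaults are illustrative.
import numpy as np
from scipy import ndimage


def compute_pro_sketch(anomaly_segmentations, ground_truth_masks,
                       fpr_limit=0.3, num_thresholds=200):
    """Integrate the mean per-region overlap over FPR in [0, fpr_limit]."""
    scores = np.asarray(anomaly_segmentations)
    masks = np.asarray(ground_truth_masks).astype(bool)

    # Candidate thresholds spanning the full score range.
    thresholds = np.linspace(scores.min(), scores.max(), num_thresholds)
    pros, fprs = [], []
    for th in thresholds:
        preds = scores >= th
        # False positive rate over all normal (non-anomalous) pixels.
        normal = ~masks
        fpr = np.logical_and(preds, normal).sum() / max(normal.sum(), 1)
        # Overlap with each connected ground-truth region, averaged over regions.
        overlaps = []
        for mask, pred in zip(masks, preds):
            labeled, n_regions = ndimage.label(mask)
            for idx in range(1, n_regions + 1):
                region = labeled == idx
                overlaps.append(np.logical_and(pred, region).sum() / region.sum())
        if overlaps:
            pros.append(np.mean(overlaps))
            fprs.append(fpr)

    pros, fprs = np.asarray(pros), np.asarray(fprs)
    keep = fprs <= fpr_limit
    if keep.sum() < 2:
        return 0.0
    order = np.argsort(fprs[keep])
    # Area under the PRO-vs-FPR curve, normalized by the FPR cap.
    return float(np.trapz(pros[keep][order], fprs[keep][order]) / fpr_limit)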