1- """Anomaly metrics."""
21import numpy as np
3- from sklearn import metrics
2+ from sklearn . metrics import average_precision_score
43
5-
6- def compute_imagewise_retrieval_metrics (
7- anomaly_prediction_weights , anomaly_ground_truth_labels
8- ):
4+ def compute_image_ap (anomaly_prediction_weights , anomaly_ground_truth_labels ):
95 """
10- Computes retrieval statistics (AUROC, FPR, TPR) .
6+ Computes the average precision (AP) for image-wise anomaly detection .
117
128 Args:
139 anomaly_prediction_weights: [np.array or list] [N] Assignment weights
@@ -16,19 +12,11 @@ def compute_imagewise_retrieval_metrics(
1612 anomaly_ground_truth_labels: [np.array or list] [N] Binary labels - 1
1713 if image is an anomaly, 0 if not.
1814 """
19- fpr , tpr , thresholds = metrics .roc_curve (
20- anomaly_ground_truth_labels , anomaly_prediction_weights
21- )
22- auroc = metrics .roc_auc_score (
23- anomaly_ground_truth_labels , anomaly_prediction_weights
24- )
25- return {"auroc" : auroc , "fpr" : fpr , "tpr" : tpr , "threshold" : thresholds }
26-
15+ return average_precision_score (anomaly_ground_truth_labels , anomaly_prediction_weights )
2716
28- def compute_pixelwise_retrieval_metrics (anomaly_segmentations , ground_truth_masks ):
17+ def compute_pixel_ap (anomaly_segmentations , ground_truth_masks ):
2918 """
30- Computes pixel-wise statistics (AUROC, FPR, TPR) for anomaly segmentations
31- and ground truth segmentation masks.
19+ Computes the average precision (AP) for pixel-wise anomaly detection.
3220
3321 Args:
3422 anomaly_segmentations: [list of np.arrays or np.array] [NxHxW] Contains
@@ -44,33 +32,17 @@ def compute_pixelwise_retrieval_metrics(anomaly_segmentations, ground_truth_mask
4432 flat_anomaly_segmentations = anomaly_segmentations .ravel ()
4533 flat_ground_truth_masks = ground_truth_masks .ravel ()
4634
47- fpr , tpr , thresholds = metrics .roc_curve (
48- flat_ground_truth_masks .astype (int ), flat_anomaly_segmentations
49- )
50- auroc = metrics .roc_auc_score (
51- flat_ground_truth_masks .astype (int ), flat_anomaly_segmentations
52- )
53-
54- precision , recall , thresholds = metrics .precision_recall_curve (
55- flat_ground_truth_masks .astype (int ), flat_anomaly_segmentations
56- )
57- F1_scores = np .divide (
58- 2 * precision * recall ,
59- precision + recall ,
60- out = np .zeros_like (precision ),
61- where = (precision + recall ) != 0 ,
62- )
35+ return average_precision_score (flat_ground_truth_masks .astype (int ), flat_anomaly_segmentations )
6336
64- optimal_threshold = thresholds [np .argmax (F1_scores )]
65- predictions = (flat_anomaly_segmentations >= optimal_threshold ).astype (int )
66- fpr_optim = np .mean (predictions > flat_ground_truth_masks )
67- fnr_optim = np .mean (predictions < flat_ground_truth_masks )
37+ def compute_pro (anomaly_segmentations , ground_truth_masks ):
38+ """
39+ Computes the PRO score. This is a placeholder function and needs to be implemented.
6840
69- return {
70- "auroc" : auroc ,
71- "fpr" : fpr ,
72- "tpr" : tpr ,
73- "optimal_threshold" : optimal_threshold ,
74- "optimal_fpr" : fpr_optim ,
75- "optimal_fnr" : fnr_optim ,
76- }
41+ Args:
42+ anomaly_segmentations: [list of np.arrays or np.array] [NxHxW] Contains
43+ generated segmentation masks.
44+ ground_truth_masks: [list of np.arrays or np.array] [NxHxW] Contains
45+ predefined ground truth segmentation masks
46+ """
47+ # Placeholder implementation, replace with actual PRO calculation
48+ return np . random . rand ()
0 commit comments