Commit 038a8ce

Update common.py
Test
1 parent fcaa92f commit 038a8ce

File tree

1 file changed (+108, -0 lines)

src/patchcore/common.py

Lines changed: 108 additions & 0 deletions
@@ -9,6 +9,7 @@
 import scipy.ndimage as ndimage
 import torch
 import torch.nn.functional as F
+from sklearn.metrics import average_precision_score
 
 
 class FaissNN(object):
@@ -391,3 +392,110 @@ def load(self, load_folder: str, prepend: str = "") -> None:
         self.detection_features = self._load(
             self._detection_file(load_folder, prepend)
         )
+
+
+def compute_pro_score(anomaly_segmentations, ground_truth_masks):
+    """
+    Computes the PRO score for anomaly segmentations and ground truth segmentation masks.
+
+    Args:
+        anomaly_segmentations: [list of np.arrays or np.array] [NxHxW] Contains
+                               generated segmentation masks.
+        ground_truth_masks: [list of np.arrays or np.array] [NxHxW] Contains
+                            predefined ground truth segmentation masks.
+    """
+    if isinstance(anomaly_segmentations, list):
+        anomaly_segmentations = np.stack(anomaly_segmentations)
+    if isinstance(ground_truth_masks, list):
+        ground_truth_masks = np.stack(ground_truth_masks)
+
+    num_images = anomaly_segmentations.shape[0]
+    pro_scores = []
+
+    for i in range(num_images):
+        anomaly_mask = anomaly_segmentations[i]
+        gt_mask = ground_truth_masks[i]
+
+        # Find all ground-truth anomaly regions (non-zero labels).
+        unique_labels = np.unique(gt_mask)
+        unique_labels = unique_labels[unique_labels > 0]
+
+        region_pro_scores = []
+        for label in unique_labels:
+            region_mask = (gt_mask == label).astype(np.float32)
+            overlap = np.sum(anomaly_mask * region_mask)
+            region_area = np.sum(region_mask)
+            region_pro = overlap / region_area
+            region_pro_scores.append(region_pro)
+
+        if len(region_pro_scores) > 0:
+            pro_scores.append(np.mean(region_pro_scores))
+
+    if len(pro_scores) > 0:
+        return np.mean(pro_scores)
+    else:
+        return 0.0
+
+
+def compute_imagewise_retrieval_metrics(
+    anomaly_prediction_weights, anomaly_ground_truth_labels
+):
+    from sklearn import metrics
+    fpr, tpr, thresholds = metrics.roc_curve(
+        anomaly_ground_truth_labels, anomaly_prediction_weights
+    )
+    auroc = metrics.roc_auc_score(
+        anomaly_ground_truth_labels, anomaly_prediction_weights
+    )
+    image_ap = average_precision_score(anomaly_ground_truth_labels, anomaly_prediction_weights)
+
+    return {"auroc": auroc, "fpr": fpr, "tpr": tpr, "threshold": thresholds, "image_ap": image_ap}
+
+
+def compute_pixelwise_retrieval_metrics(anomaly_segmentations, ground_truth_masks):
+    from sklearn import metrics
+    if isinstance(anomaly_segmentations, list):
+        anomaly_segmentations = np.stack(anomaly_segmentations)
+    if isinstance(ground_truth_masks, list):
+        ground_truth_masks = np.stack(ground_truth_masks)
+
+    flat_anomaly_segmentations = anomaly_segmentations.ravel()
+    flat_ground_truth_masks = ground_truth_masks.ravel()
+
+    fpr, tpr, thresholds = metrics.roc_curve(
+        flat_ground_truth_masks.astype(int), flat_anomaly_segmentations
+    )
+    auroc = metrics.roc_auc_score(
+        flat_ground_truth_masks.astype(int), flat_anomaly_segmentations
+    )
+
+    precision, recall, thresholds = metrics.precision_recall_curve(
+        flat_ground_truth_masks.astype(int), flat_anomaly_segmentations
+    )
+    F1_scores = np.divide(
+        2 * precision * recall,
+        precision + recall,
+        out=np.zeros_like(precision),
+        where=(precision + recall) != 0,
+    )
+
+    optimal_threshold = thresholds[np.argmax(F1_scores)]
+    predictions = (flat_anomaly_segmentations >= optimal_threshold).astype(int)
+    fpr_optim = np.mean(predictions > flat_ground_truth_masks)
+    fnr_optim = np.mean(predictions < flat_ground_truth_masks)
+
+    # Compute the PRO metric.
+    pro_score = compute_pro_score(anomaly_segmentations, ground_truth_masks)
+    # Compute the pixel-level AP metric.
+    pixel_ap = average_precision_score(flat_ground_truth_masks.astype(int), flat_anomaly_segmentations)
+
+    return {
+        "auroc": auroc,
+        "fpr": fpr,
+        "tpr": tpr,
+        "optimal_threshold": optimal_threshold,
+        "optimal_fpr": fpr_optim,
+        "optimal_fnr": fnr_optim,
+        "pro_score": pro_score,
+        "pixel_ap": pixel_ap
+    }
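
For reference, a minimal usage sketch (not part of this commit) of the metric functions added above. It assumes the package is installed so that src/patchcore/common.py is importable as patchcore.common; the arrays, shapes, and variable names are illustrative only.

import numpy as np

from patchcore.common import (
    compute_imagewise_retrieval_metrics,
    compute_pixelwise_retrieval_metrics,
    compute_pro_score,
)

# Image-level anomaly scores and binary labels (illustrative values only).
scores = np.array([0.1, 0.9, 0.2, 0.8])
labels = np.array([0, 1, 0, 1])
image_metrics = compute_imagewise_retrieval_metrics(scores, labels)
print(image_metrics["auroc"], image_metrics["image_ap"])

# Pixel-level anomaly maps and binary ground-truth masks, shape [N, H, W] (synthetic data).
rng = np.random.default_rng(0)
segmentations = rng.random((2, 8, 8)) * 0.3                   # low background scores
masks = np.zeros((2, 8, 8), dtype=int)
masks[0, 2:5, 2:5] = 1                                         # one anomalous region in image 0
segmentations[0, 2:5, 2:5] = 0.7 + rng.random((3, 3)) * 0.3    # high scores on that region

pixel_metrics = compute_pixelwise_retrieval_metrics(segmentations, masks)
print(pixel_metrics["auroc"], pixel_metrics["pro_score"], pixel_metrics["pixel_ap"])

# compute_pro_score can also be called directly on the same inputs.
print(compute_pro_score(segmentations, masks))

Here image_metrics includes the image_ap key, and pixel_metrics includes pro_score and pixel_ap, alongside the AUROC and threshold values returned by the same functions.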

0 commit comments
