import time
from contextlib import contextmanager
import io
- from collections import defaultdict
+ from collections import defaultdict, Counter

import cProfile
import pstats
@@ -431,6 +431,201 @@ def _check_object_detection_input(y, y_pred):
        )


+ def _intersection_over_union(box_1, box_2):
+     """
+     Assumes format of [y1, x1, y2, x2] or [x1, y1, x2, y2]. Either ordering works,
+     since the computation is symmetric in the two axes, as long as both boxes use
+     the same convention.
+     """
+     assert box_1[2] >= box_1[0]
+     assert box_2[2] >= box_2[0]
+     assert box_1[3] >= box_1[1]
+     assert box_2[3] >= box_2[1]
+
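+     # If exactly one of the two boxes has a mean coordinate in [0, 1), the two
+     # boxes are likely in different coordinate systems (normalized vs. pixel)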
+     if sum([mean < 1 and mean >= 0 for mean in [box_1.mean(), box_2.mean()]]) == 1:
+         logger.warning(
+             "One set of boxes appears to be normalized while the other is not"
+         )
+
+     # Determine coordinates of intersection box
+     x_left = max(box_1[1], box_2[1])
+     x_right = min(box_1[3], box_2[3])
+     y_top = max(box_1[0], box_2[0])
+     y_bottom = min(box_1[2], box_2[2])
+
+     intersect_area = max(0, x_right - x_left) * max(0, y_bottom - y_top)
+     if intersect_area == 0:
+         return 0
+
+     box_1_area = (box_1[3] - box_1[1]) * (box_1[2] - box_1[0])
+     box_2_area = (box_2[3] - box_2[1]) * (box_2[2] - box_2[0])
+
+     iou = intersect_area / (box_1_area + box_2_area - intersect_area)
+     assert iou >= 0
+     assert iou <= 1
+     return iou
+
+
+ def object_detection_AP_per_class(list_of_ys, list_of_y_preds):
+     """
+     Mean average precision for object detection. This function returns a dictionary
+     mapping each class to the average precision (AP) for the class. The mAP can be
+     computed by taking the mean of the APs across all classes.
+
+     This metric is computed over all evaluation samples, rather than on a per-sample basis.
+     """
+
+     IOU_THRESHOLD = 0.5
+     # Precision will be computed at recall points of 0, 0.1, 0.2, ..., 1
+     RECALL_POINTS = np.linspace(0, 1, 11)
+
+     # Converting all boxes to a list of dicts (a list for predicted boxes, and a
+     # separate list for ground truth boxes), where each dict corresponds to a box and
+     # has the following keys "img_idx", "label", "box", as well as "score" for predicted boxes
+     pred_boxes_list = []
+     gt_boxes_list = []
+     for img_idx, (y, y_pred) in enumerate(zip(list_of_ys, list_of_y_preds)):
+         for gt_box_idx in range(len(y["labels"][0].flatten())):
+             label = y["labels"][0][gt_box_idx]
+             box = y["boxes"][0][gt_box_idx]
+
+             gt_box_dict = {"img_idx": img_idx, "label": label, "box": box}
+             gt_boxes_list.append(gt_box_dict)
+
+         for pred_box_idx in range(len(y_pred["labels"].flatten())):
+             label = y_pred["labels"][pred_box_idx]
+             box = y_pred["boxes"][pred_box_idx]
+             score = y_pred["scores"][pred_box_idx]
+
+             pred_box_dict = {
+                 "img_idx": img_idx,
+                 "label": label,
+                 "box": box,
+                 "score": score,
+             }
+             pred_boxes_list.append(pred_box_dict)
+
+     # Union of (1) the set of all true classes and (2) the set of all predicted classes
+     set_of_class_ids = set([i["label"] for i in gt_boxes_list]) | set(
+         [i["label"] for i in pred_boxes_list]
+     )
+
+     # Remove the class ID that corresponds to a physical adversarial patch in the
+     # APRICOT dataset, if present
+     set_of_class_ids.discard(ADV_PATCH_MAGIC_NUMBER_LABEL_ID)
+
+     # Initialize dict that will store AP for each class
+     average_precisions_by_class = {}
+
+     # Compute AP for each class
+     for class_id in set_of_class_ids:
+
+         # Build lists that contain all the predicted/ground-truth boxes with a
+         # label of class_id
+         class_predicted_boxes = []
+         class_gt_boxes = []
+         for pred_box in pred_boxes_list:
+             if pred_box["label"] == class_id:
+                 class_predicted_boxes.append(pred_box)
+         for gt_box in gt_boxes_list:
+             if gt_box["label"] == class_id:
+                 class_gt_boxes.append(gt_box)
+
+         # Determine how many gt boxes (of class_id) there are in each image
+         num_gt_boxes_per_img = Counter([gt["img_idx"] for gt in class_gt_boxes])
+
+         # Initialize dict where we'll keep track of whether a gt box has been matched to a
+         # prediction yet. This is necessary because if multiple predicted boxes of class_id
+         # overlap with a single gt box, only one of the predicted boxes can be considered a
+         # true positive
+         img_idx_to_gtboxismatched_array = {}
+         for img_idx, num_gt_boxes in num_gt_boxes_per_img.items():
+             img_idx_to_gtboxismatched_array[img_idx] = np.zeros(num_gt_boxes)
+
+         # Sort all predicted boxes (of class_id) by descending confidence
+         class_predicted_boxes.sort(key=lambda x: x["score"], reverse=True)
+
+         # Initialize arrays. Once filled in, true_positives[i] indicates (with a 1 or 0)
+         # whether the ith predicted box (of class_id) is a true positive. Likewise for
+         # false_positives array
+         true_positives = np.zeros(len(class_predicted_boxes))
+         false_positives = np.zeros(len(class_predicted_boxes))
+
+         # Iterating over all predicted boxes of class_id
+         for pred_idx, pred_box in enumerate(class_predicted_boxes):
+             # Only compare gt boxes from the same image as the predicted box
+             gt_boxes_from_same_img = [
+                 gt_box
+                 for gt_box in class_gt_boxes
+                 if gt_box["img_idx"] == pred_box["img_idx"]
+             ]
+
+             # If there are no gt boxes in the predicted box's image that have the predicted class
+             if len(gt_boxes_from_same_img) == 0:
+                 false_positives[pred_idx] = 1
+                 continue
+
+             # Iterate over all gt boxes (of class_id) from the same image as the predicted
+             # box, determining which gt box has the highest IoU with the predicted box
+             highest_iou = 0
+             for gt_idx, gt_box in enumerate(gt_boxes_from_same_img):
+                 iou = _intersection_over_union(pred_box["box"], gt_box["box"])
+                 if iou >= highest_iou:
+                     highest_iou = iou
+                     highest_iou_gt_idx = gt_idx
+
+             if highest_iou > IOU_THRESHOLD:
+                 # If the gt box has not yet been covered
+                 if (
+                     img_idx_to_gtboxismatched_array[pred_box["img_idx"]][
+                         highest_iou_gt_idx
+                     ]
+                     == 0
+                 ):
+                     true_positives[pred_idx] = 1
+
+                     # Record that we've now covered this gt box. Any subsequent
+                     # pred boxes that overlap with it are considered false positives
+                     img_idx_to_gtboxismatched_array[pred_box["img_idx"]][
+                         highest_iou_gt_idx
+                     ] = 1
+                 else:
+                     # This gt box was already covered previously (i.e. a different predicted
+                     # box was deemed a true positive after overlapping with this gt box)
+                     false_positives[pred_idx] = 1
+             else:
+                 false_positives[pred_idx] = 1
+
+         # Cumulative sums of false/true positives across all predictions which were sorted by
+         # descending confidence
+         tp_cumulative_sum = np.cumsum(true_positives)
+         fp_cumulative_sum = np.cumsum(false_positives)
+
+         # Total number of gt boxes with a label of class_id
+         total_gt_boxes = len(class_gt_boxes)
+
+         recalls = tp_cumulative_sum / (total_gt_boxes + 1e-6)
+         precisions = tp_cumulative_sum / (tp_cumulative_sum + fp_cumulative_sum + 1e-6)
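+         # The 1e-6 terms guard against division by zero (e.g., when there are no
+         # gt boxes of class_id)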
+
+         interpolated_precisions = np.zeros(len(RECALL_POINTS))
+         # Interpolate the precision at each recall level by taking the max precision for which
+         # the corresponding recall exceeds the recall point
+         # See http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.157.5766&rep=rep1&type=pdf
+         for i, recall_point in enumerate(RECALL_POINTS):
+             precisions_points = precisions[np.where(recalls >= recall_point)]
+             # If there's no cutoff at which the recall >= recall_point
+             if len(precisions_points) == 0:
+                 interpolated_precisions[i] = 0
+             else:
+                 interpolated_precisions[i] = max(precisions_points)
+
+         # Compute mean precision across the different recall levels
+         average_precision = interpolated_precisions.mean()
+         average_precisions_by_class[int(class_id)] = np.around(
+             average_precision, decimals=2
+         )
+
+     return average_precisions_by_class
+
+
SUPPORTED_METRICS = {
    "categorical_accuracy": categorical_accuracy,
    "top_n_categorical_accuracy": top_n_categorical_accuracy,
@@ -449,6 +644,7 @@ def _check_object_detection_input(y, y_pred):
    "mars_mean_l2": mars_mean_l2,
    "mars_mean_patch": mars_mean_patch,
    "word_error_rate": word_error_rate,
+     "object_detection_AP_per_class": object_detection_AP_per_class,
    "object_detection_class_precision": object_detection_class_precision,
    "object_detection_class_recall": object_detection_class_recall,
}
@@ -503,6 +699,7 @@ def __init__(self, name, function=None):
            raise ValueError(f"function must be callable or None, not {function}")
        self.name = name
        self._values = []
+         self._inputs = []

    def clear(self):
        self._values.clear()
@@ -523,6 +720,9 @@ def values(self):
    def mean(self):
        return sum(float(x) for x in self._values) / len(self._values)

+     def append_inputs(self, *args):
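+         # Store raw (y, y_pred) inputs so that metrics computed across the full
+         # test set (rather than per sample) can be evaluated once at the end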
+         self._inputs.append(args)
+
    def total_wer(self):
        # checks if all values are tuples from the WER metric
        if all(isinstance(wer_tuple, tuple) for wer_tuple in self._values):
@@ -535,6 +735,12 @@ def total_wer(self):
        else:
            raise ValueError("total_wer() only for WER metric")

+     def AP_per_class(self):
+         # Computed at once across all samples
+         y_s = [i[0] for i in self._inputs]
+         y_preds = [i[1] for i in self._inputs]
+         return object_detection_AP_per_class(y_s, y_preds)
+

class MetricsLogger:
    """
@@ -550,6 +756,7 @@ def __init__(
        profiler_type=None,
        computational_resource_dict=None,
        skip_benign=None,
+         **kwargs,
    ):
        """
        task - single metric or list of metrics
@@ -565,7 +772,7 @@ def __init__(
        self.computational_resource_dict = {}
        if not self.means and not self.full:
            logger.warning(
-                 "No metric results will be produced. "
+                 "No per-sample metric results will be produced. "
                "To change this, set 'means' or 'record_metric_per_sample' to True."
            )
        if not self.tasks and not self.perturbations and not self.adversarial_tasks:
@@ -598,7 +805,10 @@ def clear(self):
    def update_task(self, y, y_pred, adversarial=False):
        tasks = self.adversarial_tasks if adversarial else self.tasks
        for metric in tasks:
-             metric.append(y, y_pred)
+             if metric.name == "object_detection_AP_per_class":
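+                 # y_pred is a length-1 batch of prediction dicts here (the object
+                 # detection pipeline runs with batch size 1), so store the single dict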
+                 metric.append_inputs(y, y_pred[0])
+             else:
+                 metric.append(y, y_pred)

    def update_perturbation(self, x, x_adv):
        for metric in self.perturbations:
@@ -624,6 +834,13 @@ def log_task(self, adversarial=False, targeted=False):
                    f"Word error rate on {task_type} examples: "
                    f"{metric.total_wer():.2%}"
                )
+             elif metric.name == "object_detection_AP_per_class":
+                 average_precision_by_class = metric.AP_per_class()
+                 logger.info(
+                     f"object_detection_mAP on {task_type} examples: "
+                     f"{np.fromiter(average_precision_by_class.values(), dtype=float).mean():.2%}."
+                     f" AP by class ID: {average_precision_by_class}"
+                 )
            else:
                logger.info(
                    f"Average {metric.name} on {task_type} test examples: "
@@ -641,6 +858,14 @@ def results(self):
            (self.perturbations, "perturbation"),
        ]:
            for metric in metrics:
+                 if metric.name == "object_detection_AP_per_class":
+                     average_precision_by_class = metric.AP_per_class()
+                     results[f"{prefix}_object_detection_mAP"] = np.fromiter(
+                         average_precision_by_class.values(), dtype=float
+                     ).mean()
+                     results[f"{prefix}_{metric.name}"] = average_precision_by_class
+                     continue
+
                if self.full:
                    results[f"{prefix}_{metric.name}"] = metric.values()
                if self.means:
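Below is a minimal usage sketch of the new metric, outside the scenario harness. The images, boxes, labels, and scores are hypothetical, and the dict layout follows the format the code above assumes: ground-truth fields carry a leading batch dimension of 1, predicted fields do not. It assumes object_detection_AP_per_class (and numpy) are in scope, e.g., imported from this module.

    import numpy as np

    # Hypothetical ground truth for two images; boxes are [y1, x1, y2, x2] in pixels
    ys = [
        {"labels": np.array([[1]]), "boxes": np.array([[[10, 10, 50, 50]]])},
        {
            "labels": np.array([[1, 2]]),
            "boxes": np.array([[[0, 0, 20, 20], [30, 30, 60, 60]]]),
        },
    ]
    # Hypothetical predictions for the same two images, with confidence scores
    y_preds = [
        {
            "labels": np.array([1]),
            "boxes": np.array([[12, 12, 48, 48]]),
            "scores": np.array([0.9]),
        },
        {
            "labels": np.array([1, 2]),
            "boxes": np.array([[0, 0, 21, 19], [100, 100, 120, 120]]),
            "scores": np.array([0.8, 0.7]),
        },
    ]

    # Both class-1 predictions overlap their gt boxes with IoU > 0.5; the class-2
    # prediction misses its gt box entirely
    ap = object_detection_AP_per_class(ys, y_preds)  # {1: 0.91, 2: 0.0}
    mAP = np.fromiter(ap.values(), dtype=float).mean()

Class 1 comes out at 0.91 rather than 1.0 because the 1e-6 in the recall denominator keeps recall strictly below 1; the precision interpolated at the recall point 1.0 is therefore 0, and the 11-point mean is 10/11 ≈ 0.91.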