
Commit 593449a

Add mAP metric for object detection (#795)
* added mAP metric
* update 'bbox' and 'category_id' keys to 'boxes' and 'labels'
* mAP is now successfully returned when calling metrics_logger.results()
* change > to >=
* clarification to iou description
* retrieve patch ID from adversarial_datasets.py rather than hardcoding
* remove TODO
* clearer variable name
* rename mAP fn; save APs by class to results dict as well as mAP scalar
* black formatting
1 parent 1cd0f1e commit 593449a

3 files changed: +230 -5 lines changed

armory/utils/config_schema.json

Lines changed: 1 addition & 0 deletions
@@ -234,6 +234,7 @@
     "mean_image_circle_patch_diameter",
     "max_image_circle_patch_diameter",
     "word_error_rate",
+    "object_detection_AP_per_class",
     "object_detection_class_recall",
     "object_detection_class_precision"
 ]

armory/utils/metrics.py

Lines changed: 228 additions & 3 deletions
@@ -11,7 +11,7 @@
 import time
 from contextlib import contextmanager
 import io
-from collections import defaultdict
+from collections import defaultdict, Counter

 import cProfile
 import pstats
@@ -431,6 +431,201 @@ def _check_object_detection_input(y, y_pred):
    )


+def _intersection_over_union(box_1, box_2):
+    """
+    Assumes format of [y1, x1, y2, x2] or [x1, y1, x2, y2]
+    """
+    assert box_1[2] >= box_1[0]
+    assert box_2[2] >= box_2[0]
+    assert box_1[3] >= box_1[1]
+    assert box_2[3] >= box_2[1]
+
+    if sum([mean < 1 and mean >= 0 for mean in [box_1.mean(), box_2.mean()]]) == 1:
+        logger.warning(
+            "One set of boxes appears to be normalized while the other is not"
+        )
+
+    # Determine coordinates of intersection box
+    x_left = max(box_1[1], box_2[1])
+    x_right = min(box_1[3], box_2[3])
+    y_top = max(box_1[0], box_2[0])
+    y_bottom = min(box_1[2], box_2[2])
+
+    intersect_area = max(0, x_right - x_left) * max(0, y_bottom - y_top)
+    if intersect_area == 0:
+        return 0
+
+    box_1_area = (box_1[3] - box_1[1]) * (box_1[2] - box_1[0])
+    box_2_area = (box_2[3] - box_2[1]) * (box_2[2] - box_2[0])
+
+    iou = intersect_area / (box_1_area + box_2_area - intersect_area)
+    assert iou >= 0
+    assert iou <= 1
+    return iou
+
+
+def object_detection_AP_per_class(list_of_ys, list_of_y_preds):
+    """
+    Mean average precision for object detection. This function returns a dictionary
+    mapping each class to the average precision (AP) for the class. The mAP can be computed
+    by taking the mean of the APs across all classes.
+
+    This metric is computed over all evaluation samples, rather than on a per-sample basis.
+    """
+
+    IOU_THRESHOLD = 0.5
+    # Precision will be computed at recall points of 0, 0.1, 0.2, ..., 1
+    RECALL_POINTS = np.linspace(0, 1, 11)
+
+    # Converting all boxes to a list of dicts (a list for predicted boxes, and a
+    # separate list for ground truth boxes), where each dict corresponds to a box and
+    # has the following keys "img_idx", "label", "box", as well as "score" for predicted boxes
+    pred_boxes_list = []
+    gt_boxes_list = []
+    for img_idx, (y, y_pred) in enumerate(zip(list_of_ys, list_of_y_preds)):
+        for gt_box_idx in range(len(y["labels"][0].flatten())):
+            label = y["labels"][0][gt_box_idx]
+            box = y["boxes"][0][gt_box_idx]
+
+            gt_box_dict = {"img_idx": img_idx, "label": label, "box": box}
+            gt_boxes_list.append(gt_box_dict)
+
+        for pred_box_idx in range(len(y_pred["labels"].flatten())):
+            label = y_pred["labels"][pred_box_idx]
+            box = y_pred["boxes"][pred_box_idx]
+            score = y_pred["scores"][pred_box_idx]
+
+            pred_box_dict = {
+                "img_idx": img_idx,
+                "label": label,
+                "box": box,
+                "score": score,
+            }
+            pred_boxes_list.append(pred_box_dict)
+
+    # Union of (1) the set of all true classes and (2) the set of all predicted classes
+    set_of_class_ids = set([i["label"] for i in gt_boxes_list]) | set(
+        [i["label"] for i in pred_boxes_list]
+    )
+
+    # Remove the class ID that corresponds to a physical adversarial patch in the APRICOT
+    # dataset, if present
+    set_of_class_ids.discard(ADV_PATCH_MAGIC_NUMBER_LABEL_ID)
+
+    # Initialize dict that will store AP for each class
+    average_precisions_by_class = {}
+
+    # Compute AP for each class
+    for class_id in set_of_class_ids:
+
+        # Build lists that contain all the predicted/ground-truth boxes with a
+        # label of class_id
+        class_predicted_boxes = []
+        class_gt_boxes = []
+        for pred_box in pred_boxes_list:
+            if pred_box["label"] == class_id:
+                class_predicted_boxes.append(pred_box)
+        for gt_box in gt_boxes_list:
+            if gt_box["label"] == class_id:
+                class_gt_boxes.append(gt_box)
+
+        # Determine how many gt boxes (of class_id) there are in each image
+        num_gt_boxes_per_img = Counter([gt["img_idx"] for gt in class_gt_boxes])
+
+        # Initialize dict where we'll keep track of whether a gt box has been matched to a
+        # prediction yet. This is necessary because if multiple predicted boxes of class_id
+        # overlap with a single gt box, only one of the predicted boxes can be considered a
+        # true positive
+        img_idx_to_gtboxismatched_array = {}
+        for img_idx, num_gt_boxes in num_gt_boxes_per_img.items():
+            img_idx_to_gtboxismatched_array[img_idx] = np.zeros(num_gt_boxes)
+
+        # Sort all predicted boxes (of class_id) by descending confidence
+        class_predicted_boxes.sort(key=lambda x: x["score"], reverse=True)
+
+        # Initialize arrays. Once filled in, true_positives[i] indicates (with a 1 or 0)
+        # whether the ith predicted box (of class_id) is a true positive. Likewise for
+        # false_positives array
+        true_positives = np.zeros(len(class_predicted_boxes))
+        false_positives = np.zeros(len(class_predicted_boxes))
+
+        # Iterating over all predicted boxes of class_id
+        for pred_idx, pred_box in enumerate(class_predicted_boxes):
+            # Only compare gt boxes from the same image as the predicted box
+            gt_boxes_from_same_img = [
+                gt_box
+                for gt_box in class_gt_boxes
+                if gt_box["img_idx"] == pred_box["img_idx"]
+            ]
+
+            # If there are no gt boxes in the predicted box's image that have the predicted class
+            if len(gt_boxes_from_same_img) == 0:
+                false_positives[pred_idx] = 1
+                continue
+
+            # Iterate over all gt boxes (of class_id) from the same image as the predicted box,
+            # determining which gt box has the highest iou with the predicted box
+            highest_iou = 0
+            for gt_idx, gt_box in enumerate(gt_boxes_from_same_img):
+                iou = _intersection_over_union(pred_box["box"], gt_box["box"])
+                if iou >= highest_iou:
+                    highest_iou = iou
+                    highest_iou_gt_idx = gt_idx
+
+            if highest_iou > IOU_THRESHOLD:
+                # If the gt box has not yet been covered
+                if (
+                    img_idx_to_gtboxismatched_array[pred_box["img_idx"]][
+                        highest_iou_gt_idx
+                    ]
+                    == 0
+                ):
+                    true_positives[pred_idx] = 1
+
+                    # Record that we've now covered this gt box. Any subsequent
+                    # pred boxes that overlap with it are considered false positives
+                    img_idx_to_gtboxismatched_array[pred_box["img_idx"]][
+                        highest_iou_gt_idx
+                    ] = 1
+                else:
+                    # This gt box was already covered previously (i.e. a different predicted
+                    # box was deemed a true positive after overlapping with this gt box)
+                    false_positives[pred_idx] = 1
+            else:
+                false_positives[pred_idx] = 1
+
+        # Cumulative sums of false/true positives across all predictions which were sorted by
+        # descending confidence
+        tp_cumulative_sum = np.cumsum(true_positives)
+        fp_cumulative_sum = np.cumsum(false_positives)
+
+        # Total number of gt boxes with a label of class_id
+        total_gt_boxes = len(class_gt_boxes)
+
+        recalls = tp_cumulative_sum / (total_gt_boxes + 1e-6)
+        precisions = tp_cumulative_sum / (tp_cumulative_sum + fp_cumulative_sum + 1e-6)
+
+        interpolated_precisions = np.zeros(len(RECALL_POINTS))
+        # Interpolate the precision at each recall level by taking the max precision for which
+        # the corresponding recall exceeds the recall point
+        # See http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.157.5766&rep=rep1&type=pdf
+        for i, recall_point in enumerate(RECALL_POINTS):
+            precisions_points = precisions[np.where(recalls >= recall_point)]
+            # If there's no cutoff at which the recall >= recall_point
+            if len(precisions_points) == 0:
+                interpolated_precisions[i] = 0
+            else:
+                interpolated_precisions[i] = max(precisions_points)
+
+        # Compute mean precision across the different recall levels
+        average_precision = interpolated_precisions.mean()
+        average_precisions_by_class[int(class_id)] = np.around(
+            average_precision, decimals=2
+        )
+
+    return average_precisions_by_class
+
+
 SUPPORTED_METRICS = {
     "categorical_accuracy": categorical_accuracy,
     "top_n_categorical_accuracy": top_n_categorical_accuracy,
@@ -449,6 +644,7 @@ def _check_object_detection_input(y, y_pred):
     "mars_mean_l2": mars_mean_l2,
     "mars_mean_patch": mars_mean_patch,
     "word_error_rate": word_error_rate,
+    "object_detection_AP_per_class": object_detection_AP_per_class,
     "object_detection_class_precision": object_detection_class_precision,
     "object_detection_class_recall": object_detection_class_recall,
 }
@@ -503,6 +699,7 @@ def __init__(self, name, function=None):
             raise ValueError(f"function must be callable or None, not {function}")
         self.name = name
         self._values = []
+        self._inputs = []

     def clear(self):
         self._values.clear()
@@ -523,6 +720,9 @@ def values(self):
     def mean(self):
         return sum(float(x) for x in self._values) / len(self._values)

+    def append_inputs(self, *args):
+        self._inputs.append(args)
+
     def total_wer(self):
         # checks if all values are tuples from the WER metric
         if all(isinstance(wer_tuple, tuple) for wer_tuple in self._values):
@@ -535,6 +735,12 @@ def total_wer(self):
         else:
             raise ValueError("total_wer() only for WER metric")

+    def AP_per_class(self):
+        # Computed at once across all samples
+        y_s = [i[0] for i in self._inputs]
+        y_preds = [i[1] for i in self._inputs]
+        return object_detection_AP_per_class(y_s, y_preds)
+

 class MetricsLogger:
     """
@@ -550,6 +756,7 @@ def __init__(
         profiler_type=None,
         computational_resource_dict=None,
         skip_benign=None,
+        **kwargs,
     ):
         """
         task - single metric or list of metrics
@@ -565,7 +772,7 @@ def __init__(
         self.computational_resource_dict = {}
         if not self.means and not self.full:
             logger.warning(
-                "No metric results will be produced. "
+                "No per-sample metric results will be produced. "
                 "To change this, set 'means' or 'record_metric_per_sample' to True."
             )
         if not self.tasks and not self.perturbations and not self.adversarial_tasks:
@@ -598,7 +805,10 @@ def clear(self):
     def update_task(self, y, y_pred, adversarial=False):
         tasks = self.adversarial_tasks if adversarial else self.tasks
         for metric in tasks:
-            metric.append(y, y_pred)
+            if metric.name == "object_detection_AP_per_class":
+                metric.append_inputs(y, y_pred[0])
+            else:
+                metric.append(y, y_pred)

     def update_perturbation(self, x, x_adv):
         for metric in self.perturbations:
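Illustration (not part of the diff): the y_pred[0] indexing above assumes the object detector returns a list with one {"boxes", "labels", "scores"} dict per image and that the scenario runs with batch size 1; the shapes below are hypothetical.

import numpy as np

# Hypothetical detector output for a single-image batch
y_pred = [
    {
        "boxes": np.zeros((3, 4)),
        "labels": np.zeros(3, dtype=int),
        "scores": np.zeros(3),
    }
]
# update_task hands the per-image dict, not the list, to the metric:
# metric.append_inputs(y, y_pred[0])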
@@ -624,6 +834,13 @@ def log_task(self, adversarial=False, targeted=False):
                     f"Word error rate on {task_type} examples: "
                     f"{metric.total_wer():.2%}"
                 )
+            elif metric.name == "object_detection_AP_per_class":
+                average_precision_by_class = metric.AP_per_class()
+                logger.info(
+                    f"object_detection_mAP on {task_type} examples: "
+                    f"{np.fromiter(average_precision_by_class.values(), dtype=float).mean():.2%}."
+                    f" AP by class ID: {average_precision_by_class}"
+                )
             else:
                 logger.info(
                     f"Average {metric.name} on {task_type} test examples: "
@@ -641,6 +858,14 @@ def results(self):
             (self.perturbations, "perturbation"),
         ]:
             for metric in metrics:
+                if metric.name == "object_detection_AP_per_class":
+                    average_precision_by_class = metric.AP_per_class()
+                    results[f"{prefix}_object_detection_mAP"] = np.fromiter(
+                        average_precision_by_class.values(), dtype=float
+                    ).mean()
+                    results[f"{prefix}_{metric.name}"] = average_precision_by_class
+                    continue
+
                 if self.full:
                     results[f"{prefix}_{metric.name}"] = metric.values()
                 if self.means:
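Illustration (not part of the diff): a sketch of the entries results() could then contain for this metric. The "adversarial" prefix is assumed for the example; only the "perturbation" prefix is visible in this hunk.

results = metrics_logger.results()
# Assumed keys, following the f"{prefix}_..." pattern above:
results["adversarial_object_detection_mAP"]           # scalar: mean AP over classes
results["adversarial_object_detection_AP_per_class"]  # dict: {class_id: AP}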

scenario_configs/xview_frcnn.json

Lines changed: 1 addition & 2 deletions
@@ -29,8 +29,7 @@
         "perturbation": "image_circle_patch_diameter",
         "record_metric_per_sample": false,
         "task": [
-            "object_detection_class_precision",
-            "object_detection_class_recall"
+            "object_detection_AP_per_class"
         ]
     },
     "model": {
