Skip to content

Commit 01cb68e

Browse files
committed
Added mAP_lim.py which is mAP but limited to a max. k of 1023
1 parent 6a24965 commit 01cb68e

File tree

3 files changed

+38
-2
lines changed

3 files changed

+38
-2
lines changed

metrics/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from metrics import e_recall, dists, rho_spectrum
2-
from metrics import nmi, f1, mAP, mAP_c, mAP_1000
2+
from metrics import nmi, f1, mAP, mAP_c, mAP_1000, mAP_lim
33
import numpy as np
44
import faiss
55
import torch
@@ -17,6 +17,8 @@ def select(metricname, opt):
1717
return mAP.Metric()
1818
elif metricname=='mAP_c':
1919
return mAP_c.Metric()
20+
elif metricname=='mAP_lim':
21+
return mAP_lim.Metric()
2022
elif metricname=='mAP_1000':
2123
return mAP_1000.Metric()
2224
elif metricname=='f1':

metrics/mAP_lim.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import torch
2+
import numpy as np
3+
import faiss
4+
5+
6+
7+
class Metric():
8+
def __init__(self, **kwargs):
9+
self.requires = ['features', 'target_labels']
10+
self.name = 'mAP'
11+
12+
def __call__(self, target_labels, features):
13+
labels, freqs = np.unique(target_labels, return_counts=True)
14+
## Account for faiss-limit at k=1023
15+
R = min(1023,len(features))
16+
17+
faiss_search_index = faiss.IndexFlatL2(features.shape[-1])
18+
faiss_search_index.add(features)
19+
nearest_neighbours = faiss_search_index.search(features, int(R+1))[1][:,1:]
20+
21+
target_labels = target_labels.reshape(-1)
22+
nn_labels = target_labels[nearest_neighbours]
23+
24+
avg_r_precisions = []
25+
for label, freq in zip(labels, freqs):
26+
rows_with_label = np.where(target_labels==label)[0]
27+
for row in rows_with_label:
28+
n_recalled_samples = np.arange(1,R+1)
29+
target_label_occ_in_row = nn_labels[row,:]==label
30+
cumsum_target_label_freq_row = np.cumsum(target_label_occ_in_row)
31+
avg_r_pr_row = np.sum(cumsum_target_label_freq_row*target_label_occ_in_row/n_recalled_samples)/freq
32+
avg_r_precisions.append(avg_r_pr_row)
33+
34+
return np.mean(avg_r_precisions)

parameters.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def basic_training_parameters(parser):
3838
##### Evaluation Parameters
3939
parser.add_argument('--no_train_metrics', action='store_true', help='Flag. If set, evaluation metrics are not computed for the training data. Saves a forward pass over the full training dataset.')
4040
parser.add_argument('--evaluate_on_gpu', action='store_true', help='Flag. If set, all metrics, when possible, are computed on the GPU (requires Faiss-GPU).')
41-
parser.add_argument('--evaluation_metrics', nargs='+', default=['e_recall@1', 'e_recall@2', 'e_recall@4', 'nmi', 'f1', 'mAP_1000', 'mAP_c', 'dists@intra', 'dists@inter', 'dists@intra_over_inter', 'rho_spectrum@0', 'rho_spectrum@-1', 'rho_spectrum@1', 'rho_spectrum@2', 'rho_spectrum@10'], type=str, help='Metrics to evaluate performance by.')
41+
parser.add_argument('--evaluation_metrics', nargs='+', default=['e_recall@1', 'e_recall@2', 'e_recall@4', 'nmi', 'f1', 'mAP_1000', 'mAP_lim', 'mAP_c', 'dists@intra', 'dists@inter', 'dists@intra_over_inter', 'rho_spectrum@0', 'rho_spectrum@-1', 'rho_spectrum@1', 'rho_spectrum@2', 'rho_spectrum@10'], type=str, help='Metrics to evaluate performance by.')
4242
parser.add_argument('--storage_metrics', nargs='+', default=['e_recall@1'], type=str, help='Improvement in these metrics on a dataset trigger checkpointing.')
4343
parser.add_argument('--evaltypes', nargs='+', default=['discriminative'], type=str, help='The network may produce multiple embeddings (ModuleDict, relevant for e.g. DiVA). If the key is listed here, the entry will be evaluated on the evaluation metrics.\
4444
Note: One may use Combined_embed1_embed2_..._embedn-w1-w1-...-wn to compute evaluation metrics on weighted (normalized) combinations.')

0 commit comments

Comments
 (0)