1
+ # -*- coding: utf-8 -*-
2
+ # Author: Lawlite
3
+ # Date: 2018/10/20
4
+ # Associate Blog: http://lawlite.me/2018/10/16/Triplet-Loss原理及其实现/#more
5
+ # License: MIT
6
+ import tensorflow as tf
7
+
8
def _pairwise_distance(embeddings, squared=False):
    """Compute the matrix of pairwise Euclidean distances between embeddings.

    Args:
        embeddings: feature vectors, shape (batch_size, vector_size).
        squared: if True, return the squared Euclidean distances; if False
            (default), take the square root and return the Euclidean
            distances themselves.

    Returns:
        distances: pairwise distance matrix, shape (batch_size, batch_size).
    """
    # |a - b|^2 = |a|^2 - 2 * a.b + |b|^2, and a.b for every pair is a single
    # matrix product of shape (batch_size, batch_size).
    dot_product = tf.matmul(embeddings, tf.transpose(embeddings))
    # The diagonal of the dot-product matrix holds each embedding's squared norm.
    square_norm = tf.diag_part(dot_product)
    # Broadcasting: (batch_size, 1) - (batch_size, batch_size) + (1, batch_size).
    distances = (
        tf.expand_dims(square_norm, axis=1)
        - 2.0 * dot_product
        + tf.expand_dims(square_norm, axis=0)
    )
    # Rounding errors can make some entries slightly negative; clamp them to 0.
    distances = tf.maximum(distances, 0.0)

    if not squared:
        # The gradient of sqrt is infinite at 0, so add a small epsilon exactly
        # where the *distance* is 0 before taking the root. The mask must be
        # built from `distances`, not from `dot_product`: a zero dot product
        # only means the two embeddings are orthogonal.
        mask = tf.to_float(tf.equal(distances, 0.0))
        distances = distances + mask * 1e-16
        distances = tf.sqrt(distances)
        # Restore exact zeros at the positions that received the epsilon.
        distances = distances * (1.0 - mask)
    return distances
34
+
35
+
36
def _get_triplet_mask(labels):
    """Build a 3-D boolean mask of the valid (anchor, positive, negative) triplets.

    Entry [a, p, n] is True when a, p and n are three distinct indices,
    labels[a] == labels[p], and labels[a] != labels[n].

    Args:
        labels: labels of the training batch, shape (batch_size,).

    Returns:
        mask: boolean tensor, shape (batch_size, batch_size, batch_size).
    """
    batch_size = tf.shape(labels)[0]

    # Label condition: anchor/positive share a label, anchor/negative do not.
    same_label = tf.equal(tf.expand_dims(labels, 0), tf.expand_dims(labels, 1))
    anchor_matches_positive = tf.expand_dims(same_label, 2)
    anchor_differs_from_negative = tf.logical_not(tf.expand_dims(same_label, 1))
    valid_labels = tf.logical_and(
        anchor_matches_positive, anchor_differs_from_negative
    )

    # Index condition: the three positions must be pairwise different.
    off_diagonal = tf.logical_not(tf.cast(tf.eye(batch_size), tf.bool))
    a_ne_p = tf.expand_dims(off_diagonal, 2)
    a_ne_n = tf.expand_dims(off_diagonal, 1)
    p_ne_n = tf.expand_dims(off_diagonal, 0)
    distinct_indices = tf.logical_and(tf.logical_and(a_ne_p, a_ne_n), p_ne_n)

    return tf.logical_and(distinct_indices, valid_labels)
61
+
62
def batch_all_triplet_loss(labels, embeddings, margin, squared=False):
    """Compute the "batch all" triplet loss over every valid triplet in the batch.

    Args:
        labels: labels of the training batch, shape (batch_size,).
        embeddings: feature vectors, shape (batch_size, vector_size).
        margin: margin added to (d(a, p) - d(a, n)) before clamping at 0.
        squared: forwarded to `_pairwise_distance`; if True the squared
            distances are used instead of the rooted ones.

    Returns:
        triplet_loss: scalar, sum of the triplet losses divided by the number
            of triplets whose loss is positive.
        fraction_positive_triplets: scalar, number of positive-loss triplets
            divided by the number of valid triplets.
    """
    pairwise_dis = _pairwise_distance(embeddings, squared=squared)

    # Shape (batch_size, batch_size, 1): distance anchor -> positive.
    anchor_positive_dist = tf.expand_dims(pairwise_dis, 2)
    assert anchor_positive_dist.shape[2] == 1, "{}".format(anchor_positive_dist.shape)
    # Shape (batch_size, 1, batch_size): distance anchor -> negative.
    anchor_negative_dist = tf.expand_dims(pairwise_dis, 1)
    assert anchor_negative_dist.shape[1] == 1, "{}".format(anchor_negative_dist.shape)

    # Broadcasts to (batch_size, batch_size, batch_size), indexed [a, p, n].
    triplet_loss = anchor_positive_dist - anchor_negative_dist + margin

    # Zero out the entries that do not correspond to a valid triplet.
    mask = _get_triplet_mask(labels)
    mask = tf.to_float(mask)
    triplet_loss = tf.multiply(mask, triplet_loss)

    # Triplets that already satisfy the margin contribute nothing.
    triplet_loss = tf.maximum(triplet_loss, 0.0)

    # Count the triplets that still have a (numerically) positive loss.
    valid_triplets = tf.to_float(tf.greater(triplet_loss, 1e-16))
    num_positive_triplets = tf.reduce_sum(valid_triplets)
    num_valid_triplets = tf.reduce_sum(mask)
    # The 1e-16 terms guard against division by zero.
    fraction_positive_triplets = num_positive_triplets / (num_valid_triplets + 1e-16)

    # Average over the positive-loss triplets only.
    triplet_loss = tf.reduce_sum(triplet_loss) / (num_positive_triplets + 1e-16)
    return triplet_loss, fraction_positive_triplets
0 commit comments