Skip to content

Commit 0de7fc6

Browse files
committed
add codes for triplet loss
1 parent 9d8282d commit 0de7fc6

File tree

5 files changed

+164
-0
lines changed

5 files changed

+164
-0
lines changed

.DS_Store

2 KB
Binary file not shown.

code/.DS_Store

4 KB
Binary file not shown.

code/triplet-loss/.DS_Store

6 KB
Binary file not shown.

code/triplet-loss/triplet_loss.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# -*- coding: utf-8 -*-
2+
# Author: Lawlite
3+
# Date: 2018/10/20
4+
# Associate Blog: http://lawlite.me/2018/10/16/Triplet-Loss原理及其实现/#more
5+
# License: MIT
6+
import tensorflow as tf
7+
8+
def _pairwise_distance(embeddings, squared=False):
    """Compute the matrix of pairwise distances between all embeddings.

    Args:
        embeddings: feature vectors, shape (batch_size, vector_size).
        squared: if True, return squared Euclidean distances; if False,
            return Euclidean distances (square root applied).

    Returns:
        distances: pairwise distance matrix, shape (batch_size, batch_size).
    """
    # |a - b|^2 = |a|^2 - 2 * a.b + |b|^2, and every a.b is available at once
    # from a single matrix product of shape (batch_size, batch_size).
    dot_product = tf.matmul(embeddings, tf.transpose(embeddings))
    # The diagonal of the product holds the squared norm of each embedding.
    square_norm = tf.diag_part(dot_product)
    # Broadcast a (batch_size, 1) column and a (1, batch_size) row against the
    # (batch_size, batch_size) product.
    distances = (
        tf.expand_dims(square_norm, axis=1)
        - 2.0 * dot_product
        + tf.expand_dims(square_norm, axis=0)
    )
    # Rounding errors can produce slightly negative values; clamp them to 0.
    distances = tf.maximum(distances, 0.0)
    if not squared:
        # Nudge exact zeros by a tiny epsilon before the sqrt and restore them
        # to 0 afterwards. The mask must be built from `distances` (not from
        # `dot_product`): it is the zero distances that need protecting.
        mask = tf.to_float(tf.equal(distances, 0.0))
        distances = distances + mask * 1e-16
        distances = tf.sqrt(distances)
        distances = distances * (1.0 - mask)
    return distances
34+
35+
36+
def _get_triplet_mask(labels):
    """Build a 3D boolean mask marking the valid (anchor, positive, negative) triplets.

    Entry [a, p, n] is True when a, p and n are three distinct indices,
    labels[a] == labels[p], and labels[a] != labels[n].

    Args:
        labels: labels of the training batch, shape (batch_size,).

    Returns:
        mask: boolean tensor, shape (batch_size, batch_size, batch_size).
    """
    # Label condition: anchor and positive share a label, anchor and negative do not.
    same_label = tf.equal(tf.expand_dims(labels, 0), tf.expand_dims(labels, 1))
    anchor_positive_match = tf.expand_dims(same_label, 2)
    anchor_negative_differ = tf.logical_not(tf.expand_dims(same_label, 1))
    labels_ok = tf.logical_and(anchor_positive_match, anchor_negative_differ)

    # Index condition: the three positions must be pairwise distinct.
    batch_size = tf.shape(labels)[0]
    off_diagonal = tf.logical_not(tf.cast(tf.eye(batch_size), tf.bool))
    a_differs_p = tf.expand_dims(off_diagonal, 2)
    a_differs_n = tf.expand_dims(off_diagonal, 1)
    p_differs_n = tf.expand_dims(off_diagonal, 0)
    indices_ok = tf.logical_and(
        tf.logical_and(a_differs_p, a_differs_n), p_differs_n
    )

    return tf.logical_and(indices_ok, labels_ok)
61+
62+
def batch_all_triplet_loss(labels, embeddings, margin, squared=False):
    """Compute the batch-all triplet loss.

    Every valid (anchor, positive, negative) triplet in the batch contributes
    max(d(a, p) - d(a, n) + margin, 0); the sum is averaged over the triplets
    whose loss is strictly positive.

    Args:
        labels: labels of the batch, shape (batch_size,).
        embeddings: feature vectors, shape (batch_size, vector_size).
        margin: margin added to d(a, p) - d(a, n).
        squared: if True, use squared Euclidean distances.

    Returns:
        triplet_loss: scalar loss averaged over triplets with positive loss.
        fraction_positive_triplets: number of triplets with positive loss
            divided by the number of valid triplets.
    """
    pairwise_dis = _pairwise_distance(embeddings, squared=squared)
    # Shape (batch_size, batch_size, 1): d(a, p), broadcast over negatives.
    anchor_positive_dist = tf.expand_dims(pairwise_dis, 2)
    assert anchor_positive_dist.shape[2] == 1, "{}".format(anchor_positive_dist.shape)
    # Shape (batch_size, 1, batch_size): d(a, n), broadcast over positives.
    anchor_negative_dist = tf.expand_dims(pairwise_dis, 1)
    assert anchor_negative_dist.shape[1] == 1, "{}".format(anchor_negative_dist.shape)

    # triplet_loss[a, p, n] = d(a, p) - d(a, n) + margin
    triplet_loss = anchor_positive_dist - anchor_negative_dist + margin
    # Zero out the entries that are not valid triplets.
    mask = _get_triplet_mask(labels)
    mask = tf.to_float(mask)
    triplet_loss = tf.multiply(mask, triplet_loss)
    # Triplets that already satisfy the margin contribute nothing.
    triplet_loss = tf.maximum(triplet_loss, 0.0)
    # Count the triplets that still have a positive loss.
    positive_triplets = tf.to_float(tf.greater(triplet_loss, 1e-16))
    num_positive_triplets = tf.reduce_sum(positive_triplets)
    num_valid_triplets = tf.reduce_sum(mask)
    # The 1e-16 terms guard against division by zero.
    fraction_positive_triplets = num_positive_triplets / (num_valid_triplets + 1e-16)
    triplet_loss = tf.reduce_sum(triplet_loss) / (num_positive_triplets + 1e-16)
    return triplet_loss, fraction_positive_triplets

code/triplet-loss/triplet_loss_np.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# -*- coding: utf-8 -*-
2+
# Author: Lawlite
3+
# Date: 2018/10/20
4+
# Associate Blog: http://lawlite.me/2018/10/16/Triplet-Loss原理及其实现/#more
5+
# License: MIT
6+
import numpy as np
7+
8+
9+
def test_pairwise_distances(squared=False, embeddings=None):
    """Compute, print and return the pairwise distance matrix of the embeddings.

    With the default embeddings and squared=False the result is
        [[ 0.  8. 16.]
         [ 8.  0.  8.]
         [16.  8.  0.]]
    i.e. rows 0 and 1 are 8 apart and rows 0 and 2 are 16 apart (after the
    square root).

    Args:
        squared: if True, return squared Euclidean distances.
        embeddings: optional array-like of shape (batch_size, vector_size);
            when None, a fixed 3 x 4 example matrix is used.

    Returns:
        distances: float32 array of shape (batch_size, batch_size).
    """
    if embeddings is None:
        embeddings = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]
    embeddings = np.asarray(embeddings, dtype=np.float32)
    # |a - b|^2 = |a|^2 - 2 * a.b + |b|^2
    dot_product = np.dot(embeddings, np.transpose(embeddings))
    square_norm = np.diag(dot_product)
    distances = (
        np.expand_dims(square_norm, axis=1)
        - 2.0 * dot_product
        + np.expand_dims(square_norm, axis=0)
    )
    # Rounding errors can leave tiny negative values, which would turn into
    # NaN under the sqrt; clamp them to 0.
    distances = np.maximum(distances, 0.0)
    if not squared:
        # Nudge exact zeros before the sqrt and restore them to 0 afterwards.
        mask = np.float32(np.equal(distances, 0.0))
        distances = distances + mask * 1e-16
        distances = np.sqrt(distances)
        distances = distances * (1.0 - mask)
    print(distances)
    return distances
26+
27+
def test_get_triplet_mask(labels):
28+
'''
29+
valid (i, j, k)满足
30+
- i, j, k都不相等
31+
- labels[i] == labels[j] && labels[i] != labels[k]
32+
33+
array([[[False, False, False],
34+
[False, False, False],
35+
[False, True, False]],
36+
37+
[[False, False, False],
38+
[False, False, False],
39+
[False, False, False]],
40+
41+
[[False, True, False],
42+
[False, False, False],
43+
[False, False, False]]])
44+
'''
45+
indices_equal = np.cast[np.bool](np.eye(np.shape(labels)[0], dtype=np.int32))
46+
indices_not_equal = np.logical_not(indices_equal)
47+
i_not_equal_j = np.expand_dims(indices_not_equal, 2)
48+
i_not_equal_k = np.expand_dims(indices_not_equal, 1)
49+
j_not_equal_k = np.expand_dims(indices_not_equal, 0)
50+
51+
distinct_indices = np.logical_and(np.logical_and(i_not_equal_j, i_not_equal_k), j_not_equal_k)
52+
label_equal = np.equal(np.expand_dims(labels, 0), np.expand_dims(labels, 1))
53+
i_equal_j = np.expand_dims(label_equal, 2)
54+
i_equal_k = np.expand_dims(label_equal, 1)
55+
56+
valid_labels = np.logical_and(i_equal_j, np.logical_not(i_equal_k))
57+
mask = np.logical_and(valid_labels, distinct_indices)
58+
return mask
59+
60+
61+
def test_batch_all_triplet_loss(margin):
    """Walk through the batch-all triplet loss with NumPy on a fixed example.

    Uses labels [1, 0, 1] (rows 0 and 2 are positives of each other, row 1 is
    the negative) together with the default distance matrix from
    test_pairwise_distances(); with margin 0 the loss is 16 - 8 = 8.

    Args:
        margin: margin added to d(anchor, positive) - d(anchor, negative).

    Returns:
        triplet_loss: loss averaged over the triplets with positive loss.
        fraction_positive_triplet: number of triplets with positive loss
            divided by the number of valid triplets.
    """
    labels = np.array([1, 0, 1])
    pairwise_distances = test_pairwise_distances()
    # (batch, batch, 1) minus (batch, 1, batch) broadcasts to
    # loss[a, p, n] = d(a, p) - d(a, n) + margin.
    anchor_positive = np.expand_dims(pairwise_distances, axis=2)
    anchor_negative = np.expand_dims(pairwise_distances, axis=1)
    triplet_loss = anchor_positive - anchor_negative + margin

    # Keep valid triplets only, then drop those that already satisfy the margin.
    # `.astype` replaces `np.cast[...]`, which no longer exists in NumPy 2.
    mask = test_get_triplet_mask(labels)
    mask = mask.astype(np.float32)
    triplet_loss = np.multiply(mask, triplet_loss)
    triplet_loss = np.maximum(triplet_loss, 0.0)

    # Count the triplets that still have a positive loss.
    valid_triplet_loss = np.greater(triplet_loss, 1e-16).astype(np.float32)
    num_positive_triplet = np.sum(valid_triplet_loss)
    num_valid_triplet_loss = np.sum(mask)
    # The 1e-16 terms guard against division by zero.
    fraction_positive_triplet = num_positive_triplet / (num_valid_triplet_loss + 1e-16)

    triplet_loss = np.sum(triplet_loss) / (num_positive_triplet + 1e-16)
    return triplet_loss, fraction_positive_triplet
80+
81+
82+
83+
if __name__ == "__main__":
    # Script entry point: run the NumPy batch-all walk-through with zero margin.
    test_batch_all_triplet_loss(margin=0.0)

0 commit comments

Comments
 (0)