Skip to content

Commit 9504246

Browse files
edgarribasoumith
authored andcommitted
add triplet margin loss (pytorch#1165)
1 parent 81cf3db commit 9504246

File tree

3 files changed

+116
-2
lines changed

3 files changed

+116
-2
lines changed

test/test_nn.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1828,6 +1828,20 @@ def test_pairwise_distance(self):
18281828
input2 = Variable(torch.randn(4, 4), requires_grad=True)
18291829
self.assertTrue(gradcheck(lambda x, y: F.pairwise_distance(x, y), (input1, input2)))
18301830

1831+
def test_triplet_margin_loss(self):
1832+
input1 = Variable(torch.randn(4, 4), requires_grad=True)
1833+
input2 = Variable(torch.randn(4, 4), requires_grad=True)
1834+
input3 = Variable(torch.randn(4, 4), requires_grad=True)
1835+
self.assertTrue(gradcheck(lambda x1, x2, x3: F.triplet_margin_loss(
1836+
x1, x2, x3), (input1, input2, input3)))
1837+
1838+
def test_triplet_margin_swap_loss(self):
1839+
input1 = Variable(torch.randn(4, 4), requires_grad=True)
1840+
input2 = Variable(torch.randn(4, 4), requires_grad=True)
1841+
input3 = Variable(torch.randn(4, 4), requires_grad=True)
1842+
self.assertTrue(gradcheck(lambda x1, x2, x3: F.triplet_margin_loss(
1843+
x1, x2, x3, swap=True), (input1, input2, input3)))
1844+
18311845

18321846
class TestNNInit(TestCase):
18331847
def setUp(self):

torch/nn/functional.py

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -634,8 +634,9 @@ def pairwise_distance(x1, x2, p=2, eps=1e-6):
634634
\Vert x \Vert _p := \left( \sum_{i=1}^n \vert x_i \vert ^ p \right) ^ {1/p}
635635
636636
Args:
637-
x (Tensor): input tensor containing the two input batches
638-
p (real): the norm degree. Default: 2
637+
x1: first input tensor
638+
x2: second input tensor
639+
p: the norm degree. Default: 2
639640
640641
Shape:
641642
- Input: :math:`(N, D)` where `D = vector dimension`
@@ -651,3 +652,55 @@ def pairwise_distance(x1, x2, p=2, eps=1e-6):
651652
diff = torch.abs(x1 - x2)
652653
out = torch.pow(diff + eps, p).sum(dim=1)
653654
return torch.pow(out, 1. / p)
655+
656+
657+
def triplet_margin_loss(anchor, positive, negative, margin=1.0, p=2, eps=1e-6, swap=False):
658+
r"""Creates a criterion that measures the triplet loss given an input tensors x1, x2, x3
659+
and a margin with a value greater than 0.
660+
This is used for measuring a relative similarity between samples. A triplet is composed by
661+
`a`, `p` and `n`: anchor, positive examples and negative example respectively.
662+
The shape of all input variables should be :math:`(N, D)`.
663+
664+
The distance swap is described in detail in the paper `Learning shallow convolutional feature descriptors with
665+
triplet losses`_ by V. Balntas, E. Riba et al.
666+
667+
.. math::
668+
L(a, p, n) = \frac{1}{N} \left( \sum_{i=1}^N \max \{d(a_i, p_i) - d(a_i, n_i) + {\rm margin}, 0\} \right)
669+
670+
where :math: `d(x_i, y_i) = \| {\bf x}_i - {\bf y}_i \|_2^2`.
671+
672+
Args:
673+
anchor: anchor input tensor
674+
positive: positive input tensor
675+
negative: negative input tensor
676+
p: the norm degree. Default: 2
677+
eps: small epsilon value to avoid numerical issues
678+
swap: compute distance swap
679+
680+
Shape:
681+
- Input: :math:`(N, D)` where `D = vector dimension`
682+
- Output: :math:`(N, 1)`
683+
684+
>>> input1 = autograd.Variable(torch.randn(100, 128))
685+
>>> input2 = autograd.Variable(torch.randn(100, 128))
686+
>>> input3 = autograd.Variable(torch.randn(100, 128))
687+
>>> output = F.triplet_margin_loss(input1, input2, input3, p=2)
688+
>>> output.backward()
689+
690+
.. _Learning shallow convolutional feature descriptors with triplet losses:
691+
http://www.iis.ee.ic.ac.uk/%7Evbalnt/shallow_descr/TFeat_paper.pdf
692+
"""
693+
assert anchor.size() == positive.size(), "Input sizes between positive and negative must be equal."
694+
assert anchor.size() == negative.size(), "Input sizes between anchor and negative must be equal."
695+
assert positive.size() == negative.size(), "Input sizes between positive and negative must be equal."
696+
assert anchor.dim() == 2, "Inputd must be a 2D matrix."
697+
assert margin > 0.0, 'Margin should be positive value.'
698+
d_p = pairwise_distance(anchor, positive, p, eps)
699+
d_n = pairwise_distance(anchor, negative, p, eps)
700+
if swap:
701+
d_s = pairwise_distance(positive, negative, p, eps)
702+
d_n = torch.min(d_n, d_s)
703+
704+
dist_hinge = torch.clamp(margin + d_p - d_n, min=0.0)
705+
loss = torch.mean(dist_hinge)
706+
return loss

torch/nn/modules/loss.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,53 @@ def forward(self, input, target):
427427
self.margin, weight=self.weight)(input, target)
428428

429429

430+
class TripletMarginLoss(Module):
431+
r"""Creates a criterion that measures the triplet loss given an input tensors x1, x2, x3
432+
and a margin with a value greater than 0.
433+
This is used for measuring a relative similarity between samples. A triplet is composed by
434+
`a`, `p` and `n`: anchor, positive examples and negative example respectively.
435+
The shape of all input variables should be :math:`(N, D)`.
436+
437+
The distance swap is described in detail in the paper `Learning shallow convolutional feature descriptors with
438+
triplet losses`_ by V. Balntas, E. Riba et al.
439+
440+
.. math::
441+
L(a, p, n) = \frac{1}{N} \left( \sum_{i=1}^N \max \{d(a_i, p_i) - d(a_i, n_i) + {\rm margin}, 0\} \right)
442+
443+
where :math: `d(x_i, y_i) = \| {\bf x}_i - {\bf y}_i \|_2^2`.
444+
445+
Args:
446+
anchor: anchor input tensor
447+
positive: positive input tensor
448+
negative: negative input tensor
449+
p: the norm degree. Default: 2
450+
451+
Shape:
452+
- Input: :math:`(N, D)` where `D = vector dimension`
453+
- Output: :math:`(N, 1)`
454+
455+
>>> triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2)
456+
>>> input1 = autograd.Variable(torch.randn(100, 128))
457+
>>> input2 = autograd.Variable(torch.randn(100, 128))
458+
>>> input3 = autograd.Variable(torch.randn(100, 128))
459+
>>> output = triplet_loss(input1, input2, input3)
460+
>>> output.backward()
461+
462+
.. _Learning shallow convolutional feature descriptors with triplet losses:
463+
http://www.iis.ee.ic.ac.uk/%7Evbalnt/shallow_descr/TFeat_paper.pdf
464+
"""
465+
466+
def __init__(self, margin=1.0, p=2, eps=1e-6, swap=False):
467+
super(TripletMarginLoss, self).__init__()
468+
self.margin = margin
469+
self.p = p
470+
self.eps = eps
471+
self.swap = swap
472+
473+
def forward(self, anchor, positive, negative):
474+
return F.triplet_margin_loss(anchor, positive, negative, self.margin,
475+
self.p, self.eps, self.swap)
476+
430477
# TODO: L1HingeEmbeddingCriterion
431478
# TODO: MSECriterion weight
432479
# TODO: ClassSimplexCriterion

0 commit comments

Comments
 (0)