@@ -634,8 +634,9 @@ def pairwise_distance(x1, x2, p=2, eps=1e-6):
634634 \Vert x \Vert _p := \left( \sum_{i=1}^n \vert x_i \vert ^ p \right) ^ {1/p}
635635
636636 Args:
637- x (Tensor): input tensor containing the two input batches
638- p (real): the norm degree. Default: 2
637+ x1: first input tensor
638+ x2: second input tensor
639+ p: the norm degree. Default: 2
639640
640641 Shape:
641642 - Input: :math:`(N, D)` where `D = vector dimension`
@@ -651,3 +652,55 @@ def pairwise_distance(x1, x2, p=2, eps=1e-6):
651652 diff = torch .abs (x1 - x2 )
652653 out = torch .pow (diff + eps , p ).sum (dim = 1 )
653654 return torch .pow (out , 1. / p )
655+
656+
def triplet_margin_loss(anchor, positive, negative, margin=1.0, p=2, eps=1e-6, swap=False):
    r"""Compute the triplet margin loss for a batch of (anchor, positive, negative) triplets.

    The loss measures relative similarity between samples: each anchor `a`
    should be closer to its positive example `p` than to its negative example
    `n` by at least `margin`. All three inputs must be 2D tensors of identical
    shape :math:`(N, D)`.

    The distance swap is described in detail in the paper `Learning shallow convolutional feature descriptors with
    triplet losses`_ by V. Balntas, E. Riba et al.

    .. math::
        L(a, p, n) = \frac{1}{N} \left( \sum_{i=1}^N \max \{d(a_i, p_i) - d(a_i, n_i) + {\rm margin}, 0\} \right)

    where :math:`d(x_i, y_i) = \| {\bf x}_i - {\bf y}_i \|_p` is the
    row-wise distance returned by :func:`pairwise_distance` (not squared).
    When ``swap`` is true, :math:`d(a_i, n_i)` is replaced by
    :math:`\min \{d(a_i, n_i), d(p_i, n_i)\}`.

    Args:
        anchor: anchor input tensor
        positive: positive input tensor
        negative: negative input tensor
        margin: the margin value; must be greater than 0. Default: 1.0
        p: the norm degree. Default: 2
        eps: small epsilon value to avoid numerical issues. Default: 1e-6
        swap: compute distance swap. Default: False

    Shape:
        - Input: :math:`(N, D)` where `D = vector dimension`
        - Output: a single value, the hinge loss averaged over the `N` triplets

    Raises:
        AssertionError: if the input sizes differ, the inputs are not 2D,
            or ``margin`` is not positive.

    >>> input1 = autograd.Variable(torch.randn(100, 128))
    >>> input2 = autograd.Variable(torch.randn(100, 128))
    >>> input3 = autograd.Variable(torch.randn(100, 128))
    >>> output = F.triplet_margin_loss(input1, input2, input3, p=2)
    >>> output.backward()

    .. _Learning shallow convolutional feature descriptors with triplet losses:
        http://www.iis.ee.ic.ac.uk/%7Evbalnt/shallow_descr/TFeat_paper.pdf
    """
    # Size equality is transitive, so checking the anchor against the other two
    # inputs also guarantees positive.size() == negative.size().
    assert anchor.size() == positive.size(), "Input sizes between anchor and positive must be equal."
    assert anchor.size() == negative.size(), "Input sizes between anchor and negative must be equal."
    assert anchor.dim() == 2, "Input must be a 2D matrix."
    assert margin > 0.0, 'Margin should be positive value.'

    d_p = pairwise_distance(anchor, positive, p, eps)
    d_n = pairwise_distance(anchor, negative, p, eps)
    if swap:
        # Distance swap: use whichever of anchor-negative / positive-negative
        # is smaller as the negative distance.
        d_s = pairwise_distance(positive, negative, p, eps)
        d_n = torch.min(d_n, d_s)

    # Hinge: only triplets that violate the margin contribute to the loss.
    dist_hinge = torch.clamp(margin + d_p - d_n, min=0.0)
    return torch.mean(dist_hinge)
0 commit comments