We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 0f458ee commit 6b84dc2Copy full SHA for 6b84dc2
docs/source/nn.rst
@@ -160,7 +160,7 @@ Pooling Layers
160
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
161
162
.. autoclass:: AdaptiveMaxPool2d
163
- :members:
+ :members:
164
165
:hidden:`AdaptiveAvgPool1d`
166
@@ -174,7 +174,7 @@ Pooling Layers
174
.. autoclass:: AdaptiveAvgPool2d
175
:members:
176
177
-
+
178
Non-linear Activations
179
----------------------------------
180
@@ -682,7 +682,7 @@ Pooling functions
682
683
.. autofunction:: adaptive_avg_pool2d
684
685
686
Non-linear activation functions
687
-------------------------------
688
@@ -814,6 +814,11 @@ Distance functions
814
815
.. autofunction:: pairwise_distance
816
817
+:hidden:`cosine_similarity`
818
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
819
820
+.. autofunction:: cosine_similarity
821
822
823
Loss functions
824
--------------
test/test_nn.py
@@ -2063,6 +2063,16 @@ def test_triplet_margin_swap_loss(self):
2063
self.assertTrue(gradcheck(lambda x1, x2, x3: F.triplet_margin_loss(
2064
x1, x2, x3, swap=True), (input1, input2, input3)))
2065
2066
+ def test_cosine_similarity(self):
2067
+ input1 = Variable(torch.randn(4, 4), requires_grad=True)
2068
+ input2 = Variable(torch.randn(4, 4), requires_grad=True)
2069
+ self.assertTrue(gradcheck(lambda x, y: F.cosine_similarity(x, y), (input1, input2)))
2070
2071
+ input1 = Variable(torch.randn(4, 5, 6), requires_grad=True)
2072
+ input2 = Variable(torch.randn(4, 5, 6), requires_grad=True)
2073
+ self.assertTrue(gradcheck(lambda x, y: F.cosine_similarity(x, y, dim=0), (input1, input2)))
2074
+ self.assertTrue(gradcheck(lambda x, y: F.cosine_similarity(x, y, dim=-1), (input1, input2)))
2075
2076
def test_bilinear(self):
2077
module = nn.Bilinear(10, 10, 8)
2078
module2 = legacy.Bilinear(10, 10, 8)
torch/nn/functional.py
@@ -694,6 +694,25 @@ def pairwise_distance(x1, x2, p=2, eps=1e-6):
694
return torch.pow(out, 1. / p)
695
696
697
+def cosine_similarity(x1, x2, dim=1, eps=1e-8):
698
+ r"""Returns cosine similarity between x1 and x2, computed along dim.
699
700
+ Args:
701
+ x1 (Variable): First input.
702
+ x2 (Variable): Second input (of size matching x1).
703
+ dim (int, optional): Dimension of vectors. Default: 1
704
+ eps (float, optional): Small value to avoid division by zero. Default: 1e-8
705
706
+ Shape:
707
+ - Input: :math:`(\ast_1, D, \ast_2)` where D is at position `dim`.
708
+ - Output: :math:`(\ast_1, \ast_2)` where 1 is at position `dim`.
709
+ """
710
+ w12 = torch.sum(x1 * x2, dim)
711
+ w1 = torch.norm(x1, 2, dim)
712
+ w2 = torch.norm(x2, 2, dim)
713
+ return (w12 / (w1 * w2).clamp(min=eps)).squeeze()
714
715
716
def triplet_margin_loss(anchor, positive, negative, margin=1.0, p=2, eps=1e-6, swap=False):
717
r"""Creates a criterion that measures the triplet loss given an input tensors x1, x2, x3
718
and a margin with a value greater than 0.
0 commit comments