Skip to content

Commit 6b84dc2

Browse files
apaszkesoumith
authored andcommitted
Add F.cosine_similarity (pytorch#1502)
1 parent 0f458ee commit 6b84dc2

File tree

3 files changed

+37
-3
lines changed

3 files changed

+37
-3
lines changed

docs/source/nn.rst

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ Pooling Layers
160160
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
161161

162162
.. autoclass:: AdaptiveMaxPool2d
163-
:members:
163+
:members:
164164

165165
:hidden:`AdaptiveAvgPool1d`
166166
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -174,7 +174,7 @@ Pooling Layers
174174
.. autoclass:: AdaptiveAvgPool2d
175175
:members:
176176

177-
177+
178178
Non-linear Activations
179179
----------------------------------
180180

@@ -682,7 +682,7 @@ Pooling functions
682682

683683
.. autofunction:: adaptive_avg_pool2d
684684

685-
685+
686686
Non-linear activation functions
687687
-------------------------------
688688

@@ -814,6 +814,11 @@ Distance functions
814814

815815
.. autofunction:: pairwise_distance
816816

817+
:hidden:`cosine_similarity`
818+
~~~~~~~~~~~~~~~~~~~~~~~~~~~
819+
820+
.. autofunction:: cosine_similarity
821+
817822

818823
Loss functions
819824
--------------

test/test_nn.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2063,6 +2063,16 @@ def test_triplet_margin_swap_loss(self):
20632063
self.assertTrue(gradcheck(lambda x1, x2, x3: F.triplet_margin_loss(
20642064
x1, x2, x3, swap=True), (input1, input2, input3)))
20652065

2066+
def test_cosine_similarity(self):
2067+
input1 = Variable(torch.randn(4, 4), requires_grad=True)
2068+
input2 = Variable(torch.randn(4, 4), requires_grad=True)
2069+
self.assertTrue(gradcheck(lambda x, y: F.cosine_similarity(x, y), (input1, input2)))
2070+
2071+
input1 = Variable(torch.randn(4, 5, 6), requires_grad=True)
2072+
input2 = Variable(torch.randn(4, 5, 6), requires_grad=True)
2073+
self.assertTrue(gradcheck(lambda x, y: F.cosine_similarity(x, y, dim=0), (input1, input2)))
2074+
self.assertTrue(gradcheck(lambda x, y: F.cosine_similarity(x, y, dim=-1), (input1, input2)))
2075+
20662076
def test_bilinear(self):
20672077
module = nn.Bilinear(10, 10, 8)
20682078
module2 = legacy.Bilinear(10, 10, 8)

torch/nn/functional.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -694,6 +694,25 @@ def pairwise_distance(x1, x2, p=2, eps=1e-6):
694694
return torch.pow(out, 1. / p)
695695

696696

697+
def cosine_similarity(x1, x2, dim=1, eps=1e-8):
698+
r"""Returns cosine similarity between x1 and x2, computed along dim.
699+
700+
Args:
701+
x1 (Variable): First input.
702+
x2 (Variable): Second input (of size matching x1).
703+
dim (int, optional): Dimension of vectors. Default: 1
704+
eps (float, optional): Small value to avoid division by zero. Default: 1e-8
705+
706+
Shape:
707+
- Input: :math:`(\ast_1, D, \ast_2)` where D is at position `dim`.
708+
- Output: :math:`(\ast_1, \ast_2)` where 1 is at position `dim`.
709+
"""
710+
w12 = torch.sum(x1 * x2, dim)
711+
w1 = torch.norm(x1, 2, dim)
712+
w2 = torch.norm(x2, 2, dim)
713+
return (w12 / (w1 * w2).clamp(min=eps)).squeeze()
714+
715+
697716
def triplet_margin_loss(anchor, positive, negative, margin=1.0, p=2, eps=1e-6, swap=False):
698717
r"""Creates a criterion that measures the triplet loss given an input tensors x1, x2, x3
699718
and a margin with a value greater than 0.

0 commit comments

Comments
 (0)