
Commit e537023

hughperkins authored and soumith committed

add functional embedding (pytorch#1987)

1 parent 09abaa2 · commit e537023

2 files changed: +81 −0 lines changed

test/test_nn.py

Lines changed: 14 additions & 0 deletions

@@ -803,6 +803,20 @@ def test_embedding_padding_idx(self):
         self.assertEqual(output[0][0].sum().data[0], 0)
         self.assertEqual(output[1][2].sum().data[0], 0)
 
+    def test_embedding_functional(self):
+        a = Variable(torch.LongTensor([
+            [1, 3, 2],
+            [0, 2, 1]
+        ]))
+        embeddings = Variable(torch.rand(4, 3), requires_grad=True)
+
+        embed_old = torch.nn.Embedding(4, 3)
+        embed_old.weight.data = embeddings.data
+        res_old = embed_old(a)
+
+        res_F = F.embedding(a, embeddings)
+        self.assertEqual(res_old, res_F)
+
     def _test_EmbeddingBag(self, cuda, mode):
         # check a known test example
         es = nn.EmbeddingBag(5, 2, mode=mode)
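
The key assertion in the new test is that the functional call and the nn.Embedding module produce identical outputs when they share the same weight data. As a rough standalone sketch of that equivalence (not part of the commit; Variable-era API as used in this diff, and the index_select line is only an illustration of what the lookup computes, not the backend implementation):

>>> import torch
>>> import torch.nn.functional as F
>>> from torch.autograd import Variable
>>> a = Variable(torch.LongTensor([[1, 3, 2], [0, 2, 1]]))         # indices, shape (2, 3)
>>> embeddings = Variable(torch.rand(4, 3), requires_grad=True)    # weight matrix, shape (4, 3)
>>> res_F = F.embedding(a, embeddings)                             # shape (2, 3, 3)
>>> manual = embeddings.index_select(0, a.view(-1)).view(2, 3, 3)  # same rows picked by hand
>>> (res_F - manual).abs().max()                                   # expected to be 0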

torch/nn/functional.py

Lines changed: 67 additions & 0 deletions

@@ -565,6 +565,73 @@ def bilinear(input1, input2, weight, bias=None):
     return Bilinear.apply(input1, input2, weight, bias)
 
 
+def embedding(input, embedding_matrix,
+              max_norm=None, norm_type=2, scale_grad_by_freq=False,
+              sparse=False):
+    r"""A simple lookup table that looks up embeddings in a fixed dictionary and size.
+
+    This module is often used to retrieve word embeddings using indices.
+    The input to the module is a list of indices, and the embedding matrix,
+    and the output is the corresponding word embeddings.
+
+    Args:
+        input: tensor, containing indices into the embedding matrix
+        embedding_matrix:
+            Number of rows should correspond to the maximum possible index + 1,
+            number of columns is the embedding size
+        max_norm (float, optional): If given, will renormalize the embeddings to always have a norm lesser than this
+        norm_type (float, optional): The p of the p-norm to compute for the max_norm option
+        scale_grad_by_freq (boolean, optional): if given, this will scale gradients by the frequency of
+            the words in the mini-batch.
+
+    Shape:
+        - Input: LongTensor `(N, W)`, N = mini-batch, W = number of indices to extract per mini-batch
+        - Embedding_matrix: FloatTensor `(V, embedding_dim)`, V = maximum index + 1, embedding_dim = embedding size
+        - Output: `(N, W, embedding_dim)`
+
+    Examples::
+
+        >>> # a batch of 2 samples of 4 indices each
+        >>> input = Variable(torch.LongTensor([[1,2,4,5],[4,3,2,9]]))
+        >>> # an embedding matrix containing 10 tensors of size 3
+        >>> embedding_matrix = Variable(torch.rand(10, 3))
+        >>> torch.nn.functional.embedding(input, embedding_matrix)
+
+        Variable containing:
+        (0 ,.,.) =
+         -1.0822  1.2522  0.2434
+          0.8393 -0.6062 -0.3348
+          0.6597  0.0350  0.0837
+          0.5521  0.9447  0.0498
+
+        (1 ,.,.) =
+          0.6597  0.0350  0.0837
+         -0.1527  0.0877  0.4260
+          0.8393 -0.6062 -0.3348
+         -0.8738 -0.9054  0.4281
+        [torch.FloatTensor of size 2x4x3]
+
+        >>> # example with padding_idx
+        >>> embedding_matrix = Variable(torch.rand(10, 3))
+        >>> embedding_matrix[0].zero_()
+        >>> input = Variable(torch.LongTensor([[0,2,0,5]]))
+        >>> torch.nn.functional.embedding(input, embedding_matrix)
+
+        Variable containing:
+        (0 ,.,.) =
+          0.0000  0.0000  0.0000
+          0.3452  0.4937 -0.9361
+          0.0000  0.0000  0.0000
+          0.0706 -2.1962 -0.6276
+        [torch.FloatTensor of size 1x4x3]
+
+    """
+    return torch.nn.backends.thnn.backend.Embedding(
+        -1, max_norm, norm_type,
+        scale_grad_by_freq, sparse
+    )(input, embedding_matrix)
+
+
 def batch_norm(input, running_mean, running_var, weight=None, bias=None,
                training=False, momentum=0.1, eps=1e-5):
     f = torch._C._functions.BatchNorm(running_mean, running_var, training, momentum, eps, torch.backends.cudnn.enabled)
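
One practical consequence of the functional form (my gloss, not stated in the commit): because the embedding matrix is passed in as an ordinary Variable rather than being owned by a module, gradients from the lookup flow back into it, or into whatever computation produced it, which is what the requires_grad=True flag in the new test sets up. A minimal sketch with the same Variable-era API:

>>> import torch
>>> import torch.nn.functional as F
>>> from torch.autograd import Variable
>>> embedding_matrix = Variable(torch.rand(10, 3), requires_grad=True)
>>> input = Variable(torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]]))
>>> out = F.embedding(input, embedding_matrix)   # shape (2, 4, 3)
>>> out.sum().backward()                         # scalar loss, purely for illustration
>>> embedding_matrix.grad                        # expected: nonzero only in the looked-up rows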
