
Commit f3f4789

ezyang authored and soumith committed
Convert Embedding to new style. (pytorch#1916)
Signed-off-by: Edward Z. Yang <[email protected]>
1 parent e537023 commit f3f4789
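
For orientation, here is a minimal sketch of the "old style" vs. "new style" autograd Function pattern this commit applies to Embedding; the toy Scale functions and their names are illustrative only, not part of the commit. Per-call state moves from self (set in __init__) onto the ctx object, forward and backward become static methods, and callers invoke Function.apply instead of instantiating the class.

```python
import torch
from torch.autograd import Function


# Old style: the Function is instantiated with its hyperparameters and then
# called like a module; per-call state lives on `self`.
class ScaleOld(Function):
    def __init__(self, factor):
        super(ScaleOld, self).__init__()
        self.factor = factor

    def forward(self, x):
        return x * self.factor

    def backward(self, grad_output):
        return grad_output * self.factor


# New style: forward/backward are static, per-call state lives on `ctx`,
# every argument is passed through `apply`, and backward returns one
# gradient slot per forward input (None for non-differentiable ones).
class ScaleNew(Function):
    @staticmethod
    def forward(ctx, x, factor):
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * ctx.factor, None
```

This is why the call sites in torch/nn/functional.py and torch/nn/modules/sparse.py below change from Embedding(args...)(input, weight) to Embedding.apply(input, weight, args...).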

File tree: 3 files changed, +46 −41 lines changed


torch/nn/_functions/thnn/sparse.py

Lines changed: 40 additions & 37 deletions
@@ -1,51 +1,51 @@
 import torch
 from torch.autograd.function import Function
 from torch._thnn import type2backend
+from torch.autograd.function import once_differentiable
 
 from . import _all_functions
 
 
 class Embedding(Function):
 
-    def __init__(self, padding_idx, max_norm, norm_type, scale_grad_by_freq,
-                 sparse=False):
-        super(Embedding, self).__init__()
-        self.padding_idx = padding_idx
-        self.max_norm = max_norm
-        self.norm_type = norm_type
-        self.scale_grad_by_freq = scale_grad_by_freq
-        self._indices = None
-        self.sparse = sparse
-
-    def _renorm(self, indices, weight):
+    @staticmethod
+    def _renorm(ctx, indices, weight, max_norm, norm_type):
         if indices.dim() == 2:
             indices = indices.clone().view(-1)
 
-        self._backend.LookupTable_renorm(
-            self._backend.library_state,
+        ctx._backend.LookupTable_renorm(
+            ctx._backend.library_state,
             indices,
             weight,
-            self.max_norm,
-            self.norm_type
+            max_norm,
+            norm_type
         )
 
-    def forward(self, indices, weight):
+    @classmethod
+    def forward(cls, ctx, indices, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq,
+                sparse=False):
+
+        ctx.padding_idx = padding_idx
+        ctx.scale_grad_by_freq = scale_grad_by_freq
+        ctx._indices = None
+        ctx.sparse = sparse
+
         assert indices.dim() <= 2
-        assert not self.needs_input_grad[0], "Embedding doesn't " \
+        assert not ctx.needs_input_grad[0], "Embedding doesn't " \
             "compute the gradient w.r.t. the indices"
 
-        self._backend = type2backend[type(weight)]
-        self._weight_size = weight.size()
+        ctx._backend = type2backend[type(weight)]
+        ctx._weight_size = weight.size()
 
         if not indices.is_contiguous():
-            self._indices = indices.contiguous()
-            indices = self._indices
+            ctx._indices = indices.contiguous()
+            indices = ctx._indices
         else:
-            self.save_for_backward(indices)
+            ctx.save_for_backward(indices)
 
         output = weight.new()
-        if self.max_norm is not None:
-            self._renorm(indices, weight)
+        if max_norm is not None:
+            cls._renorm(indices, weight, max_norm, norm_type)
 
         if indices.dim() == 1:
             output = torch.index_select(weight, 0, indices)
@@ -55,14 +55,16 @@ def forward(self, indices, weight):
 
         return output
 
-    def backward(self, grad_output):
-        if self._indices is not None:
-            indices = self._indices
+    @staticmethod
+    @once_differentiable
+    def backward(ctx, grad_output):
+        if ctx._indices is not None:
+            indices = ctx._indices
         else:
-            indices, = self.saved_tensors
+            indices, = ctx.saved_tensors
 
         grad_output = grad_output.contiguous()
-        if not self.sparse:
+        if not ctx.sparse:
             if indices.dim() == 2:
                 indices = indices.view(-1)
 
@@ -75,17 +77,18 @@ def backward(self, grad_output):
             _count = torch.IntTensor()
             _sorted = _indices = None
 
-            grad_weight = grad_output.new(self._weight_size).zero_()
-            self._backend.LookupTable_accGradParameters(
-                self._backend.library_state,
+            grad_weight = grad_output.new(ctx._weight_size).zero_()
+            # Doesn't support Variable grad_output
+            ctx._backend.LookupTable_accGradParameters(
+                ctx._backend.library_state,
                 indices,
                 grad_output,
                 grad_weight,
                 _count,
                 _sorted,
                 _indices,
-                self.scale_grad_by_freq,
-                self.padding_idx,
+                ctx.scale_grad_by_freq,
+                ctx.padding_idx,
                 1
             )
         else:
@@ -96,10 +99,10 @@ def backward(self, grad_output):
             SparseTensor = getattr(torch.sparse, tensor_type)
             grad_weight = SparseTensor(
                 indices.view(1, -1),
-                grad_output.view(-1, self._weight_size[1]),
-                self._weight_size,
+                grad_output.view(-1, ctx._weight_size[1]),
+                ctx._weight_size,
             )
-        return None, grad_weight
+        return None, grad_weight, None, None, None, None, None
 
 
 _all_functions.append(Embedding)
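
A short sketch of the backward contract the converted Embedding now follows (the toy MulByConst below is made up for illustration): under @once_differentiable, backward receives and returns plain tensors rather than Variables, and it must return one value per forward argument, with None for the non-differentiable ones. That is why Embedding.backward above returns None, grad_weight, None, None, None, None, None.

```python
import torch
from torch.autograd import Function
from torch.autograd.function import once_differentiable


class MulByConst(Function):
    @staticmethod
    def forward(ctx, x, const):
        # Two forward inputs (x, const) -> backward must return two values.
        ctx.const = const
        return x * const

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        # grad_output arrives as a plain tensor; no further graph is built,
        # so this backward is only once-differentiable.
        return grad_output * ctx.const, None
```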

torch/nn/functional.py

Lines changed: 3 additions & 2 deletions
@@ -626,10 +626,11 @@ def embedding(input, embedding_matrix,
     [torch.FloatTensor of size 1x4x3]
 
     """
-    return torch.nn.backends.thnn.backend.Embedding(
+    return torch.nn.backends.thnn.backend.Embedding.apply(
+        input, embedding_matrix,
         -1, max_norm, norm_type,
         scale_grad_by_freq, sparse
-    )(input, embedding_matrix)
+    )
 
 
 def batch_norm(input, running_mean, running_var, weight=None, bias=None,
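
A hedged usage sketch of the functional call path after this change, assuming the Variable-era API of this period; the sizes follow the docstring example above:

```python
import torch
import torch.nn.functional as F
from torch.autograd import Variable

weight = Variable(torch.randn(10, 3), requires_grad=True)  # 10 embeddings of dim 3
input = Variable(torch.LongTensor([[1, 2, 4, 5]]))         # one row of indices

out = F.embedding(input, weight)   # now dispatches through Embedding.apply(...)
print(out.size())                  # (1, 4, 3)
out.sum().backward()               # gradient flows to `weight`, not to `input`
```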

torch/nn/modules/sparse.py

Lines changed: 3 additions & 2 deletions
@@ -88,10 +88,11 @@ def forward(self, input):
         padding_idx = self.padding_idx
         if padding_idx is None:
             padding_idx = -1
-        return self._backend.Embedding(
+        return self._backend.Embedding.apply(
+            input, self.weight,
             padding_idx, self.max_norm, self.norm_type,
             self.scale_grad_by_freq, self.sparse
-        )(input, self.weight)
+        )
 
     def __repr__(self):
         s = '{name}({num_embeddings}, {embedding_dim}'
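
Likewise, a hedged sketch of the module path exercised by the forward() above, again assuming the Variable-era API:

```python
import torch
import torch.nn as nn
from torch.autograd import Variable

emb = nn.Embedding(10, 3, padding_idx=0)          # index 0 is treated as padding
idx = Variable(torch.LongTensor([[0, 2, 5, 9]]))

out = emb(idx)             # forward() calls self._backend.Embedding.apply(...)
out.sum().backward()       # no gradient is accumulated for the padding row
print(out.size())          # (1, 4, 3)
```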
