@@ -1,51 +1,51 @@
 import torch
 from torch.autograd.function import Function
 from torch._thnn import type2backend
+from torch.autograd.function import once_differentiable
 
 from . import _all_functions
 
 
 class Embedding(Function):
 
-    def __init__(self, padding_idx, max_norm, norm_type, scale_grad_by_freq,
-                 sparse=False):
-        super(Embedding, self).__init__()
-        self.padding_idx = padding_idx
-        self.max_norm = max_norm
-        self.norm_type = norm_type
-        self.scale_grad_by_freq = scale_grad_by_freq
-        self._indices = None
-        self.sparse = sparse
-
-    def _renorm(self, indices, weight):
+    @staticmethod
+    def _renorm(ctx, indices, weight, max_norm, norm_type):
         if indices.dim() == 2:
             indices = indices.clone().view(-1)
 
-        self._backend.LookupTable_renorm(
-            self._backend.library_state,
+        ctx._backend.LookupTable_renorm(
+            ctx._backend.library_state,
             indices,
             weight,
-            self.max_norm,
-            self.norm_type
+            max_norm,
+            norm_type
         )
 
-    def forward(self, indices, weight):
+    @classmethod
+    def forward(cls, ctx, indices, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq,
+                sparse=False):
+
+        ctx.padding_idx = padding_idx
+        ctx.scale_grad_by_freq = scale_grad_by_freq
+        ctx._indices = None
+        ctx.sparse = sparse
+
         assert indices.dim() <= 2
-        assert not self.needs_input_grad[0], "Embedding doesn't " \
+        assert not ctx.needs_input_grad[0], "Embedding doesn't " \
             "compute the gradient w.r.t. the indices"
 
-        self._backend = type2backend[type(weight)]
-        self._weight_size = weight.size()
+        ctx._backend = type2backend[type(weight)]
+        ctx._weight_size = weight.size()
 
         if not indices.is_contiguous():
-            self._indices = indices.contiguous()
-            indices = self._indices
+            ctx._indices = indices.contiguous()
+            indices = ctx._indices
         else:
-            self.save_for_backward(indices)
+            ctx.save_for_backward(indices)
 
         output = weight.new()
-        if self.max_norm is not None:
-            self._renorm(indices, weight)
+        if max_norm is not None:
+            cls._renorm(ctx, indices, weight, max_norm, norm_type)
 
         if indices.dim() == 1:
             output = torch.index_select(weight, 0, indices)
@@ -55,14 +55,16 @@ def forward(self, indices, weight):
 
         return output
 
-    def backward(self, grad_output):
-        if self._indices is not None:
-            indices = self._indices
+    @staticmethod
+    @once_differentiable
+    def backward(ctx, grad_output):
+        if ctx._indices is not None:
+            indices = ctx._indices
         else:
-            indices, = self.saved_tensors
+            indices, = ctx.saved_tensors
 
         grad_output = grad_output.contiguous()
-        if not self.sparse:
+        if not ctx.sparse:
             if indices.dim() == 2:
                 indices = indices.view(-1)
 
@@ -75,17 +77,18 @@ def backward(self, grad_output):
                     _count = torch.IntTensor()
                     _sorted = _indices = None
 
-            grad_weight = grad_output.new(self._weight_size).zero_()
-            self._backend.LookupTable_accGradParameters(
-                self._backend.library_state,
+            grad_weight = grad_output.new(ctx._weight_size).zero_()
+            # Doesn't support Variable grad_output
+            ctx._backend.LookupTable_accGradParameters(
+                ctx._backend.library_state,
                 indices,
                 grad_output,
                 grad_weight,
                 _count,
                 _sorted,
                 _indices,
-                self.scale_grad_by_freq,
-                self.padding_idx,
+                ctx.scale_grad_by_freq,
+                ctx.padding_idx,
                 1
             )
         else:
@@ -96,10 +99,10 @@ def backward(self, grad_output):
             SparseTensor = getattr(torch.sparse, tensor_type)
             grad_weight = SparseTensor(
                 indices.view(1, -1),
-                grad_output.view(-1, self._weight_size[1]),
-                self._weight_size,
+                grad_output.view(-1, ctx._weight_size[1]),
+                ctx._weight_size,
             )
-        return None, grad_weight
+        return None, grad_weight, None, None, None, None, None
 
 
 _all_functions.append(Embedding)
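
For orientation, a minimal usage sketch (not part of the diff) of the converted class: new-style autograd Functions are invoked through Embedding.apply(...) rather than by instantiating the class, and the seven return values of backward line up one-to-one with the seven forward arguments after ctx. The snippet assumes the Variable-era API this file targets, and assumes padding_idx=-1 means "no padding index", as the nn.Embedding module wrapper of this era passed it.

# Usage sketch under the assumptions above; `Embedding` is the class defined in this file.
import torch
from torch.autograd import Variable

weight = Variable(torch.randn(10, 3), requires_grad=True)  # 10 embeddings of dim 3
indices = Variable(torch.LongTensor([[1, 2], [4, 9]]))     # 2 x 2 lookup indices

# Argument order follows forward(ctx, indices, weight, padding_idx, max_norm,
# norm_type, scale_grad_by_freq, sparse); -1 is assumed to disable padding_idx.
output = Embedding.apply(indices, weight, -1, None, 2, False, False)

output.sum().backward()
# backward returns (None, grad_weight, None, None, None, None, None):
# a gradient only for `weight`, None for indices and the non-tensor arguments.
print(weight.grad.size())  # torch.Size([10, 3])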