@@ -35,12 +35,15 @@ def _map_list(self, l1, l2, f):
3535 del l1 [i ]
3636 return l1
3737
38- def _backward (self , method , input , gradOutput , scale ):
38+ def _backward (self , method , input , gradOutput , scale = 1 ):
3939 isTable = isinstance (input , list )
4040 wasTable = isinstance (self .gradInput , list )
4141 if isTable :
4242 for i , module in enumerate (self .modules ):
43- currentGradInput = getattr (module , method )(input , gradOutput [i ], scale )
43+ if method == 'updateGradInput' :
44+ currentGradInput = module .updateGradInput (input , gradOutput [i ])
45+ elif method == 'backward' :
46+ currentGradInput = module .backward (input , gradOutput [i ], scale )
4447 if not isinstance (currentGradInput , list ):
4548 raise RuntimeError ("currentGradInput is not a table!" )
4649
@@ -68,7 +71,10 @@ def fn(l, i, v):
6871 else :
6972 self .gradInput = self .gradInput if not wasTable else input .clone ()
7073 for i , module in enumerate (self .modules ):
71- currentGradInput = getattr (module , method )(input , gradOutput [i ], scale )
74+ if method == 'updateGradInput' :
75+ currentGradInput = module .updateGradInput (input , gradOutput [i ])
76+ elif method == 'backward' :
77+ currentGradInput = module .backward (input , gradOutput [i ], scale )
7278 if i == 0 :
7379 self .gradInput .resize_as_ (currentGradInput ).copy_ (currentGradInput )
7480 else :
0 commit comments