Skip to content

Commit 518864a

Browse files
adamlerersoumith
authored andcommitted
Fix bug in legacy NN updateGradParameters (pytorch#714)
1 parent d9dccfd commit 518864a

File tree

2 files changed

+21
-6
lines changed

2 files changed

+21
-6
lines changed

test/test_legacy_nn.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1154,6 +1154,15 @@ def test_NarrowTable(self):
11541154
module.__repr__()
11551155
str(module)
11561156

1157+
def test_accUpdateGradParameters(self):
1158+
module = nn.LookupTable(5, 3)
1159+
module.weight.fill_(2)
1160+
input = torch.LongTensor([1, 3])
1161+
output = module.updateOutput(input)
1162+
module.backwardUpdate(input, output, 0.1)
1163+
self.assertEqual(module.weight[0, 0], 2)
1164+
self.assertEqual(module.weight[3, 0], 1.8)
1165+
11571166
def _build_net(self):
11581167
return (nn.Sequential()
11591168
.add(nn.Concat(0)

torch/legacy/nn/Module.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,19 @@ def accGradParameters(self, input, gradOutput, scale=1):
4949
pass
5050

5151
def accUpdateGradParameters(self, input, gradOutput, lr):
52-
gradWeight = self.gradWeight
53-
gradBias = self.gradBias
54-
self.gradWeight = self.weight
55-
self.gradBias = self.bias
52+
has_weight = hasattr(self, 'weight') and self.weight is not None
53+
has_bias = hasattr(self, 'bias') and self.bias is not None
54+
if has_weight:
55+
gradWeight = self.gradWeight
56+
self.gradWeight = self.weight
57+
if has_bias:
58+
gradBias = self.gradBias
59+
self.gradBias = self.bias
5660
self.accGradParameters(input, gradOutput, -lr)
57-
self.gradWeight = gradWeight
58-
self.gradBias = gradBias
61+
if has_weight:
62+
self.gradWeight = gradWeight
63+
if has_bias:
64+
self.gradBias = gradBias
5965

6066
def sharedAccUpdateGradParameters(self, input, gradOutput, lr):
6167
if self.parameters():

0 commit comments

Comments
 (0)