Skip to content

Commit 1031d67

Browse files
szagoruykoapaszke
authored andcommitted
legacy fixes (pytorch#287)
1 parent 0d7d29f commit 1031d67

File tree

2 files changed

+10
-3
lines changed

2 files changed

+10
-3
lines changed

torch/legacy/nn/ConcatTable.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,15 @@ def _map_list(self, l1, l2, f):
3535
del l1[i]
3636
return l1
3737

38-
def _backward(self, method, input, gradOutput, scale):
38+
def _backward(self, method, input, gradOutput, scale=1):
3939
isTable = isinstance(input, list)
4040
wasTable = isinstance(self.gradInput, list)
4141
if isTable:
4242
for i, module in enumerate(self.modules):
43-
currentGradInput = getattr(module, method)(input, gradOutput[i], scale)
43+
if method == 'updateGradInput':
44+
currentGradInput = module.updateGradInput(input, gradOutput[i])
45+
elif method == 'backward':
46+
currentGradInput = module.backward(input, gradOutput[i], scale)
4447
if not isinstance(currentGradInput, list):
4548
raise RuntimeError("currentGradInput is not a table!")
4649

@@ -68,7 +71,10 @@ def fn(l, i, v):
6871
else:
6972
self.gradInput = self.gradInput if not wasTable else input.clone()
7073
for i, module in enumerate(self.modules):
71-
currentGradInput = getattr(module, method)(input, gradOutput[i], scale)
74+
if method == 'updateGradInput':
75+
currentGradInput = module.updateGradInput(input, gradOutput[i])
76+
elif method == 'backward':
77+
currentGradInput = module.backward(input, gradOutput[i], scale)
7278
if i == 0:
7379
self.gradInput.resize_as_(currentGradInput).copy_(currentGradInput)
7480
else:

torch/utils/serialization/read_lua_file.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,7 @@ def ensure_type(obj, type_map):
344344
ensure_attr('ClassNLLCriterion', 'weights')
345345
ensure_attr('ParallelCriterion', 'repeatTarget')
346346
ensure_attr('MultiMarginCriterion', 'weights')
347+
ensure_attr('SpatialConvolution', 'bias', 'finput', 'fgradInput', 'gradWeight', 'gradBias')
347348
attr_map('ReLU', {'val': 'value'})
348349
attr_map('Threshold', {'val': 'value'})
349350
attr_map('Unsqueeze', {'pos': 'dim'})

0 commit comments

Comments
 (0)