Skip to content

Commit 00410c4

Browse files
apaszkesoumith
authored andcommitted
Fix broken THNN groups in conv functions
1 parent 8b9276b commit 00410c4

File tree

2 files changed

+9
-1
lines changed

2 files changed

+9
-1
lines changed

test/test_nn.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1308,6 +1308,11 @@ def add_test(test):
13081308
input_size=(2, 4, 6, 5),
13091309
cudnn=True,
13101310
),
1311+
dict(
1312+
fullname='Conv2d_groups_thnn',
1313+
constructor=lambda: nn.Conv2d(4, 6, (3, 2), groups=2),
1314+
input_size=(2, 4, 6, 5),
1315+
),
13111316
dict(
13121317
module_name='ConvTranspose2d',
13131318
constructor_args=(3, 4, 3, (3, 2), 1, (1, 1)),

torch/nn/_functions/conv.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,8 @@ def _thnn(self, fn_name, input, weight, *args):
149149
res = []
150150
for g in range(self.groups):
151151
def group(tensor, dim=None):
152+
if tensor is None:
153+
return None
152154
if dim is None:
153155
dim = 0 if tensor.dim() == 1 else 1
154156
n = tensor.size(dim) // self.groups
@@ -158,7 +160,8 @@ def group(tensor, dim=None):
158160
grouped_args += [group(t) for t in args]
159161
res.append(impl[fn_name](self, self._bufs[g], *grouped_args))
160162
if fn_name == 'grad_params':
161-
return [torch.cat(t, 0) for t in zip(*res)]
163+
return [torch.cat(t, 0) if t[0] is not None else None
164+
for t in zip(*res)]
162165
else:
163166
return torch.cat(res, 1)
164167

0 commit comments

Comments
 (0)