Skip to content

Commit 977f964

Browse files
gchanansoumith
authored andcommitted
Fix ZeroPad2d backwards with negative pads.
1 parent 38b42e0 commit 977f964

File tree

2 files changed

+15
-6
lines changed

2 files changed

+15
-6
lines changed

test/test_nn.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3562,6 +3562,12 @@ def add_test(test):
35623562
constructor_args=((1, 2, 3, 4),),
35633563
input_size=(2, 3, 4, 4)
35643564
),
3565+
dict(
3566+
module_name='ZeroPad2d',
3567+
constructor_args=((-1, -1, -1, -2),),
3568+
input_size=(2, 3, 4, 4),
3569+
desc='negative_dims'
3570+
),
35653571
dict(
35663572
module_name='ConstantPad2d',
35673573
constructor_args=((1, 2, 3, 4), 2),

torch/nn/_functions/padding.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,21 +45,24 @@ def backward(ctx, grad_output):
4545
pad_l, pad_r, pad_t, pad_b = ctx.pad
4646

4747
grad_input = Variable(grad_output.data.new(ctx.input_size).zero_())
48-
grad_input_slices = [slice(0, x + 1,) for x in ctx.input_size]
48+
grad_input_slices = [slice(0, x,) for x in ctx.input_size]
4949

5050
def narrow_slice(dim, start, length):
5151
grad_input_slices[dim] = (slice(grad_input_slices[dim].start + start,
52-
grad_input_slices[dim].start + start + length + 1))
52+
grad_input_slices[dim].start + start + length))
53+
54+
def slice_length(dim):
55+
return grad_input_slices[dim].stop - grad_input_slices[dim].start
5356

5457
# crop grad_input if necessary
5558
if pad_t < 0:
56-
narrow_slice(2, -pad_t, grad_input.size(2) + pad_t)
59+
narrow_slice(2, -pad_t, slice_length(2) + pad_t)
5760
if pad_b < 0:
58-
narrow_slice(2, 0, grad_input.size(2) + pad_b)
61+
narrow_slice(2, 0, slice_length(2) + pad_b)
5962
if pad_l < 0:
60-
narrow_slice(3, -pad_l, grad_input.size(3) + pad_l)
63+
narrow_slice(3, -pad_l, slice_length(3) + pad_l)
6164
if pad_r < 0:
62-
narrow_slice(3, 0, grad_input.size(3) + pad_r)
65+
narrow_slice(3, 0, slice_length(3) + pad_r)
6366

6467
# crop grad_output if necessary
6568
cg_output = grad_output

0 commit comments

Comments
 (0)