@@ -45,21 +45,24 @@ def backward(ctx, grad_output):
4545 pad_l , pad_r , pad_t , pad_b = ctx .pad
4646
4747 grad_input = Variable (grad_output .data .new (ctx .input_size ).zero_ ())
48- grad_input_slices = [slice (0 , x + 1 ,) for x in ctx .input_size ]
48+ grad_input_slices = [slice (0 , x ,) for x in ctx .input_size ]
4949
5050 def narrow_slice (dim , start , length ):
5151 grad_input_slices [dim ] = (slice (grad_input_slices [dim ].start + start ,
52- grad_input_slices [dim ].start + start + length + 1 ))
52+ grad_input_slices [dim ].start + start + length ))
53+
54+ def slice_length (dim ):
55+ return grad_input_slices [dim ].stop - grad_input_slices [dim ].start
5356
5457 # crop grad_input if necessary
5558 if pad_t < 0 :
56- narrow_slice (2 , - pad_t , grad_input . size (2 ) + pad_t )
59+ narrow_slice (2 , - pad_t , slice_length (2 ) + pad_t )
5760 if pad_b < 0 :
58- narrow_slice (2 , 0 , grad_input . size (2 ) + pad_b )
61+ narrow_slice (2 , 0 , slice_length (2 ) + pad_b )
5962 if pad_l < 0 :
60- narrow_slice (3 , - pad_l , grad_input . size (3 ) + pad_l )
63+ narrow_slice (3 , - pad_l , slice_length (3 ) + pad_l )
6164 if pad_r < 0 :
62- narrow_slice (3 , 0 , grad_input . size (3 ) + pad_r )
65+ narrow_slice (3 , 0 , slice_length (3 ) + pad_r )
6366
6467 # crop grad_output if necessary
6568 cg_output = grad_output
0 commit comments