Commit f0c7124

albanD authored and soumith committed

Allow support for negative dimension argument for all functions

1 parent e7f5220 commit f0c7124
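
In user-facing terms, every function that takes a dim argument now also accepts Python-style negative indices, which wrap to dim + ndim. A minimal sketch of the behavior this enables (assuming a build that includes this commit):

import torch

x = torch.randn(10, 20, 30)
# -1 wraps to -1 + x.dim() = 2, the last dimension:
assert torch.equal(x.sum(-1), x.sum(2))
assert torch.equal(x.transpose(-1, -2), x.transpose(2, 1))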

File tree: 13 files changed, +402 −187 lines


setup.py

Lines changed: 2 additions & 1 deletion
@@ -154,10 +154,11 @@ def run(self):
     from tools.cwrap.plugins.KwargsPlugin import KwargsPlugin
     from tools.cwrap.plugins.NullableArguments import NullableArguments
     from tools.cwrap.plugins.CuDNNPlugin import CuDNNPlugin
+    from tools.cwrap.plugins.WrapDim import WrapDim
     thp_plugin = THPPlugin()
     cwrap('torch/csrc/generic/TensorMethods.cwrap', plugins=[
         BoolOption(), thp_plugin, AutoGPU(condition='IS_CUDA'),
-        ArgcountSortPlugin(), KwargsPlugin()
+        ArgcountSortPlugin(), KwargsPlugin(), WrapDim()
     ])
     cwrap('torch/csrc/cudnn/cuDNN.cwrap', plugins=[
         CuDNNPlugin(), NullableArguments()

test/test_autograd.py

Lines changed: 138 additions & 110 deletions
Large diffs are not rendered by default.

test/test_cuda.py

Lines changed: 30 additions & 0 deletions
@@ -155,12 +155,15 @@ def tmp(t):
     ('fmod', small_3d, lambda t: [small_3d_positive(t)], 'tensor'),
     ('chunk', medium_2d, lambda t: [4],),
     ('chunk', medium_2d, lambda t: [4, 1], 'dim'),
+    ('chunk', medium_2d, lambda t: [4, -2], 'neg_dim'),
     ('clamp', medium_2d_scaled, lambda t: [-1, 5],),
     ('clone', medium_2d, lambda t: [],),
     ('contiguous', medium_2d, lambda t: [],),
     ('cross', new_t(M, 3, M), lambda t: [new_t(M, 3, M)(t)],),
     ('cumprod', small_3d, lambda t: [1],),
+    ('cumprod', small_3d, lambda t: [-1], 'neg_dim'),
     ('cumsum', small_3d, lambda t: [1],),
+    ('cumsum', small_3d, lambda t: [-1], 'neg_dim'),
     ('dim', small_3d, lambda t: [],),
     ('dist', small_2d, lambda t: [small_2d(t)],),
     ('dist', small_2d, lambda t: [small_2d(t), 3], '3_norm'),
@@ -188,52 +191,72 @@ def tmp(t):
     # TODO: positive case
     ('kthvalue', small_3d_unique, lambda t: [3],),
     ('kthvalue', small_3d_unique, lambda t: [3, 1], 'dim'),
+    ('kthvalue', small_3d_unique, lambda t: [3, -1], 'neg_dim'),
     ('lerp', small_3d, lambda t: [small_3d(t), 0.3],),
     ('max', small_3d_unique, lambda t: [],),
     ('max', small_3d_unique, lambda t: [1], 'dim'),
+    ('max', small_3d_unique, lambda t: [-1], 'neg_dim'),
     ('max', medium_2d, lambda t: [medium_2d(t)], 'elementwise'),
     ('min', small_3d_unique, lambda t: [],),
     ('min', small_3d_unique, lambda t: [1], 'dim'),
+    ('min', small_3d_unique, lambda t: [-1], 'neg_dim'),
     ('min', medium_2d, lambda t: [medium_2d(t)], 'elementwise'),
     ('mean', small_3d, lambda t: [],),
+    ('mean', small_3d, lambda t: [-1], 'neg_dim'),
     ('mean', small_3d, lambda t: [1], 'dim'),
     ('mode', small_3d, lambda t: [],),
     ('mode', small_3d, lambda t: [1], 'dim'),
+    ('mode', small_3d, lambda t: [-1], 'neg_dim'),
     ('remainder', small_3d, lambda t: [3], 'value'),
     ('remainder', small_3d, lambda t: [small_3d_positive(t)], 'tensor'),
     ('std', small_3d, lambda t: [],),
     ('std', small_3d, lambda t: [1], 'dim'),
+    ('std', small_3d, lambda t: [-1], 'neg_dim'),
     ('var', small_3d, lambda t: [],),
     ('var', small_3d, lambda t: [1], 'dim'),
+    ('var', small_3d, lambda t: [-1], 'neg_dim'),
     ('ndimension', small_3d, lambda t: [],),
     ('nelement', small_3d, lambda t: [],),
     ('numel', small_3d, lambda t: [],),
     ('narrow', small_3d, lambda t: [1, 3, 2],),
+    ('narrow', small_3d, lambda t: [-1, 3, 2], 'neg_dim'),
     ('nonzero', small_3d, lambda t: [],),
     ('norm', small_3d, lambda t: [],),
     ('norm', small_3d, lambda t: [3], '3_norm'),
     ('norm', small_3d, lambda t: [3, 0], '3_norm_dim'),
+    ('norm', small_3d, lambda t: [3, -2], '3_norm_neg_dim'),
     ('ones', small_3d, lambda t: [1, 2, 3, 4, 5],),
     ('permute', new_t(1, 2, 3, 4), lambda t: [2, 1, 3, 0],),
     ('prod', small_2d_oneish, lambda t: [],),
     ('prod', small_3d, lambda t: [1], 'dim'),
+    ('prod', small_3d, lambda t: [-1], 'neg_dim'),
     ('sum', small_2d, lambda t: [],),
     ('sum', small_3d, lambda t: [1], 'dim'),
+    ('sum', small_3d, lambda t: [-1], 'neg_dim'),
     ('renorm', small_3d, lambda t: [2, 1, 1], '2_norm'),
+    ('renorm', small_3d, lambda t: [2, -1, 1], '2_norm_neg_dim'),
     ('renorm', small_3d, lambda t: [1.5, 1, 1], '1_5_norm'),
     ('repeat', small_2d, lambda t: [2, 2, 2],),
     ('size', new_t(1, 2, 3, 4), lambda t: [],),
+    ('size', new_t(1, 2, 3, 4), lambda t: [1], 'dim'),
+    ('size', new_t(1, 2, 3, 4), lambda t: [-2], 'neg_dim'),
     ('sort', small_3d_unique, lambda t: [],),
     ('sort', small_3d_unique, lambda t: [1], 'dim'),
+    ('sort', small_3d_unique, lambda t: [-1], 'neg_dim'),
     ('sort', small_3d_unique, lambda t: [1, True], 'dim_descending'),
+    ('sort', small_3d_unique, lambda t: [-1, True], 'neg_dim_descending'),
     ('split', small_3d, lambda t: [2],),
     ('split', small_3d, lambda t: [2, 1], 'dim'),
+    ('split', small_3d, lambda t: [2, -3], 'neg_dim'),
     ('squeeze', new_t(1, 2, 1, 4), lambda t: [],),
     ('squeeze', new_t(1, 2, 1, 4), lambda t: [2], 'dim'),
+    ('squeeze', new_t(1, 2, 1, 4), lambda t: [-2], 'neg_dim'),
     ('t', new_t(1, 2), lambda t: [],),
     ('transpose', new_t(1, 2, 3, 4), lambda t: [1, 2],),
+    ('transpose', new_t(1, 2, 3, 4), lambda t: [-1, -2], 'neg_dim'),
     ('to_list', small_3d, lambda t: [],),
     ('topk', small_3d, lambda t: [2, 1, False, True], 'dim_sort'),
+    ('topk', small_3d, lambda t: [2, -1, False, True], 'neg_dim_sort'),
     ('topk', small_3d, lambda t: [2, 1, True, True], 'dim_desc_sort'),
     ('trace', medium_2d, lambda t: [],),
     ('tril', medium_2d, lambda t: [],),
@@ -243,6 +266,7 @@ def tmp(t):
     ('triu', medium_2d, lambda t: [2], 'positive'),
     ('triu', medium_2d, lambda t: [-2], 'negative'),
     ('unsqueeze', new_t(2, 3, 4), lambda t: [2],),
+    ('unsqueeze', new_t(2, 3, 4), lambda t: [-2], 'neg_dim'),
     ('view', small_3d, lambda t: [100, 10],),
     ('view_as', small_3d, lambda t: [t(100, 10)],),
     ('zero', small_3d, lambda t: [],),
@@ -467,6 +491,9 @@ def test_scatter_cpu(self):
     def test_scatter_cpu_dim(self):
         self._test_scatter(torch.randn(4, 4), dim=1)

+    def test_scatter_cpu_neg_dim(self):
+        self._test_scatter(torch.randn(4, 4), dim=-2)
+
     def test_scatter_cpu_sizes(self):
         self._test_scatter(torch.randn(6, 4), chunk_sizes=(2, 4))

@@ -476,6 +503,9 @@ def test_scatter_gpu(self):
     def test_scatter_gpu_dim(self):
         self._test_scatter(torch.randn(4, 4).cuda(), dim=1)

+    def test_scatter_gpu_neg_dim(self):
+        self._test_scatter(torch.randn(4, 4).cuda(), dim=-2)
+
     def test_scatter_gpu_sizes(self):
         self._test_scatter(torch.randn(6, 4).cuda(), chunk_sizes=(2, 4))
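
Each tuple above is (method name, input constructor, argument constructor, test-name suffix). The harness builds the input on the CPU, mirrors it onto the GPU, invokes the named method on both, and checks that the results agree. Roughly, with an assumed stand-in for the suite's small_3d_unique helper (the stand-in and variable names here are illustrative, not the real harness):

import torch

def small_3d_unique(t):
    # Assumed stand-in: a 5x5x5 tensor with all-distinct values.
    return t(5, 5, 5).copy_(torch.randperm(125).float().view(5, 5, 5))

name, constr, arg_constr = 'sort', small_3d_unique, lambda t: [-1]
cpu = constr(torch.FloatTensor)
gpu = cpu.cuda()
cpu_res = getattr(cpu, name)(*arg_constr(torch.FloatTensor))
gpu_res = getattr(gpu, name)(*arg_constr(torch.cuda.FloatTensor))
# The suite then asserts that cpu_res and gpu_res match.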

test/test_torch.py

Lines changed: 100 additions & 15 deletions
@@ -2,6 +2,7 @@
 import os
 import math
 import random
+import copy
 import torch
 import torch.cuda
 import tempfile
@@ -3132,23 +3133,107 @@ def test_Size(self):
         self.assertIsInstance(x[:-1], torch.Size)
         self.assertIsInstance(x + x, torch.Size)

-    def test_transpose_neg(self):
-        x = torch.randn(10, 20, 30)
-        ndim = 3
+# Functions to test negative dimension wrapping
+METHOD = 1
+INPLACE_METHOD = 2
+FUNCTIONAL = 4
+DIM_ARG = None

-        for i, j in combinations(range(ndim), 2):
-            a = x.transpose(i, j)
-            b = x.transpose(i - ndim, j - ndim)
-            self.assertEqual(a, b)

-            a = torch.transpose(x, i, j)
-            b = torch.transpose(x, i - ndim, j - ndim)
-            self.assertEqual(a, b)
-
-            a = x.clone()
-            x.transpose_(i, j)
-            x.transpose_(i - ndim, j - ndim)
-            self.assertEqual(a, x)
+def make_neg_dim_test(name, tensor_arg, arg_constr, types, extra_dim=0):
+    def neg_dim_test(self):
+        if isinstance(tensor_arg, list):
+            assert METHOD not in types and INPLACE_METHOD not in types
+            x = [torch.randn(arg) for arg in tensor_arg]
+            ndim = len(tensor_arg[-1])
+        else:
+            x = torch.randn(*tensor_arg)
+            ndim = len(tensor_arg)
+        ndim += extra_dim
+
+        n_dim_to_test = sum(map(lambda e: e is DIM_ARG, arg_constr()))
+
+        for dims_val in combinations(range(ndim), n_dim_to_test):
+            arg = arg_constr()
+            arg_neg = copy.deepcopy(arg)
+            idx = 0
+            for i, v in enumerate(arg):
+                if v is DIM_ARG:
+                    arg[i] = dims_val[idx]
+                    arg_neg[i] = dims_val[idx] - ndim
+                    idx += 1
+
+            if METHOD in types:
+                a = getattr(x, name)(*arg)
+                b = getattr(x, name)(*arg_neg)
+                self.assertEqual(a, b)
+
+            if INPLACE_METHOD in types:
+                a = x.clone()
+                getattr(a, name + '_')(*arg)
+                b = x.clone()
+                getattr(b, name + '_')(*arg_neg)
+                self.assertEqual(a, b)
+
+            if FUNCTIONAL in types:
+                a = getattr(torch, name)(x, *arg)
+                b = getattr(torch, name)(x, *arg_neg)
+                self.assertEqual(a, b)
+
+    return neg_dim_test
+
+
+def idx_tensor(size, max_val):
+    return torch.LongTensor(*size).random_(0, max_val - 1)
+
+neg_dim_tests = [
+    ('narrow', (10, 20, 30), lambda: [DIM_ARG, 0, 5], [METHOD]),
+    ('transpose', (10, 20, 30), lambda: [DIM_ARG, DIM_ARG], [METHOD, INPLACE_METHOD, FUNCTIONAL]),
+    ('size', (10, 20, 30), lambda: [DIM_ARG], [METHOD]),
+    ('cat', [(2, 3, 4), (2, 3, 4)], lambda: [DIM_ARG], [FUNCTIONAL]),
+    ('chunk', (10, 20, 30), lambda: [5, DIM_ARG], [METHOD, FUNCTIONAL]),
+    ('gather', (10, 20), lambda: [DIM_ARG, idx_tensor((10, 20), 10)], [METHOD, FUNCTIONAL]),
+    ('index_select', (10, 10), lambda: [DIM_ARG, idx_tensor((10,), 10)], [METHOD, FUNCTIONAL]),
+    ('split', (10, 20), lambda: [5, DIM_ARG], [METHOD, FUNCTIONAL]),
+    ('squeeze', (10, 1, 20, 1), lambda: [DIM_ARG], [METHOD, INPLACE_METHOD, FUNCTIONAL]),
+    ('stack', [(2, 3, 4), (2, 3, 4)], lambda: [DIM_ARG], [FUNCTIONAL]),
+    ('unbind', (2, 3, 4), lambda: [DIM_ARG], [FUNCTIONAL]),
+    ('unsqueeze', (10, 20), lambda: [DIM_ARG], [METHOD, INPLACE_METHOD, FUNCTIONAL], 1),
+    ('cumprod', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
+    ('cumsum', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
+    ('mean', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
+    ('median', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
+    ('mode', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
+    ('norm', (10, 20), lambda: [2, DIM_ARG], [METHOD, FUNCTIONAL]),
+    ('prod', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
+    ('std', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
+    ('sum', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
+    ('var', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
+    ('kthvalue', (10, 20), lambda: [3, DIM_ARG], [METHOD, FUNCTIONAL]),
+    ('max', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
+    ('min', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
+    ('sort', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]),
+    ('topk', (10, 20), lambda: [5, DIM_ARG], [METHOD, FUNCTIONAL]),
+    ('renorm', (10, 20), lambda: [2, DIM_ARG, 1], [METHOD, INPLACE_METHOD, FUNCTIONAL]),
+    ('index_add', (10, 10), lambda: [DIM_ARG, idx_tensor((10,), 10), torch.randn(10, 10)], [INPLACE_METHOD]),
+    ('index_copy', (10, 10), lambda: [DIM_ARG, idx_tensor((10,), 10), torch.randn(10, 10)], [INPLACE_METHOD]),
+    ('index_fill', (10, 10), lambda: [DIM_ARG, idx_tensor((10,), 10), 12], [INPLACE_METHOD]),
+    ('scatter', (10, 10), lambda: [DIM_ARG, idx_tensor((10, 10), 10), torch.randn(10, 10)], [INPLACE_METHOD]),
+    ('select', (10, 20), lambda: [DIM_ARG, 3], [METHOD]),
+    ('unfold', (10, 20), lambda: [DIM_ARG, 5, 2], [METHOD]),
+]
+
+for decl in neg_dim_tests:
+    if len(decl) == 4:
+        name, tensor_arg, arg_constr, types = decl
+        extra_dim = 0
+    elif len(decl) == 5:
+        name, tensor_arg, arg_constr, types, extra_dim = decl
+
+    test_name = 'test_' + name + '_neg_dim'
+
+    assert not hasattr(TestTorch, test_name), "Duplicated test name: " + test_name
+    setattr(TestTorch, test_name, make_neg_dim_test(name, tensor_arg, arg_constr, types, extra_dim))

 if __name__ == '__main__':
     run_tests()
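
The loop above stamps one generated test per entry onto TestTorch, named test_<name>_neg_dim, each asserting that a wrapped negative dim gives the same result as its non-negative form; this subsumes the hand-written test_transpose_neg it replaces. For transpose on a 3-D tensor the generated check amounts to this sketch:

import torch

x = torch.randn(10, 20, 30)
ndim = x.dim()
for i, j in ((0, 1), (0, 2), (1, 2)):
    assert torch.equal(x.transpose(i, j), x.transpose(i - ndim, j - ndim))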

tools/cwrap/cwrap.py

Lines changed: 1 addition & 1 deletion
@@ -198,7 +198,7 @@ def build_option_args(self, arguments, arg_unpack):
         arguments = self.get_assign_args(arguments)
         for arg, unpack in zip(arguments, arg_unpack):
             if arg['type'] == 'CONSTANT':
-                call_arg.append(str(arg['name']))
+                call_arg.append(unpack)
             else:
                 var_name = "arg_" + str(arg.get('assign_name', arg['name']))
                 res = self.ARG_ASSIGN_TEMPLATE.substitute(

tools/cwrap/plugins/WrapDim.py

Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
+from . import CWrapPlugin
+from string import Template
+
+
+class WrapDim(CWrapPlugin):
+
+    NDIM_TEMPLATE = Template(
+        """${arg_tensor}->nDimension""")
+
+    CODE_TEMPLATE = Template(
+        """THPUtils_assert(${arg_dim} >= -(${ndim}) && ${arg_dim} < (${ndim}),
+"dimension out of range (expected to be in range of [%d, %d], but got %d)",
+-(${ndim}), (${ndim})-1, ${arg_dim});
+if (${arg_dim} < 0) ${arg_dim} += (${ndim});""")
+
+    def initialize(self, cwrap):
+        self.cwrap = cwrap
+
+    def process_option_code_template(self, template, option):
+        new_code = []
+        for i, arg in enumerate(option['arguments']):
+            if 'wrap_dim' not in arg:
+                continue
+
+            params = arg.get('wrap_dim').split("+")
+            arg_tensor = params[0]
+
+            arg_tensor = "arg_" + arg_tensor
+            arg_dim = "arg_" + arg.get('assign_name', arg['name'])
+
+            params[0] = self.NDIM_TEMPLATE.substitute(arg_tensor=arg_tensor)
+            ndim = "+".join(params)
+
+            new_code.append(self.CODE_TEMPLATE.substitute(
+                arg_dim=arg_dim,
+                ndim=ndim))
+            new_code.append("")
+
+        template = new_code + template
+        return template
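
To see what the plugin injects into the generated wrapper, one can run the substitution by hand. The sketch below reuses CODE_TEMPLATE from the file above; the arg_dim/arg_self names follow the plugin's "arg_" + name convention for an argument named dim declared with wrap_dim: self (the concrete names are illustrative):

from string import Template

CODE_TEMPLATE = Template(
    """THPUtils_assert(${arg_dim} >= -(${ndim}) && ${arg_dim} < (${ndim}),
"dimension out of range (expected to be in range of [%d, %d], but got %d)",
-(${ndim}), (${ndim})-1, ${arg_dim});
if (${arg_dim} < 0) ${arg_dim} += (${ndim});""")

# Substituting as the plugin would for `wrap_dim: self` prints the C guard:
print(CODE_TEMPLATE.substitute(arg_dim='arg_dim', ndim='arg_self->nDimension'))

The emitted guard rejects out-of-range dims with a readable error and adds ndim to negative values before the underlying TH call runs.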

tools/cwrap/plugins/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -65,3 +65,4 @@ def process_pre_arg_assign(self, template, option):
 from .AutoGPU import AutoGPU
 from .CuDNNPlugin import CuDNNPlugin
 from .GenericNN import GenericNN
+from .WrapDim import WrapDim

torch/_torch_docs.py

Lines changed: 2 additions & 0 deletions
@@ -4297,6 +4297,8 @@

 The returned tensor shares the same underlying data with this tensor.

+A negative dim value can be used and will correspond to :math:`dim + input.dim() + 1`
+
 Args:
     input (Tensor): the input `Tensor`
     dim (int): The index at which to insert the singleton dimension
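
Concretely, per the formula in the added doc line (a small sketch, assuming a build that includes this commit):

import torch

x = torch.randn(2, 3)
# dim = -1 corresponds to -1 + x.dim() + 1 = 2, i.e. a new trailing axis:
assert x.unsqueeze(-1).size() == x.unsqueeze(2).size() == torch.Size([2, 3, 1])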

torch/autograd/_functions/reduce.py

Lines changed: 8 additions & 3 deletions
@@ -60,8 +60,9 @@ def backward(self, grad_output):
             return grad_input
         else:
             input, output = self.saved_tensors
+            dim = self.dim if self.dim >= 0 else self.dim + input.dim()
             zero_mask = input == 0
-            slice_zero_count = zero_mask.sum(self.dim)
+            slice_zero_count = zero_mask.sum(dim)
             total_zeros = slice_zero_count.sum()
             grad_input = grad_output.mul(output).expand_as(input).div(input)
             if total_zeros == 0:
@@ -71,17 +72,21 @@ def backward(self, grad_output):
                 grad_input[some_zeros] = 0

             single_zero_idx = slice_zero_count.eq(1).nonzero()
+
+            if len(single_zero_idx) == 0:
+                return grad_input
+
             for idx in single_zero_idx:
                 idx_tuple = tuple(idx.cpu())
-                input_idx_tuple = idx_tuple[:self.dim] + (slice(0, None),) + idx_tuple[self.dim + 1:]
+                input_idx_tuple = idx_tuple[:dim] + (slice(0, None),) + idx_tuple[dim + 1:]

                 # slice_mask and input_copy are 1D
                 slice_mask = zero_mask[input_idx_tuple]
                 input_copy = input[input_idx_tuple].clone()
                 zero_idx = slice_mask.nonzero()[0, 0]
                 input_copy[zero_idx] = 1.

-                grad_idx_tuple = idx_tuple[:self.dim] + (zero_idx,) + idx_tuple[self.dim + 1:]
+                grad_idx_tuple = idx_tuple[:dim] + (zero_idx,) + idx_tuple[dim + 1:]
                 grad_input[grad_idx_tuple] = grad_output[idx_tuple] * input_copy.prod()

             return grad_input
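
The normalization matters here because idx_tuple[:dim] slices a plain Python tuple: a negative dim would count from the wrong end and build a bad index. In isolation (the wrap_dim helper name is mine; the inline conditional above is what the code actually uses):

def wrap_dim(dim, ndim):
    # Map a possibly negative dim into [0, ndim), as Prod.backward now does inline.
    return dim if dim >= 0 else dim + ndim

idx_tuple = (4, 7)      # location of a slice containing a single zero
dim = wrap_dim(-2, 2)   # -2 on a 2-D input wraps to 0
assert idx_tuple[:dim] + (slice(0, None),) + idx_tuple[dim + 1:] == (slice(0, None), 7)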

torch/autograd/variable.py

Lines changed: 2 additions & 0 deletions
@@ -661,10 +661,12 @@ def transpose(self, dim1, dim2):
         return Transpose(dim1, dim2)(self)

     def select(self, dim, _index):
+        dim = dim if dim >= 0 else dim + self.dim()
         index = tuple(slice(None, None) for _ in range(dim)) + (_index,)
         return Index(index)(self)

     def narrow(self, dim, start_index, length):
+        dim = dim if dim >= 0 else dim + self.dim()
         index = tuple(slice(None, None) for _ in range(dim)) + \
             (slice(start_index, start_index + length),)
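
select and narrow build their index tuples with range(dim), which is empty for a negative dim, so Variable must wrap in Python before the tuple is built rather than relying on the C-level check. A quick sketch against the old Variable API used in this file (assuming a build that includes this commit):

import torch
from torch.autograd import Variable

v = Variable(torch.randn(4, 5))
assert torch.equal(v.narrow(-1, 0, 2).data, v.narrow(1, 0, 2).data)   # -1 wraps to 1
assert torch.equal(v.select(-2, 3).data, v.select(0, 3).data)         # -2 wraps to 0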
