Skip to content

Commit 31894ca

Browse files
killeen authored and soumith committed
add support for advanced indexing with less than ndim indexers, ellipsis (pytorch#2144)
1 parent 95ccbf8 commit 31894ca

File tree

3 files changed

+95
-11
lines changed

3 files changed

+95
-11
lines changed

test/test_autograd.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -555,6 +555,9 @@ def check_index(x, y, idx):
555555
check_index(x, y, ([slice(None), [2, 3]]))
556556
check_index(x, y, ([[2, 3], slice(None)]))
557557

558+
# advanced indexing, with less dim, or ellipsis
559+
check_index(x, y, ([0], ))
560+
558561
x = torch.arange(1, 49).view(4, 3, 4)
559562
y = Variable(x, requires_grad=True)
560563

@@ -570,6 +573,14 @@ def check_index(x, y, idx):
570573
check_index(x, y, (slice(None), [2, 1], slice(None)))
571574
check_index(x, y, ([2, 1], slice(None), slice(None)))
572575

576+
# advanced indexing, with less dim, or ellipsis
577+
check_index(x, y, ([0], ))
578+
check_index(x, y, ([0], slice(None)))
579+
check_index(x, y, ([0], Ellipsis))
580+
check_index(x, y, ([1, 2], [0, 1]))
581+
check_index(x, y, ([1, 2], [0, 1], Ellipsis))
582+
check_index(x, y, (Ellipsis, [1, 2], [0, 1]))
583+
573584
def test_indexing_duplicates(self):
574585
x = torch.arange(1, 17).view(4, 4)
575586
y = Variable(x, requires_grad=True)
@@ -1458,6 +1469,9 @@ class dont_convert(tuple):
14581469
(Index, (), (torch.rand(S, S, S), dont_convert([slice(None), [0, 3], slice(None)])), 'adv_index_mid'),
14591470
(Index, (), (torch.rand(S, S, S), dont_convert([[0, 3], slice(None), slice(None)])), 'adv_index_beg'),
14601471
(Index, (), (torch.rand(S, S, S), dont_convert([[0, 3], [1, 2], slice(None)])), 'adv_index_comb'),
1472+
(Index, (), (torch.rand(S, S, S), dont_convert([[0, 3], ])), 'adv_index_sub'),
1473+
(Index, (), (torch.rand(S, S, S), dont_convert([[0, 3], slice(None)])), 'adv_index_sub_2'),
1474+
(Index, (), (torch.rand(S, S, S), dont_convert([[0, 3], Ellipsis])), 'adv_index_sub_3'),
14611475
(View, (), (torch.rand(S, S, S), torch.Size([S * S, S]))),
14621476
(Expand, (), ((1, S, 1, S, 1), torch.Size([5, S, 5, S, 5]))),
14631477
(Expand, (), ((S, 1), torch.Size([S, S, S])), 'new_dim'),

test/test_torch.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2762,6 +2762,16 @@ def ri(indices):
27622762
self.assertEqual(strided[rows, columns],
27632763
torch.Tensor([[4, 6], [2, 3]]))
27642764

2765+
# Tests using less than the number of dims, and ellipsis
2766+
2767+
# reference is 1 2
2768+
# 3 4
2769+
# 5 6
2770+
reference = conv_fn(consec((3, 2)))
2771+
self.assertEqual(reference[ri([0, 2]), ], torch.Tensor([[1, 2], [5, 6]]))
2772+
self.assertEqual(reference[ri([1]), ...], torch.Tensor([[3, 4]]))
2773+
self.assertEqual(reference[..., ri([1])], torch.Tensor([[2], [4], [6]]))
2774+
27652775
if TEST_NUMPY:
27662776
# we use numpy to compare against, to verify that our advanced
27672777
# indexing semantics are the same, and also for ease of test
@@ -2864,6 +2874,19 @@ def get_set_tensor(indexed, indexer):
28642874
[[[0, 1], [2, 3]], [[0]], slice(None)],
28652875
[[[2, 1]], [[0, 3], [4, 4]], slice(None)],
28662876
[[[2]], [[0, 3], [4, 1]], slice(None)],
2877+
2878+
# less dim, ellipsis
2879+
[[0, 2], ],
2880+
[[0, 2], slice(None)],
2881+
[[0, 2], Ellipsis],
2882+
[[0, 2], slice(None), Ellipsis],
2883+
[[0, 2], Ellipsis, slice(None)],
2884+
[[0, 2], [1, 3]],
2885+
[[0, 2], [1, 3], Ellipsis],
2886+
[Ellipsis, [1, 3], [2, 3]],
2887+
[Ellipsis, [2, 3, 4]],
2888+
[Ellipsis, slice(None), [2, 3, 4]],
2889+
[slice(None), Ellipsis, [2, 3, 4]],
28672890
]
28682891

28692892
for indexer in indices_to_test:
@@ -2917,6 +2940,25 @@ def get_set_tensor(indexed, indexer):
29172940
[[0], [4], [1, 3, 4], slice(None)],
29182941
[[1], [0, 2, 3], [1], slice(None)],
29192942
[[[1, 2], [1, 2]], [[0, 1], [2, 3]], [[2, 3], [3, 5]], slice(None)],
2943+
2944+
# less dim, ellipsis
2945+
[Ellipsis, [0, 3, 4]],
2946+
[Ellipsis, slice(None), [0, 3, 4]],
2947+
[Ellipsis, slice(None), slice(None), [0, 3, 4]],
2948+
[slice(None), Ellipsis, [0, 3, 4]],
2949+
[slice(None), slice(None), Ellipsis, [0, 3, 4]],
2950+
[slice(None), [0, 2, 3], [1, 3, 4]],
2951+
[slice(None), [0, 2, 3], [1, 3, 4], Ellipsis],
2952+
[Ellipsis, [0, 2, 3], [1, 3, 4], slice(None)],
2953+
[[0], [1, 2, 4]],
2954+
[[0], [1, 2, 4], slice(None)],
2955+
[[0], [1, 2, 4], Ellipsis],
2956+
[[0], [1, 2, 4], Ellipsis, slice(None)],
2957+
[[1], ],
2958+
[[0, 2, 1], [3], [4]],
2959+
[[0, 2, 1], [3], [4], slice(None)],
2960+
[[0, 2, 1], [3], [4], Ellipsis],
2961+
[Ellipsis, [0, 2, 1], [3], [4]],
29202962
]
29212963

29222964
for indexer in indices_to_test:

torch/csrc/generic/Tensor.cpp

Lines changed: 39 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -531,29 +531,31 @@ static bool THPTensor_(_checkAdvancedIndexing)(THPTensor *indexed, PyObject *arg
531531
//
532532
// 1. "Basic Integer Array Indexing" the integer-array indexing strategy
533533
// where we have ndim sequence/LongTensor arguments
534-
// 2. Combining Advanced Indexing with ":", with the limitation that
534+
// 2. Combining Advanced Indexing with ":", or "..." , with the limitation that
535535
// the advanced indexing dimensions must be adjacent, i.e.:
536536
//
537537
// x[:, :, [1,2], [3,4], :] --> valid
538+
// x[[1,2], [3,4]] --> valid
539+
// x[[1,2], [3,4], ...] --> valid
538540
// x[:, [1,2], :, [3,4], :] --> not valid
539541

540542
// Verification, Step #1 -- ndim sequencers
541543
if (THPTensor_(_checkBasicIntegerArrayIndexing)(indexed, arg)) return true;
542544

543545
// Verification, Step #2 -- at least one sequencer, all the rest are
544-
// ':', can be less than ndim indexers, all sequencers adjacent
546+
// ':' and/or a single '...', can be less than ndim indexers, all sequencers
547+
// adjacent
545548

546549
long ndim = THTensor_(nDimension)(LIBRARY_STATE indexed->cdata);
547-
// TODO: should this be == ndim? --> for now, yes, but to support
548-
// other things, no
549-
if (PySequence_Check(arg) && PySequence_Size(arg) == ndim) {
550+
if (PySequence_Check(arg) && PySequence_Size(arg) <= ndim) {
550551
THPObjectPtr fast = THPObjectPtr(PySequence_Fast(arg, NULL));
551552

552553
bool sequenceFound = false;
553-
bool nonColonFound = false;
554+
bool nonColonEllipsisFound = false;
555+
bool ellipsisFound = false;
554556
Py_ssize_t lastSeqDim = -1;
555557

556-
for (Py_ssize_t i = 0; i < ndim; ++i) {
558+
for (Py_ssize_t i = 0; i < PySequence_Fast_GET_SIZE(fast.get()); ++i) {
557559
PyObject *item = PySequence_Fast_GET_ITEM(fast.get(), i);
558560
if (THPIndexTensor_Check(item) || PySequence_Check(item)) {
559561
sequenceFound = true;
@@ -573,17 +575,25 @@ static bool THPTensor_(_checkAdvancedIndexing)(THPTensor *indexed, PyObject *arg
573575
Py_ssize_t start, end, length, step;
574576
if (THPUtils_parseSlice(item, dimSize, &start, &end, &step, &length)) {
575577
if (start != 0 || end != dimSize || step != 1 || length != dimSize) {
576-
nonColonFound = true;
578+
nonColonEllipsisFound = true;
577579
break;
578580
}
579581
}
580582
continue;
581583
}
582-
nonColonFound = true;
584+
if (Py_TYPE(item) == &PyEllipsis_Type) {
585+
if (ellipsisFound) {
586+
// Can't have duplicate ellipses
587+
return false;
588+
}
589+
ellipsisFound = true;
590+
continue;
591+
}
592+
nonColonEllipsisFound = true;
583593
break;
584594
}
585595

586-
return sequenceFound && (!nonColonFound);
596+
return sequenceFound && (!nonColonEllipsisFound);
587597
}
588598
return false;
589599

@@ -639,6 +649,7 @@ static bool THPTensor_(_convertToTensorIndexers)(
639649
// 1. A LongTensor
640650
// 2. A sequence that can be converted into a LongTensor
641651
// 3. A empty slice object (i.e. ':')
652+
// 4. An Ellipsis (i.e. '...')
642653
//
643654
// This function loops through all of the indexing elements. If we encounter
644655
// a LongTensor, we record the dimension at which it occurs. If we encounter
@@ -658,12 +669,29 @@ static bool THPTensor_(_convertToTensorIndexers)(
658669
std::vector<Py_ssize_t> indexingDims;
659670
std::vector<THPIndexTensor*>indexers;
660671

672+
// The indexing matches advanced indexing requirements. In the case that
673+
// the user has an Ellipsis, and/or less dimensions than are in the
674+
// Tensor being indexed, we "fill in" empty Slices to these dimensions
675+
// so that the resulting advanced indexing code still works
676+
677+
678+
661679
// The top-level indexer should be a sequence, per the check above
662680
THPObjectPtr fast(PySequence_Fast(index, NULL));
663681
sequenceLength = PySequence_Fast_GET_SIZE(fast.get());
682+
int ellipsisOffset = 0;
664683

665684
for (Py_ssize_t i = 0; i < sequenceLength; ++i) {
666685
PyObject *item = PySequence_Fast_GET_ITEM(fast.get(), i);
686+
687+
// If this is an ellipsis, then the "positions" of all subsequent advanced
// indexing objects should be shifted, e.g. if we have a 5D Tensor
689+
// x, and then x[..., [2, 3]], then the "position" of [2, 3] is 4
690+
if (Py_TYPE(item) == &PyEllipsis_Type) {
691+
ellipsisOffset = THTensor_(nDimension)(LIBRARY_STATE indexed) - sequenceLength;
692+
continue;
693+
}
694+
667695
if (!PySlice_Check(item)) {
668696
// Returns NULL upon conversion failure
669697
THPIndexTensor *indexer = (THPIndexTensor *)PyObject_CallFunctionObjArgs(
@@ -680,7 +708,7 @@ static bool THPTensor_(_convertToTensorIndexers)(
680708
}
681709
return false;
682710
}
683-
indexingDims.push_back(i);
711+
indexingDims.push_back(i + ellipsisOffset);
684712
indexers.push_back(indexer);
685713
}
686714
}

0 commit comments

Comments
 (0)