Skip to content

Commit a45ad7c

Browse files
killeent authored and soumith committed
Advanced Indexing Part 1 -- Purely Integer Array Indexing
1 parent f09027b commit a45ad7c

File tree

7 files changed

+791
-4
lines changed

7 files changed

+791
-4
lines changed

test/test_autograd.py

Lines changed: 33 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -543,6 +543,10 @@ def check_index(idx):
543543
check_index(torch.LongTensor([0, 2]))
544544
check_index(torch.rand(4, 4).bernoulli().byte())
545545
check_index((Ellipsis, slice(2, None)))
546+
check_index(([0], [0]))
547+
check_index(([1, 2, 3], [0]))
548+
check_index(([1, 2], [2, 1]))
549+
check_index(([[1, 2], [3, 0]], [[0, 1], [2, 3]]))
546550

547551
def test_indexing_duplicates(self):
548552
x = torch.arange(1, 17).view(4, 4)
@@ -555,6 +559,29 @@ def test_indexing_duplicates(self):
555559
expected_grad[i] += 1
556560
self.assertEqual(y.grad.data, expected_grad)
557561

562+
# with advanced indexing
563+
x = torch.arange(1, 17).view(4, 4)
564+
y = Variable(x, requires_grad=True)
565+
566+
idx = [[1, 1, 3, 2, 1, 2], [0]]
567+
y[idx].sum().backward()
568+
expected_grad = torch.zeros(4, 4)
569+
for i in idx[0]:
570+
for j in idx[1]:
571+
expected_grad[i][j] += 1
572+
573+
self.assertEqual(y.grad.data, expected_grad)
574+
575+
x = torch.arange(1, 17).view(4, 4)
576+
y = Variable(x, requires_grad=True)
577+
idx = [[[1, 2], [0, 0]], [[0, 1], [1, 1]]]
578+
y[idx].sum().backward()
579+
expected_grad = torch.Tensor([[0, 2, 0, 0],
580+
[1, 0, 0, 0],
581+
[0, 1, 0, 0],
582+
[0, 0, 0, 0]])
583+
self.assertEqual(y.grad.data, expected_grad)
584+
558585
def test_basic_op_grad_fallback(self):
559586
"""Grad output might need to be reshaped to match the second argument."""
560587
x = Variable(torch.randn(4, 6), requires_grad=True)
@@ -793,8 +820,12 @@ def test_setitem(self):
793820
self._test_setitem((5, 5), 1)
794821
self._test_setitem((5,), 1)
795822
self._test_setitem((1,), 0)
823+
self._test_setitem((10,), [[0, 4, 2]])
824+
self._test_setitem((5, 5), [[0, 4], [2, 2]])
796825
self._test_setitem_tensor((5, 5), 3)
826+
self._test_setitem_tensor((5, 5), [[0, 1], [1, 0]])
797827
self._test_setitem_tensor((5,), 3)
828+
self._test_setitem_tensor((5,), [[0, 1, 2, 3]])
798829

799830
def test_setitem_mask(self):
800831
mask = torch.ByteTensor(5, 5).bernoulli_()
@@ -1345,6 +1376,8 @@ class dont_convert(tuple):
13451376
(Index, (), (torch.rand(S, S, S), dont_convert([1, 2]))),
13461377
(Index, (), (torch.rand(S, S, S), slice(0, 3)), 'slice'),
13471378
(Index, (), (torch.rand(S, S, S), dont_convert([slice(0, 3), 1])), 'slice_index'),
1379+
(Index, (), (torch.rand(S, S, S), dont_convert([[0, 2, 3], [1, 3, 3], [0, 0, 2]])), 'adv_index'),
1380+
(Index, (), (torch.rand(S, S, S), dont_convert([[0, 0, 3], [1, 1, 3], [0, 0, 2]])), 'adv_index_dup'),
13481381
(View, (), (torch.rand(S, S, S), torch.Size([S * S, S]))),
13491382
(Expand, (), ((1, S, 1, S, 1), torch.Size([5, S, 5, S, 5]))),
13501383
(Expand, (), ((S, 1), torch.Size([S, S, S])), 'new_dim'),

test/test_cuda.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -856,6 +856,12 @@ def test_broadcast_fused_matmul(self):
856856
def test_broadcast_batched_matmul(self):
857857
TestTorch._test_broadcast_batched_matmul(self, lambda t: t.cuda())
858858

859+
def test_advancedindex(self):
    """Run the shared advanced-indexing tests with tensors converted to CUDA."""
    TestTorch._test_advancedindex(self, lambda t: t.cuda())
861+
862+
def test_advancedindex_big(self):
    """Run the shared large-tensor advanced-indexing test on CUDA tensors."""
    TestTorch._test_advancedindex_big(self, lambda t: t.cuda())
864+
859865
def test_btrifact(self):
860866
TestTorch._test_btrifact(self, lambda t: t.cuda())
861867

test/test_torch.py

Lines changed: 245 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2496,6 +2496,251 @@ def test_index(self):
24962496
self.assertRaises(TypeError, lambda: reference[0.0, ..., 0.0:2.0])
24972497
self.assertRaises(TypeError, lambda: reference[0.0, :, 0.0])
24982498

2499+
@staticmethod
def _test_advancedindex(self, conv_fn):
    """Tests for Integer Array Indexing, Part I - purely integer array indexing.

    ``conv_fn`` converts each freshly built tensor before use (identity for
    the CPU test, ``.cuda()`` for the CUDA test), so the same checks run on
    both backends.  Exercises 1-D and 2-D gathers and scatters through
    integer-array indices, including non-contiguous (strided and transposed)
    tensors.
    """

    # Build a tensor of the given size holding start, start+1, start+2, ...
    def consec(size, start=1):
        sequence = torch.ones(int(torch.Tensor(size).prod(0)[0])).cumsum(0)
        sequence.add_(start - 1)
        return sequence.view(*size)

    # pick a random valid indexer type (LongTensor, list, or tuple) so the
    # tests cover every accepted sequence form
    def ri(indices):
        choice = random.randint(0, 2)
        if choice == 0:
            return torch.LongTensor(indices)
        elif choice == 1:
            return list(indices)
        else:
            return tuple(indices)

    # First, we will test indexing to generate return values

    # Case 1: Purely Integer Array Indexing
    reference = conv_fn(consec((10,)))
    self.assertEqual(reference[ri([0]), ], consec((1,)))
    self.assertEqual(reference[ri([3]), ], consec((1,), 4))
    self.assertEqual(reference[ri([2, 3, 4]), ], consec((3,), 3))
    self.assertEqual(reference[ri([0, 2, 4]), ], torch.Tensor([1, 3, 5]))

    # setting values
    reference[ri([0],), ] = -1
    self.assertEqual(reference[ri([0]), ], torch.Tensor([-1]))
    reference[ri([2, 3, 4]), ] = 3
    self.assertEqual(reference[ri([2, 3, 4]), ], torch.Tensor([3, 3, 3]))
    reference[ri([0, 2, 4]), ] = conv_fn(torch.Tensor([5, 4, 3]))
    self.assertEqual(reference[ri([0, 2, 4]), ], torch.Tensor([5, 4, 3]))

    # Tensor with stride != 1

    # strided is [1, 3, 5, 7]
    reference = conv_fn(consec((10,)))
    strided = conv_fn(torch.Tensor())
    strided.set_(reference.storage(), storage_offset=0,
                 size=torch.Size([4]), stride=[2])

    self.assertEqual(strided[ri([0]), ], torch.Tensor([1]))
    self.assertEqual(strided[ri([3]), ], torch.Tensor([7]))
    self.assertEqual(strided[ri([1, 2]), ], torch.Tensor([3, 5]))
    self.assertEqual(strided[ri([[2, 1], [0, 3]]), ],
                     torch.Tensor([[5, 3], [1, 7]]))

    # strided is [5, 9] (storage slots 4 and 8: offset 4, stride 4)
    strided = conv_fn(torch.Tensor())
    strided.set_(reference.storage(), storage_offset=4,
                 size=torch.Size([2]), stride=[4])
    self.assertEqual(strided[ri([0]), ], torch.Tensor([5]))
    self.assertEqual(strided[ri([1]), ], torch.Tensor([9]))
    self.assertEqual(strided[ri([0, 1]), ], torch.Tensor([5, 9]))
    self.assertEqual(strided[ri([[0, 1], [1, 0]]), ],
                     torch.Tensor([[5, 9], [9, 5]]))

    # reference is 1 2
    #              3 4
    #              5 6
    reference = conv_fn(consec((3, 2)))
    self.assertEqual(reference[ri([0, 1, 2]), ri([0])], torch.Tensor([1, 3, 5]))
    self.assertEqual(reference[ri([0, 1, 2]), ri([1])], torch.Tensor([2, 4, 6]))
    self.assertEqual(reference[ri([0]), ri([0])], consec((1,)))
    self.assertEqual(reference[ri([2]), ri([1])], consec((1,), 6))
    self.assertEqual(reference[[ri([0, 0]), ri([0, 1])]], torch.Tensor([1, 2]))
    self.assertEqual(reference[[ri([0, 1, 1, 0, 2]), ri([1])]],
                     torch.Tensor([2, 4, 4, 2, 6]))
    self.assertEqual(reference[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]],
                     torch.Tensor([1, 2, 3, 3]))

    rows = ri([[0, 0],
               [1, 2]])
    columns = [0],
    self.assertEqual(reference[rows, columns], torch.Tensor([[1, 1],
                                                             [3, 5]]))

    rows = ri([[0, 0],
               [1, 2]])
    columns = ri([1, 0])
    self.assertEqual(reference[rows, columns], torch.Tensor([[2, 1],
                                                             [4, 5]]))
    rows = ri([[0, 0],
               [1, 2]])
    columns = ri([[0, 1],
                  [1, 0]])
    self.assertEqual(reference[rows, columns], torch.Tensor([[1, 2],
                                                             [4, 5]]))

    # setting values
    reference[ri([0]), ri([1])] = -1
    self.assertEqual(reference[ri([0]), ri([1])], torch.Tensor([-1]))
    reference[ri([0, 1, 2]), ri([0])] = conv_fn(torch.Tensor([-1, 2, -4]))
    self.assertEqual(reference[ri([0, 1, 2]), ri([0])],
                     torch.Tensor([-1, 2, -4]))
    reference[rows, columns] = conv_fn(torch.Tensor([[4, 6], [2, 3]]))
    self.assertEqual(reference[rows, columns],
                     torch.Tensor([[4, 6], [2, 3]]))

    # Verify still works with Transposed (i.e. non-contiguous) Tensors

    reference = conv_fn(torch.Tensor([[0, 1, 2, 3],
                                      [4, 5, 6, 7],
                                      [8, 9, 10, 11]])).t_()

    # Transposed: [[0, 4, 8],
    #              [1, 5, 9],
    #              [2, 6, 10],
    #              [3, 7, 11]]

    self.assertEqual(reference[ri([0, 1, 2]), ri([0])],
                     torch.Tensor([0, 1, 2]))
    self.assertEqual(reference[ri([0, 1, 2]), ri([1])],
                     torch.Tensor([4, 5, 6]))
    self.assertEqual(reference[ri([0]), ri([0])], torch.Tensor([0]))
    self.assertEqual(reference[ri([2]), ri([1])], torch.Tensor([6]))
    self.assertEqual(reference[[ri([0, 0]), ri([0, 1])]], torch.Tensor([0, 4]))
    self.assertEqual(reference[[ri([0, 1, 1, 0, 3]), ri([1])]],
                     torch.Tensor([4, 5, 5, 4, 7]))
    self.assertEqual(reference[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]],
                     torch.Tensor([0, 4, 1, 1]))

    rows = ri([[0, 0],
               [1, 2]])
    columns = [0],
    self.assertEqual(reference[rows, columns], torch.Tensor([[0, 0],
                                                             [1, 2]]))

    rows = ri([[0, 0],
               [1, 2]])
    columns = ri([1, 0])
    self.assertEqual(reference[rows, columns], torch.Tensor([[4, 0],
                                                             [5, 2]]))
    rows = ri([[0, 0],
               [1, 3]])
    columns = ri([[0, 1],
                  [1, 2]])
    self.assertEqual(reference[rows, columns], torch.Tensor([[0, 4],
                                                             [5, 11]]))

    # setting values
    reference[ri([0]), ri([1])] = -1
    self.assertEqual(reference[ri([0]), ri([1])], torch.Tensor([-1]))
    reference[ri([0, 1, 2]), ri([0])] = conv_fn(torch.Tensor([-1, 2, -4]))
    self.assertEqual(reference[ri([0, 1, 2]), ri([0])],
                     torch.Tensor([-1, 2, -4]))
    reference[rows, columns] = conv_fn(torch.Tensor([[4, 6], [2, 3]]))
    self.assertEqual(reference[rows, columns],
                     torch.Tensor([[4, 6], [2, 3]]))

    # stride != 1

    # strided is [[1 3 5 7],
    #             [9 11 13 15]]

    reference = conv_fn(torch.arange(0, 24).view(3, 8))
    strided = conv_fn(torch.Tensor())
    strided.set_(reference.storage(), 1, size=torch.Size([2, 4]),
                 stride=[8, 2])

    self.assertEqual(strided[ri([0, 1]), ri([0])], torch.Tensor([1, 9]))
    self.assertEqual(strided[ri([0, 1]), ri([1])], torch.Tensor([3, 11]))
    self.assertEqual(strided[ri([0]), ri([0])], torch.Tensor([1]))
    self.assertEqual(strided[ri([1]), ri([3])], torch.Tensor([15]))
    self.assertEqual(strided[[ri([0, 0]), ri([0, 3])]], torch.Tensor([1, 7]))
    self.assertEqual(strided[[ri([1]), ri([0, 1, 1, 0, 3])]],
                     torch.Tensor([9, 11, 11, 9, 15]))
    self.assertEqual(strided[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]],
                     torch.Tensor([1, 3, 9, 9]))

    rows = ri([[0, 0],
               [1, 1]])
    columns = [0],
    self.assertEqual(strided[rows, columns], torch.Tensor([[1, 1],
                                                           [9, 9]]))

    rows = ri([[0, 1],
               [1, 0]])
    columns = ri([1, 2])
    self.assertEqual(strided[rows, columns], torch.Tensor([[3, 13],
                                                           [11, 5]]))
    rows = ri([[0, 0],
               [1, 1]])
    columns = ri([[0, 1],
                  [1, 2]])
    self.assertEqual(strided[rows, columns], torch.Tensor([[1, 3],
                                                           [11, 13]]))

    # setting values

    # strided is [[10, 11],
    #             [17, 18]]

    reference = conv_fn(torch.arange(0, 24).view(3, 8))
    strided = conv_fn(torch.Tensor())
    strided.set_(reference.storage(), 10, size=torch.Size([2, 2]),
                 stride=[7, 1])
    self.assertEqual(strided[ri([0]), ri([1])], torch.Tensor([11]))
    strided[ri([0]), ri([1])] = -1
    self.assertEqual(strided[ri([0]), ri([1])], torch.Tensor([-1]))

    reference = conv_fn(torch.arange(0, 24).view(3, 8))
    strided = conv_fn(torch.Tensor())
    strided.set_(reference.storage(), 10, size=torch.Size([2, 2]),
                 stride=[7, 1])
    self.assertEqual(strided[ri([0, 1]), ri([1, 0])],
                     torch.Tensor([11, 17]))
    strided[ri([0, 1]), ri([1, 0])] = conv_fn(torch.Tensor([-1, 2]))
    self.assertEqual(strided[ri([0, 1]), ri([1, 0])],
                     torch.Tensor([-1, 2]))

    reference = conv_fn(torch.arange(0, 24).view(3, 8))
    strided = conv_fn(torch.Tensor())
    strided.set_(reference.storage(), 10, size=torch.Size([2, 2]),
                 stride=[7, 1])

    rows = ri([[0],
               [1]])
    columns = ri([[0, 1],
                  [0, 1]])
    self.assertEqual(strided[rows, columns],
                     torch.Tensor([[10, 11], [17, 18]]))
    strided[rows, columns] = conv_fn(torch.Tensor([[4, 6], [2, 3]]))
    self.assertEqual(strided[rows, columns],
                     torch.Tensor([[4, 6], [2, 3]]))

    # TODO: error raising tests
2730+
2731+
def test_advancedindex(self):
    """Run the advanced-indexing tests with the identity conversion (no device move)."""
    self._test_advancedindex(self, lambda x: x)
2733+
2734+
@staticmethod
2735+
def _test_advancedindex_big(self, conv_fn):
2736+
reference = conv_fn(torch.arange(0, 123344).int())
2737+
2738+
self.assertEqual(reference[[0, 123, 44488, 68807, 123343], ],
2739+
torch.LongTensor([0, 123, 44488, 68807, 123343]))
2740+
2741+
def test_advancedindex_big(self):
    """Run the large-tensor advanced-indexing test with the identity conversion."""
    self._test_advancedindex_big(self, lambda x: x)
2743+
24992744
def test_newindex(self):
25002745
reference = self._consecutive((3, 3, 3))
25012746
# This relies on __index__() being correct - but we have separate tests for that

torch/autograd/_functions/tensor.py

Lines changed: 32 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -14,13 +14,19 @@ def forward(ctx, i, index):
1414
ctx.input_size = i.size()
1515
ctx.index = index
1616
result = i.index(ctx.index)
17-
ctx.mark_shared_storage((i, result))
17+
ctx.advanced_indexing = i._check_advanced_indexing(index)
18+
if not ctx.advanced_indexing:
19+
ctx.mark_shared_storage((i, result))
1820
return result
1921

2022
@staticmethod
def backward(ctx, grad_output):
    """Scatter ``grad_output`` back into a zero tensor shaped like the input.

    For advanced (integer-array) indexing the same input element may be
    selected more than once, so gradients must be *accumulated* with
    ``_advanced_index_add`` rather than assigned; plain indexing selects
    each element at most once, so direct assignment suffices.
    """
    grad_input = Variable(grad_output.data.new(ctx.input_size).zero_())
    if ctx.advanced_indexing:
        # in-place accumulate: duplicate indices each contribute their grad
        grad_input._advanced_index_add(ctx.index, grad_output)
    else:
        grad_input[ctx.index] = grad_output
    # second return slot is the (non-differentiable) index argument
    return grad_input, None
2531

2632

@@ -195,6 +201,29 @@ def backward(ctx, grad_output):
195201
return grad_tensor1, None, None, grad_tensor2, None
196202

197203

204+
class AdvancedIndexAdd(InplaceFunction):
    """Autograd function for in-place advanced-index accumulation.

    Computes ``tensor1[adv_index] += tensor2`` (with duplicate indices
    accumulating) and defines its backward pass.
    """

    @staticmethod
    def forward(ctx, tensor1, adv_index, tensor2):
        # the index argument is not differentiable
        assert not ctx.needs_input_grad[1]
        # keep the index only if we must route gradients back to tensor2
        if ctx.needs_input_grad[2]:
            ctx.adv_index = adv_index
        # tensor1 is modified in place
        ctx.mark_dirty(tensor1)
        return tensor1._advanced_index_add(adv_index, tensor2)

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        grad_tensor1 = grad_tensor2 = None

        # addition passes the output gradient straight through to tensor1
        if ctx.needs_input_grad[0]:
            grad_tensor1 = grad_output

        # tensor2's gradient is the output gradient gathered at the indices
        if ctx.needs_input_grad[2]:
            grad_tensor2 = grad_output._advanced_index_select(ctx.adv_index)
        return grad_tensor1, None, grad_tensor2
225+
226+
198227
class IndexCopy(InplaceFunction):
199228

200229
@staticmethod

torch/autograd/variable.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -627,6 +627,9 @@ def dist(self, tensor, p=2):
627627
def index_add(self, dim, index, tensor):
628628
return IndexAdd.apply(self, dim, index, tensor)
629629

630+
def _advanced_index_add(self, index, tensor):
    """In-place accumulate ``tensor`` into ``self`` at the advanced ``index``
    (duplicate indices accumulate), tracked by autograd."""
    return AdvancedIndexAdd.apply(self, index, tensor)
632+
630633
def index_add_(self, dim, index, tensor):
631634
return IndexAdd.apply(self, dim, index, tensor, True)
632635

0 commit comments

Comments
 (0)