@@ -2496,6 +2496,251 @@ def test_index(self):
24962496 self .assertRaises (TypeError , lambda : reference [0.0 , ..., 0.0 :2.0 ])
24972497 self .assertRaises (TypeError , lambda : reference [0.0 , :, 0.0 ])
24982498
2499+ @staticmethod
2500+ def _test_advancedindex (self , conv_fn ):
2501+ # Tests for Integer Array Indexing, Part I - Purely integer array
2502+ # indexing
2503+
2504+ def consec (size , start = 1 ):
2505+ sequence = torch .ones (int (torch .Tensor (size ).prod (0 )[0 ])).cumsum (0 )
2506+ sequence .add_ (start - 1 )
2507+ return sequence .view (* size )
2508+
2509+ # pick a random valid indexer type
2510+ def ri (indices ):
2511+ choice = random .randint (0 , 2 )
2512+ if choice == 0 :
2513+ return torch .LongTensor (indices )
2514+ elif choice == 1 :
2515+ return list (indices )
2516+ else :
2517+ return tuple (indices )
2518+
2519+ # First, we will test indexing to generate return values
2520+
2521+ # Case 1: Purely Integer Array Indexing
2522+ reference = conv_fn (consec ((10 ,)))
2523+ self .assertEqual (reference [ri ([0 ]), ], consec ((1 ,)))
2524+ self .assertEqual (reference [ri ([3 ]), ], consec ((1 ,), 4 ))
2525+ self .assertEqual (reference [ri ([2 , 3 , 4 ]), ], consec ((3 ,), 3 ))
2526+ self .assertEqual (reference [ri ([0 , 2 , 4 ]), ], torch .Tensor ([1 , 3 , 5 ]))
2527+
2528+ # setting values
2529+ reference [ri ([0 ],), ] = - 1
2530+ self .assertEqual (reference [ri ([0 ]), ], torch .Tensor ([- 1 ]))
2531+ reference [ri ([2 , 3 , 4 ]), ] = 3
2532+ self .assertEqual (reference [ri ([2 , 3 , 4 ]), ], torch .Tensor ([3 , 3 , 3 ]))
2533+ reference [ri ([0 , 2 , 4 ]), ] = conv_fn (torch .Tensor ([5 , 4 , 3 ]))
2534+ self .assertEqual (reference [ri ([0 , 2 , 4 ]), ], torch .Tensor ([5 , 4 , 3 ]))
2535+
2536+ # Tensor with stride != 1
2537+
2538+ # strided is [1, 3, 5, 7]
2539+ reference = conv_fn (consec ((10 ,)))
2540+ strided = conv_fn (torch .Tensor ())
2541+ strided .set_ (reference .storage (), storage_offset = 0 ,
2542+ size = torch .Size ([4 ]), stride = [2 ])
2543+
2544+ self .assertEqual (strided [ri ([0 ]), ], torch .Tensor ([1 ]))
2545+ self .assertEqual (strided [ri ([3 ]), ], torch .Tensor ([7 ]))
2546+ self .assertEqual (strided [ri ([1 , 2 ]), ], torch .Tensor ([3 , 5 ]))
2547+ self .assertEqual (strided [ri ([[2 , 1 ], [0 , 3 ]]), ],
2548+ torch .Tensor ([[5 , 3 ], [1 , 7 ]]))
2549+
2550+ # stride is [4, 8]
2551+ strided = conv_fn (torch .Tensor ())
2552+ strided .set_ (reference .storage (), storage_offset = 4 ,
2553+ size = torch .Size ([2 ]), stride = [4 ])
2554+ self .assertEqual (strided [ri ([0 ]), ], torch .Tensor ([5 ]))
2555+ self .assertEqual (strided [ri ([1 ]), ], torch .Tensor ([9 ]))
2556+ self .assertEqual (strided [ri ([0 , 1 ]), ], torch .Tensor ([5 , 9 ]))
2557+ self .assertEqual (strided [ri ([[0 , 1 ], [1 , 0 ]]), ],
2558+ torch .Tensor ([[5 , 9 ], [9 , 5 ]]))
2559+
2560+ # reference is 1 2
2561+ # 3 4
2562+ # 5 6
2563+ reference = conv_fn (consec ((3 , 2 )))
2564+ self .assertEqual (reference [ri ([0 , 1 , 2 ]), ri ([0 ])], torch .Tensor ([1 , 3 , 5 ]))
2565+ self .assertEqual (reference [ri ([0 , 1 , 2 ]), ri ([1 ])], torch .Tensor ([2 , 4 , 6 ]))
2566+ self .assertEqual (reference [ri ([0 ]), ri ([0 ])], consec ((1 ,)))
2567+ self .assertEqual (reference [ri ([2 ]), ri ([1 ])], consec ((1 ,), 6 ))
2568+ self .assertEqual (reference [[ri ([0 , 0 ]), ri ([0 , 1 ])]], torch .Tensor ([1 , 2 ]))
2569+ self .assertEqual (reference [[ri ([0 , 1 , 1 , 0 , 2 ]), ri ([1 ])]],
2570+ torch .Tensor ([2 , 4 , 4 , 2 , 6 ]))
2571+ self .assertEqual (reference [[ri ([0 , 0 , 1 , 1 ]), ri ([0 , 1 , 0 , 0 ])]],
2572+ torch .Tensor ([1 , 2 , 3 , 3 ]))
2573+
2574+ rows = ri ([[0 , 0 ],
2575+ [1 , 2 ]])
2576+ columns = [0 ],
2577+ self .assertEqual (reference [rows , columns ], torch .Tensor ([[1 , 1 ],
2578+ [3 , 5 ]]))
2579+
2580+ rows = ri ([[0 , 0 ],
2581+ [1 , 2 ]])
2582+ columns = ri ([1 , 0 ])
2583+ self .assertEqual (reference [rows , columns ], torch .Tensor ([[2 , 1 ],
2584+ [4 , 5 ]]))
2585+ rows = ri ([[0 , 0 ],
2586+ [1 , 2 ]])
2587+ columns = ri ([[0 , 1 ],
2588+ [1 , 0 ]])
2589+ self .assertEqual (reference [rows , columns ], torch .Tensor ([[1 , 2 ],
2590+ [4 , 5 ]]))
2591+
2592+ # setting values
2593+ reference [ri ([0 ]), ri ([1 ])] = - 1
2594+ self .assertEqual (reference [ri ([0 ]), ri ([1 ])], torch .Tensor ([- 1 ]))
2595+ reference [ri ([0 , 1 , 2 ]), ri ([0 ])] = conv_fn (torch .Tensor ([- 1 , 2 , - 4 ]))
2596+ self .assertEqual (reference [ri ([0 , 1 , 2 ]), ri ([0 ])], torch .Tensor ([- 1 ,
2597+ 2 , - 4 ]))
2598+ reference [rows , columns ] = conv_fn (torch .Tensor ([[4 , 6 ], [2 , 3 ]]))
2599+ self .assertEqual (reference [rows , columns ],
2600+ torch .Tensor ([[4 , 6 ], [2 , 3 ]]))
2601+
2602+ # Verify still works with Tranposed (i.e. non-contiguous) Tensors
2603+
2604+ reference = conv_fn (torch .Tensor ([[0 , 1 , 2 , 3 ],
2605+ [4 , 5 , 6 , 7 ],
2606+ [8 , 9 , 10 , 11 ]])).t_ ()
2607+
2608+ # Tranposed: [[0, 4, 8],
2609+ # [1, 5, 9],
2610+ # [2, 6, 10],
2611+ # [3, 7, 11]]
2612+
2613+ self .assertEqual (reference [ri ([0 , 1 , 2 ]), ri ([0 ])], torch .Tensor ([0 , 1 ,
2614+ 2 ]))
2615+ self .assertEqual (reference [ri ([0 , 1 , 2 ]), ri ([1 ])], torch .Tensor ([4 , 5 ,
2616+ 6 ]))
2617+ self .assertEqual (reference [ri ([0 ]), ri ([0 ])], torch .Tensor ([0 ]))
2618+ self .assertEqual (reference [ri ([2 ]), ri ([1 ])], torch .Tensor ([6 ]))
2619+ self .assertEqual (reference [[ri ([0 , 0 ]), ri ([0 , 1 ])]], torch .Tensor ([0 , 4 ]))
2620+ self .assertEqual (reference [[ri ([0 , 1 , 1 , 0 , 3 ]), ri ([1 ])]],
2621+ torch .Tensor ([4 , 5 , 5 , 4 , 7 ]))
2622+ self .assertEqual (reference [[ri ([0 , 0 , 1 , 1 ]), ri ([0 , 1 , 0 , 0 ])]],
2623+ torch .Tensor ([0 , 4 , 1 , 1 ]))
2624+
2625+ rows = ri ([[0 , 0 ],
2626+ [1 , 2 ]])
2627+ columns = [0 ],
2628+ self .assertEqual (reference [rows , columns ], torch .Tensor ([[0 , 0 ],
2629+ [1 , 2 ]]))
2630+
2631+ rows = ri ([[0 , 0 ],
2632+ [1 , 2 ]])
2633+ columns = ri ([1 , 0 ])
2634+ self .assertEqual (reference [rows , columns ], torch .Tensor ([[4 , 0 ],
2635+ [5 , 2 ]]))
2636+ rows = ri ([[0 , 0 ],
2637+ [1 , 3 ]])
2638+ columns = ri ([[0 , 1 ],
2639+ [1 , 2 ]])
2640+ self .assertEqual (reference [rows , columns ], torch .Tensor ([[0 , 4 ],
2641+ [5 , 11 ]]))
2642+
2643+ # setting values
2644+ reference [ri ([0 ]), ri ([1 ])] = - 1
2645+ self .assertEqual (reference [ri ([0 ]), ri ([1 ])], torch .Tensor ([- 1 ]))
2646+ reference [ri ([0 , 1 , 2 ]), ri ([0 ])] = conv_fn (torch .Tensor ([- 1 , 2 , - 4 ]))
2647+ self .assertEqual (reference [ri ([0 , 1 , 2 ]), ri ([0 ])], torch .Tensor ([- 1 ,
2648+ 2 , - 4 ]))
2649+ reference [rows , columns ] = conv_fn (torch .Tensor ([[4 , 6 ], [2 , 3 ]]))
2650+ self .assertEqual (reference [rows , columns ],
2651+ torch .Tensor ([[4 , 6 ], [2 , 3 ]]))
2652+
2653+ # stride != 1
2654+
2655+ # strided is [[1 3 5 7],
2656+ # [9 11 13 15]]
2657+
2658+ reference = conv_fn (torch .arange (0 , 24 ).view (3 , 8 ))
2659+ strided = conv_fn (torch .Tensor ())
2660+ strided .set_ (reference .storage (), 1 , size = torch .Size ([2 , 4 ]),
2661+ stride = [8 , 2 ])
2662+
2663+ self .assertEqual (strided [ri ([0 , 1 ]), ri ([0 ])], torch .Tensor ([1 , 9 ]))
2664+ self .assertEqual (strided [ri ([0 , 1 ]), ri ([1 ])], torch .Tensor ([3 , 11 ]))
2665+ self .assertEqual (strided [ri ([0 ]), ri ([0 ])], torch .Tensor ([1 ]))
2666+ self .assertEqual (strided [ri ([1 ]), ri ([3 ])], torch .Tensor ([15 ]))
2667+ self .assertEqual (strided [[ri ([0 , 0 ]), ri ([0 , 3 ])]], torch .Tensor ([1 , 7 ]))
2668+ self .assertEqual (strided [[ri ([1 ]), ri ([0 , 1 , 1 , 0 , 3 ])]],
2669+ torch .Tensor ([9 , 11 , 11 , 9 , 15 ]))
2670+ self .assertEqual (strided [[ri ([0 , 0 , 1 , 1 ]), ri ([0 , 1 , 0 , 0 ])]],
2671+ torch .Tensor ([1 , 3 , 9 , 9 ]))
2672+
2673+ rows = ri ([[0 , 0 ],
2674+ [1 , 1 ]])
2675+ columns = [0 ],
2676+ self .assertEqual (strided [rows , columns ], torch .Tensor ([[1 , 1 ],
2677+ [9 , 9 ]]))
2678+
2679+ rows = ri ([[0 , 1 ],
2680+ [1 , 0 ]])
2681+ columns = ri ([1 , 2 ])
2682+ self .assertEqual (strided [rows , columns ], torch .Tensor ([[3 , 13 ],
2683+ [11 , 5 ]]))
2684+ rows = ri ([[0 , 0 ],
2685+ [1 , 1 ]])
2686+ columns = ri ([[0 , 1 ],
2687+ [1 , 2 ]])
2688+ self .assertEqual (strided [rows , columns ], torch .Tensor ([[1 , 3 ],
2689+ [11 , 13 ]]))
2690+
2691+ # setting values
2692+
2693+ # strided is [[10, 11],
2694+ # [17, 18]]
2695+
2696+ reference = conv_fn (torch .arange (0 , 24 ).view (3 , 8 ))
2697+ strided = conv_fn (torch .Tensor ())
2698+ strided .set_ (reference .storage (), 10 , size = torch .Size ([2 , 2 ]),
2699+ stride = [7 , 1 ])
2700+ self .assertEqual (strided [ri ([0 ]), ri ([1 ])], torch .Tensor ([11 ]))
2701+ strided [ri ([0 ]), ri ([1 ])] = - 1
2702+ self .assertEqual (strided [ri ([0 ]), ri ([1 ])], torch .Tensor ([- 1 ]))
2703+
2704+ reference = conv_fn (torch .arange (0 , 24 ).view (3 , 8 ))
2705+ strided = conv_fn (torch .Tensor ())
2706+ strided .set_ (reference .storage (), 10 , size = torch .Size ([2 , 2 ]),
2707+ stride = [7 , 1 ])
2708+ self .assertEqual (strided [ri ([0 , 1 ]), ri ([1 , 0 ])], torch .Tensor ([11 ,
2709+ 17 ]))
2710+ strided [ri ([0 , 1 ]), ri ([1 , 0 ])] = conv_fn (torch .Tensor ([- 1 , 2 ]))
2711+ self .assertEqual (strided [ri ([0 , 1 ]), ri ([1 , 0 ])], torch .Tensor ([- 1 ,
2712+ 2 ]))
2713+
2714+ reference = conv_fn (torch .arange (0 , 24 ).view (3 , 8 ))
2715+ strided = conv_fn (torch .Tensor ())
2716+ strided .set_ (reference .storage (), 10 , size = torch .Size ([2 , 2 ]),
2717+ stride = [7 , 1 ])
2718+
2719+ rows = ri ([[0 ],
2720+ [1 ]])
2721+ columns = ri ([[0 , 1 ],
2722+ [0 , 1 ]])
2723+ self .assertEqual (strided [rows , columns ],
2724+ torch .Tensor ([[10 , 11 ], [17 , 18 ]]))
2725+ strided [rows , columns ] = conv_fn (torch .Tensor ([[4 , 6 ], [2 , 3 ]]))
2726+ self .assertEqual (strided [rows , columns ],
2727+ torch .Tensor ([[4 , 6 ], [2 , 3 ]]))
2728+
2729+ # TODO: error raising tests
2730+
2731+ def test_advancedindex (self ):
2732+ self ._test_advancedindex (self , lambda x : x )
2733+
2734+ @staticmethod
2735+ def _test_advancedindex_big (self , conv_fn ):
2736+ reference = conv_fn (torch .arange (0 , 123344 ).int ())
2737+
2738+ self .assertEqual (reference [[0 , 123 , 44488 , 68807 , 123343 ], ],
2739+ torch .LongTensor ([0 , 123 , 44488 , 68807 , 123343 ]))
2740+
2741+ def test_advancedindex_big (self ):
2742+ self ._test_advancedindex_big (self , lambda x : x )
2743+
24992744 def test_newindex (self ):
25002745 reference = self ._consecutive ((3 , 3 , 3 ))
25012746 # This relies on __index__() being correct - but we have separate tests for that
0 commit comments