Skip to content

Commit c4120f3

Browse files
killeent authored and soumith committed
move to model with cuda indexing tensors for cuda tensor adv indexing
1 parent 8b42308 commit c4120f3

File tree

2 files changed

+49
-72
lines changed

2 files changed

+49
-72
lines changed

test/test_torch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2546,7 +2546,7 @@ def consec(size, start=1):
25462546
def ri(indices):
25472547
choice = random.randint(0, 2)
25482548
if choice == 0:
2549-
return torch.LongTensor(indices)
2549+
return conv_fn(torch.LongTensor(indices))
25502550
elif choice == 1:
25512551
return list(indices)
25522552
else:

torch/csrc/generic/Tensor.cpp

Lines changed: 48 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -423,16 +423,19 @@ static PyObject * THPTensor_(pynew)(PyTypeObject *type, PyObject *args, PyObject
423423
#define THIndexTensor_(NAME) TH_CONCAT_2(THCudaLongTensor_,NAME)
424424
#define THPIndexTensor THCPLongTensor
425425
#define THPIndexTensor_Check THCPLongTensor_Check
426+
#define THPIndexTensorClass THCPLongTensorClass
426427
#elif defined(THD_GENERIC_FILE)
427428
#define THIndexTensor THDLongTensor
428429
#define THIndexTensor_(NAME) TH_CONCAT_2(THDLongTensor_,NAME)
429430
#define THPIndexTensor THDPLongTensor
430431
#define THPIndexTensor_Check THDPLongTensor_Check
432+
#define THPIndexTensorClass THDPLongTensorClass
431433
#else
432434
#define THIndexTensor THLongTensor
433435
#define THIndexTensor_(NAME) TH_CONCAT_2(THLongTensor_,NAME)
434436
#define THPIndexTensor THPLongTensor
435437
#define THPIndexTensor_Check THPLongTensor_Check
438+
#define THPIndexTensorClass THPLongTensorClass
436439
#endif
437440

438441
static bool THPTensor_(_indexOnce)(PyObject *index, int &indexed_dim,
@@ -514,7 +517,7 @@ static bool THPTensor_(_checkBasicIntegerArrayIndexing)(THPTensor *indexed, PyOb
514517
THPObjectPtr fast = THPObjectPtr(PySequence_Fast(arg, NULL));
515518
for (Py_ssize_t i = 0; i < ndim; ++i) {
516519
PyObject *item = PySequence_Fast_GET_ITEM(fast.get(), i);
517-
if (!THPLongTensor_Check(item) && !PySequence_Check(item)) {
520+
if (!THPIndexTensor_Check(item) && !PySequence_Check(item)) {
518521
return false;
519522
}
520523
}
@@ -629,7 +632,7 @@ static bool THPTensor_(_convertToTensorIndexers)(
629632
PyObject *index,
630633
THTensorPtr& indexed,
631634
Py_ssize_t& sequenceLength,
632-
std::unordered_map<Py_ssize_t, THLongTensorPtr>& broadcasted) {
635+
std::unordered_map<Py_ssize_t, THPPointer<THIndexTensor>>& broadcasted) {
633636

634637
// At the top-level, each indexing element must be one of 3 things:
635638
//
@@ -647,13 +650,13 @@ static bool THPTensor_(_convertToTensorIndexers)(
647650
// output map, with the dimension of the original tensor as the key.
648651

649652
// Indexes all indexing Tensors (pre-broadcast) by which dimension they occurred.
650-
// Because we rely upon the THPLongTensor constructor to handle sequence -> tensor
653+
// Because we rely upon the THPIndexTensor constructor to handle sequence -> tensor
651654
// conversions, we store THPTensors rather than THTensors. We use an ordered map
652655
// to maintain the order of Tensors via dimension. Because this is limited to
653656
// ndim(Tensor), it should always be small + fast.
654657

655658
std::vector<Py_ssize_t> indexingDims;
656-
std::vector<THPLongTensor*>indexers;
659+
std::vector<THPIndexTensor*>indexers;
657660

658661
// The top-level indexer should be a sequence, per the check above
659662
THPObjectPtr fast(PySequence_Fast(index, NULL));
@@ -663,16 +666,16 @@ static bool THPTensor_(_convertToTensorIndexers)(
663666
PyObject *item = PySequence_Fast_GET_ITEM(fast.get(), i);
664667
if (!PySlice_Check(item)) {
665668
// Returns NULL upon conversion failure
666-
THPLongTensor *indexer = (THPLongTensor *)PyObject_CallFunctionObjArgs(
667-
THPLongTensorClass, PySequence_Fast_GET_ITEM(fast.get(), i), NULL);
669+
THPIndexTensor *indexer = (THPIndexTensor *)PyObject_CallFunctionObjArgs(
670+
THPIndexTensorClass, PySequence_Fast_GET_ITEM(fast.get(), i), NULL);
668671
if (!indexer) {
669672
PyErr_Format(PyExc_IndexError,
670673
"When performing advanced indexing the indexing objects must be LongTensors or "
671674
"convertible to LongTensors");
672675

673676
// Clean up Indexers
674677
for (auto& idx : indexers) {
675-
THLongTensor_free(idx->cdata);
678+
THIndexTensor_(free)(LIBRARY_STATE idx->cdata);
676679
Py_DECREF(idx);
677680
}
678681
return false;
@@ -684,43 +687,46 @@ static bool THPTensor_(_convertToTensorIndexers)(
684687

685688
// Next, we need to verify that the Tensors are broadcastable. Keep these
686689
// as raw pointer vectors
687-
std::vector<THLongTensor*> maybeBroadcasted;
688-
std::vector<THLongTensor*> candidates;
690+
std::vector<THIndexTensor*> maybeBroadcasted;
691+
std::vector<THIndexTensor*> candidates;
689692

690693
// Extract the underlying Tensors for use in the expansion API call
691694
for (const auto& indexer : indexers) {
692-
maybeBroadcasted.emplace_back(THLongTensor_new());
695+
maybeBroadcasted.emplace_back(THIndexTensor_(new)(LIBRARY_STATE_NOARGS));
693696
// borrow the underlying Tensor from the indexer map
694697
candidates.emplace_back(indexer->cdata);
695698
}
696699

697700
// Broadcast/Expand indexing Tensors as necessary
698701
try {
699-
THLongTensor_expandNd(maybeBroadcasted.data(), candidates.data(), maybeBroadcasted.size());
702+
THIndexTensor_(expandNd)(LIBRARY_STATE maybeBroadcasted.data(), candidates.data(), maybeBroadcasted.size());
700703

701704
// Broadcast succeeded, place Broadcasted Tensors into output map by the index at
702705
// which they occurred, transferring ownership to that map object
703706
for (unsigned int i = 0; i < indexingDims.size(); ++i) {
704-
THLongTensorPtr owned(maybeBroadcasted[i]);
707+
THPPointer<THIndexTensor> owned(maybeBroadcasted[i]);
705708
broadcasted[indexingDims[i]] = std::move(owned);
706709
}
707710

708711
// Next, before doing any further work, we want to verify that all the indices
709-
// are in bounds at each advanced index dimension
712+
// are in bounds at each advanced index dimension. This occurs only on the CPU,
713+
// as point gets on CUDA Tensors would be slow. CUDA out of bounds errors
714+
// will trigger a device-side assert
710715

711-
ptrdiff_t nElement = THLongTensor_nElement(broadcasted.begin()->second.get());
716+
#if !defined(THC_GENERIC_FILE)
717+
ptrdiff_t nElement = THIndexTensor_(nElement)(LIBRARY_STATE broadcasted.begin()->second.get());
712718
THLongStoragePtr viewer(THLongStorage_newWithSize(1));
713719
THLongStorage_set(viewer.get(), 0, nElement);
714720
for (auto& dimBroadcast : broadcasted) {
715721
Py_ssize_t dim = dimBroadcast.first;
716722
long sizeAtDim = THTensor_(size)(LIBRARY_STATE indexed, dim);
717723

718724
// Need to make contiguous to view as 1D :/
719-
THLongTensorPtr contig(THLongTensor_newContiguous(dimBroadcast.second.get()));
725+
THPPointer<THIndexTensor> contig(THIndexTensor_(newContiguous)(LIBRARY_STATE dimBroadcast.second.get()));
720726

721727
// View as 1D + get1D makes me sad :(
722-
THLongTensorPtr flat(THLongTensor_newView(contig.get(), viewer));
723-
for (ptrdiff_t i = 0; i < THLongTensor_nElement(flat.get()); ++i) {
728+
THPPointer<THIndexTensor> flat(THIndexTensor_(newView)(LIBRARY_STATE contig.get(), viewer));
729+
for (ptrdiff_t i = 0; i < THIndexTensor_(nElement)(LIBRARY_STATE flat.get()); ++i) {
724730
long indexAtDim = THTensor_fastGet1d(flat.get(), i);
725731
if (indexAtDim >= sizeAtDim) {
726732
PyErr_Format(PyExc_IndexError, "index %lld from broadcast indexer is out of range "
@@ -729,41 +735,42 @@ static bool THPTensor_(_convertToTensorIndexers)(
729735

730736
// Clean up Indexers
731737
for (auto& idx : indexers) {
732-
THLongTensor_free(idx->cdata);
738+
THIndexTensor_(free)(LIBRARY_STATE idx->cdata);
733739
Py_DECREF(idx);
734740
}
735741

736742
return false;
737743
}
738744
}
739745
}
746+
#endif
740747
} catch (std::exception& e) {
741748
// Broadcasted failed, cleanup and return error. I'm not sure if there is a better
742749
// way to do this where we don't have to manually clean up the memory
743750
for (const auto& tensor : maybeBroadcasted) {
744-
THLongTensor_free(tensor);
751+
THIndexTensor_(free)(LIBRARY_STATE tensor);
745752
}
746753
PyErr_Format(PyExc_IndexError, "The advanced indexing objects could not be broadcast");
747754

748755
// Clean up Indexers
749756
for (auto& idx : indexers) {
750-
THLongTensor_free(idx->cdata);
757+
THIndexTensor_(free)(LIBRARY_STATE idx->cdata);
751758
Py_DECREF(idx);
752759
}
753760
return false;
754761
}
755762

756763
// Clean up Indexers
757764
for (auto& idx : indexers) {
758-
THLongTensor_free(idx->cdata);
765+
THIndexTensor_(free)(LIBRARY_STATE idx->cdata);
759766
Py_DECREF(idx);
760767
}
761768
return true;
762769
}
763770

764771
static inline long THPTensor_(_indexToOffset)(
765772
THTensorPtr& indexed,
766-
std::unordered_map<Py_ssize_t, THLongTensorPtr>& broadcasted,
773+
std::unordered_map<Py_ssize_t, THPPointer<THIndexTensor>>& broadcasted,
767774
ptrdiff_t index)
768775
{
769776
// We need to translate an "index" into a linear offset within the Tensor indexed.
@@ -832,7 +839,7 @@ static inline long THPTensor_(_indexToOffset)(
832839

833840
auto broadcast = broadcasted.find(i);
834841
if (broadcast != broadcasted.end()) {
835-
sizeAtDim = THLongTensor_nElement(broadcast->second.get());
842+
sizeAtDim = THIndexTensor_(nElement)(LIBRARY_STATE broadcast->second.get());
836843
indexAtDim = THTensor_fastGet1d(broadcast->second.get(), index % sizeAtDim);
837844

838845
if (i > 0 && broadcasted.find(i - 1) != broadcasted.end()) {
@@ -860,7 +867,7 @@ static inline long THPTensor_(_indexToOffset)(
860867
static THIndexTensor* THPTensor_(_calculateLinearIndices)(
861868
THTensorPtr& indexed,
862869
Py_ssize_t sequenceLength,
863-
std::unordered_map<Py_ssize_t, THLongTensorPtr>& broadcasted) {
870+
std::unordered_map<Py_ssize_t, THPPointer<THIndexTensor>>& broadcasted) {
864871

865872
// Get the number of indices to generate - this is the product of the size at each dimension,
866873
// that is not part of the advanced indexing, multiplied by the nElement of one of the broadcast
@@ -881,7 +888,7 @@ static THIndexTensor* THPTensor_(_calculateLinearIndices)(
881888
// --> total_size = 50
882889

883890
// TODO: should this be 1? what if there are no things to index? ????
884-
ptrdiff_t indexingElements = THLongTensor_nElement(broadcasted.begin()->second.get());
891+
ptrdiff_t indexingElements = THIndexTensor_(nElement)(LIBRARY_STATE broadcasted.begin()->second.get());
885892
for (Py_ssize_t i = 0; i < THTensor_(nDimension)(LIBRARY_STATE indexed.get()); ++i) {
886893
indexingElements *= broadcasted.find(i) != broadcasted.end() ?
887894
1 : THTensor_(size)(LIBRARY_STATE indexed.get(), i);
@@ -890,18 +897,18 @@ static THIndexTensor* THPTensor_(_calculateLinearIndices)(
890897
// The broadcasted advanced indexing tensor might not be one-dimensional, but we are
891898
// generating a vector of indices, so we need to view the indexer as 1D prior to getting
892899
// the value for the particular dimension.
893-
std::unordered_map<Py_ssize_t, THLongTensorPtr> flattenedBroadcasters;
900+
std::unordered_map<Py_ssize_t, THPPointer<THIndexTensor>> flattenedBroadcasters;
894901
THLongStorage *indexerSize = THLongStorage_newWithSize(1);
895902

896903
// All broadcast Tensors have the same number of elements
897-
ptrdiff_t dimIndexingElements = THLongTensor_nElement(broadcasted.begin()->second.get());
904+
ptrdiff_t dimIndexingElements = THIndexTensor_(nElement)(LIBRARY_STATE broadcasted.begin()->second.get());
898905
THLongStorage_set(indexerSize, 0, dimIndexingElements);
899906

900907
for (auto& broadcast : broadcasted) {
901-
THLongTensor *contig = THLongTensor_newContiguous(broadcast.second.get());
902-
THLongTensorPtr flat(THLongTensor_newView(contig, indexerSize));
908+
THIndexTensor *contig = THIndexTensor_(newContiguous)(LIBRARY_STATE broadcast.second.get());
909+
THPPointer<THIndexTensor> flat(THIndexTensor_(newView)(LIBRARY_STATE contig, indexerSize));
903910
flattenedBroadcasters[broadcast.first] = std::move(flat);
904-
THLongTensor_free(contig);
911+
THIndexTensor_(free)(LIBRARY_STATE contig);
905912
}
906913
THLongStorage_free(indexerSize);
907914

@@ -916,47 +923,17 @@ static THIndexTensor* THPTensor_(_calculateLinearIndices)(
916923
std::vector<THCudaLongTensor *> indexers(
917924
THTensor_(nDimension)(LIBRARY_STATE indexed.get()), NULL);
918925

919-
// Count the number of advanced indexers, and set the pointers to NULL for
920-
// those that are not advanced indexing dims
921-
unsigned int advancedIndexers = 0;
922926
for (int i = 0; i < THTensor_(nDimension)(LIBRARY_STATE indexed.get()); ++i) {
923927
if (flattenedBroadcasters.count(i) > 0) {
924-
++advancedIndexers;
925-
}
926-
}
927-
928-
// Allocate a single buffer to hold all of the indexing elements across all advanced
929-
// indexing dimensions
930-
THCudaLongTensor *broadcastIndicesChunk = THCudaLongTensor_newWithSize1d(
931-
LIBRARY_STATE dimIndexingElements * advancedIndexers);
932-
933-
// Copy the individual broadcast Tensors to the GPU
934-
unsigned int dimsHandled = 0;
935-
for (int i = 0; i < THTensor_(nDimension)(LIBRARY_STATE indexed.get()); ++i) {
936-
if (flattenedBroadcasters.count(i) > 0) {
937-
THCudaLongTensor *view = THCudaLongTensor_newWithStorage1d(
938-
LIBRARY_STATE
939-
THCudaLongTensor_storage(LIBRARY_STATE broadcastIndicesChunk),
940-
dimIndexingElements * dimsHandled,
941-
dimIndexingElements,
942-
1);
943-
THCudaLongTensor_copyAsyncCPU(LIBRARY_STATE view, flattenedBroadcasters[i].get());
944-
indexers[i] = view;
945-
++dimsHandled;
928+
indexers[i] = flattenedBroadcasters[i].get();
946929
}
947930
}
948931

949932
THTensor_(calculateAdvancedIndexingOffsets)(LIBRARY_STATE cudaIndices, indexed, baseOffset, indexers.data());
950933

951-
// Free the indexers
952-
for (auto ptr : indexers) {
953-
if (ptr != NULL) {
954-
THCudaLongTensor_free(LIBRARY_STATE ptr);
955-
}
956-
}
957934
return cudaIndices;
958935
#else
959-
THLongTensor *linearIndices = THLongTensor_newWithSize1d(indexingElements);
936+
THIndexTensor *linearIndices = THIndexTensor_(newWithSize1d)(LIBRARY_STATE indexingElements);
960937
long baseOffset = THTensor_(storageOffset)(LIBRARY_STATE indexed);
961938
for (ptrdiff_t i = 0; i < indexingElements; ++i) {
962939
long linearIdx = THPTensor_(_indexToOffset)(
@@ -970,7 +947,7 @@ static THIndexTensor* THPTensor_(_calculateLinearIndices)(
970947
static bool THPTensor_(_advancedIndexCommonInit)(
971948
PyObject *index,
972949
THTensorPtr &indexed,
973-
std::unordered_map<Py_ssize_t, THLongTensorPtr>& broadcasted,
950+
std::unordered_map<Py_ssize_t, THPPointer<THIndexTensor>>& broadcasted,
974951
THIndexTensor **linearIndices,
975952
THTensor **flattened) {
976953

@@ -1016,7 +993,7 @@ static void THPTensor_(_advancedIndexCommonCleanup)(
1016993

1017994
static bool THPTensor_(_advancedIndexGet)(PyObject *index, THTensorPtr &tresult)
1018995
{
1019-
std::unordered_map<Py_ssize_t, THLongTensorPtr> broadcasted;
996+
std::unordered_map<Py_ssize_t, THPPointer<THIndexTensor>> broadcasted;
1020997
THIndexTensor *linearIndices = NULL;
1021998
THTensor *flattened = NULL;
1022999
bool success = THPTensor_(_advancedIndexCommonInit)(
@@ -1063,15 +1040,15 @@ static bool THPTensor_(_advancedIndexGet)(PyObject *index, THTensorPtr &tresult)
10631040
if (baseDims == 0) {
10641041
auto iter = broadcasted.begin();
10651042
THTensor_(resizeNd)(LIBRARY_STATE result,
1066-
THLongTensor_nDimension(iter->second.get()),
1043+
THIndexTensor_(nDimension)(LIBRARY_STATE iter->second.get()),
10671044
iter->second.get()->size,
10681045
NULL);
10691046
} else {
10701047
// We have at least one dimension that is not part of advanced indexing. This
10711048
// implementation is pretty much shit, there might be a better way of doing this...
1072-
THLongTensor *broadcastShape = broadcasted.begin()->second.get();
1049+
THIndexTensor *broadcastShape = broadcasted.begin()->second.get();
10731050

1074-
int indexedDims = THLongTensor_nDimension(broadcastShape);
1051+
int indexedDims = THIndexTensor_(nDimension)(LIBRARY_STATE broadcastShape);
10751052
THLongStorage *outputShape = THLongStorage_newWithSize(baseDims + indexedDims);
10761053

10771054
int baseDimPtr = 0;
@@ -1085,7 +1062,7 @@ static bool THPTensor_(_advancedIndexGet)(PyObject *index, THTensorPtr &tresult)
10851062
++outputDimPtr;
10861063
} else if (!insertedSubspace) {
10871064
for (int dim = 0; dim < indexedDims; ++dim) {
1088-
outputShape->data[outputDimPtr] = THLongTensor_size(iter->second.get(), dim);
1065+
outputShape->data[outputDimPtr] = THIndexTensor_(size)(LIBRARY_STATE iter->second.get(), dim);
10891066
++outputDimPtr;
10901067
}
10911068
insertedSubspace = true;
@@ -1114,7 +1091,7 @@ static bool THPTensor_(_advancedIndexGet)(PyObject *index, THTensorPtr &tresult)
11141091

11151092
static bool THPTensor_(_advancedIndexSet)(PyObject *index, THTensorPtr &dest, PyObject *src)
11161093
{
1117-
std::unordered_map<Py_ssize_t, THLongTensorPtr> broadcasted;
1094+
std::unordered_map<Py_ssize_t, THPPointer<THIndexTensor>> broadcasted;
11181095
THIndexTensor *linearIndices = NULL;
11191096
THTensor *flattened = NULL;
11201097
bool success = THPTensor_(_advancedIndexCommonInit)(
@@ -1153,7 +1130,7 @@ static bool THPTensor_(_advancedIndexSet)(PyObject *index, THTensorPtr &dest, Py
11531130
}
11541131

11551132
static bool THPTensor_(_advancedIndexAdd)(PyObject *index, THTensorPtr &dest, THTensorPtr &src) {
1156-
std::unordered_map<Py_ssize_t, THLongTensorPtr> broadcasted;
1133+
std::unordered_map<Py_ssize_t, THPPointer<THIndexTensor>> broadcasted;
11571134
THIndexTensor *linearIndices = NULL;
11581135
THTensor *flattened = NULL;
11591136
bool success = THPTensor_(_advancedIndexCommonInit)(
@@ -1178,7 +1155,7 @@ static bool THPTensor_(_advancedIndexAdd)(PyObject *index, THTensorPtr &dest, TH
11781155
}
11791156

11801157
static bool THPTensor_(_advancedIndexSelect)(PyObject *index, THTensorPtr &dest, THTensorPtr &src) {
1181-
std::unordered_map<Py_ssize_t, THLongTensorPtr> broadcasted;
1158+
std::unordered_map<Py_ssize_t, THPPointer<THIndexTensor>> broadcasted;
11821159
THIndexTensor *linearIndices = NULL;
11831160
THTensor *flattened = NULL;
11841161
bool success = THPTensor_(_advancedIndexCommonInit)(

0 commit comments

Comments (0)