Skip to content

Commit 88b4232

Browse files
Martin Raisonsoumith
authored andcommitted
spcadd, sparseMask, cadd, csub, cmul + tests
1 parent ec260fe commit 88b4232

File tree

21 files changed

+1121
-169
lines changed

21 files changed

+1121
-169
lines changed

test/common.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -128,11 +128,12 @@ def assertEqual(self, x, y, prec=None, message=''):
128128

129129
if torch.is_tensor(x) and torch.is_tensor(y):
130130
def assertTensorsEqual(a, b):
131-
max_err = 0
132131
super(TestCase, self).assertEqual(a.size(), b.size())
133-
for index in iter_indices(a):
134-
max_err = max(max_err, abs(a[index] - b[index]))
135-
self.assertLessEqual(max_err, prec, message)
132+
if a.numel() > 0:
133+
b = b.type_as(a)
134+
b = b.cuda(device=a.get_device()) if a.is_cuda else b.cpu()
135+
max_err = (a - b).abs().max()
136+
self.assertLessEqual(max_err, prec, message)
136137
self.assertEqual(x.is_sparse, y.is_sparse, message)
137138
if x.is_sparse:
138139
assertTensorsEqual(x.indices(), y.indices())
@@ -161,11 +162,12 @@ def assertNotEqual(self, x, y, prec=None, message=''):
161162
y = y.data
162163

163164
if torch.is_tensor(x) and torch.is_tensor(y):
164-
max_err = 0
165165
if x.size() != y.size():
166166
super(TestCase, self).assertNotEqual(x.size(), y.size())
167-
for index in iter_indices(x):
168-
max_err = max(max_err, abs(x[index] - y[index]))
167+
self.assertGreater(x.numel(), 0)
168+
y = y.type_as(x)
169+
y = y.cuda(device=x.get_device()) if x.is_cuda else y.cpu()
170+
max_err = (x - y).abs().max()
169171
self.assertGreaterEqual(max_err, prec, message)
170172
elif type(x) == str and type(y) == str:
171173
super(TestCase, self).assertNotEqual(x, y)

test/test_sparse.py

Lines changed: 120 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
type_triplets = [cpu_triplet]
1616
if torch.cuda.is_available():
1717
cuda_triplet = (
18-
torch.cuda.LongTensor,
18+
torch.cuda.IntTensor,
1919
torch.cuda.DoubleTensor,
2020
torch.cuda.sparse.DoubleTensor)
2121
type_triplets.append(cuda_triplet)
@@ -312,27 +312,33 @@ def test_shape(di, dj, dk):
312312
test_shape(3000, 64, 300)
313313

314314
def _test_spadd_shape(self, shape_i, shape_v=None):
315-
shape = shape_i + (shape_v or [])
316-
x, _, _ = self._gen_sparse(len(shape_i), 10, shape)
317-
y = torch.randn(*shape)
318-
r = random.random()
315+
for is_cuda in [False, True]:
316+
shape = shape_i + (shape_v or [])
317+
x, _, _ = self._gen_sparse(len(shape_i), 10, shape, is_cuda)
318+
y = torch.randn(*shape)
319+
if is_cuda:
320+
y = y.cuda()
321+
r = random.random()
319322

320-
expected = y + r * x.to_dense()
321-
res = torch.add(y, r, x)
323+
expected = y + r * x.to_dense()
324+
res = torch.add(y, r, x)
322325

323-
self.assertEqual(res, expected)
326+
self.assertEqual(res, expected)
324327

325-
# Non contiguous dense tensor
326-
s = list(shape)
327-
s[0] = shape[-1]
328-
s[-1] = shape[0]
329-
y = torch.randn(*s).transpose_(0, len(s) - 1)
330-
r = random.random()
328+
# Non contiguous dense tensor
329+
s = list(shape)
330+
s[0] = shape[-1]
331+
s[-1] = shape[0]
332+
y = torch.randn(*s)
333+
if is_cuda:
334+
y = y.cuda()
335+
y.transpose_(0, len(s) - 1)
336+
r = random.random()
331337

332-
expected = y + r * x.to_dense()
333-
res = torch.add(y, r, x)
338+
expected = y + r * x.to_dense()
339+
res = torch.add(y, r, x)
334340

335-
self.assertEqual(res, expected)
341+
self.assertEqual(res, expected)
336342

337343
def test_spadd(self):
338344
self._test_spadd_shape([5, 6])
@@ -347,49 +353,50 @@ def test_spadd_hybrid(self):
347353
self._test_spadd_shape([5, 5, 5, 5, 5, 5], [2])
348354

349355
def _test_basic_ops_shape(self, shape_i, shape_v=None):
350-
shape = shape_i + (shape_v or [])
351-
x1, _, _ = self._gen_sparse(len(shape_i), 9, shape)
352-
x2, _, _ = self._gen_sparse(len(shape_i), 12, shape)
353-
354-
y1 = x1 + x2
355-
y2 = x1.clone()
356-
y2.add_(x2)
357-
expected = x1.to_dense() + x2.to_dense()
358-
self.assertEqual(y1.to_dense(), expected)
359-
self.assertEqual(y2.to_dense(), expected)
360-
361-
y1 = x1 - x2
362-
y2 = x1.clone()
363-
y2.sub_(x2)
364-
expected = x1.to_dense() - x2.to_dense()
365-
self.assertEqual(y1.to_dense(), expected)
366-
self.assertEqual(y2.to_dense(), expected)
367-
368-
y1 = x1 * x2
369-
y2 = x1.clone()
370-
y2.mul_(x2)
371-
expected = x1.to_dense() * x2.to_dense()
372-
self.assertEqual(y1.to_dense(), expected)
373-
self.assertEqual(y2.to_dense(), expected)
374-
375-
y1 = x1 * 37.5
376-
y2 = x1.clone()
377-
y2.mul_(37.5)
378-
expected = x1.to_dense() * 37.5
379-
self.assertEqual(y1.to_dense(), expected)
380-
self.assertEqual(y2.to_dense(), expected)
381-
382-
y1 = x1 / 37.5
383-
y2 = x1.clone()
384-
y2.div_(37.5)
385-
expected = x1.to_dense() / 37.5
386-
self.assertEqual(y1.to_dense(), expected)
387-
self.assertEqual(y2.to_dense(), expected)
388-
389-
y = x1.clone()
390-
y.zero_()
391-
expected = torch.zeros(x1.size())
392-
self.assertEqual(y.to_dense(), expected)
356+
for is_cuda in [False, True]:
357+
shape = shape_i + (shape_v or [])
358+
x1, _, _ = self._gen_sparse(len(shape_i), 9, shape, is_cuda)
359+
x2, _, _ = self._gen_sparse(len(shape_i), 12, shape, is_cuda)
360+
361+
y1 = x1 + x2
362+
y2 = x1.clone()
363+
y2.add_(x2)
364+
expected = x1.to_dense() + x2.to_dense()
365+
self.assertEqual(y1.to_dense(), expected)
366+
self.assertEqual(y2.to_dense(), expected)
367+
368+
y1 = x1 - x2
369+
y2 = x1.clone()
370+
y2.sub_(x2)
371+
expected = x1.to_dense() - x2.to_dense()
372+
self.assertEqual(y1.to_dense(), expected)
373+
self.assertEqual(y2.to_dense(), expected)
374+
375+
y1 = x1 * x2
376+
y2 = x1.clone()
377+
y2.mul_(x2)
378+
expected = x1.to_dense() * x2.to_dense()
379+
self.assertEqual(y1.to_dense(), expected)
380+
self.assertEqual(y2.to_dense(), expected)
381+
382+
y1 = x1 * 37.5
383+
y2 = x1.clone()
384+
y2.mul_(37.5)
385+
expected = x1.to_dense() * 37.5
386+
self.assertEqual(y1.to_dense(), expected)
387+
self.assertEqual(y2.to_dense(), expected)
388+
389+
y1 = x1 / 37.5
390+
y2 = x1.clone()
391+
y2.div_(37.5)
392+
expected = x1.to_dense() / 37.5
393+
self.assertEqual(y1.to_dense(), expected)
394+
self.assertEqual(y2.to_dense(), expected)
395+
396+
y = x1.clone()
397+
y.zero_()
398+
expected = torch.zeros(x1.size())
399+
self.assertEqual(y.to_dense(), expected)
393400

394401
def test_basic_ops(self):
395402
self._test_basic_ops_shape([5, 6])
@@ -403,6 +410,59 @@ def test_basic_ops_hybrid(self):
403410
self._test_basic_ops_shape([50, 30, 20], [2])
404411
self._test_basic_ops_shape([5, 5, 5, 5, 5, 5], [2])
405412

413+
def _test_sparse_mask_shape(self, shape_i, shape_v=None):
414+
for is_cuda in [False, True]:
415+
shape = shape_i + (shape_v or [])
416+
x1, _, _ = self._gen_sparse(len(shape_i), 9, shape, is_cuda)
417+
x2, _, _ = self._gen_sparse(len(shape_i), 12, shape, is_cuda)
418+
419+
y1 = x1 + x2
420+
y2 = x1.clone()
421+
y2.add_(x2)
422+
expected = x1.to_dense() + x2.to_dense()
423+
self.assertEqual(y1.to_dense(), expected)
424+
self.assertEqual(y2.to_dense(), expected)
425+
426+
def test_sparse_mask(self):
427+
for IndexTensor, ValueTensor, SparseTensor in type_triplets:
428+
i = IndexTensor([
429+
[1, 3, 3, 0, 4],
430+
[2, 1, 1, 2, 3],
431+
])
432+
v = ValueTensor([1, 2, 3, 4, 5])
433+
x = SparseTensor(i, v, torch.Size([5, 4]))
434+
dense = ValueTensor([
435+
[1, 2, 3, 4],
436+
[5, 6, 7, 8],
437+
[9, 10, 11, 12],
438+
[13, 14, 15, 16],
439+
[17, 18, 19, 20],
440+
])
441+
exp_v = ValueTensor([7, 14, 14, 3, 20])
442+
expected = SparseTensor(i, exp_v, torch.Size([5, 4]))
443+
res = dense.sparse_mask(x)
444+
self.assertEqual(res, expected)
445+
446+
def test_sparse_mask_hybrid(self):
447+
for IndexTensor, ValueTensor, SparseTensor in type_triplets:
448+
i = IndexTensor([
449+
[1, 3, 3, 0, 4],
450+
[2, 1, 1, 2, 3],
451+
])
452+
v = ValueTensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6]])
453+
x = SparseTensor(i, v, torch.Size([5, 4, 2]))
454+
dense = ValueTensor([
455+
[[1, 3], [2, 2], [3, 3], [4, 2]],
456+
[[5, 7], [6, 7], [7, 9], [8, 9]],
457+
[[9, 2], [10, 4], [11, 1], [12, 3]],
458+
[[13, 5], [14, 1], [15, 1], [16, 6]],
459+
[[17, 7], [18, 2], [19, 7], [20, 1]],
460+
])
461+
exp_v = ValueTensor([[7, 9], [14, 1], [14, 1], [3, 3], [20, 1]])
462+
expected = SparseTensor(i, exp_v, torch.Size([5, 4, 2]))
463+
res = dense.sparse_mask(x)
464+
self.assertEqual(res, expected)
465+
406466

407467
if __name__ == '__main__':
408468
run_tests()

tools/cwrap/plugins/THPPlugin.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ class THPPlugin(CWrapPlugin):
1919

2020
'THCudaTensor*': Template('((THCPFloatTensor*)$arg)->cdata'),
2121
'THCudaDoubleTensor*': Template('((THCPDoubleTensor*)$arg)->cdata'),
22+
'THCudaIntTensor*': Template('((THCPIntTensor*)$arg)->cdata'),
2223
'THCudaLongTensor*': Template('((THCPLongTensor*)$arg)->cdata'),
2324

2425
'THSFloatTensor*': Template('((THSPFloatTensor*)$arg)->cdata'),
@@ -56,6 +57,7 @@ class THPPlugin(CWrapPlugin):
5657

5758
'THCudaTensor*': Template('(PyObject*)Py_TYPE($arg) == THCPFloatTensorClass'),
5859
'THCudaDoubleTensor*': Template('(PyObject*)Py_TYPE($arg) == THCPDoubleTensorClass'),
60+
'THCudaIntTensor*': Template('(PyObject*)Py_TYPE($arg) == THCPIntTensorClass'),
5961
'THCudaLongTensor*': Template('(PyObject*)Py_TYPE($arg) == THCPLongTensorClass'),
6062

6163
'THSDoubleTensor*': Template('(PyObject*)Py_TYPE($arg) == THSPDoubleTensorClass'),
@@ -86,8 +88,10 @@ class THPPlugin(CWrapPlugin):
8688
RETURN_WRAPPER = {
8789
'THTensor*': Template('return THPTensor_(New)($result);'),
8890
'THSTensor*': Template('return THSPTensor_(New)($result);'),
91+
'THIndexTensor*': Template('return THPIndexTensor_(New)($result);'),
8992
'THLongTensor*': Template('return THPLongTensor_New($result);'),
9093
'THLongStorage*': Template('return THPLongStorage_New($result);'),
94+
'THCudaIntTensor*': Template('return THCPIntTensor_New($result);'),
9195
'THCudaLongTensor*': Template('return THCPLongTensor_New($result);'),
9296
# TODO: make it smarter - it should return python long if result doesn't fit into an int
9397
'long': Template('return PyInt_FromLong($result);'),
@@ -174,6 +178,7 @@ def _allocate(typename, tmpl, cuda_tmpl=None, sparse=False):
174178
'THDoubleTensor*': '" THPModuleStr "DoubleTensor',
175179
'THCudaTensor*': 'torch.cuda.FloatTensor',
176180
'THCudaDoubleTensor*': 'torch.cuda.DoubleTensor',
181+
'THCudaIntTensor*': 'torch.cuda.IntTensor',
177182
'THCudaLongTensor*': 'torch.cuda.LongTensor',
178183
'THSize*': 'torch.Size',
179184
'THStride*': 'tuple',

torch/_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def _cuda(self, device=None, async=False):
5757
with torch.cuda.device(device):
5858
if self.is_sparse:
5959
new_type = getattr(torch.cuda.sparse, self.__class__.__name__)
60-
indices = self.indices().cuda(device, async)
60+
indices = self.indices().cuda(device, async).int()
6161
values = self.values().cuda(device, async)
6262
return new_type(indices, values, self.size())
6363
else:

torch/csrc/generic/SparseTensor.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@ static void THSPTensor_(dealloc)(THSPTensor* self)
2626
static PyObject * THSPTensor_(pynew)(PyTypeObject *type, PyObject *args, PyObject *kwargs)
2727
{
2828
#ifdef THC_GENERIC_FILE
29-
#define THPIndexTensor_Check THCPLongTensor_Check
30-
#define THPIndexTensor THCPLongTensor
31-
#define THIndexTensor THCudaLongTensor
29+
#define THPIndexTensor_Check THCPIntTensor_Check
30+
#define THPIndexTensor THCPIntTensor
31+
#define THIndexTensor THCudaIntTensor
3232
#else
3333
#define THPIndexTensor_Check THPLongTensor_Check
3434
#define THPIndexTensor THPLongTensor

torch/csrc/generic/methods/SparseTensor.cwrap

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,12 +83,11 @@ PyObject * THSPTensor_(size)(PyObject *self, PyObject *args, PyObject *kwargs)
8383
name: indices
8484
defined_if: "IS_CUDA"
8585
sparse: yes
86-
return: THCudaLongTensor*
86+
return: THCudaIntTensor*
8787
arguments:
8888
- THSTensor* self
8989
]]
9090

91-
9291
[[
9392
name: values
9493
sparse: yes
@@ -372,3 +371,13 @@ PyObject * THSPTensor_(size)(PyObject *self, PyObject *args, PyObject *kwargs)
372371
- THTensor* self
373372
- THSTensor* mask
374373
]]
374+
375+
[[
376+
name: getDevice
377+
sparse: yes
378+
python_name: get_device
379+
defined_if: IS_CUDA
380+
return: long
381+
arguments:
382+
- THSTensor* self
383+
]]

0 commit comments

Comments
 (0)