Skip to content

Commit b3db52f

Browse files
gchanansoumith
authored andcommitted
Support __neg__, .neg(), and neg_() for Long, Int, Short tensor types.
1 parent d19ee9c commit b3db52f

File tree

3 files changed

+35
-7
lines changed

3 files changed

+35
-7
lines changed

test/test_cuda.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -455,6 +455,9 @@ def test_type_conversions_same_gpu(self):
455455
x = torch.randn(5, 5).cuda(1)
456456
self.assertEqual(x.int().get_device(), 1)
457457

458+
def test_neg(self):
459+
TestTorch._test_neg(self, lambda t: t.cuda())
460+
458461
def _test_broadcast(self, input):
459462
if torch.cuda.device_count() < 2:
460463
raise unittest.SkipTest("only one GPU detected")

test/test_torch.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -341,14 +341,33 @@ def test_csub(self):
341341
res_csub.sub_(scalar)
342342
self.assertEqual(res_add, res_csub)
343343

344-
def test_neg(self):
345-
a = torch.randn(100, 90)
346-
zeros = torch.Tensor().resize_as_(a).zero_()
344+
@staticmethod
345+
def _test_neg(self, cast):
346+
float_types = ['torch.DoubleTensor', 'torch.FloatTensor', 'torch.LongTensor']
347+
int_types = ['torch.IntTensor', 'torch.ShortTensor']
348+
349+
for t in float_types + int_types:
350+
if t in float_types:
351+
a = cast(torch.randn(100, 90).type(t))
352+
else:
353+
a = cast(torch.Tensor(100, 90).type(t).random_())
354+
zeros = cast(torch.Tensor().type(t)).resize_as_(a).zero_()
355+
356+
res_add = torch.add(zeros, -1, a)
357+
res_neg = a.clone()
358+
res_neg.neg_()
359+
self.assertEqual(res_neg, res_add)
360+
361+
# test out of place as well
362+
res_neg_out_place = a.clone().neg()
363+
self.assertEqual(res_neg_out_place, res_add)
347364

348-
res_add = torch.add(zeros, -1, a)
349-
res_neg = a.clone()
350-
res_neg.neg_()
351-
self.assertEqual(res_neg, res_add)
365+
# test via __neg__ operator
366+
res_neg_op = -a.clone()
367+
self.assertEqual(res_neg_op, res_add)
368+
369+
def test_neg(self):
370+
self._test_neg(self, lambda t: t)
352371

353372
def test_reciprocal(self):
354373
a = torch.randn(100, 89)

torch/csrc/generic/methods/TensorMath.cwrap

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -979,6 +979,9 @@
979979
name: neg
980980
types:
981981
- floating_point
982+
- Long
983+
- Int
984+
- Short
982985
backends:
983986
- CPU
984987
- CUDA
@@ -998,6 +1001,9 @@
9981001
name: neg_
9991002
types:
10001003
- floating_point
1004+
- Long
1005+
- Int
1006+
- Short
10011007
backends:
10021008
- CPU
10031009
- CUDA

0 commit comments

Comments
 (0)