Skip to content

Commit b16a352

Browse files
bunelrsoumith
authored andcommitted
Fix remainder and cremainder for integer types
1 parent 4026593 commit b16a352

File tree

2 files changed

+32
-4
lines changed

2 files changed

+32
-4
lines changed

test/test_cuda.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,9 @@ def tmp(t):
208208
('mode', small_3d, lambda t: [1], 'dim'),
209209
('mode', small_3d, lambda t: [-1], 'neg_dim'),
210210
('remainder', small_3d, lambda t: [3], 'value'),
211+
('remainder', small_3d, lambda t: [-3], 'negative_value'),
211212
('remainder', small_3d, lambda t: [small_3d_positive(t)], 'tensor'),
213+
('remainder', small_3d, lambda t: [0 - small_3d_positive(t)], 'negative_tensor'),
212214
('std', small_3d, lambda t: [],),
213215
('std', small_3d, lambda t: [1], 'dim'),
214216
('std', small_3d, lambda t: [-1], 'neg_dim'),

test/test_torch.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -356,14 +356,40 @@ def test_fmod(self):
356356
self.assertEqual(res1, res2)
357357

358358
def test_remainder(self):
359+
# Check the Floating point case
359360
m1 = torch.Tensor(10, 10).uniform_(-10., 10.)
360361
res1 = m1.clone()
361-
q = 2.1
362-
res1[:, 3].remainder_(q)
363362
res2 = m1.clone()
364-
for i in range(m1.size(0)):
365-
res2[i, 3] = res2[i, 3] % q
363+
qs = torch.range(-5.1, 4.1)
364+
# Check the case where the divisor is a simple float
365+
for col_idx, q in enumerate(qs):
366+
# Reference
367+
for i in range(m1.size(0)):
368+
res2[i, col_idx] = res2[i, col_idx] % q
369+
# To test
370+
res1[:, col_idx].remainder_(q)
366371
self.assertEqual(res1, res2)
372+
# Check the case where the divisor is a tensor
373+
res1 = m1.clone()
374+
res1.remainder_(qs.unsqueeze(0).expand_as(res1))
375+
self.assertEqual(res1, res2)
376+
377+
# Check the LongTensor case
378+
long_m1 = torch.LongTensor(10, 10).random_(-10, 10)
379+
long_res1 = long_m1.clone()
380+
long_res2 = long_m1.clone()
381+
long_qs = torch.range(-5, 4).long()
382+
long_qs[5] = 5 # Can't handle the divisor=0 case
383+
for col_idx, long_q in enumerate(long_qs):
384+
# Reference
385+
for i in range(long_m1.size(0)):
386+
long_res2[i, col_idx] = long_res2[i, col_idx] % long_q
387+
# To test
388+
long_res1[:, col_idx].remainder_(long_q)
389+
self.assertEqual(long_res1, long_res2)
390+
# Divisor is a tensor case
391+
long_res1 = long_m1.clone()
392+
long_res1.remainder_(long_qs.unsqueeze(0).expand_as(long_res1))
367393

368394
def test_mm(self):
369395
# helper function

0 commit comments

Comments
 (0)