@@ -356,14 +356,40 @@ def test_fmod(self):
356356 self .assertEqual (res1 , res2 )
357357
358358 def test_remainder (self ):
359+ # Check the Floating point case
359360 m1 = torch .Tensor (10 , 10 ).uniform_ (- 10. , 10. )
360361 res1 = m1 .clone ()
361- q = 2.1
362- res1 [:, 3 ].remainder_ (q )
363362 res2 = m1 .clone ()
364- for i in range (m1 .size (0 )):
365- res2 [i , 3 ] = res2 [i , 3 ] % q
363+ qs = torch .range (- 5.1 , 4.1 )
364+ # Check the case where the divisor is a simple float
365+ for col_idx , q in enumerate (qs ):
366+ # Reference
367+ for i in range (m1 .size (0 )):
368+ res2 [i , col_idx ] = res2 [i , col_idx ] % q
369+ # To test
370+ res1 [:, col_idx ].remainder_ (q )
366371 self .assertEqual (res1 , res2 )
372+ # Check the case where the divisor is a tensor
373+ res1 = m1 .clone ()
374+ res1 .remainder_ (qs .unsqueeze (0 ).expand_as (res1 ))
375+ self .assertEqual (res1 , res2 )
376+
377+ # Check the LongTensor case
378+ long_m1 = torch .LongTensor (10 , 10 ).random_ (- 10 , 10 )
379+ long_res1 = long_m1 .clone ()
380+ long_res2 = long_m1 .clone ()
381+ long_qs = torch .range (- 5 , 4 ).long ()
382+ long_qs [5 ] = 5 # Can't handle the divisor=0 case
383+ for col_idx , long_q in enumerate (long_qs ):
384+ # Reference
385+ for i in range (long_m1 .size (0 )):
386+ long_res2 [i , col_idx ] = long_res2 [i , col_idx ] % long_q
387+ # To test
388+ long_res1 [:, col_idx ].remainder_ (long_q )
389+ self .assertEqual (long_res1 , long_res2 )
390+ # Divisor is a tensor case
391+ long_res1 = long_m1 .clone ()
392+ long_res1 .remainder_ (long_qs .unsqueeze (0 ).expand_as (long_res1 ))
367393
368394 def test_mm (self ):
369395 # helper function
0 commit comments