@@ -1092,6 +1092,13 @@ def expected_output(dim):
10921092 indices .add_ (1 )
10931093 self .assertRaises (RuntimeError , lambda : output .backward (grad_output ))
10941094
1095+ def test_batchnorm_eval (self ):
1096+ self ._test_batchnorm_eval ()
1097+
1098+ @unittest .skipIf (not TEST_CUDA , "CUDA unavailable" )
1099+ def test_batchnorm_eval_cuda (self ):
1100+ self ._test_batchnorm_eval (torch .cuda .FloatTensor )
1101+
10951102 def test_MaxPool1d_indices (self ):
10961103 self ._test_maxpool_indices (1 )
10971104
@@ -2434,31 +2441,27 @@ def test_batchnorm_raises_error_if_bias_is_not_same_size_as_input(self):
24342441 with self .assertRaises (RuntimeError ):
24352442 F .batch_norm (input , running_mean , running_var , bias = Parameter (torch .rand (size )))
24362443
2437- def test_batchnorm_eval (self ):
2438- types = (torch .FloatTensor ,)
2439- if TEST_CUDA :
2440- types += (torch .cuda .FloatTensor ,)
2441- for tp in types :
2442- module = nn .BatchNorm1d (3 ).type (tp )
2443- module .eval ()
2444-
2445- data = Variable (torch .rand (4 , 3 ).type (tp ), requires_grad = True )
2446- grad = torch .rand (4 , 3 ).type (tp )
2447-
2448- # 1st pass
2449- res1 = module (data )
2450- res1 .backward (grad )
2451- grad1 = data .grad .data .clone ()
2452-
2453- # 2nd pass
2454- if data .grad is not None :
2455- data .grad .data .zero_ ()
2456-
2457- res2 = module (data )
2458- res2 .backward (grad )
2459- grad2 = data .grad .data .clone ()
2460- self .assertEqual (res1 , res2 )
2461- self .assertEqual (grad1 , grad2 )
2444+ def _test_batchnorm_eval (self , test_type = torch .FloatTensor ):
2445+ module = nn .BatchNorm1d (3 ).type (test_type )
2446+ module .eval ()
2447+
2448+ data = Variable (torch .rand (4 , 3 ).type (test_type ), requires_grad = True )
2449+ grad = torch .rand (4 , 3 ).type (test_type )
2450+
2451+ # 1st pass
2452+ res1 = module (data )
2453+ res1 .backward (grad )
2454+ grad1 = data .grad .data .clone ()
2455+
2456+ # 2nd pass
2457+ if data .grad is not None :
2458+ data .grad .data .zero_ ()
2459+
2460+ res2 = module (data )
2461+ res2 .backward (grad )
2462+ grad2 = data .grad .data .clone ()
2463+ self .assertEqual (res1 , res2 )
2464+ self .assertEqual (grad1 , grad2 )
24622465
24632466 def test_pairwise_distance (self ):
24642467 input1 = Variable (torch .randn (4 , 4 ), requires_grad = True )
0 commit comments