Skip to content

Commit 0b00095

Browse files
csarofeensoumith
authored andcommitted
Split batchnorm eval test into cpu and cuda functions. (pytorch#2273)
1 parent 42328b7 commit 0b00095

File tree

1 file changed

+28
-25
lines changed

1 file changed

+28
-25
lines changed

test/test_nn.py

Lines changed: 28 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1092,6 +1092,13 @@ def expected_output(dim):
10921092
indices.add_(1)
10931093
self.assertRaises(RuntimeError, lambda: output.backward(grad_output))
10941094

1095+
def test_batchnorm_eval(self):
1096+
self._test_batchnorm_eval()
1097+
1098+
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
1099+
def test_batchnorm_eval_cuda(self):
1100+
self._test_batchnorm_eval(torch.cuda.FloatTensor)
1101+
10951102
def test_MaxPool1d_indices(self):
10961103
self._test_maxpool_indices(1)
10971104

@@ -2434,31 +2441,27 @@ def test_batchnorm_raises_error_if_bias_is_not_same_size_as_input(self):
24342441
with self.assertRaises(RuntimeError):
24352442
F.batch_norm(input, running_mean, running_var, bias=Parameter(torch.rand(size)))
24362443

2437-
def test_batchnorm_eval(self):
2438-
types = (torch.FloatTensor,)
2439-
if TEST_CUDA:
2440-
types += (torch.cuda.FloatTensor,)
2441-
for tp in types:
2442-
module = nn.BatchNorm1d(3).type(tp)
2443-
module.eval()
2444-
2445-
data = Variable(torch.rand(4, 3).type(tp), requires_grad=True)
2446-
grad = torch.rand(4, 3).type(tp)
2447-
2448-
# 1st pass
2449-
res1 = module(data)
2450-
res1.backward(grad)
2451-
grad1 = data.grad.data.clone()
2452-
2453-
# 2nd pass
2454-
if data.grad is not None:
2455-
data.grad.data.zero_()
2456-
2457-
res2 = module(data)
2458-
res2.backward(grad)
2459-
grad2 = data.grad.data.clone()
2460-
self.assertEqual(res1, res2)
2461-
self.assertEqual(grad1, grad2)
2444+
def _test_batchnorm_eval(self, test_type=torch.FloatTensor):
2445+
module = nn.BatchNorm1d(3).type(test_type)
2446+
module.eval()
2447+
2448+
data = Variable(torch.rand(4, 3).type(test_type), requires_grad=True)
2449+
grad = torch.rand(4, 3).type(test_type)
2450+
2451+
# 1st pass
2452+
res1 = module(data)
2453+
res1.backward(grad)
2454+
grad1 = data.grad.data.clone()
2455+
2456+
# 2nd pass
2457+
if data.grad is not None:
2458+
data.grad.data.zero_()
2459+
2460+
res2 = module(data)
2461+
res2.backward(grad)
2462+
grad2 = data.grad.data.clone()
2463+
self.assertEqual(res1, res2)
2464+
self.assertEqual(grad1, grad2)
24622465

24632466
def test_pairwise_distance(self):
24642467
input1 = Variable(torch.randn(4, 4), requires_grad=True)

0 commit comments

Comments
 (0)