Skip to content

Commit 2d01f38

Browse files
szagoruykosoumith
authored andcommitted
fallback to nn batchnorm on backward-evaluate (pytorch#589)
1 parent f8d4f98 commit 2d01f38

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

torch/nn/_functions/batchnorm.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ def backward(self, grad_output):
5454
if (len(self.needs_input_grad) > 1 and self.needs_input_grad[2]) or self.use_cudnn:
5555
grad_bias = bias.new(bias.size()).zero_()
5656

57-
if self.use_cudnn:
57+
if self.use_cudnn and self.training:
58+
# cudnn does not support backward in evaluate mode
5859
torch._C._cudnn_batch_norm_backward(
5960
input, grad_output, grad_input,
6061
grad_weight, grad_bias, weight,

0 commit comments

Comments
 (0)