Skip to content

Commit edd41d8

Browse files
authored
BatchNorm fallback to THNN when eps < CUDNN_BN_MIN_EPSILON (pytorch#1742)
1 parent 352f8b2 commit edd41d8

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

torch/csrc/autograd/functions/batch_normalization.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@ namespace torch { namespace autograd {
1717

1818
using thpp::Tensor;
1919

20+
#ifndef CUDNN_BN_MIN_EPSILON
21+
#define CUDNN_BN_MIN_EPSILON 0
22+
#endif
23+
2024
auto BatchNormForward::apply(const variable_list& inputs) -> variable_list {
2125
check_input_variables("BatchNorm", inputs, 3, 1);
2226

@@ -41,7 +45,7 @@ auto BatchNormForward::apply(const variable_list& inputs) -> variable_list {
4145
std::unique_ptr<Tensor> save_std(output->newTensor());
4246
save_std->resizeAs(*running_var);
4347

44-
if (use_cudnn) {
48+
if (use_cudnn && eps >= CUDNN_BN_MIN_EPSILON) {
4549
#ifdef WITH_CUDNN
4650
torch::cudnn::cudnn_batch_norm_forward(
4751
state,
@@ -123,7 +127,7 @@ auto BatchNormBackward::apply(const variable_list& grad_outputs) -> variable_lis
123127
}
124128
}
125129

126-
if (use_cudnn) {
130+
if (use_cudnn && eps >= CUDNN_BN_MIN_EPSILON) {
127131
#ifdef WITH_CUDNN
128132
torch::cudnn::cudnn_batch_norm_backward(
129133
state,

0 commit comments

Comments
 (0)