Commit 018af47 (keras-team#8785)

TF convnets perf improvement: fused BN + native bias_add for NCHW.

* TF convnets perf improvement: fused BN + native bias_add for NCHW.
* Add docstrings
* Skip some theano tests
* Fix typo

1 parent 3897383 · commit 018af47
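
For context, the speedup comes from replacing composite TensorFlow graphs with single fused kernels. Below is a minimal TF 1.x sketch of the two batch-norm paths this commit switches between; it is not part of the commit, and the shapes and names are illustrative only.

import tensorflow as tf

x = tf.placeholder(tf.float32, shape=(None, 3, 32, 32))  # NCHW layout
gamma = tf.ones([3])   # per-channel scale
beta = tf.zeros([3])   # per-channel offset

# Unfused path: batch statistics and normalization run as separate
# kernels, with explicit broadcasting over (N, H, W).
mean, var = tf.nn.moments(x, axes=[0, 2, 3], keep_dims=True)
unfused = tf.nn.batch_normalization(
    x, mean, var,
    offset=tf.reshape(beta, (1, 3, 1, 1)),
    scale=tf.reshape(gamma, (1, 3, 1, 1)),
    variance_epsilon=1e-3)

# Fused path: a single kernel computes the statistics and normalizes.
# The NCHW variant is cuDNN-backed and generally needs GPU support,
# which is why the diff below guards it with _has_nchw_support().
fused, batch_mean, batch_var = tf.nn.fused_batch_norm(
    x, gamma, beta, epsilon=1e-3, data_format='NCHW')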

File tree: 2 files changed, +181 −32 lines

keras/backend/tensorflow_backend.py (128 additions, 29 deletions)
@@ -1684,9 +1684,9 @@ def cos(x):
     return tf.cos(x)
 
 
-def normalize_batch_in_training(x, gamma, beta,
-                                reduction_axes, epsilon=1e-3):
-    """Computes mean and std for batch then apply batch_normalization on batch.
+def _regular_normalize_batch_in_training(x, gamma, beta,
+                                         reduction_axes, epsilon=1e-3):
+    """Non-fused version of `normalize_batch_in_training`.
 
     # Arguments
         x: Input tensor or variable.
@@ -1701,36 +1701,131 @@ def normalize_batch_in_training(x, gamma, beta,
     """
     mean, var = tf.nn.moments(x, reduction_axes,
                               shift=None, name=None, keep_dims=False)
-    if sorted(reduction_axes) == list(range(ndim(x)))[:-1]:
-        normed = tf.nn.batch_normalization(x, mean, var,
-                                           beta, gamma,
-                                           epsilon)
-    else:
-        # need broadcasting
-        target_shape = []
-        for axis in range(ndim(x)):
-            if axis in reduction_axes:
-                target_shape.append(1)
-            else:
-                target_shape.append(tf.shape(x)[axis])
-        target_shape = tf.stack(target_shape)
+    normed = tf.nn.batch_normalization(x, mean, var,
+                                       beta, gamma,
+                                       epsilon)
+    return normed, mean, var
 
-        broadcast_mean = tf.reshape(mean, target_shape)
-        broadcast_var = tf.reshape(var, target_shape)
-        if gamma is None:
-            broadcast_gamma = None
-        else:
-            broadcast_gamma = tf.reshape(gamma, target_shape)
-        if beta is None:
-            broadcast_beta = None
+
+def _broadcast_normalize_batch_in_training(x, gamma, beta,
+                                           reduction_axes, epsilon=1e-3):
+    """Non-fused, broadcast version of `normalize_batch_in_training`.
+
+    # Arguments
+        x: Input tensor or variable.
+        gamma: Tensor by which to scale the input.
+        beta: Tensor with which to center the input.
+        reduction_axes: iterable of integers,
+            axes over which to normalize.
+        epsilon: Fuzz factor.
+
+    # Returns
+        A tuple length of 3, `(normalized_tensor, mean, variance)`.
+    """
+    mean, var = tf.nn.moments(x, reduction_axes,
+                              shift=None, name=None, keep_dims=False)
+    target_shape = []
+    for axis in range(ndim(x)):
+        if axis in reduction_axes:
+            target_shape.append(1)
         else:
-            broadcast_beta = tf.reshape(beta, target_shape)
-        normed = tf.nn.batch_normalization(x, broadcast_mean, broadcast_var,
-                                           broadcast_beta, broadcast_gamma,
-                                           epsilon)
+            target_shape.append(tf.shape(x)[axis])
+    target_shape = tf.stack(target_shape)
+
+    broadcast_mean = tf.reshape(mean, target_shape)
+    broadcast_var = tf.reshape(var, target_shape)
+    if gamma is None:
+        broadcast_gamma = None
+    else:
+        broadcast_gamma = tf.reshape(gamma, target_shape)
+    if beta is None:
+        broadcast_beta = None
+    else:
+        broadcast_beta = tf.reshape(beta, target_shape)
+
+    normed = tf.nn.batch_normalization(
+        x,
+        broadcast_mean,
+        broadcast_var,
+        broadcast_beta,
+        broadcast_gamma,
+        epsilon)
     return normed, mean, var
 
 
+def _fused_normalize_batch_in_training(x, gamma, beta, reduction_axes,
+                                       epsilon=1e-3):
+    """Fused version of `normalize_batch_in_training`.
+
+    # Arguments
+        x: Input tensor or variable.
+        gamma: Tensor by which to scale the input.
+        beta: Tensor with which to center the input.
+        reduction_axes: iterable of integers,
+            axes over which to normalize.
+        epsilon: Fuzz factor.
+
+    # Returns
+        A tuple length of 3, `(normalized_tensor, mean, variance)`.
+    """
+    if list(reduction_axes) == [0, 1, 2]:
+        normalization_axis = 3
+        tf_data_format = 'NHWC'
+    else:
+        normalization_axis = 1
+        tf_data_format = 'NCHW'
+
+    if gamma is None:
+        gamma = tf.constant(1.0,
+                            dtype=x.dtype,
+                            shape=[x.get_shape()[normalization_axis]])
+    if beta is None:
+        beta = tf.constant(0.0,
+                           dtype=x.dtype,
+                           shape=[x.get_shape()[normalization_axis]])
+
+    return tf.nn.fused_batch_norm(
+        x,
+        gamma,
+        beta,
+        epsilon=epsilon,
+        data_format=tf_data_format)
+
+
+def normalize_batch_in_training(x, gamma, beta,
+                                reduction_axes, epsilon=1e-3):
+    """Computes mean and std for batch then apply batch_normalization on batch.
+
+    # Arguments
+        x: Input tensor or variable.
+        gamma: Tensor by which to scale the input.
+        beta: Tensor with which to center the input.
+        reduction_axes: iterable of integers,
+            axes over which to normalize.
+        epsilon: Fuzz factor.
+
+    # Returns
+        A tuple length of 3, `(normalized_tensor, mean, variance)`.
+    """
+    if ndim(x) == 4 and list(reduction_axes) in [[0, 1, 2], [0, 2, 3]]:
+        if not _has_nchw_support() and list(reduction_axes) == [0, 2, 3]:
+            return _broadcast_normalize_batch_in_training(x, gamma, beta,
+                                                          reduction_axes,
+                                                          epsilon=epsilon)
+        return _fused_normalize_batch_in_training(
+            x, gamma, beta, reduction_axes,
+            epsilon=epsilon)
+    else:
+        if sorted(reduction_axes) == list(range(ndim(x)))[:-1]:
+            return _regular_normalize_batch_in_training(x, gamma, beta,
+                                                        reduction_axes,
+                                                        epsilon=epsilon)
+        else:
+            return _broadcast_normalize_batch_in_training(x, gamma, beta,
+                                                          reduction_axes,
+                                                          epsilon=epsilon)
+
+
 def batch_normalization(x, mean, var, beta, gamma, epsilon=1e-3):
     """Applies batch normalization on x given mean, var, beta and gamma.
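
As a usage sketch of the dispatch just added (assuming the TensorFlow backend; shapes here are illustrative): reducing a 4D tensor over [0, 2, 3] corresponds to BatchNormalization(axis=1) on channels_first data, which now routes to the fused NCHW kernel when _has_nchw_support() is true and to the broadcast fallback otherwise.

from keras import backend as K

x = K.placeholder(shape=(None, 3, 32, 32))  # channels_first input
gamma = K.ones((3,))   # scale, one per channel
beta = K.zeros((3,))   # offset, one per channel

# ndim(x) == 4 and reduction_axes == [0, 2, 3], so this call hits
# _fused_normalize_batch_in_training when NCHW is supported,
# _broadcast_normalize_batch_in_training otherwise.
normed, mean, var = K.normalize_batch_in_training(
    x, gamma, beta, reduction_axes=[0, 2, 3])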
@@ -3573,7 +3668,11 @@ def bias_add(x, bias, data_format=None):
     elif ndim(x) == 4:
         if data_format == 'channels_first':
             if len(bias_shape) == 1:
-                x += reshape(bias, (1, bias_shape[0], 1, 1))
+                if _has_nchw_support():
+                    x = tf.nn.bias_add(x, bias,
+                                       data_format='NCHW')
+                else:
+                    x += reshape(bias, (1, bias_shape[0], 1, 1))
             else:
                 x += reshape(bias, (1, bias_shape[2]) + bias_shape[:2])
         elif data_format == 'channels_last':
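
The bias_add change follows the same pattern; below is a sketch (not from the commit) of the two equivalent formulations in TF 1.x, with illustrative shapes.

import tensorflow as tf

x = tf.placeholder(tf.float32, shape=(None, 3, 32, 32))  # NCHW
bias = tf.zeros([3])  # one bias per channel

# Native op: the bias is broadcast along the channel axis inside a
# single kernel.
y_native = tf.nn.bias_add(x, bias, data_format='NCHW')

# Fallback kept for setups without NCHW support: explicit reshape to
# (1, C, 1, 1) followed by a broadcast add.
y_fallback = x + tf.reshape(bias, (1, 3, 1, 1))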

tests/keras/layers/normalization_test.py (53 additions, 3 deletions)
@@ -3,6 +3,7 @@
 from numpy.testing import assert_allclose
 
 from keras.layers import Input
+from keras import regularizers
 from keras.utils.test_utils import layer_test, keras_test
 from keras.layers import normalization
 from keras.models import Sequential, Model
@@ -16,23 +17,35 @@
 
 @keras_test
 def test_basic_batchnorm():
-    from keras import regularizers
     layer_test(normalization.BatchNormalization,
                kwargs={'momentum': 0.9,
                        'epsilon': 0.1,
                        'gamma_regularizer': regularizers.l2(0.01),
                        'beta_regularizer': regularizers.l2(0.01)},
                input_shape=(3, 4, 2))
+    layer_test(normalization.BatchNormalization,
+               kwargs={'momentum': 0.9,
+                       'epsilon': 0.1,
+                       'axis': 1},
+               input_shape=(3, 4, 2))
     layer_test(normalization.BatchNormalization,
                kwargs={'gamma_initializer': 'ones',
                        'beta_initializer': 'ones',
                        'moving_mean_initializer': 'zeros',
                        'moving_variance_initializer': 'ones'},
-               input_shape=(3, 4, 2))
+               input_shape=(3, 4, 2, 4))
+    if K.backend() != 'theano':
+        layer_test(normalization.BatchNormalization,
+                   kwargs={'momentum': 0.9,
+                           'epsilon': 0.1,
+                           'axis': 1,
+                           'scale': False,
+                           'center': False},
+                   input_shape=(3, 4, 2, 4))
 
 
 @keras_test
-def test_batchnorm_correctness():
+def test_batchnorm_correctness_1d():
     model = Sequential()
     norm = normalization.BatchNormalization(input_shape=(10,), momentum=0.8)
     model.add(norm)
@@ -49,6 +62,24 @@ def test_batchnorm_correctness():
     assert_allclose(out.std(), 1.0, atol=1e-1)
 
 
+@keras_test
+def test_batchnorm_correctness_2d():
+    model = Sequential()
+    norm = normalization.BatchNormalization(axis=1, input_shape=(10, 6), momentum=0.8)
+    model.add(norm)
+    model.compile(loss='mse', optimizer='sgd')
+
+    # centered on 5.0, variance 10.0
+    x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 10, 6))
+    model.fit(x, x, epochs=4, verbose=0)
+    out = model.predict(x)
+    out -= np.reshape(K.eval(norm.beta), (1, 10, 1))
+    out /= np.reshape(K.eval(norm.gamma), (1, 10, 1))
+
+    assert_allclose(out.mean(axis=(0, 2)), 0.0, atol=1e-1)
+    assert_allclose(out.std(axis=(0, 2)), 1.0, atol=1e-1)
+
+
 @keras_test
 def test_batchnorm_training_argument():
     bn1 = normalization.BatchNormalization(input_shape=(10,))
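
The correctness tests all check the same invariant; here is a condensed NumPy sketch of what test_batchnorm_correctness_2d asserts (standalone, not part of the test file).

import numpy as np

# Input centered on 5.0 with std 10.0; the feature axis is 1.
x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 10, 6))

# A converged BatchNormalization(axis=1) should output data whose
# per-feature statistics, reduced over axes (0, 2), are standard normal.
mean = x.mean(axis=(0, 2), keepdims=True)
std = x.std(axis=(0, 2), keepdims=True)
out = (x - mean) / std

assert np.allclose(out.mean(axis=(0, 2)), 0.0, atol=1e-1)
assert np.allclose(out.std(axis=(0, 2)), 1.0, atol=1e-1)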
@@ -106,6 +137,25 @@ def test_batchnorm_convnet():
     assert_allclose(np.std(out, axis=(0, 2, 3)), 1.0, atol=1e-1)
 
 
+@keras_test
+@pytest.mark.skipif((K.backend() == 'theano'),
+                    reason='Bug with theano backend')
+def test_batchnorm_convnet_no_center_no_scale():
+    model = Sequential()
+    norm = normalization.BatchNormalization(axis=-1, center=False, scale=False,
+                                            input_shape=(3, 4, 4), momentum=0.8)
+    model.add(norm)
+    model.compile(loss='mse', optimizer='sgd')
+
+    # centered on 5.0, variance 10.0
+    x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 3, 4, 4))
+    model.fit(x, x, epochs=4, verbose=0)
+    out = model.predict(x)
+
+    assert_allclose(np.mean(out, axis=(0, 2, 3)), 0.0, atol=1e-1)
+    assert_allclose(np.std(out, axis=(0, 2, 3)), 1.0, atol=1e-1)
+
+
 @keras_test
 def test_shared_batchnorm():
     '''Test that a BN layer can be shared
