@@ -1684,9 +1684,9 @@ def cos(x):
16841684 return tf .cos (x )
16851685
16861686
1687- def normalize_batch_in_training (x , gamma , beta ,
1688- reduction_axes , epsilon = 1e-3 ):
1689- """Computes mean and std for batch then apply batch_normalization on batch .
1687+ def _regular_normalize_batch_in_training (x , gamma , beta ,
1688+ reduction_axes , epsilon = 1e-3 ):
1689+ """Non-fused version of `normalize_batch_in_training` .
16901690
16911691 # Arguments
16921692 x: Input tensor or variable.
@@ -1701,36 +1701,131 @@ def normalize_batch_in_training(x, gamma, beta,
17011701 """
17021702 mean , var = tf .nn .moments (x , reduction_axes ,
17031703 shift = None , name = None , keep_dims = False )
1704- if sorted (reduction_axes ) == list (range (ndim (x )))[:- 1 ]:
1705- normed = tf .nn .batch_normalization (x , mean , var ,
1706- beta , gamma ,
1707- epsilon )
1708- else :
1709- # need broadcasting
1710- target_shape = []
1711- for axis in range (ndim (x )):
1712- if axis in reduction_axes :
1713- target_shape .append (1 )
1714- else :
1715- target_shape .append (tf .shape (x )[axis ])
1716- target_shape = tf .stack (target_shape )
1704+ normed = tf .nn .batch_normalization (x , mean , var ,
1705+ beta , gamma ,
1706+ epsilon )
1707+ return normed , mean , var
17171708
1718- broadcast_mean = tf .reshape (mean , target_shape )
1719- broadcast_var = tf .reshape (var , target_shape )
1720- if gamma is None :
1721- broadcast_gamma = None
1722- else :
1723- broadcast_gamma = tf .reshape (gamma , target_shape )
1724- if beta is None :
1725- broadcast_beta = None
1709+
def _broadcast_normalize_batch_in_training(x, gamma, beta,
                                           reduction_axes, epsilon=1e-3):
    """Non-fused batch normalization with explicitly broadcast statistics.

    The batch moments are computed over `reduction_axes`, then the
    moments (and `gamma`/`beta` when given) are reshaped so that they
    have size 1 on every reduced axis before being handed to
    `tf.nn.batch_normalization`.

    # Arguments
        x: Input tensor or variable.
        gamma: Tensor by which to scale the input, or `None`.
        beta: Tensor with which to center the input, or `None`.
        reduction_axes: iterable of integers,
            axes over which to normalize.
        epsilon: Fuzz factor.

    # Returns
        A tuple length of 3, `(normalized_tensor, mean, variance)`.
        `mean` and `variance` are the non-reshaped moments.
    """
    mean, var = tf.nn.moments(x, reduction_axes,
                              shift=None, name=None, keep_dims=False)

    # 1 on every reduced axis, the dynamic size of `x` on the others.
    target_shape = tf.stack(
        [1 if axis in reduction_axes else tf.shape(x)[axis]
         for axis in range(ndim(x))])

    def _to_target_shape(tensor):
        # `gamma` / `beta` may be absent; pass `None` through untouched.
        if tensor is None:
            return None
        return tf.reshape(tensor, target_shape)

    broadcast_mean = _to_target_shape(mean)
    broadcast_var = _to_target_shape(var)
    broadcast_gamma = _to_target_shape(gamma)
    broadcast_beta = _to_target_shape(beta)

    # Positional order here is (x, mean, variance, beta, gamma, epsilon).
    normed = tf.nn.batch_normalization(x, broadcast_mean, broadcast_var,
                                       broadcast_beta, broadcast_gamma,
                                       epsilon)
    return normed, mean, var
17321754
17331755
def _fused_normalize_batch_in_training(x, gamma, beta, reduction_axes,
                                       epsilon=1e-3):
    """Batch normalization through `tf.nn.fused_batch_norm`.

    The data format is derived from `reduction_axes`: `[0, 1, 2]`
    selects `'NHWC'` (normalized axis 3); anything else selects
    `'NCHW'` (normalized axis 1). A missing `gamma` is replaced by a
    constant of ones and a missing `beta` by a constant of zeros, each
    sized from the static shape of `x` along the normalized axis.

    # Arguments
        x: Input tensor or variable.
        gamma: Tensor by which to scale the input, or `None`.
        beta: Tensor with which to center the input, or `None`.
        reduction_axes: iterable of integers,
            axes over which to normalize.
        epsilon: Fuzz factor.

    # Returns
        The value returned by `tf.nn.fused_batch_norm`;
        a tuple length of 3, `(normalized_tensor, mean, variance)`.
    """
    channels_last = list(reduction_axes) == [0, 1, 2]
    normalization_axis = 3 if channels_last else 1
    tf_data_format = 'NHWC' if channels_last else 'NCHW'

    # The fused op is always called with explicit scale and offset,
    # so build neutral defaults when the caller supplied none.
    if gamma is None:
        gamma = tf.constant(1.0,
                            dtype=x.dtype,
                            shape=[x.get_shape()[normalization_axis]])
    if beta is None:
        beta = tf.constant(0.0,
                           dtype=x.dtype,
                           shape=[x.get_shape()[normalization_axis]])

    return tf.nn.fused_batch_norm(x, gamma, beta,
                                  epsilon=epsilon,
                                  data_format=tf_data_format)
1793+
1794+
def normalize_batch_in_training(x, gamma, beta,
                                reduction_axes, epsilon=1e-3):
    """Computes batch mean and variance, then batch-normalizes `x`.

    Picks one of three implementations:

    - the fused one for 4D inputs reduced over `[0, 1, 2]` or
      `[0, 2, 3]` (the latter only when `_has_nchw_support()` is true,
      otherwise the broadcast one is used);
    - the regular one when the reduction covers every axis but the last;
    - the broadcast one in all remaining cases.

    # Arguments
        x: Input tensor or variable.
        gamma: Tensor by which to scale the input.
        beta: Tensor with which to center the input.
        reduction_axes: iterable of integers,
            axes over which to normalize.
        epsilon: Fuzz factor.

    # Returns
        A tuple length of 3, `(normalized_tensor, mean, variance)`.
    """
    if ndim(x) == 4 and list(reduction_axes) in ([0, 1, 2], [0, 2, 3]):
        # Fused kernel candidate; fall back when NCHW is requested
        # but not supported.
        if not _has_nchw_support() and list(reduction_axes) == [0, 2, 3]:
            implementation = _broadcast_normalize_batch_in_training
        else:
            implementation = _fused_normalize_batch_in_training
    elif sorted(reduction_axes) == list(range(ndim(x)))[:-1]:
        # Reducing over all axes except the last needs no reshaping.
        implementation = _regular_normalize_batch_in_training
    else:
        implementation = _broadcast_normalize_batch_in_training

    return implementation(x, gamma, beta, reduction_axes, epsilon=epsilon)
1827+
1828+
17341829def batch_normalization (x , mean , var , beta , gamma , epsilon = 1e-3 ):
17351830 """Applies batch normalization on x given mean, var, beta and gamma.
17361831
@@ -3573,7 +3668,11 @@ def bias_add(x, bias, data_format=None):
35733668 elif ndim (x ) == 4 :
35743669 if data_format == 'channels_first' :
35753670 if len (bias_shape ) == 1 :
3576- x += reshape (bias , (1 , bias_shape [0 ], 1 , 1 ))
3671+ if _has_nchw_support ():
3672+ x = tf .nn .bias_add (x , bias ,
3673+ data_format = 'NCHW' )
3674+ else :
3675+ x += reshape (bias , (1 , bias_shape [0 ], 1 , 1 ))
35773676 else :
35783677 x += reshape (bias , (1 , bias_shape [2 ]) + bias_shape [:2 ])
35793678 elif data_format == 'channels_last' :
0 commit comments