Skip to content

Commit 3b7153d

Browse files
Vijay Vasudevantensorflower-gardener
authored andcommitted
Fix L2Normalize when passing a list of dims. Fixes tensorflow#3932.
Change: 130947456
1 parent a5bcbf1 commit 3b7153d

File tree

2 files changed

+22
-4
lines changed

2 files changed

+22
-4
lines changed

tensorflow/python/ops/nn.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -547,7 +547,8 @@ def l2_normalize(x, dim, epsilon=1e-12, name=None):
547547
548548
Args:
549549
x: A `Tensor`.
550-
dim: Dimension along which to normalize.
550+
dim: Dimension along which to normalize. A scalar or a vector of
551+
integers.
551552
epsilon: A lower bound value for the norm. Will use `sqrt(epsilon)` as the
552553
divisor if `norm < sqrt(epsilon)`.
553554
name: A name for this operation (optional).
@@ -557,7 +558,7 @@ def l2_normalize(x, dim, epsilon=1e-12, name=None):
557558
"""
558559
with ops.name_scope(name, "l2_normalize", [x]) as name:
559560
x = ops.convert_to_tensor(x, name="x")
560-
square_sum = math_ops.reduce_sum(math_ops.square(x), [dim], keep_dims=True)
561+
square_sum = math_ops.reduce_sum(math_ops.square(x), dim, keep_dims=True)
561562
x_inv_norm = math_ops.rsqrt(math_ops.maximum(square_sum, epsilon))
562563
return math_ops.mul(x, x_inv_norm, name=name)
563564

tensorflow/python/ops/nn_test.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,8 +180,14 @@ def testGradient(self):
180180
class L2NormalizeTest(tf.test.TestCase):
181181

182182
def _l2Normalize(self, x, dim):
183-
norm = np.apply_along_axis(np.linalg.norm, dim, x)
184-
return x / np.expand_dims(norm, dim)
183+
if isinstance(dim, list):
184+
norm = np.linalg.norm(x, axis=tuple(dim))
185+
for d in dim:
186+
norm = np.expand_dims(norm, d)
187+
return x / norm
188+
else:
189+
norm = np.apply_along_axis(np.linalg.norm, dim, x)
190+
return x / np.expand_dims(norm, dim)
185191

186192
def testL2Normalize(self):
187193
x_shape = [20, 7, 3]
@@ -194,6 +200,17 @@ def testL2Normalize(self):
194200
y_tf = tf.nn.l2_normalize(x_tf, dim)
195201
self.assertAllClose(y_np, y_tf.eval())
196202

203+
def testL2NormalizeDimArray(self):
204+
x_shape = [20, 7, 3]
205+
np.random.seed(1)
206+
x_np = np.random.random_sample(x_shape).astype(np.float32)
207+
dim = [1, 2]
208+
y_np = self._l2Normalize(x_np, dim)
209+
with self.test_session():
210+
x_tf = tf.constant(x_np, name="x")
211+
y_tf = tf.nn.l2_normalize(x_tf, dim)
212+
self.assertAllClose(y_np, y_tf.eval())
213+
197214
def testL2NormalizeGradient(self):
198215
x_shape = [20, 7, 3]
199216
np.random.seed(1)

0 commit comments

Comments
 (0)