@@ -180,8 +180,14 @@ def testGradient(self):
180180class L2NormalizeTest (tf .test .TestCase ):
181181
182182 def _l2Normalize (self , x , dim ):
183- norm = np .apply_along_axis (np .linalg .norm , dim , x )
184- return x / np .expand_dims (norm , dim )
183+ if isinstance (dim , list ):
184+ norm = np .linalg .norm (x , axis = tuple (dim ))
185+ for d in dim :
186+ norm = np .expand_dims (norm , d )
187+ return x / norm
188+ else :
189+ norm = np .apply_along_axis (np .linalg .norm , dim , x )
190+ return x / np .expand_dims (norm , dim )
185191
186192 def testL2Normalize (self ):
187193 x_shape = [20 , 7 , 3 ]
@@ -194,6 +200,17 @@ def testL2Normalize(self):
194200 y_tf = tf .nn .l2_normalize (x_tf , dim )
195201 self .assertAllClose (y_np , y_tf .eval ())
196202
203+ def testL2NormalizeDimArray (self ):
204+ x_shape = [20 , 7 , 3 ]
205+ np .random .seed (1 )
206+ x_np = np .random .random_sample (x_shape ).astype (np .float32 )
207+ dim = [1 , 2 ]
208+ y_np = self ._l2Normalize (x_np , dim )
209+ with self .test_session ():
210+ x_tf = tf .constant (x_np , name = "x" )
211+ y_tf = tf .nn .l2_normalize (x_tf , dim )
212+ self .assertAllClose (y_np , y_tf .eval ())
213+
197214 def testL2NormalizeGradient (self ):
198215 x_shape = [20 , 7 , 3 ]
199216 np .random .seed (1 )
0 commit comments