@@ -53,6 +53,27 @@ def testHandleDtypeShapeMatch(self):
5353 0 ,
5454 dtype = dtypes .int32 )).run ()
5555
56+ def testReadVariableDtypeMismatch (self ):
57+ with context .eager_mode ():
58+ handle = resource_variable_ops .var_handle_op (
59+ dtype = dtypes .int32 , shape = [1 ], name = "foo" )
60+ with self .assertRaisesRegexp (errors .InvalidArgumentError ,
61+ "Trying to read variable with wrong dtype. "
62+ "Expected float got int32." ):
63+ _ = resource_variable_ops .read_variable_op (handle , dtype = dtypes .float32 )
64+
65+ def testAssignVariableDtypeMismatch (self ):
66+ with context .eager_mode ():
67+ handle = resource_variable_ops .var_handle_op (
68+ dtype = dtypes .int32 , shape = [1 ], name = "foo" )
69+ resource_variable_ops .assign_variable_op (
70+ handle , constant_op .constant ([1 ]))
71+ with self .assertRaisesRegexp (errors .InvalidArgumentError ,
72+ "Trying to assign variable with wrong "
73+ "dtype. Expected int32 got float." ):
74+ resource_variable_ops .assign_variable_op (
75+ handle , constant_op .constant ([1. ], dtype = dtypes .float32 ))
76+
5677 @test_util .run_in_graph_and_eager_modes ()
5778 def testDtypeSurvivesIdentity (self ):
5879 handle = resource_variable_ops .var_handle_op (dtype = dtypes .int32 , shape = [])
@@ -188,7 +209,7 @@ def testSparseRead(self):
188209 with self .test_session ():
189210 init_value = np .reshape (np .arange (np .power (4 , 3 )), (4 , 4 , 4 ))
190211 v = resource_variable_ops .ResourceVariable (
191- constant_op .constant (init_value , dtype = dtypes .int32 ))
212+ constant_op .constant (init_value , dtype = dtypes .int32 ), name = "var0" )
192213 self .evaluate (variables .global_variables_initializer ())
193214
194215 value = self .evaluate (v .sparse_read ([0 , 3 , 1 , 2 ]))
@@ -294,18 +315,18 @@ def testSharedName(self):
294315 @test_util .run_in_graph_and_eager_modes ()
295316 def testSharedNameWithNamescope (self ):
296317 with ops .name_scope ("foo" ):
297- v = resource_variable_ops .ResourceVariable (300.0 , name = "var1 " )
318+ v = resource_variable_ops .ResourceVariable (300.0 , name = "var3 " )
298319 self .evaluate (variables .global_variables_initializer ())
299320
300321 w = resource_variable_ops .var_handle_op (
301- dtype = v .dtype .base_dtype , shape = v .get_shape (), shared_name = "foo/var1 " )
322+ dtype = v .dtype .base_dtype , shape = v .get_shape (), shared_name = "foo/var3 " )
302323 w_read = resource_variable_ops .read_variable_op (w , v .dtype .base_dtype )
303324 self .assertEqual (300.0 , self .evaluate (w_read ))
304325
305326 @test_util .run_in_graph_and_eager_modes ()
306327 def testShape (self ):
307328 v = resource_variable_ops .ResourceVariable (
308- name = "var1 " , initial_value = array_ops .ones (shape = [10 , 20 , 35 ]))
329+ name = "var4 " , initial_value = array_ops .ones (shape = [10 , 20 , 35 ]))
309330 self .assertEqual ("(10, 20, 35)" , str (v .shape ))
310331 self .assertEqual ("(10, 20, 35)" , str (v .get_shape ()))
311332 self .assertEqual ("(10, 20, 35)" , str (v .value ().shape ))
@@ -343,13 +364,13 @@ def testVariableEager(self):
343364 constraint = lambda x : x
344365 with ops .name_scope ("foo" ):
345366 v = resource_variable_ops .ResourceVariable (
346- name = "var1 " ,
367+ name = "var5 " ,
347368 initial_value = init ,
348369 caching_device = "cpu:0" ,
349370 constraint = constraint )
350371 # Test properties
351372 self .assertEqual (dtypes .int32 , v .dtype )
352- self .assertEqual ("foo/var1 :0" , v .name )
373+ self .assertEqual ("foo/var5 :0" , v .name )
353374 self .assertAllEqual ([10 , 20 , 35 ], v .shape .as_list ())
354375 self .assertAllEqual (init .device , v .device )
355376 self .assertTrue (isinstance (v .handle , ops .EagerTensor ))
@@ -360,8 +381,8 @@ def testVariableEager(self):
360381 # Callable init.
361382 callable_init = lambda : init * 2
362383 v2 = resource_variable_ops .ResourceVariable (
363- initial_value = callable_init , name = "v2 " )
364- self .assertEqual ("v2 :0" , v2 .name )
384+ initial_value = callable_init , name = "var6 " )
385+ self .assertEqual ("var6 :0" , v2 .name )
365386 self .assertAllEqual (2 * init .numpy (), v2 .read_value ().numpy ())
366387
367388 # Test assign_add.
0 commit comments