Skip to content

Commit a6343d0

Browse files
Ali Yahyatensorflower-gardener
authored andcommitted
Currently, calling read_variable_op or assign_variable_op with a dtype that does not match the underlying type of the variable fails at runtime instead of at graph construction time. This CL changes that.
PiperOrigin-RevId: 166527195
1 parent 51642db commit a6343d0

File tree

2 files changed

+46
-9
lines changed

2 files changed

+46
-9
lines changed

tensorflow/core/kernels/resource_variable_ops.cc

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,9 @@ REGISTER_RESOURCE_HANDLE_KERNEL(Var);
4141
template <typename Device, typename T>
4242
class ReadVariableOp : public OpKernel {
4343
public:
44-
explicit ReadVariableOp(OpKernelConstruction* c) : OpKernel(c) {}
44+
explicit ReadVariableOp(OpKernelConstruction* c) : OpKernel(c) {
45+
OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_));
46+
}
4547

4648
void Compute(OpKernelContext* ctx) override {
4749
Var* variable = nullptr;
@@ -63,8 +65,16 @@ class ReadVariableOp : public OpKernel {
6365
ctx->allocate_output(0, variable->tensor()->shape(), &out));
6466
functor::DenseUpdate<Device, T, ASSIGN> copy_functor;
6567
const Tensor& t = *variable->tensor();
68+
OP_REQUIRES(
69+
ctx, dtype_ == t.dtype(),
70+
errors::InvalidArgument(
71+
"Trying to read variable with wrong dtype. Expected ",
72+
DataTypeString(dtype_), " got ", DataTypeString(t.dtype())));
6673
copy_functor(ctx->eigen_device<Device>(), out->flat<T>(), t.flat<T>());
6774
}
75+
76+
private:
77+
DataType dtype_;
6878
};
6979

7080
// TODO(apassos) register for the GPU as well.
@@ -223,6 +233,12 @@ class AssignVariableOp : public OpKernel {
223233
}));
224234
core::ScopedUnref s(variable);
225235

236+
OP_REQUIRES(context, variable->tensor()->dtype() == dtype_,
237+
errors::InvalidArgument(
238+
"Trying to assign variable with wrong dtype. Expected ",
239+
DataTypeString(variable->tensor()->dtype()), " got ",
240+
DataTypeString(dtype_)));
241+
226242
// TODO(apassos): holding a lock and copying is unnecessary if we are the
227243
// last user of the value tensor. This should essentially always be the
228244
// case, yet the refcount is usually 2 instead of 1. Figure out what needs

tensorflow/python/kernel_tests/resource_variable_ops_test.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,27 @@ def testHandleDtypeShapeMatch(self):
5353
0,
5454
dtype=dtypes.int32)).run()
5555

56+
def testReadVariableDtypeMismatch(self):
57+
with context.eager_mode():
58+
handle = resource_variable_ops.var_handle_op(
59+
dtype=dtypes.int32, shape=[1], name="foo")
60+
with self.assertRaisesRegexp(errors.InvalidArgumentError,
61+
"Trying to read variable with wrong dtype. "
62+
"Expected float got int32."):
63+
_ = resource_variable_ops.read_variable_op(handle, dtype=dtypes.float32)
64+
65+
def testAssignVariableDtypeMismatch(self):
66+
with context.eager_mode():
67+
handle = resource_variable_ops.var_handle_op(
68+
dtype=dtypes.int32, shape=[1], name="foo")
69+
resource_variable_ops.assign_variable_op(
70+
handle, constant_op.constant([1]))
71+
with self.assertRaisesRegexp(errors.InvalidArgumentError,
72+
"Trying to assign variable with wrong "
73+
"dtype. Expected int32 got float."):
74+
resource_variable_ops.assign_variable_op(
75+
handle, constant_op.constant([1.], dtype=dtypes.float32))
76+
5677
@test_util.run_in_graph_and_eager_modes()
5778
def testDtypeSurvivesIdentity(self):
5879
handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[])
@@ -188,7 +209,7 @@ def testSparseRead(self):
188209
with self.test_session():
189210
init_value = np.reshape(np.arange(np.power(4, 3)), (4, 4, 4))
190211
v = resource_variable_ops.ResourceVariable(
191-
constant_op.constant(init_value, dtype=dtypes.int32))
212+
constant_op.constant(init_value, dtype=dtypes.int32), name="var0")
192213
self.evaluate(variables.global_variables_initializer())
193214

194215
value = self.evaluate(v.sparse_read([0, 3, 1, 2]))
@@ -294,18 +315,18 @@ def testSharedName(self):
294315
@test_util.run_in_graph_and_eager_modes()
295316
def testSharedNameWithNamescope(self):
296317
with ops.name_scope("foo"):
297-
v = resource_variable_ops.ResourceVariable(300.0, name="var1")
318+
v = resource_variable_ops.ResourceVariable(300.0, name="var3")
298319
self.evaluate(variables.global_variables_initializer())
299320

300321
w = resource_variable_ops.var_handle_op(
301-
dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="foo/var1")
322+
dtype=v.dtype.base_dtype, shape=v.get_shape(), shared_name="foo/var3")
302323
w_read = resource_variable_ops.read_variable_op(w, v.dtype.base_dtype)
303324
self.assertEqual(300.0, self.evaluate(w_read))
304325

305326
@test_util.run_in_graph_and_eager_modes()
306327
def testShape(self):
307328
v = resource_variable_ops.ResourceVariable(
308-
name="var1", initial_value=array_ops.ones(shape=[10, 20, 35]))
329+
name="var4", initial_value=array_ops.ones(shape=[10, 20, 35]))
309330
self.assertEqual("(10, 20, 35)", str(v.shape))
310331
self.assertEqual("(10, 20, 35)", str(v.get_shape()))
311332
self.assertEqual("(10, 20, 35)", str(v.value().shape))
@@ -343,13 +364,13 @@ def testVariableEager(self):
343364
constraint = lambda x: x
344365
with ops.name_scope("foo"):
345366
v = resource_variable_ops.ResourceVariable(
346-
name="var1",
367+
name="var5",
347368
initial_value=init,
348369
caching_device="cpu:0",
349370
constraint=constraint)
350371
# Test properties
351372
self.assertEqual(dtypes.int32, v.dtype)
352-
self.assertEqual("foo/var1:0", v.name)
373+
self.assertEqual("foo/var5:0", v.name)
353374
self.assertAllEqual([10, 20, 35], v.shape.as_list())
354375
self.assertAllEqual(init.device, v.device)
355376
self.assertTrue(isinstance(v.handle, ops.EagerTensor))
@@ -360,8 +381,8 @@ def testVariableEager(self):
360381
# Callable init.
361382
callable_init = lambda: init * 2
362383
v2 = resource_variable_ops.ResourceVariable(
363-
initial_value=callable_init, name="v2")
364-
self.assertEqual("v2:0", v2.name)
384+
initial_value=callable_init, name="var6")
385+
self.assertEqual("var6:0", v2.name)
365386
self.assertAllEqual(2 * init.numpy(), v2.read_value().numpy())
366387

367388
# Test assign_add.

0 commit comments

Comments
 (0)