Skip to content

Commit 1549491

Browse files
Akshay Moditensorflower-gardener
Akshay Modi
authored andcommitted
_autopacking_helper casts tensors when it needs to.
PiperOrigin-RevId: 210612986
1 parent 33b460d commit 1549491

File tree

2 files changed

+36
-14
lines changed

2 files changed

+36
-14
lines changed

tensorflow/python/kernel_tests/stack_op_test.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -277,24 +277,34 @@ def testDtype(self):
277277
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], dtype=dtypes.float64)
278278
self.assertEqual(dtypes.float64, t_2.dtype)
279279

280+
t_3 = ops.convert_to_tensor(
281+
[[0., 0., 0.],
282+
constant_op.constant([0., 0., 0.], dtype=dtypes.float64), [0., 0., 0.]
283+
],
284+
dtype=dtypes.float32)
285+
self.assertEqual(dtypes.float32, t_3.dtype)
286+
287+
t_4 = ops.convert_to_tensor(
288+
[constant_op.constant([0., 0., 0.], dtype=dtypes.float64)],
289+
dtype=dtypes.float32)
290+
self.assertEqual(dtypes.float32, t_4.dtype)
291+
280292
with self.assertRaises(TypeError):
281293
ops.convert_to_tensor([
282294
constant_op.constant(
283295
[0., 0., 0.], dtype=dtypes.float32), constant_op.constant(
284296
[0., 0., 0.], dtype=dtypes.float64), [0., 0., 0.]
285297
])
286298

287-
with self.assertRaises(TypeError):
288-
ops.convert_to_tensor(
289-
[[0., 0., 0.], constant_op.constant(
290-
[0., 0., 0.], dtype=dtypes.float64), [0., 0., 0.]],
291-
dtype=dtypes.float32)
299+
def testDtypeConversionWhenTensorDtypeMismatch(self):
300+
t_0 = ops.convert_to_tensor([0., 0., 0.])
301+
self.assertEqual(dtypes.float32, t_0.dtype)
292302

293-
with self.assertRaises(TypeError):
294-
ops.convert_to_tensor(
295-
[constant_op.constant(
296-
[0., 0., 0.], dtype=dtypes.float64)],
297-
dtype=dtypes.float32)
303+
t_1 = ops.convert_to_tensor([0, 0, 0])
304+
self.assertEqual(dtypes.int32, t_1.dtype)
305+
306+
t_2 = ops.convert_to_tensor([t_0, t_0, t_1], dtype=dtypes.float64)
307+
self.assertEqual(dtypes.float64, t_2.dtype)
298308

299309
def testPlaceholder(self):
300310
with self.test_session(use_gpu=True):

tensorflow/python/ops/array_ops.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
from tensorflow.python.ops.gen_array_ops import *
4444
from tensorflow.python.ops.gen_array_ops import reverse_v2 as reverse # pylint: disable=unused-import
4545
from tensorflow.python.util import deprecation
46+
from tensorflow.python.util import nest
4647
from tensorflow.python.util.tf_export import tf_export
4748
# pylint: enable=wildcard-import
4849

@@ -948,6 +949,15 @@ def _get_dtype_from_nested_lists(list_or_tuple):
948949
return None
949950

950951

952+
def _cast_nested_seqs_to_dtype(dtype):
953+
def _maybe_cast(elem):
954+
if ops.is_dense_tensor_like(elem):
955+
if dtype != elem.dtype.base_dtype:
956+
elem = gen_math_ops.cast(elem, dtype)
957+
return elem
958+
return _maybe_cast
959+
960+
951961
def _autopacking_conversion_function(v, dtype=None, name=None, as_ref=False):
952962
"""Tensor conversion function that automatically packs arguments."""
953963
if as_ref:
@@ -957,9 +967,11 @@ def _autopacking_conversion_function(v, dtype=None, name=None, as_ref=False):
957967
# We did not find any tensor-like objects in the nested lists, so defer to
958968
# other conversion functions.
959969
return NotImplemented
960-
if dtype is not None and dtype != inferred_dtype:
961-
return NotImplemented
962-
return _autopacking_helper(v, inferred_dtype, name or "packed")
970+
if dtype is None:
971+
dtype = inferred_dtype
972+
elif dtype != inferred_dtype:
973+
v = nest.map_structure(_cast_nested_seqs_to_dtype(dtype), v)
974+
return _autopacking_helper(v, dtype, name or "packed")
963975

964976

965977
# pylint: enable=invalid-name
@@ -1715,7 +1727,7 @@ def placeholder(dtype, shape=None, name=None):
17151727
@compatibility(eager)
17161728
Placeholders are not compatible with eager execution.
17171729
@end_compatibility
1718-
1730+
17191731
Args:
17201732
dtype: The type of elements in the tensor to be fed.
17211733
shape: The shape of the tensor to be fed (optional). If the shape is not

0 commit comments

Comments
 (0)