Skip to content

Commit 1de4bf1

Browse files
t-vifchollet
authored andcommitted
achieve compatibility with tensorflow 1.0rc1 (keras-team#5296)
1 parent 7016e8f commit 1de4bf1

File tree

1 file changed

+25
-9
lines changed

1 file changed

+25
-9
lines changed

keras/backend/tensorflow_backend.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@
2828
# Change its value via `manual_variable_initialization(value)`.
2929
_MANUAL_VAR_INIT = False
3030

31+
# These two integers contain the tensorflow version for coping with API breaks.
32+
tf_major_version = int(tf.__version__.split('.')[0])
33+
tf_minor_version = int(tf.__version__.split('.')[1])
34+
3135

3236
def clear_session():
3337
"""Destroys the current TF graph and creates a new one.
@@ -240,9 +244,14 @@ def variable(value, dtype=None, name=None):
240244
sparse_coo = value.tocoo()
241245
indices = np.concatenate((np.expand_dims(sparse_coo.row, 1),
242246
np.expand_dims(sparse_coo.col, 1)), 1)
243-
v = tf.SparseTensor(indices=indices,
244-
values=sparse_coo.data,
245-
shape=sparse_coo.shape)
247+
if tf_major_version >= 1:
248+
v = tf.SparseTensor(indices=indices,
249+
values=sparse_coo.data,
250+
dense_shape=sparse_coo.shape)
251+
else:
252+
v = tf.SparseTensor(indices=indices,
253+
values=sparse_coo.data,
254+
shape=sparse_coo.shape)
246255
v._dims = len(sparse_coo.shape)
247256
v._keras_shape = sparse_coo.shape
248257
v._uses_learning_phase = False
@@ -1430,10 +1439,13 @@ def concatenate(tensors, axis=-1):
14301439
if py_all([is_sparse(x) for x in tensors]):
14311440
return tf.sparse_concat(axis, tensors)
14321441
else:
1433-
try:
1434-
return tf.concat_v2([to_dense(x) for x in tensors], axis)
1435-
except AttributeError:
1436-
return tf.concat(axis, [to_dense(x) for x in tensors])
1442+
if tf_major_version >= 1:
1443+
return tf.concat([to_dense(x) for x in tensors], axis)
1444+
else:
1445+
try:
1446+
return tf.concat_v2([to_dense(x) for x in tensors], axis)
1447+
except AttributeError:
1448+
return tf.concat(axis, [to_dense(x) for x in tensors])
14371449

14381450

14391451
def reshape(x, shape):
@@ -3061,8 +3073,12 @@ def ctc_decode(y_pred, input_length, greedy=True, beam_width=100,
30613073
sequence_length=input_length, beam_width=beam_width,
30623074
top_paths=top_paths)
30633075

3064-
decoded_dense = [tf.sparse_to_dense(st.indices, st.shape, st.values, default_value=-1)
3065-
for st in decoded]
3076+
if tf_major_version >= 1:
3077+
decoded_dense = [tf.sparse_to_dense(st.indices, st.dense_shape, st.values, default_value=-1)
3078+
for st in decoded]
3079+
else:
3080+
decoded_dense = [tf.sparse_to_dense(st.indices, st.shape, st.values, default_value=-1)
3081+
for st in decoded]
30663082

30673083
return (decoded_dense, log_prob)
30683084

0 commit comments

Comments
 (0)