|
28 | 28 | # Change its value via `manual_variable_initialization(value)`.
|
29 | 29 | _MANUAL_VAR_INIT = False
|
30 | 30 |
|
| 31 | +# These two integers contain the tensorflow version for coping with API breaks. |
| 32 | +tf_major_version = int(tf.__version__.split('.')[0]) |
| 33 | +tf_minor_version = int(tf.__version__.split('.')[1]) |
| 34 | + |
31 | 35 |
|
32 | 36 | def clear_session():
|
33 | 37 | """Destroys the current TF graph and creates a new one.
|
@@ -240,9 +244,14 @@ def variable(value, dtype=None, name=None):
|
240 | 244 | sparse_coo = value.tocoo()
|
241 | 245 | indices = np.concatenate((np.expand_dims(sparse_coo.row, 1),
|
242 | 246 | np.expand_dims(sparse_coo.col, 1)), 1)
|
243 |
| - v = tf.SparseTensor(indices=indices, |
244 |
| - values=sparse_coo.data, |
245 |
| - shape=sparse_coo.shape) |
| 247 | + if tf_major_version >= 1: |
| 248 | + v = tf.SparseTensor(indices=indices, |
| 249 | + values=sparse_coo.data, |
| 250 | + dense_shape=sparse_coo.shape) |
| 251 | + else: |
| 252 | + v = tf.SparseTensor(indices=indices, |
| 253 | + values=sparse_coo.data, |
| 254 | + shape=sparse_coo.shape) |
246 | 255 | v._dims = len(sparse_coo.shape)
|
247 | 256 | v._keras_shape = sparse_coo.shape
|
248 | 257 | v._uses_learning_phase = False
|
@@ -1430,10 +1439,13 @@ def concatenate(tensors, axis=-1):
|
1430 | 1439 | if py_all([is_sparse(x) for x in tensors]):
|
1431 | 1440 | return tf.sparse_concat(axis, tensors)
|
1432 | 1441 | else:
|
1433 |
| - try: |
1434 |
| - return tf.concat_v2([to_dense(x) for x in tensors], axis) |
1435 |
| - except AttributeError: |
1436 |
| - return tf.concat(axis, [to_dense(x) for x in tensors]) |
| 1442 | + if tf_major_version >= 1: |
| 1443 | + return tf.concat([to_dense(x) for x in tensors], axis) |
| 1444 | + else: |
| 1445 | + try: |
| 1446 | + return tf.concat_v2([to_dense(x) for x in tensors], axis) |
| 1447 | + except AttributeError: |
| 1448 | + return tf.concat(axis, [to_dense(x) for x in tensors]) |
1437 | 1449 |
|
1438 | 1450 |
|
1439 | 1451 | def reshape(x, shape):
|
@@ -3061,8 +3073,12 @@ def ctc_decode(y_pred, input_length, greedy=True, beam_width=100,
|
3061 | 3073 | sequence_length=input_length, beam_width=beam_width,
|
3062 | 3074 | top_paths=top_paths)
|
3063 | 3075 |
|
3064 |
| - decoded_dense = [tf.sparse_to_dense(st.indices, st.shape, st.values, default_value=-1) |
3065 |
| - for st in decoded] |
| 3076 | + if tf_major_version >= 1: |
| 3077 | + decoded_dense = [tf.sparse_to_dense(st.indices, st.dense_shape, st.values, default_value=-1) |
| 3078 | + for st in decoded] |
| 3079 | + else: |
| 3080 | + decoded_dense = [tf.sparse_to_dense(st.indices, st.shape, st.values, default_value=-1) |
| 3081 | + for st in decoded] |
3066 | 3082 |
|
3067 | 3083 | return (decoded_dense, log_prob)
|
3068 | 3084 |
|
|
0 commit comments