Skip to content

Commit c7e8c54

Browse files
authored
update tensorflow api to 1.8
1 parent 6672f93 commit c7e8c54

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

modules.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ def multihead_attention(queries,
224224
# Causality = Future blinding
225225
if causality:
226226
diag_vals = tf.ones_like(outputs[0, :, :]) # (T_q, T_k)
227-
tril = tf.contrib.linalg.LinearOperatorTriL(diag_vals).to_dense() # (T_q, T_k)
227+
tril = tf.linalg.LinearOperatorLowerTriangular(diag_vals).to_dense() # (T_q, T_k)
228228
masks = tf.tile(tf.expand_dims(tril, 0), [tf.shape(outputs)[0], 1, 1]) # (h*N, T_q, T_k)
229229

230230
paddings = tf.ones_like(masks)*(-2**32+1)

0 commit comments

Comments
 (0)