Skip to content

Commit 04cad40

Browse files
committed
Only rely on CudnnRNNV3 in TF
1 parent b01bb52 commit 04cad40

File tree

1 file changed

+71
-87
lines changed
  • keras/backend/tensorflow

1 file changed

+71
-87
lines changed

keras/backend/tensorflow/rnn.py

Lines changed: 71 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -682,13 +682,20 @@ def _cudnn_gru(
682682
_assert_valid_mask(mask)
683683
sequence_lengths = _compute_sequence_length_from_mask(mask, time_major)
684684
else:
685-
sequence_lengths = None
685+
if time_major:
686+
batch_dim = tf.shape(inputs)[1]
687+
max_sequence_length = tf.shape(inputs)[0]
688+
else:
689+
batch_dim = tf.shape(inputs)[0]
690+
max_sequence_length = tf.shape(inputs)[1]
691+
sequence_lengths = tf.fill([batch_dim], max_sequence_length)
686692

687693
if not time_major and sequence_lengths is None:
688694
inputs = tf.transpose(inputs, perm=(1, 0, 2))
689695
seq_axis, batch_axis = (0, 1)
690696
else:
691697
seq_axis, batch_axis = (0, 1) if time_major else (1, 0)
698+
692699
# For init_h, cuDNN expects one more dim of num_layers before or after batch
693700
# dim for time major or batch major inputs respectively
694701
init_h = tf.expand_dims(initial_state, axis=seq_axis)
@@ -719,49 +726,36 @@ def _cudnn_gru(
719726
transpose_weights=True,
720727
)
721728

722-
if sequence_lengths is not None:
723-
if go_backwards:
724-
# Three reversals are required. E.g.,
725-
# normal input = [1, 2, 3, 0, 0] # where 0 need to be masked
726-
# reversed_input_to_cudnn = [3, 2, 1, 0, 0]
727-
# output_from_cudnn = [6, 5, 4, 0, 0]
728-
# expected_output = [0, 0, 6, 5 ,4]
729-
inputs = tf.reverse_sequence(
730-
inputs,
731-
sequence_lengths,
732-
seq_axis=seq_axis,
733-
batch_axis=batch_axis,
734-
)
735-
outputs, h, _, _, _ = tf.raw_ops.CudnnRNNV3(
736-
input=inputs,
737-
input_h=init_h,
738-
input_c=0,
739-
params=params,
740-
is_training=True,
741-
rnn_mode="gru",
742-
sequence_lengths=sequence_lengths,
743-
time_major=time_major,
729+
if go_backwards:
730+
# Three reversals are required. E.g.,
731+
# normal input = [1, 2, 3, 0, 0] # where 0 need to be masked
732+
# reversed_input_to_cudnn = [3, 2, 1, 0, 0]
733+
# output_from_cudnn = [6, 5, 4, 0, 0]
734+
# expected_output = [0, 0, 6, 5 ,4]
735+
inputs = tf.reverse_sequence(
736+
inputs,
737+
sequence_lengths,
738+
seq_axis=seq_axis,
739+
batch_axis=batch_axis,
744740
)
745-
if go_backwards:
746-
outputs = tf.reverse_sequence(
747-
outputs,
748-
sequence_lengths,
749-
seq_axis=seq_axis,
750-
batch_axis=batch_axis,
751-
)
752-
outputs = tf.reverse(outputs, axis=[seq_axis])
753-
else:
754-
if go_backwards:
755-
# Reverse axis 0 since the input is already convert to time major.
756-
inputs = tf.reverse(inputs, axis=[0])
757-
outputs, h, _, _ = tf.raw_ops.CudnnRNN(
758-
input=inputs,
759-
input_h=init_h,
760-
input_c=0,
761-
params=params,
762-
is_training=True,
763-
rnn_mode="gru",
741+
outputs, h, _, _, _ = tf.raw_ops.CudnnRNNV3(
742+
input=inputs,
743+
input_h=init_h,
744+
input_c=0,
745+
params=params,
746+
is_training=True,
747+
rnn_mode="gru",
748+
sequence_lengths=sequence_lengths,
749+
time_major=time_major,
750+
)
751+
if go_backwards:
752+
outputs = tf.reverse_sequence(
753+
outputs,
754+
sequence_lengths,
755+
seq_axis=seq_axis,
756+
batch_axis=batch_axis,
764757
)
758+
outputs = tf.reverse(outputs, axis=[seq_axis])
765759

766760
last_output = outputs[-1]
767761
if not time_major and sequence_lengths is None and return_sequences:
@@ -880,7 +874,13 @@ def _cudnn_lstm(
880874
_assert_valid_mask(mask)
881875
sequence_lengths = _compute_sequence_length_from_mask(mask, time_major)
882876
else:
883-
sequence_lengths = None
877+
if time_major:
878+
batch_dim = tf.shape(inputs)[1]
879+
max_sequence_length = tf.shape(inputs)[0]
880+
else:
881+
batch_dim = tf.shape(inputs)[0]
882+
max_sequence_length = tf.shape(inputs)[1]
883+
sequence_lengths = tf.fill([batch_dim], max_sequence_length)
884884

885885
if not time_major and sequence_lengths is None:
886886
inputs = tf.transpose(inputs, perm=(1, 0, 2))
@@ -918,52 +918,36 @@ def _cudnn_lstm(
918918
transpose_weights=True,
919919
)
920920

921-
if sequence_lengths is not None:
922-
if go_backwards:
923-
# Three reversals are required. E.g.,
924-
# normal input = [1, 2, 3, 0, 0] # where 0 need to be masked
925-
# reversed_input_to_cudnn = [3, 2, 1, 0, 0]
926-
# output_from_cudnn = [6, 5, 4, 0, 0]
927-
# expected_output = [0, 0, 6, 5 ,4]
928-
inputs = tf.reverse_sequence(
929-
inputs,
930-
sequence_lengths,
931-
seq_axis=seq_axis,
932-
batch_axis=batch_axis,
933-
)
934-
outputs, h, c, _, _ = tf.raw_ops.CudnnRNNV3(
935-
input=inputs,
936-
input_h=init_h,
937-
input_c=init_c,
938-
params=params,
939-
is_training=True,
940-
rnn_mode="lstm",
941-
sequence_lengths=sequence_lengths,
942-
time_major=time_major,
921+
if go_backwards:
922+
# Three reversals are required. E.g.,
923+
# normal input = [1, 2, 3, 0, 0] # where 0 need to be masked
924+
# reversed_input_to_cudnn = [3, 2, 1, 0, 0]
925+
# output_from_cudnn = [6, 5, 4, 0, 0]
926+
# expected_output = [0, 0, 6, 5 ,4]
927+
inputs = tf.reverse_sequence(
928+
inputs,
929+
sequence_lengths,
930+
seq_axis=seq_axis,
931+
batch_axis=batch_axis,
943932
)
944-
if go_backwards:
945-
outputs = tf.reverse_sequence(
946-
outputs,
947-
sequence_lengths,
948-
seq_axis=seq_axis,
949-
batch_axis=batch_axis,
950-
)
951-
outputs = tf.reverse(outputs, axis=[seq_axis])
952-
else:
953-
# # Fill the array with shape [batch] with value of max timesteps.
954-
# sequence_length = array_ops.fill([array_ops.shape(inputs)[1]],
955-
# array_ops.shape(inputs)[0])
956-
if go_backwards:
957-
# Reverse axis 0 since the input is already convert to time major.
958-
inputs = tf.reverse(inputs, axis=[0])
959-
outputs, h, c, _ = tf.raw_ops.CudnnRNN(
960-
input=inputs,
961-
input_h=init_h,
962-
input_c=init_c,
963-
params=params,
964-
is_training=True,
965-
rnn_mode="lstm",
933+
outputs, h, c, _, _ = tf.raw_ops.CudnnRNNV3(
934+
input=inputs,
935+
input_h=init_h,
936+
input_c=init_c,
937+
params=params,
938+
is_training=True,
939+
rnn_mode="lstm",
940+
sequence_lengths=sequence_lengths,
941+
time_major=time_major,
942+
)
943+
if go_backwards:
944+
outputs = tf.reverse_sequence(
945+
outputs,
946+
sequence_lengths,
947+
seq_axis=seq_axis,
948+
batch_axis=batch_axis,
966949
)
950+
outputs = tf.reverse(outputs, axis=[seq_axis])
967951

968952
last_output = outputs[-1]
969953
if not time_major and sequence_lengths is None and return_sequences:

0 commit comments

Comments
 (0)