@@ -682,13 +682,20 @@ def _cudnn_gru(
682
682
_assert_valid_mask (mask )
683
683
sequence_lengths = _compute_sequence_length_from_mask (mask , time_major )
684
684
else :
685
- sequence_lengths = None
685
+ if time_major :
686
+ batch_dim = tf .shape (inputs )[1 ]
687
+ max_sequence_length = tf .shape (inputs )[0 ]
688
+ else :
689
+ batch_dim = tf .shape (inputs )[0 ]
690
+ max_sequence_length = tf .shape (inputs )[1 ]
691
+ sequence_lengths = tf .fill ([batch_dim ], max_sequence_length )
686
692
687
693
if not time_major and sequence_lengths is None :
688
694
inputs = tf .transpose (inputs , perm = (1 , 0 , 2 ))
689
695
seq_axis , batch_axis = (0 , 1 )
690
696
else :
691
697
seq_axis , batch_axis = (0 , 1 ) if time_major else (1 , 0 )
698
+
692
699
# For init_h, cuDNN expects one more dim of num_layers before or after batch
693
700
# dim for time major or batch major inputs respectively
694
701
init_h = tf .expand_dims (initial_state , axis = seq_axis )
@@ -719,49 +726,36 @@ def _cudnn_gru(
719
726
transpose_weights = True ,
720
727
)
721
728
722
- if sequence_lengths is not None :
723
- if go_backwards :
724
- # Three reversals are required. E.g.,
725
- # normal input = [1, 2, 3, 0, 0] # where 0 need to be masked
726
- # reversed_input_to_cudnn = [3, 2, 1, 0, 0]
727
- # output_from_cudnn = [6, 5, 4, 0, 0]
728
- # expected_output = [0, 0, 6, 5 ,4]
729
- inputs = tf .reverse_sequence (
730
- inputs ,
731
- sequence_lengths ,
732
- seq_axis = seq_axis ,
733
- batch_axis = batch_axis ,
734
- )
735
- outputs , h , _ , _ , _ = tf .raw_ops .CudnnRNNV3 (
736
- input = inputs ,
737
- input_h = init_h ,
738
- input_c = 0 ,
739
- params = params ,
740
- is_training = True ,
741
- rnn_mode = "gru" ,
742
- sequence_lengths = sequence_lengths ,
743
- time_major = time_major ,
729
+ if go_backwards :
730
+ # Three reversals are required. E.g.,
731
+ # normal input = [1, 2, 3, 0, 0] # where 0 need to be masked
732
+ # reversed_input_to_cudnn = [3, 2, 1, 0, 0]
733
+ # output_from_cudnn = [6, 5, 4, 0, 0]
734
+ # expected_output = [0, 0, 6, 5 ,4]
735
+ inputs = tf .reverse_sequence (
736
+ inputs ,
737
+ sequence_lengths ,
738
+ seq_axis = seq_axis ,
739
+ batch_axis = batch_axis ,
744
740
)
745
- if go_backwards :
746
- outputs = tf .reverse_sequence (
747
- outputs ,
748
- sequence_lengths ,
749
- seq_axis = seq_axis ,
750
- batch_axis = batch_axis ,
751
- )
752
- outputs = tf .reverse (outputs , axis = [seq_axis ])
753
- else :
754
- if go_backwards :
755
- # Reverse axis 0 since the input is already convert to time major.
756
- inputs = tf .reverse (inputs , axis = [0 ])
757
- outputs , h , _ , _ = tf .raw_ops .CudnnRNN (
758
- input = inputs ,
759
- input_h = init_h ,
760
- input_c = 0 ,
761
- params = params ,
762
- is_training = True ,
763
- rnn_mode = "gru" ,
741
+ outputs , h , _ , _ , _ = tf .raw_ops .CudnnRNNV3 (
742
+ input = inputs ,
743
+ input_h = init_h ,
744
+ input_c = 0 ,
745
+ params = params ,
746
+ is_training = True ,
747
+ rnn_mode = "gru" ,
748
+ sequence_lengths = sequence_lengths ,
749
+ time_major = time_major ,
750
+ )
751
+ if go_backwards :
752
+ outputs = tf .reverse_sequence (
753
+ outputs ,
754
+ sequence_lengths ,
755
+ seq_axis = seq_axis ,
756
+ batch_axis = batch_axis ,
764
757
)
758
+ outputs = tf .reverse (outputs , axis = [seq_axis ])
765
759
766
760
last_output = outputs [- 1 ]
767
761
if not time_major and sequence_lengths is None and return_sequences :
@@ -880,7 +874,13 @@ def _cudnn_lstm(
880
874
_assert_valid_mask (mask )
881
875
sequence_lengths = _compute_sequence_length_from_mask (mask , time_major )
882
876
else :
883
- sequence_lengths = None
877
+ if time_major :
878
+ batch_dim = tf .shape (inputs )[1 ]
879
+ max_sequence_length = tf .shape (inputs )[0 ]
880
+ else :
881
+ batch_dim = tf .shape (inputs )[0 ]
882
+ max_sequence_length = tf .shape (inputs )[1 ]
883
+ sequence_lengths = tf .fill ([batch_dim ], max_sequence_length )
884
884
885
885
if not time_major and sequence_lengths is None :
886
886
inputs = tf .transpose (inputs , perm = (1 , 0 , 2 ))
@@ -918,52 +918,36 @@ def _cudnn_lstm(
918
918
transpose_weights = True ,
919
919
)
920
920
921
- if sequence_lengths is not None :
922
- if go_backwards :
923
- # Three reversals are required. E.g.,
924
- # normal input = [1, 2, 3, 0, 0] # where 0 need to be masked
925
- # reversed_input_to_cudnn = [3, 2, 1, 0, 0]
926
- # output_from_cudnn = [6, 5, 4, 0, 0]
927
- # expected_output = [0, 0, 6, 5 ,4]
928
- inputs = tf .reverse_sequence (
929
- inputs ,
930
- sequence_lengths ,
931
- seq_axis = seq_axis ,
932
- batch_axis = batch_axis ,
933
- )
934
- outputs , h , c , _ , _ = tf .raw_ops .CudnnRNNV3 (
935
- input = inputs ,
936
- input_h = init_h ,
937
- input_c = init_c ,
938
- params = params ,
939
- is_training = True ,
940
- rnn_mode = "lstm" ,
941
- sequence_lengths = sequence_lengths ,
942
- time_major = time_major ,
921
+ if go_backwards :
922
+ # Three reversals are required. E.g.,
923
+ # normal input = [1, 2, 3, 0, 0] # where 0 need to be masked
924
+ # reversed_input_to_cudnn = [3, 2, 1, 0, 0]
925
+ # output_from_cudnn = [6, 5, 4, 0, 0]
926
+ # expected_output = [0, 0, 6, 5 ,4]
927
+ inputs = tf .reverse_sequence (
928
+ inputs ,
929
+ sequence_lengths ,
930
+ seq_axis = seq_axis ,
931
+ batch_axis = batch_axis ,
943
932
)
944
- if go_backwards :
945
- outputs = tf .reverse_sequence (
946
- outputs ,
947
- sequence_lengths ,
948
- seq_axis = seq_axis ,
949
- batch_axis = batch_axis ,
950
- )
951
- outputs = tf .reverse (outputs , axis = [seq_axis ])
952
- else :
953
- # # Fill the array with shape [batch] with value of max timesteps.
954
- # sequence_length = array_ops.fill([array_ops.shape(inputs)[1]],
955
- # array_ops.shape(inputs)[0])
956
- if go_backwards :
957
- # Reverse axis 0 since the input is already convert to time major.
958
- inputs = tf .reverse (inputs , axis = [0 ])
959
- outputs , h , c , _ = tf .raw_ops .CudnnRNN (
960
- input = inputs ,
961
- input_h = init_h ,
962
- input_c = init_c ,
963
- params = params ,
964
- is_training = True ,
965
- rnn_mode = "lstm" ,
933
+ outputs , h , c , _ , _ = tf .raw_ops .CudnnRNNV3 (
934
+ input = inputs ,
935
+ input_h = init_h ,
936
+ input_c = init_c ,
937
+ params = params ,
938
+ is_training = True ,
939
+ rnn_mode = "lstm" ,
940
+ sequence_lengths = sequence_lengths ,
941
+ time_major = time_major ,
942
+ )
943
+ if go_backwards :
944
+ outputs = tf .reverse_sequence (
945
+ outputs ,
946
+ sequence_lengths ,
947
+ seq_axis = seq_axis ,
948
+ batch_axis = batch_axis ,
966
949
)
950
+ outputs = tf .reverse (outputs , axis = [seq_axis ])
967
951
968
952
last_output = outputs [- 1 ]
969
953
if not time_major and sequence_lengths is None and return_sequences :
0 commit comments