@@ -112,7 +112,7 @@ class RNN(RNNBase):
112112
113113 Inputs: input, h_0
114114 - `input`: A (seq_len x batch x input_size) tensor containing the features of the input sequence.
115- - `h_0`: A (num_layers x batch x hidden_size) tensor containing the initial hidden state for each element in the batch.
115+ - `h_0`: A (( num_layers * num_directions) x batch x hidden_size) tensor containing the initial hidden state for each element in the batch.
116116
117117 Outputs: output, h_n
118118 - `output`: A (seq_len x batch x hidden_size) tensor containing the output features (h_k) from the last layer of the RNN, for each k
@@ -184,8 +184,8 @@ class LSTM(RNNBase):
184184
185185 Inputs: `input, (h_0, c_0)`
186186 - `input` : A (seq_len x batch x input_size) tensor containing the features of the input sequence.
187- - `h_0` : A (num_layers x batch x hidden_size) tensor containing the initial hidden state for each element in the batch.
188- - `c_0` : A (num_layers x batch x hidden_size) tensor containing the initial cell state for each element in the batch.
187+ - `h_0` : A (( num_layers * num_directions) x batch x hidden_size) tensor containing the initial hidden state for each element in the batch.
188+ - `c_0` : A (( num_layers * num_directions) x batch x hidden_size) tensor containing the initial cell state for each element in the batch.
189189
190190 Outputs: output, (h_n, c_n)
191191 - `output` : A (seq_len x batch x hidden_size) tensor containing the output features `(h_t)` from the last layer of the RNN, for each t
@@ -241,7 +241,7 @@ class GRU(RNNBase):
241241
242242 Inputs: `input, h_0`
243243 - `input` : A `(seq_len x batch x input_size)` tensor containing the features of the input sequence.
244- - `h_0` : A `(num_layers x batch x hidden_size)` tensor containing the initial hidden state for each element in the batch.
244+ - `h_0` : A `(( num_layers * num_directions) x batch x hidden_size)` tensor containing the initial hidden state for each element in the batch.
245245
246246 Outputs: `output, h_n`
247247 - `output` : A `(seq_len x batch x hidden_size)` tensor containing the output features `(h_t)` from the last layer of the RNN, for each t
0 commit comments