@@ -69,16 +69,31 @@ def forward(self, input, hx):
         )
         return func(input, self.all_weights, hx)
 
+    def __repr__(self):
+        s = '{name}({input_size}, {hidden_size}'
+        if self.num_layers != 1:
+            s += ', num_layers={num_layers}'
+        if self.bias is not True:
+            s += ', bias={bias}'
+        if self.batch_first is not False:
+            s += ', batch_first={batch_first}'
+        if self.dropout != 0:
+            s += ', dropout={dropout}'
+        if self.bidirectional is not False:
+            s += ', bidirectional={bidirectional}'
+        s += ')'
+        return s.format(name=self.__class__.__name__, **self.__dict__)
+
 
 class RNN(RNNBase):
     r"""Applies a multi-layer Elman RNN with tanh or ReLU non-linearity to an input sequence.
 
 
     For each element in the input sequence, each layer computes the following
     function:
-
+
     .. math::
-
+
         h_t = \tanh(w_{ih} * x_t + b_{ih} + w_{hh} * h_{(t-1)} + b_{hh})
 
     where :math:`h_t` is the hidden state at time `t`, and :math:`x_t` is the hidden
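
The new RNNBase.__repr__ above only appends constructor arguments that differ from their defaults. A minimal sketch of the expected output (hypothetical usage, assuming torch.nn is importable with this patch applied):

import torch.nn as nn

# defaults suppressed: only input_size and hidden_size are printed
print(nn.RNN(10, 20))
# RNN(10, 20)

# non-default arguments are appended in the order of the checks above
print(nn.RNN(10, 20, num_layers=2, dropout=0.5, bidirectional=True))
# RNN(10, 20, num_layers=2, dropout=0.5, bidirectional=True)
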
@@ -104,9 +119,9 @@ class RNN(RNNBase):
         - `h_n`: A (num_layers x batch x hidden_size) tensor containing the hidden state for k=seq_len
 
     Attributes:
-        weight_ih_l[k]: the learnable input-hidden weights of the k-th layer,
+        weight_ih_l[k]: the learnable input-hidden weights of the k-th layer,
             of shape `(input_size x hidden_size)`
-        weight_hh_l[k]: the learnable hidden-hidden weights of the k-th layer,
+        weight_hh_l[k]: the learnable hidden-hidden weights of the k-th layer,
             of shape `(hidden_size x hidden_size)`
         bias_ih_l[k]: the learnable input-hidden bias of the k-th layer, of shape `(hidden_size)`
         bias_hh_l[k]: the learnable hidden-hidden bias of the k-th layer, of shape `(hidden_size)`
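
The per-layer parameters documented above are exposed as module attributes named by layer index. A hypothetical access sketch (attribute names taken from the docstring; shapes omitted here):

import torch.nn as nn

rnn = nn.RNN(input_size=10, hidden_size=20, num_layers=2)
w_ih0 = rnn.weight_ih_l0   # input-hidden weights of layer 0
w_hh1 = rnn.weight_hh_l1   # hidden-hidden weights of layer 1
b_ih0 = rnn.bias_ih_l0     # input-hidden bias of layer 0
b_hh1 = rnn.bias_hh_l1     # hidden-hidden bias of layer 1
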
@@ -155,7 +170,7 @@ class LSTM(RNNBase):
 
     where :math:`h_t` is the hidden state at time `t`, :math:`c_t` is the cell state at time `t`,
     :math:`x_t` is the hidden state of the previous layer at time `t` or :math:`input_t` for the first layer,
-    and :math:`i_t`, :math:`f_t`, :math:`g_t`, :math:`o_t` are the input, forget,
+    and :math:`i_t`, :math:`f_t`, :math:`g_t`, :math:`o_t` are the input, forget,
     cell, and out gates, respectively.
 
     Args:
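
For reference, the gates named in this hunk follow the standard LSTM formulation (the full .. math:: block sits just above this hunk in the file); in the usual notation:

    i_t = \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{(t-1)} + b_{hi})
    f_t = \sigma(W_{if} x_t + b_{if} + W_{hf} h_{(t-1)} + b_{hf})
    g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{(t-1)} + b_{hg})
    o_t = \sigma(W_{io} x_t + b_{io} + W_{ho} h_{(t-1)} + b_{ho})
    c_t = f_t * c_{(t-1)} + i_t * g_t
    h_t = o_t * \tanh(c_t)
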
@@ -249,7 +264,19 @@ def __init__(self, *args, **kwargs):
         super(GRU, self).__init__('GRU', *args, **kwargs)
 
 
-class RNNCell(Module):
+class RNNCellBase(Module):
+
+    def __repr__(self):
+        s = '{name}({input_size}, {hidden_size}'
+        if 'bias' in self.__dict__ and self.bias != True:
+            s += ', bias={bias}'
+        if 'nonlinearity' in self.__dict__ and self.nonlinearity != "tanh":
+            s += ', nonlinearity={nonlinearity}'
+        s += ')'
+        return s.format(name=self.__class__.__name__, **self.__dict__)
+
+
+class RNNCell(RNNCellBase):
     r"""An Elman RNN cell with tanh or ReLU non-linearity.
 
     .. math::
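
The shared RNNCellBase.__repr__ reports only non-default options of the cell classes below; a hypothetical sketch of the expected output under this patch:

import torch.nn as nn

print(nn.RNNCell(10, 20))                       # RNNCell(10, 20)
print(nn.RNNCell(10, 20, nonlinearity='relu'))  # RNNCell(10, 20, nonlinearity=relu)
print(nn.LSTMCell(10, 20, bias=False))          # LSTMCell(10, 20, bias=False)
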
@@ -325,7 +352,7 @@ def forward(self, input, hx):
         )
 
 
-class LSTMCell(Module):
+class LSTMCell(RNNCellBase):
    r"""A long short-term memory (LSTM) cell.
 
     .. math::
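
Reparenting LSTMCell does not change its forward interface; a minimal single-step usage sketch (Variable wrapping and shapes are illustrative, following the cell API in this file):

import torch
import torch.nn as nn
from torch.autograd import Variable

cell = nn.LSTMCell(10, 20)           # input_size=10, hidden_size=20
x = Variable(torch.randn(3, 10))     # batch of 3 inputs
h0 = Variable(torch.zeros(3, 20))    # initial hidden state
c0 = Variable(torch.zeros(3, 20))    # initial cell state
h1, c1 = cell(x, (h0, c0))           # one time step
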
@@ -399,7 +426,7 @@ def forward(self, input, hx):
         )
 
 
-class GRUCell(Module):
+class GRUCell(RNNCellBase):
     r"""A gated recurrent unit (GRU) cell
     .. math::
 
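
GRUCell picks up the same base class and therefore the same __repr__; a minimal usage sketch (illustrative shapes, following the cell API in this file):

import torch
import torch.nn as nn
from torch.autograd import Variable

cell = nn.GRUCell(10, 20)
x = Variable(torch.randn(3, 10))
h = Variable(torch.zeros(3, 20))
h = cell(x, h)                       # GRU cells carry no separate cell state
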