
Commit 3152be5

Add repr to RNNs and Embedding (pytorch#428)

adamlerer authored and soumith committed
1 parent 3a07228 commit 3152be5
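In short: the recurrent modules and Embedding gain a __repr__ that echoes their constructor arguments, printing only the options that differ from their defaults. A quick illustration of the intended output, assuming a PyTorch build that includes this commit:

    import torch.nn as nn

    # Non-default options show up; defaults stay hidden.
    rnn = nn.LSTM(10, 20, num_layers=2, bidirectional=True)
    print(rnn)  # LSTM(10, 20, num_layers=2, bidirectional=True)

    emb = nn.Embedding(1000, 128, padding_idx=0)
    print(emb)  # Embedding(1000, 128, padding_idx=0)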

File tree

2 files changed: +53 -11 lines

torch/nn/modules/rnn.py

Lines changed: 35 additions & 8 deletions
@@ -69,16 +69,31 @@ def forward(self, input, hx):
         )
         return func(input, self.all_weights, hx)
 
+    def __repr__(self):
+        s = '{name}({input_size}, {hidden_size}'
+        if self.num_layers != 1:
+            s += ', num_layers={num_layers}'
+        if self.bias is not True:
+            s += ', bias={bias}'
+        if self.batch_first is not False:
+            s += ', batch_first={batch_first}'
+        if self.dropout != 0:
+            s += ', dropout={dropout}'
+        if self.bidirectional is not False:
+            s += ', bidirectional={bidirectional}'
+        s += ')'
+        return s.format(name=self.__class__.__name__, **self.__dict__)
+
 
 class RNN(RNNBase):
     r"""Applies a multi-layer Elman RNN with tanh or ReLU non-linearity to an input sequence.
 
 
     For each element in the input sequence, each layer computes the following
     function:
-
+
     .. math::
-
+
         h_t = \tanh(w_{ih} * x_t + b_{ih} + w_{hh} * h_{(t-1)} + b_{hh})
 
     where :math:`h_t` is the hidden state at time `t`, and :math:`x_t` is the hidden
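The method above builds the format string incrementally, appending a {field} placeholder only when an attribute differs from its default, then fills every placeholder in one pass from the instance dict. A minimal, torch-free sketch of the same pattern (the Toy class and its attributes are illustrative, not part of the commit):

    # Sketch of the incremental-format-string repr pattern used above.
    class Toy(object):
        def __init__(self, input_size, hidden_size, num_layers=1, dropout=0):
            self.input_size = input_size
            self.hidden_size = hidden_size
            self.num_layers = num_layers
            self.dropout = dropout

        def __repr__(self):
            # Always show the two positional arguments...
            s = '{name}({input_size}, {hidden_size}'
            # ...and add a placeholder only for non-default settings.
            if self.num_layers != 1:
                s += ', num_layers={num_layers}'
            if self.dropout != 0:
                s += ', dropout={dropout}'
            s += ')'
            # One format() call resolves every placeholder from __dict__.
            return s.format(name=self.__class__.__name__, **self.__dict__)

    print(Toy(10, 20))                # Toy(10, 20)
    print(Toy(10, 20, num_layers=2))  # Toy(10, 20, num_layers=2)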
@@ -104,9 +119,9 @@ class RNN(RNNBase):
     - `h_n`: A (num_layers x batch x hidden_size) tensor containing the hidden state for k=seq_len
 
     Attributes:
-        weight_ih_l[k]: the learnable input-hidden weights of the k-th layer, 
+        weight_ih_l[k]: the learnable input-hidden weights of the k-th layer,
            of shape `(input_size x hidden_size)`
-        weight_hh_l[k]: the learnable hidden-hidden weights of the k-th layer, 
+        weight_hh_l[k]: the learnable hidden-hidden weights of the k-th layer,
            of shape `(hidden_size x hidden_size)`
         bias_ih_l[k]: the learnable input-hidden bias of the k-th layer, of shape `(hidden_size)`
         bias_hh_l[k]: the learnable hidden-hidden bias of the k-th layer, of shape `(hidden_size)`
@@ -155,7 +170,7 @@ class LSTM(RNNBase):
     where :math:`h_t` is the hidden state at time `t`, :math:`c_t` is the cell state at time `t`,
     :math:`x_t` is the hidden state of the previous layer at time `t` or :math:`input_t` for the first layer,
-    and :math:`i_t`, :math:`f_t`, :math:`g_t`, :math:`o_t` are the input, forget, 
+    and :math:`i_t`, :math:`f_t`, :math:`g_t`, :math:`o_t` are the input, forget,
     cell, and out gates, respectively.
 
     Args:
@@ -249,7 +264,19 @@ def __init__(self, *args, **kwargs):
         super(GRU, self).__init__('GRU', *args, **kwargs)
 
 
-class RNNCell(Module):
+class RNNCellBase(Module):
+
+    def __repr__(self):
+        s = '{name}({input_size}, {hidden_size}'
+        if 'bias' in self.__dict__ and self.bias != True:
+            s += ', bias={bias}'
+        if 'nonlinearity' in self.__dict__ and self.nonlinearity != "tanh":
+            s += ', nonlinearity={nonlinearity}'
+        s += ')'
+        return s.format(name=self.__class__.__name__, **self.__dict__)
+
+
+class RNNCell(RNNCellBase):
     r"""An Elman RNN cell with tanh or ReLU non-linearity.
 
     .. math::
@@ -325,7 +352,7 @@ def forward(self, input, hx):
         )
 
 
-class LSTMCell(Module):
+class LSTMCell(RNNCellBase):
     r"""A long short-term memory (LSTM) cell.
 
     .. math::
@@ -399,7 +426,7 @@ def forward(self, input, hx):
         )
 
 
-class GRUCell(Module):
+class GRUCell(RNNCellBase):
     r"""A gated recurrent unit (GRU) cell
     .. math::
 
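With all three cell classes now inheriting from RNNCellBase, they share the repr defined there; the 'bias' in self.__dict__ and 'nonlinearity' in self.__dict__ guards let one method serve cells that lack those attributes. Expected output on a build containing this change (note the nonlinearity value is interpolated without quotes):

    import torch.nn as nn

    print(nn.RNNCell(10, 20))                       # RNNCell(10, 20)
    print(nn.RNNCell(10, 20, nonlinearity='relu'))  # RNNCell(10, 20, nonlinearity=relu)
    print(nn.LSTMCell(10, 20))                      # LSTMCell(10, 20)
    print(nn.GRUCell(10, 20, bias=False))           # GRUCell(10, 20, bias=False)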

torch/nn/modules/sparse.py

Lines changed: 18 additions & 3 deletions
@@ -21,7 +21,7 @@ class Embedding(Module):
 
     Attributes:
         weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim)
-
+
     Shape:
         - Input: LongTensor `(N, W)`, N = mini-batch, W = number of indices to extract per mini-batch
         - Output: `(N, W, embedding_dim)`
@@ -40,7 +40,7 @@ class Embedding(Module):
         0.8393 -0.6062 -0.3348
         0.6597  0.0350  0.0837
         0.5521  0.9447  0.0498
-
+
        (1 ,.,.) =
         0.6597  0.0350  0.0837
        -0.1527  0.0877  0.4260
@@ -60,7 +60,7 @@ class Embedding(Module):
         0.0000  0.0000  0.0000
         0.0706 -2.1962 -0.6276
        [torch.FloatTensor of size 1x4x3]
-
+
     """
     def __init__(self, num_embeddings, embedding_dim, padding_idx=None,
                  max_norm=None, norm_type=2, scale_grad_by_freq=False,
@@ -91,5 +91,20 @@ def forward(self, input):
             self.scale_grad_by_freq, self.sparse
         )(input, self.weight)
 
+    def __repr__(self):
+        s = '{name}({num_embeddings}, {embedding_dim}'
+        if self.padding_idx is not None:
+            s += ', padding_idx={padding_idx}'
+        if self.max_norm is not None:
+            s += ', max_norm={max_norm}'
+        if self.norm_type != 2:
+            s += ', norm_type={norm_type}'
+        if self.scale_grad_by_freq is not False:
+            s += ', scale_grad_by_freq={scale_grad_by_freq}'
+        if self.sparse is not False:
+            s += ', sparse=True'
+        s += ')'
+        return s.format(name=self.__class__.__name__, **self.__dict__)
+
 
 # TODO: SparseLinear
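As a sanity check, each optional Embedding argument surfaces in the repr only when it differs from its default; for example:

    import torch.nn as nn

    print(nn.Embedding(1000, 128))                # Embedding(1000, 128)
    print(nn.Embedding(1000, 128, max_norm=1.0))  # Embedding(1000, 128, max_norm=1.0)
    print(nn.Embedding(1000, 128, sparse=True))   # Embedding(1000, 128, sparse=True)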
