
Commit 8889f9c

[fbsync] Replace TransformerEncoder in torchtext with better transformer (#1703)
1 parent 7bc0071 commit 8889f9c

1 file changed: +54 −26 lines changed

torchtext/models/roberta/modules.py

Lines changed: 54 additions & 26 deletions
@@ -110,19 +110,17 @@ def __init__(
         super().__init__()
         self.padding_idx = padding_idx
         self.token_embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx)
-        self.layers = nn.ModuleList(
-            [
-                TransformerEncoderLayer(
-                    embedding_dim=embedding_dim,
-                    num_attention_heads=num_attention_heads,
-                    ffn_dimension=ffn_dimension,
-                    dropout=dropout,
-                    normalize_before=normalize_before,
-                    scaling=scaling,
-                )
-                for _ in range(num_encoder_layers)
-            ]
+        ffn_dimension = ffn_dimension or 4 * embedding_dim
+        layer = torch.nn.TransformerEncoderLayer(
+            d_model=embedding_dim,
+            nhead=num_attention_heads,
+            dim_feedforward=ffn_dimension,
+            dropout=dropout,
+            activation="gelu",
+            batch_first=True,
+            norm_first=normalize_before,
         )
+        self.layers = torch.nn.TransformerEncoder(encoder_layer=layer, num_layers=num_encoder_layers)
         self.positional_embedding = PositionalEmbedding(max_seq_len, embedding_dim, padding_idx)
         self.embedding_layer_norm = nn.LayerNorm(embedding_dim)
         self.dropout = nn.Dropout(dropout)
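For orientation, a minimal standalone sketch of the new wiring on the + side of this hunk: one torch.nn.TransformerEncoderLayer (gelu, batch_first=True, norm_first) stacked by torch.nn.TransformerEncoder. The sizes below (embedding_dim=768, 12 heads, 2 layers, dropout 0.1) are illustrative assumptions, not values taken from the torchtext config.

import torch

# Illustrative hyperparameters (assumed, not the model's actual config).
embedding_dim = 768
num_attention_heads = 12
num_encoder_layers = 2
dropout = 0.1
normalize_before = False
ffn_dimension = None

# As in the diff: fall back to a 4x-wide feed-forward block when unset.
ffn_dimension = ffn_dimension or 4 * embedding_dim

layer = torch.nn.TransformerEncoderLayer(
    d_model=embedding_dim,
    nhead=num_attention_heads,
    dim_feedforward=ffn_dimension,
    dropout=dropout,
    activation="gelu",
    batch_first=True,        # inputs and outputs are B x T x C
    norm_first=normalize_before,
)
encoder = torch.nn.TransformerEncoder(encoder_layer=layer, num_layers=num_encoder_layers)

tokens = torch.randn(4, 16, embedding_dim)  # B x T x C
print(encoder(tokens).shape)                # torch.Size([4, 16, 768])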
@@ -153,27 +151,57 @@ def forward(
 
         padded_embedded = embedded * (1 - padding_mask.unsqueeze(-1).type_as(embedded))
 
-        encoded = padded_embedded.transpose(0, 1)
-
         if self.return_all_layers:
-            states = [encoded]
-
-            for layer in self.layers:
+            encoded = padded_embedded
+            # B x T x C
+            # Then transpose back to T x B x C
+            states = [encoded.transpose(1, 0)]
+            for layer in self.layers.layers:
                 encoded = layer(encoded, padding_mask, attn_mask)
-                states.append(encoded)
-
+                encoded_t = encoded.transpose(1, 0)
+                states.append(encoded_t)
             if self.normalize_before:
                 for i, state in enumerate(states):
                     states[i] = self.embedding_layer_norm(state)
-
-            # states are returned as T x B x C
             return states
         else:
-            for layer in self.layers:
-                encoded = layer(encoded, padding_mask, attn_mask)
-
+            # B x T x C
+            # Then transpose back to T x B x C
+            encoded = self.layers(padded_embedded).transpose(1, 0)
             if self.normalize_before:
                 encoded = self.embedding_layer_norm(encoded)
-
-            # states are returned as T x B x C
             return encoded
+
+    def _load_from_state_dict(
+        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+    ):
+        better_to_old_names = {
+            "self_attn.in_proj_weight": "attention.input_projection.weight",
+            "self_attn.in_proj_bias": "attention.input_projection.bias",
+            "self_attn.out_proj.weight": "attention.output_projection.weight",
+            "self_attn.out_proj.bias": "attention.output_projection.bias",
+            "linear1.weight": "residual_mlp.mlp.0.weight",
+            "linear1.bias": "residual_mlp.mlp.0.bias",
+            "linear2.weight": "residual_mlp.mlp.3.weight",
+            "linear2.bias": "residual_mlp.mlp.3.bias",
+            "norm1.weight": "attention_layer_norm.weight",
+            "norm1.bias": "attention_layer_norm.bias",
+            "norm2.weight": "final_layer_norm.weight",
+            "norm2.bias": "final_layer_norm.bias",
+        }
+        for i in range(self.layers.num_layers):
+            for better, old in better_to_old_names.items():
+                better_name = prefix + "layers.layers.{}.".format(i) + better
+                old_name = prefix + "layers.{}.".format(i) + old
+                if old_name in state_dict:
+                    state_dict[better_name] = state_dict[old_name]
+                    state_dict.pop(old_name)
+                elif better_name in state_dict:
+                    # Do nothing
+                    pass
+                elif strict:
+                    missing_keys.append(better_name)
+
+        super(TransformerEncoder, self)._load_from_state_dict(
+            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+        )
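The forward changes are mostly about tensor layout: with batch_first=True the stack runs in B x T x C, so the up-front transpose is gone and the results are transposed back to T x B x C before being returned, both for the per-layer states and for the final output. A hedged sketch of just that convention, with made-up sizes and without the padding/attention masks:

import torch

embedding_dim, batch, seq = 8, 2, 5  # toy sizes, assumed for illustration

layer = torch.nn.TransformerEncoderLayer(d_model=embedding_dim, nhead=2, batch_first=True)
encoder = torch.nn.TransformerEncoder(encoder_layer=layer, num_layers=3)

padded_embedded = torch.randn(batch, seq, embedding_dim)  # B x T x C

# return_all_layers-style path: iterate the stacked layers directly and
# transpose each intermediate back to T x B x C before collecting it.
encoded = padded_embedded
states = [encoded.transpose(1, 0)]
for sub_layer in encoder.layers:
    encoded = sub_layer(encoded)            # stays B x T x C
    states.append(encoded.transpose(1, 0))  # stored as T x B x C

# Single-output path: run the whole stack, then transpose once at the end.
final = encoder(padded_embedded).transpose(1, 0)
assert final.shape == (seq, batch, embedding_dim)  # T x B x C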

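The new _load_from_state_dict keeps old checkpoints loadable by renaming the custom-layer parameter keys to the torch.nn.TransformerEncoderLayer names before the stock loader runs. A toy illustration of that renaming follows; the "encoder." prefix and the string values are placeholders, not a real checkpoint.

# Old-format keys (custom layer names) are rewritten in place to the
# torch.nn.TransformerEncoderLayer names so the regular load finds them.
state_dict = {
    "encoder.layers.0.attention.input_projection.weight": "old in_proj weight",
    "encoder.layers.0.residual_mlp.mlp.0.weight": "old linear1 weight",
}
better_to_old_names = {
    "self_attn.in_proj_weight": "attention.input_projection.weight",
    "linear1.weight": "residual_mlp.mlp.0.weight",
}
prefix, i = "encoder.", 0  # hypothetical module prefix, layer index 0
for better, old in better_to_old_names.items():
    better_name = prefix + "layers.layers.{}.".format(i) + better
    old_name = prefix + "layers.{}.".format(i) + old
    if old_name in state_dict:
        # Move the value under the new key.
        state_dict[better_name] = state_dict.pop(old_name)

print(sorted(state_dict))
# ['encoder.layers.layers.0.linear1.weight',
#  'encoder.layers.layers.0.self_attn.in_proj_weight']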