@@ -110,19 +110,17 @@ def __init__(
         super().__init__()
         self.padding_idx = padding_idx
         self.token_embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx)
-        self.layers = nn.ModuleList(
-            [
-                TransformerEncoderLayer(
-                    embedding_dim=embedding_dim,
-                    num_attention_heads=num_attention_heads,
-                    ffn_dimension=ffn_dimension,
-                    dropout=dropout,
-                    normalize_before=normalize_before,
-                    scaling=scaling,
-                )
-                for _ in range(num_encoder_layers)
-            ]
+        ffn_dimension = ffn_dimension or 4 * embedding_dim
+        layer = torch.nn.TransformerEncoderLayer(
+            d_model=embedding_dim,
+            nhead=num_attention_heads,
+            dim_feedforward=ffn_dimension,
+            dropout=dropout,
+            activation="gelu",
+            batch_first=True,
+            norm_first=normalize_before,
         )
+        self.layers = torch.nn.TransformerEncoder(encoder_layer=layer, num_layers=num_encoder_layers)
         self.positional_embedding = PositionalEmbedding(max_seq_len, embedding_dim, padding_idx)
         self.embedding_layer_norm = nn.LayerNorm(embedding_dim)
         self.dropout = nn.Dropout(dropout)
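For context, a minimal standalone sketch (not part of the patch) of how the stock torch.nn.TransformerEncoderLayer / torch.nn.TransformerEncoder pair used above composes with batch_first and norm_first; the concrete sizes here are illustrative assumptions, not values taken from this module.

    import torch

    # Illustrative hyperparameters (assumptions, not the module's defaults).
    embedding_dim, num_attention_heads, num_encoder_layers = 64, 4, 2
    ffn_dimension = 4 * embedding_dim  # same fallback the patch applies when ffn_dimension is None

    layer = torch.nn.TransformerEncoderLayer(
        d_model=embedding_dim,
        nhead=num_attention_heads,
        dim_feedforward=ffn_dimension,
        dropout=0.1,
        activation="gelu",
        batch_first=True,   # inputs and outputs are B x T x C, matching the new forward()
        norm_first=True,    # pre-norm, i.e. normalize_before=True
    )
    encoder = torch.nn.TransformerEncoder(encoder_layer=layer, num_layers=num_encoder_layers)

    x = torch.randn(2, 16, embedding_dim)    # B x T x C
    out = encoder(x)                         # still B x T x C because batch_first=True
    assert out.shape == x.shape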
@@ -153,27 +151,57 @@ def forward(
 
         padded_embedded = embedded * (1 - padding_mask.unsqueeze(-1).type_as(embedded))
 
-        encoded = padded_embedded.transpose(0, 1)
-
         if self.return_all_layers:
-            states = [encoded]
-
-            for layer in self.layers:
+            encoded = padded_embedded
+            # B x T x C
+            # Then transpose back to T x B x C
+            states = [encoded.transpose(1, 0)]
+            for layer in self.layers.layers:
                 encoded = layer(encoded, padding_mask, attn_mask)
-                states.append(encoded)
-
+                encoded_t = encoded.transpose(1, 0)
+                states.append(encoded_t)
             if self.normalize_before:
                 for i, state in enumerate(states):
                     states[i] = self.embedding_layer_norm(state)
-
-            # states are returned as T x B x C
             return states
         else:
-            for layer in self.layers:
-                encoded = layer(encoded, padding_mask, attn_mask)
-
+            # B x T x C
+            # Then transpose back to T x B x C
+            encoded = self.layers(padded_embedded).transpose(1, 0)
             if self.normalize_before:
                 encoded = self.embedding_layer_norm(encoded)
-
-            # states are returned as T x B x C
             return encoded
+
+    def _load_from_state_dict(
+        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+    ):
+        better_to_old_names = {
+            "self_attn.in_proj_weight": "attention.input_projection.weight",
+            "self_attn.in_proj_bias": "attention.input_projection.bias",
+            "self_attn.out_proj.weight": "attention.output_projection.weight",
+            "self_attn.out_proj.bias": "attention.output_projection.bias",
+            "linear1.weight": "residual_mlp.mlp.0.weight",
+            "linear1.bias": "residual_mlp.mlp.0.bias",
+            "linear2.weight": "residual_mlp.mlp.3.weight",
+            "linear2.bias": "residual_mlp.mlp.3.bias",
+            "norm1.weight": "attention_layer_norm.weight",
+            "norm1.bias": "attention_layer_norm.bias",
+            "norm2.weight": "final_layer_norm.weight",
+            "norm2.bias": "final_layer_norm.bias",
+        }
+        for i in range(self.layers.num_layers):
+            for better, old in better_to_old_names.items():
+                better_name = prefix + "layers.layers.{}.".format(i) + better
+                old_name = prefix + "layers.{}.".format(i) + old
+                if old_name in state_dict:
+                    state_dict[better_name] = state_dict[old_name]
+                    state_dict.pop(old_name)
+                elif better_name in state_dict:
+                    # Do nothing
+                    pass
+                elif strict:
+                    missing_keys.append(better_name)
+
+        super(TransformerEncoder, self)._load_from_state_dict(
+            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+        )
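To illustrate what the new _load_from_state_dict above does: checkpoints saved with the old hand-written layers store per-layer parameters under layers.{i}.<old name>, while the stock torch.nn.TransformerEncoder registers them under layers.layers.{i}.<new name>. The sketch below replays that renaming on a toy state dict; remap_old_keys is a hypothetical helper written for illustration (only two of the twelve mappings are repeated), not code from the patch.

    import torch

    # Hypothetical helper mirroring the rename loop in _load_from_state_dict above.
    def remap_old_keys(state_dict, num_layers, prefix=""):
        better_to_old = {
            "self_attn.in_proj_weight": "attention.input_projection.weight",
            "norm1.weight": "attention_layer_norm.weight",
            # ... the remaining entries follow the mapping table in the patch
        }
        for i in range(num_layers):
            for better, old in better_to_old.items():
                old_name = prefix + "layers.{}.".format(i) + old
                if old_name in state_dict:
                    # Move the tensor to the key the stock module expects.
                    state_dict[prefix + "layers.layers.{}.".format(i) + better] = state_dict.pop(old_name)
        return state_dict

    old_checkpoint = {"layers.0.attention.input_projection.weight": torch.zeros(3 * 8, 8)}
    remapped = remap_old_keys(old_checkpoint, num_layers=1)
    assert "layers.layers.0.self_attn.in_proj_weight" in remapped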