@@ -75,6 +75,7 @@ def __init__(
         dropout: float = 0.0,
         word_dropout: float = 0.05,
         locked_dropout: float = 0.5,
+        reproject_to: int = None,
         train_initial_hidden_state: bool = False,
         rnn_type: str = "LSTM",
         pickle_module: str = "pickle",
@@ -92,6 +93,7 @@ def __init__(
         :param rnn_layers: number of RNN layers
         :param dropout: dropout probability
         :param word_dropout: word dropout probability
+        :param reproject_to: set this to control the dimensionality of the reprojection layer
         :param locked_dropout: locked dropout probability
         :param train_initial_hidden_state: if True, trains initial hidden state of RNN
         :param beta: Parameter for F-beta score for evaluation and training annealing
@@ -154,12 +156,16 @@ def __init__(
         if locked_dropout > 0.0:
             self.locked_dropout = flair.nn.LockedDropout(locked_dropout)
 
-        rnn_input_dim: int = self.embeddings.embedding_length
+        embedding_dim: int = self.embeddings.embedding_length
 
-        self.relearn_embeddings: bool = True
+        # if no dimensionality for reprojection layer is set, reproject to equal dimension
+        self.reproject_to = reproject_to
+        if self.reproject_to is None: self.reproject_to = embedding_dim
+        rnn_input_dim: int = self.reproject_to
 
+        self.relearn_embeddings: bool = True
         if self.relearn_embeddings:
-            self.embedding2nn = torch.nn.Linear(rnn_input_dim, rnn_input_dim)
+            self.embedding2nn = torch.nn.Linear(embedding_dim, rnn_input_dim)
 
         self.train_initial_hidden_state = train_initial_hidden_state
         self.bidirectional = True
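
The effect of this hunk: embeddings of size embedding_length now pass through the embedding2nn linear layer into a reproject_to-dimensional space (defaulting to the embedding dimension, which preserves the old behavior), and the RNN consumes the reprojected input. A minimal usage sketch of the new parameter; the toy tag dictionary and GloVe embeddings below are illustrative assumptions, not part of this diff:

# toy tag dictionary; a real one would come from corpus.make_tag_dictionary("ner")
from flair.data import Dictionary
from flair.embeddings import WordEmbeddings
from flair.models import SequenceTagger

tag_dictionary = Dictionary(add_unk=False)
for tag in ["O", "B-PER", "I-PER", "<START>", "<STOP>"]:
    tag_dictionary.add_item(tag)

embeddings = WordEmbeddings("glove")  # 100-dimensional GloVe vectors

tagger = SequenceTagger(
    hidden_size=256,
    embeddings=embeddings,
    tag_dictionary=tag_dictionary,
    tag_type="ner",
    reproject_to=64,  # embedding2nn becomes Linear(100, 64); the RNN sees 64-dim inputs
)

Setting reproject_to below the embedding dimension is the typical use: it shrinks large stacked embeddings before the RNN instead of reprojecting to equal size.
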
@@ -237,6 +243,7 @@ def _get_state_dict(self):
             "rnn_type": self.rnn_type,
             "beta": self.beta,
             "weight_dict": self.weight_dict,
+            "reproject_to": self.reproject_to,
         }
         return model_state
 
@@ -260,6 +267,7 @@ def _init_model_with_state_dict(state):
         )
         beta = 1.0 if "beta" not in state.keys() else state["beta"]
         weights = None if "weight_dict" not in state.keys() else state["weight_dict"]
+        reproject_to = None if "reproject_to" not in state.keys() else state["reproject_to"]
 
         model = SequenceTagger(
             hidden_size=state["hidden_size"],
@@ -276,6 +284,7 @@ def _init_model_with_state_dict(state):
             rnn_type=rnn_type,
             beta=beta,
             loss_weights=weights,
+            reproject_to=reproject_to,
         )
         model.load_state_dict(state["state_dict"])
         return model
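
Because reproject_to is written into the state dict and read back with a None fallback, checkpoints saved before this change still load and simply reproject to the embedding dimension. Continuing the sketch above, a round-trip check; the file name is illustrative:

from flair.models import SequenceTagger

tagger.save("tagger-with-reprojection.pt")  # hypothetical path
loaded = SequenceTagger.load("tagger-with-reprojection.pt")
assert loaded.reproject_to == tagger.reproject_to == 64
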
@@ -1006,7 +1015,7 @@ def _fetch_model(model_name) -> str:
             [hu_path, "release-de-pos-0", "de-pos-ud-hdt-v0.5.pt"]
         )
 
-        model_map["de-pos-fine-grained"] = "/".join(
+        model_map["de-pos-tweets"] = "/".join(
             [
                 aws_resource_path_v04,
                 "POS-fine-grained-german-tweets",
@@ -1028,8 +1037,8 @@ def _fetch_model(model_name) -> str:
         model_map["nl-ner"] = "/".join(
             [aws_resource_path_v04, "NER-conll2002-dutch", "nl-ner-conll02-v0.1.pt"]
         )
-        model_map["ml-pos"] = "https://raw.githubusercontent.com/qburst/models-repository/master/FlairMalayalamModels/malayalam-upos-model.pt"
-        model_map["ml-xpos"] = "https://raw.githubusercontent.com/qburst/models-repository/master/FlairMalayalamModels/malayalam-xpos-model.pt"
+        model_map["ml-pos"] = "https://raw.githubusercontent.com/qburst/models-repository/master/FlairMalayalamModels/malayalam-xpos-model.pt"
+        model_map["ml-upos"] = "https://raw.githubusercontent.com/qburst/models-repository/master/FlairMalayalamModels/malayalam-upos-model.pt"
 
         cache_dir = Path("models")
         if model_name in model_map:
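
With the corrected map entries, the pretrained taggers resolve under their intended names: "de-pos-tweets" replaces the old "de-pos-fine-grained" key, "ml-pos" now fetches the Malayalam XPOS model, and "ml-upos" the UPOS model. A short usage sketch; the model is downloaded on first call:

from flair.data import Sentence
from flair.models import SequenceTagger

tagger = SequenceTagger.load("de-pos-tweets")  # formerly "de-pos-fine-grained"

sentence = Sentence("Das ist ein Test")
tagger.predict(sentence)
print(sentence.to_tagged_string())
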