 from huggingface_hub import model_info
 from sklearn.decomposition import PCA
 from tokenizers import Tokenizer
-from tokenizers.models import BPE, Unigram
 from transformers import AutoModel, AutoTokenizer, PreTrainedModel, PreTrainedTokenizerFast

-from model2vec.distill.inference import (
-    create_output_embeddings_from_model,
-    create_output_embeddings_from_model_and_tokens,
-)
-from model2vec.distill.tokenizer import add_tokens, preprocess_vocabulary, remove_tokens
+from model2vec.distill.inference import create_embeddings
+from model2vec.distill.tokenizer import replace_vocabulary
 from model2vec.distill.utils import select_optimal_device
 from model2vec.model import StaticModel

@@ -71,52 +67,41 @@ def distill_from_model(
     :return: A StaticModel

     """
-    sif_coefficient = _validate_parameters(tokenizer, vocabulary, apply_zipf, sif_coefficient, use_subword)
+    backend_tokenizer = tokenizer.backend_tokenizer
+    sif_coefficient, token_remove_regex = _validate_parameters(
+        vocabulary, apply_zipf, sif_coefficient, use_subword, token_remove_pattern
+    )
+
+    if vocabulary is None:
+        vocabulary = []

     device = select_optimal_device(device)
     # Make a base list of tokens.
-    tokens: list[str] = []
-    if use_subword:
-        # Create the subword embeddings.
-        tokens, embeddings = create_output_embeddings_from_model(model=model, tokenizer=tokenizer, device=device)
-        new_tokenizer, embeddings = _remove_tokens_and_embeddings(tokenizer, token_remove_pattern, tokens, embeddings)
-    else:
-        # We need to keep the unk token in the tokenizer.
-        unk_token = tokenizer.backend_tokenizer.model.unk_token
-        # Remove all tokens except the UNK token.
-        new_tokenizer = remove_tokens(tokenizer.backend_tokenizer, list(set(tokenizer.get_vocab()) - {unk_token}))
-        # We need to set embeddings to None because we don't know the dimensions of the embeddings yet.
-        embeddings = None
-
-    if vocabulary:
-        # Preprocess the vocabulary with the original tokenizer.
-        preprocessed_vocabulary = preprocess_vocabulary(tokenizer.backend_tokenizer, vocabulary)
-        n_tokens_before = len(preprocessed_vocabulary)
-        # Clean the vocabulary by removing duplicate tokens and tokens that are in the subword vocabulary.
-        cleaned_vocabulary = _clean_vocabulary(preprocessed_vocabulary, tokens)
-        n_tokens_after = len(cleaned_vocabulary)
-        logger.info(
-            f"Adding {n_tokens_after} tokens to the vocabulary. Removed {n_tokens_before - n_tokens_after} tokens during preprocessing."
-        )
-        # Only create embeddings if we have tokens to add.
-        if cleaned_vocabulary:
-            # Create the embeddings.
-            _, token_embeddings = create_output_embeddings_from_model_and_tokens(
-                model=model,
-                tokenizer=tokenizer,
-                tokens=cleaned_vocabulary,
-                device=device,
-            )
+    subword_vocab: dict[str, int] = tokenizer.get_vocab()
+    subword_tokens: list[str] = [k for k, _ in sorted(subword_vocab.items(), key=lambda x: x[1])]
+
+    n_tokens_before = len(vocabulary)
+    # Clean the vocabulary by removing duplicate tokens and tokens that are in the subword vocabulary.
+    cleaned_vocabulary = _clean_vocabulary(tokenizer.backend_tokenizer, vocabulary, subword_tokens)
+    n_tokens_after = len(cleaned_vocabulary)
+    logger.info(
+        f"Adding {n_tokens_after} tokens to the vocabulary. Removed {n_tokens_before - n_tokens_after} tokens during preprocessing."
+    )

-            # If we don't have subword tokens, we still need to create
-            # some embeddings for [UNK] and some other special tokens.
-            if embeddings is None:
-                embeddings = np.zeros((new_tokenizer.get_vocab_size(), token_embeddings.shape[1]))
-            embeddings = np.concatenate([embeddings, token_embeddings], axis=0)
-            # Add the cleaned vocabulary to the tokenizer.
-            new_tokenizer = add_tokens(new_tokenizer, cleaned_vocabulary)
-        else:
-            logger.warning("Didn't create any token embeddings as all tokens were duplicates or empty.")
+    # Create the embeddings.
+    all_tokens, embeddings = create_embeddings(
+        model=model,
+        tokenizer=tokenizer,
+        tokens=cleaned_vocabulary,
+        device=device,
+        use_subword=use_subword,
+        token_remove_regex=token_remove_regex,
+    )
+
+    unk_token = tokenizer.special_tokens_map.get("unk_token")
+    pad_token = tokenizer.special_tokens_map.get("pad_token")
+    # Add the cleaned vocabulary to the tokenizer.
+    backend_tokenizer = replace_vocabulary(backend_tokenizer, all_tokens, unk_token=unk_token, pad_token=pad_token)

     # Post process the embeddings by applying PCA and Zipf weighting.
     embeddings = _post_process_embeddings(np.asarray(embeddings), pca_dims, sif_coefficient=sif_coefficient)
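Taken together, the new body always derives the subword token list from the tokenizer, cleans any user-supplied vocabulary against it, creates all embeddings in a single create_embeddings call, and swaps the resulting vocabulary into the backend tokenizer via replace_vocabulary. A minimal usage sketch of the reworked entry point, assuming distill_from_model is importable from model2vec.distill and using an example checkpoint name (neither is prescribed by this diff):

from transformers import AutoModel, AutoTokenizer
from model2vec.distill import distill_from_model

model_name = "BAAI/bge-base-en-v1.5"  # example checkpoint, not part of this commit
model = AutoModel.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Keep the subword vocabulary, add two custom tokens, and drop BERT-style unused slots.
static_model = distill_from_model(
    model=model,
    tokenizer=tokenizer,
    vocabulary=["astronaut", "spaceship"],
    device="cpu",
    pca_dims=256,
    use_subword=True,
    token_remove_pattern=r"\[unused\d+\]",
)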
@@ -150,7 +135,7 @@ def distill_from_model(

     return StaticModel(
         vectors=embeddings,
-        tokenizer=new_tokenizer,
+        tokenizer=backend_tokenizer,
         config=config,
         base_model_name=model_name,
         language=language,
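Because the StaticModel is now handed the backend tokenizer directly, nothing changes for callers. Continuing the sketch above with model2vec's public StaticModel API (encode and save_pretrained are assumed from the library, not shown in this diff):

vectors = static_model.encode(["distillation turns a transformer into static vectors"])
print(vectors.shape)  # (1, 256) with the pca_dims used above
static_model.save_pretrained("my-distilled-model")  # persists vectors, tokenizer and config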
@@ -159,22 +144,22 @@ def distill_from_model(


 def _validate_parameters(
-    tokenizer: PreTrainedTokenizerFast,
     vocabulary: list[str] | None,
     apply_zipf: bool | None,
     sif_coefficient: float | None,
     use_subword: bool,
-) -> float | None:
+    token_remove_pattern: str | None,
+) -> tuple[float | None, re.Pattern | None]:
     """
     Validate the parameters passed to the distillation function.

-    :param tokenizer: The tokenizer to use.
     :param vocabulary: The vocabulary to use.
     :param apply_zipf: DEPRECATED: This parameter used to control whether Zipf is applied.
         Zipf weighting is now controlled by the sif_coefficient parameter. If this is set to None, no weighting is applied.
     :param sif_coefficient: The SIF coefficient to use. If this is None, no weighting is applied.
         Should be a value >= 0 and < 1.0. A value of 1e-4 is a good default.
     :param use_subword: Whether to keep subword tokens in the vocabulary. If this is False, you must pass a vocabulary, and the returned tokenizer will only detect full words.
+    :param token_remove_pattern: If this is set to a string, we compile this into a regex. Any tokens that conform to this regex pattern will be removed from the vocabulary.
     :return: The SIF coefficient to use.
     :raises: ValueError if the PCA dimension is larger than the number of dimensions in the embeddings.
     :raises: ValueError if the vocabulary contains duplicate tokens.
@@ -204,49 +189,14 @@ def _validate_parameters(
             "You must pass a vocabulary if you don't use subword tokens. Either pass a vocabulary, or set use_subword to True."
         )

-    if vocabulary and isinstance(tokenizer.backend_tokenizer.model, (BPE, Unigram)):
-        raise ValueError(
-            "You passed a vocabulary, but the model you are using does not use a WordPiece tokenizer. "
-            "This is not supported yet."
-            "Feel free to open an issue if this is a blocker: https://github.com/MinishLab/model2vec/issues"
-        )
-
-    return sif_coefficient
-
-
-def _remove_tokens_and_embeddings(
-    tokenizer: PreTrainedTokenizerFast, token_remove_pattern: str | None, tokens: list[str], embeddings: np.ndarray
-) -> tuple[Tokenizer, np.ndarray]:
-    if not token_remove_pattern:
-        return tokenizer.backend_tokenizer, embeddings
-
-    try:
-        token_regex = re.compile(token_remove_pattern)
-    except re.error as e:
-        raise ValueError(f"Invalid regex pattern: {token_remove_pattern}") from e
-    # Remove any unused tokens from the tokenizer and embeddings.
-    wrong_tokens = [x for x in tokens if token_regex.match(x)]
-    vocab = tokenizer.get_vocab()
-    # Get the ids of the unused token.
-    wrong_token_ids = [vocab[token] for token in wrong_tokens]
-
-    if len(wrong_token_ids) == len(vocab):
-        raise ValueError(
-            "All tokens in the vocabulary are unused tokens. This will result in an empty tokenizer. "
-            "Please provide a valid token removal pattern. The pattern is now: {token_remove_pattern}"
-        )
-
-    # Remove the unused tokens from the tokenizer.
-    new_tokenizer = remove_tokens(tokenizer.backend_tokenizer, wrong_tokens)
-    if new_tokenizer.get_vocab_size() == tokenizer.backend_tokenizer.get_vocab_size():
-        # This happens if we didn't remove any tokens.
-        return new_tokenizer, embeddings
-
-    # Remove the embeddings of the unused tokens.
-    embeddings = np.delete(embeddings, wrong_token_ids, axis=0)
-    logger.info(f"Removed {len(wrong_tokens)} unused tokens from the tokenizer and embeddings.")
+    token_remove_regex: re.Pattern | None = None
+    if token_remove_pattern is not None:
+        try:
+            token_remove_regex = re.compile(token_remove_pattern)
+        except re.error as e:
+            raise ValueError(f"Couldn't compile the regex pattern: {e}")

-    return new_tokenizer, embeddings
+    return sif_coefficient, token_remove_regex


 def distill(
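With _remove_tokens_and_embeddings gone, _validate_parameters in the hunk above only compiles token_remove_pattern; the actual removal now happens inside create_embeddings via token_remove_regex. A hedged sketch of how the new tuple return is consumed (argument values are illustrative, mirroring the call site in distill_from_model):

sif_coefficient, token_remove_regex = _validate_parameters(
    vocabulary=None,
    apply_zipf=None,
    sif_coefficient=1e-4,
    use_subword=True,
    token_remove_pattern=r"\[unused\d+\]",
)
# The compiled regex matches tokens such as "[unused42]"; an invalid pattern
# (for example r"[unused") raises ValueError("Couldn't compile the regex pattern: ...").
assert token_remove_regex is not None and token_remove_regex.match("[unused42]")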
@@ -345,26 +295,40 @@ def _post_process_embeddings(
     return embeddings


-def _clean_vocabulary(preprocessed_vocabulary: list[str], added_tokens: list[str]) -> list[str]:
+def _clean_vocabulary(tokenizer: Tokenizer, vocabulary: list[str], added_tokens: list[str]) -> list[str]:
     """Cleans a vocabulary by removing duplicates and tokens that were already in the vocabulary."""
     added_tokens_set = set(added_tokens)
     seen_tokens = set()
     cleaned_vocabulary = []
     n_empty = 0
     n_duplicates = 0
-    for token in preprocessed_vocabulary:
+    n_multiword = 0
+    for token in vocabulary:
+        if tokenizer.normalizer is not None:
+            token = tokenizer.normalizer.normalize_str(token)
+
         if not token:
             n_empty += 1
             continue
         if token in seen_tokens or token in added_tokens_set:
             n_duplicates += 1
             continue
+
+        pre_tokenizer = tokenizer.pre_tokenizer
+        if pre_tokenizer is not None:
+            pretokenized_tokens = pre_tokenizer.pre_tokenize_str(token)
+            if len(pretokenized_tokens) != 1:
+                n_multiword += 1
+                continue
+
         seen_tokens.add(token)
         cleaned_vocabulary.append(token)

     if n_duplicates:
         logger.warning(f"Removed {n_duplicates} duplicate tokens.")
     if n_empty:
         logger.warning(f"Removed {n_empty} empty tokens.")
+    if n_multiword:
+        logger.warning(f"Removed {n_multiword} multiword tokens.")

     return cleaned_vocabulary
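_clean_vocabulary now pushes every candidate token through the backend tokenizer's normalizer and pre-tokenizer, so tokens that normalize to the empty string, duplicate an existing entry, or split into several pieces are dropped. A hedged illustration using the tokenizers API on an example checkpoint:

from tokenizers import Tokenizer

backend = Tokenizer.from_pretrained("bert-base-uncased")  # example checkpoint

# Normalization lower-cases (and strips accents), so "Tokyo" and "tokyo" become duplicates.
print(backend.normalizer.normalize_str("Tokyo"))  # 'tokyo'

# The pre-tokenizer splits on whitespace and punctuation; more than one piece marks a multiword token.
print(backend.pre_tokenizer.pre_tokenize_str("new york"))  # [('new', (0, 3)), ('york', (4, 8))] -> rejected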