Commit 844c3fa

rewrite backend (MinishLab#207)

* rewrite backend
* fixes, clean-up
* fix multiword, clean
* add pad token
* fix tests
* Fix backend
* comments/docstring

1 parent a00aaab commit 844c3fa

File tree

7 files changed: +321 -306 lines changed

model2vec/distill/distillation.py

Lines changed: 60 additions & 96 deletions
@@ -9,14 +9,10 @@
 from huggingface_hub import model_info
 from sklearn.decomposition import PCA
 from tokenizers import Tokenizer
-from tokenizers.models import BPE, Unigram
 from transformers import AutoModel, AutoTokenizer, PreTrainedModel, PreTrainedTokenizerFast
 
-from model2vec.distill.inference import (
-    create_output_embeddings_from_model,
-    create_output_embeddings_from_model_and_tokens,
-)
-from model2vec.distill.tokenizer import add_tokens, preprocess_vocabulary, remove_tokens
+from model2vec.distill.inference import create_embeddings
+from model2vec.distill.tokenizer import replace_vocabulary
 from model2vec.distill.utils import select_optimal_device
 from model2vec.model import StaticModel
 
@@ -71,52 +67,41 @@ def distill_from_model(
     :return: A StaticModel
 
     """
-    sif_coefficient = _validate_parameters(tokenizer, vocabulary, apply_zipf, sif_coefficient, use_subword)
+    backend_tokenizer = tokenizer.backend_tokenizer
+    sif_coefficient, token_remove_regex = _validate_parameters(
+        vocabulary, apply_zipf, sif_coefficient, use_subword, token_remove_pattern
+    )
+
+    if vocabulary is None:
+        vocabulary = []
 
     device = select_optimal_device(device)
     # Make a base list of tokens.
-    tokens: list[str] = []
-    if use_subword:
-        # Create the subword embeddings.
-        tokens, embeddings = create_output_embeddings_from_model(model=model, tokenizer=tokenizer, device=device)
-        new_tokenizer, embeddings = _remove_tokens_and_embeddings(tokenizer, token_remove_pattern, tokens, embeddings)
-    else:
-        # We need to keep the unk token in the tokenizer.
-        unk_token = tokenizer.backend_tokenizer.model.unk_token
-        # Remove all tokens except the UNK token.
-        new_tokenizer = remove_tokens(tokenizer.backend_tokenizer, list(set(tokenizer.get_vocab()) - {unk_token}))
-        # We need to set embeddings to None because we don't know the dimensions of the embeddings yet.
-        embeddings = None
-
-    if vocabulary:
-        # Preprocess the vocabulary with the original tokenizer.
-        preprocessed_vocabulary = preprocess_vocabulary(tokenizer.backend_tokenizer, vocabulary)
-        n_tokens_before = len(preprocessed_vocabulary)
-        # Clean the vocabulary by removing duplicate tokens and tokens that are in the subword vocabulary.
-        cleaned_vocabulary = _clean_vocabulary(preprocessed_vocabulary, tokens)
-        n_tokens_after = len(cleaned_vocabulary)
-        logger.info(
-            f"Adding {n_tokens_after} tokens to the vocabulary. Removed {n_tokens_before - n_tokens_after} tokens during preprocessing."
-        )
-        # Only create embeddings if we have tokens to add.
-        if cleaned_vocabulary:
-            # Create the embeddings.
-            _, token_embeddings = create_output_embeddings_from_model_and_tokens(
-                model=model,
-                tokenizer=tokenizer,
-                tokens=cleaned_vocabulary,
-                device=device,
-            )
+    subword_vocab: dict[str, int] = tokenizer.get_vocab()
+    subword_tokens: list[str] = [k for k, _ in sorted(subword_vocab.items(), key=lambda x: x[1])]
+
+    n_tokens_before = len(vocabulary)
+    # Clean the vocabulary by removing duplicate tokens and tokens that are in the subword vocabulary.
+    cleaned_vocabulary = _clean_vocabulary(tokenizer.backend_tokenizer, vocabulary, subword_tokens)
+    n_tokens_after = len(cleaned_vocabulary)
+    logger.info(
+        f"Adding {n_tokens_after} tokens to the vocabulary. Removed {n_tokens_before - n_tokens_after} tokens during preprocessing."
+    )
 
-            # If we don't have subword tokens, we still need to create
-            # some embeddings for [UNK] and some other special tokens.
-            if embeddings is None:
-                embeddings = np.zeros((new_tokenizer.get_vocab_size(), token_embeddings.shape[1]))
-            embeddings = np.concatenate([embeddings, token_embeddings], axis=0)
-            # Add the cleaned vocabulary to the tokenizer.
-            new_tokenizer = add_tokens(new_tokenizer, cleaned_vocabulary)
-        else:
-            logger.warning("Didn't create any token embeddings as all tokens were duplicates or empty.")
+    # Create the embeddings.
+    all_tokens, embeddings = create_embeddings(
+        model=model,
+        tokenizer=tokenizer,
+        tokens=cleaned_vocabulary,
+        device=device,
+        use_subword=use_subword,
+        token_remove_regex=token_remove_regex,
+    )
+
+    unk_token = tokenizer.special_tokens_map.get("unk_token")
+    pad_token = tokenizer.special_tokens_map.get("pad_token")
+    # Add the cleaned vocabulary to the tokenizer.
+    backend_tokenizer = replace_vocabulary(backend_tokenizer, all_tokens, unk_token=unk_token, pad_token=pad_token)
 
     # Post process the embeddings by applying PCA and Zipf weighting.
     embeddings = _post_process_embeddings(np.asarray(embeddings), pca_dims, sif_coefficient=sif_coefficient)
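Note: the rewritten flow above pulls everything it needs from the Hugging Face tokenizer object itself: the full subword vocabulary via get_vocab() and the special tokens via special_tokens_map, instead of poking at the backend model. A minimal sketch of those two lookups, assuming a stock fast tokenizer (bert-base-uncased is only an illustrative choice):

```python
from transformers import AutoTokenizer

# Illustrative tokenizer; any fast (backend) tokenizer works the same way.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# get_vocab() maps token string -> integer id; sorting by id gives the id-ordered token list.
subword_vocab = tokenizer.get_vocab()
subword_tokens = [token for token, _ in sorted(subword_vocab.items(), key=lambda item: item[1])]
print(len(subword_tokens), subword_tokens[:3])  # e.g. 30522 ['[PAD]', '[unused0]', '[unused1]']

# Special tokens come from the mapping; .get() returns None if the tokenizer has no such token.
print(tokenizer.special_tokens_map.get("unk_token"))  # '[UNK]'
print(tokenizer.special_tokens_map.get("pad_token"))  # '[PAD]'
```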
@@ -150,7 +135,7 @@ def distill_from_model(
 
     return StaticModel(
         vectors=embeddings,
-        tokenizer=new_tokenizer,
+        tokenizer=backend_tokenizer,
         config=config,
         base_model_name=model_name,
         language=language,
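For context, the public entry point keeps the same outward shape after the rewrite: you hand it a model and tokenizer and get a StaticModel back, built from the post-processed embeddings and the rewritten backend tokenizer. A hedged usage sketch based on the parameters visible in this diff (the model name, vocabulary, and removal pattern are illustrative examples, not project defaults):

```python
from transformers import AutoModel, AutoTokenizer

from model2vec.distill.distillation import distill_from_model

model_name = "bert-base-uncased"  # example model; any encoder with a fast tokenizer should do
model = AutoModel.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

static_model = distill_from_model(
    model=model,
    tokenizer=tokenizer,
    vocabulary=["tokenization", "wordpiece"],  # optional extra single-word tokens
    device="cpu",
    token_remove_pattern=r"\[unused\d+\]",  # example pattern: drop BERT's reserved [unusedN] slots
)
print(type(static_model).__name__)  # StaticModel
```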
@@ -159,22 +144,22 @@ def distill_from_model(
 
 
 def _validate_parameters(
-    tokenizer: PreTrainedTokenizerFast,
     vocabulary: list[str] | None,
     apply_zipf: bool | None,
     sif_coefficient: float | None,
     use_subword: bool,
-) -> float | None:
+    token_remove_pattern: str | None,
+) -> tuple[float | None, re.Pattern | None]:
     """
     Validate the parameters passed to the distillation function.
 
-    :param tokenizer: The tokenizer to use.
     :param vocabulary: The vocabulary to use.
     :param apply_zipf: DEPRECATED: This parameter used to control whether Zipf is applied.
         Zipf weighting is now controlled by the sif_coefficient parameter. If this is set to None, no weighting is applied.
     :param sif_coefficient: The SIF coefficient to use. If this is None, no weighting is applied.
         Should be a value >= 0 and < 1.0. A value of 1e-4 is a good default.
     :param use_subword: Whether to keep subword tokens in the vocabulary. If this is False, you must pass a vocabulary, and the returned tokenizer will only detect full words.
+    :param token_remove_pattern: If this is set to a string, we compile this into a regex. Any tokens that conform to this regex pattern will be removed from the vocabulary.
     :return: The SIF coefficient to use.
     :raises: ValueError if the PCA dimension is larger than the number of dimensions in the embeddings.
     :raises: ValueError if the vocabulary contains duplicate tokens.
@@ -204,49 +189,14 @@ def _validate_parameters(
             "You must pass a vocabulary if you don't use subword tokens. Either pass a vocabulary, or set use_subword to True."
         )
 
-    if vocabulary and isinstance(tokenizer.backend_tokenizer.model, (BPE, Unigram)):
-        raise ValueError(
-            "You passed a vocabulary, but the model you are using does not use a WordPiece tokenizer. "
-            "This is not supported yet."
-            "Feel free to open an issue if this is a blocker: https://github.com/MinishLab/model2vec/issues"
-        )
-
-    return sif_coefficient
-
-
-def _remove_tokens_and_embeddings(
-    tokenizer: PreTrainedTokenizerFast, token_remove_pattern: str | None, tokens: list[str], embeddings: np.ndarray
-) -> tuple[Tokenizer, np.ndarray]:
-    if not token_remove_pattern:
-        return tokenizer.backend_tokenizer, embeddings
-
-    try:
-        token_regex = re.compile(token_remove_pattern)
-    except re.error as e:
-        raise ValueError(f"Invalid regex pattern: {token_remove_pattern}") from e
-    # Remove any unused tokens from the tokenizer and embeddings.
-    wrong_tokens = [x for x in tokens if token_regex.match(x)]
-    vocab = tokenizer.get_vocab()
-    # Get the ids of the unused token.
-    wrong_token_ids = [vocab[token] for token in wrong_tokens]
-
-    if len(wrong_token_ids) == len(vocab):
-        raise ValueError(
-            "All tokens in the vocabulary are unused tokens. This will result in an empty tokenizer. "
-            "Please provide a valid token removal pattern. The pattern is now: {token_remove_pattern}"
-        )
-
-    # Remove the unused tokens from the tokenizer.
-    new_tokenizer = remove_tokens(tokenizer.backend_tokenizer, wrong_tokens)
-    if new_tokenizer.get_vocab_size() == tokenizer.backend_tokenizer.get_vocab_size():
-        # This happens if we didn't remove any tokens.
-        return new_tokenizer, embeddings
-
-    # Remove the embeddings of the unused tokens.
-    embeddings = np.delete(embeddings, wrong_token_ids, axis=0)
-    logger.info(f"Removed {len(wrong_tokens)} unused tokens from the tokenizer and embeddings.")
+    token_remove_regex: re.Pattern | None = None
+    if token_remove_pattern is not None:
+        try:
+            token_remove_regex = re.compile(token_remove_pattern)
+        except re.error as e:
+            raise ValueError(f"Couldn't compile the regex pattern: {e}")
 
-    return new_tokenizer, embeddings
+    return sif_coefficient, token_remove_regex
 
 
 def distill(
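The regex handling is now centralized in _validate_parameters: the pattern is compiled once up front (surfacing re.error as a ValueError) and the compiled token_remove_regex is handed to create_embeddings. A small sketch of the compile-and-filter idea, with an illustrative pattern for BERT-style reserved tokens (the actual filtering lives in create_embeddings, which is not shown in this file):

```python
import re

token_remove_pattern = r"\[unused\d+\]"  # example pattern matching tokens like "[unused39]"

try:
    token_remove_regex: re.Pattern | None = re.compile(token_remove_pattern)
except re.error as e:
    # Mirrors the new validation: an invalid pattern (e.g. "[unclosed") becomes a ValueError.
    raise ValueError(f"Couldn't compile the regex pattern: {e}")

tokens = ["[PAD]", "[unused0]", "[unused1]", "[UNK]", "hello"]
kept = [t for t in tokens if token_remove_regex.match(t) is None]
print(kept)  # ['[PAD]', '[UNK]', 'hello']
```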
@@ -345,26 +295,40 @@ def _post_process_embeddings(
     return embeddings
 
 
-def _clean_vocabulary(preprocessed_vocabulary: list[str], added_tokens: list[str]) -> list[str]:
+def _clean_vocabulary(tokenizer: Tokenizer, vocabulary: list[str], added_tokens: list[str]) -> list[str]:
     """Cleans a vocabulary by removing duplicates and tokens that were already in the vocabulary."""
     added_tokens_set = set(added_tokens)
     seen_tokens = set()
     cleaned_vocabulary = []
     n_empty = 0
     n_duplicates = 0
-    for token in preprocessed_vocabulary:
+    n_multiword = 0
+    for token in vocabulary:
+        if tokenizer.normalizer is not None:
+            token = tokenizer.normalizer.normalize_str(token)
+
         if not token:
             n_empty += 1
             continue
         if token in seen_tokens or token in added_tokens_set:
             n_duplicates += 1
             continue
+
+        pre_tokenizer = tokenizer.pre_tokenizer
+        if pre_tokenizer is not None:
+            pretokenized_tokens = pre_tokenizer.pre_tokenize_str(token)
+            if len(pretokenized_tokens) != 1:
+                n_multiword += 1
+                continue
+
         seen_tokens.add(token)
         cleaned_vocabulary.append(token)
 
     if n_duplicates:
         logger.warning(f"Removed {n_duplicates} duplicate tokens.")
     if n_empty:
         logger.warning(f"Removed {n_empty} empty tokens.")
+    if n_multiword:
+        logger.warning(f"Removed {n_multiword} multiword tokens.")
 
     return cleaned_vocabulary
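The new _clean_vocabulary delegates normalization and the multiword check to the backend tokenizers objects: a candidate token is first run through the tokenizer's normalizer, then dropped if the pre-tokenizer splits it into more than one piece. A quick sketch of those two calls, assuming bert-base-uncased purely as an example:

```python
from tokenizers import Tokenizer

backend_tokenizer = Tokenizer.from_pretrained("bert-base-uncased")  # example tokenizer

# Normalization (here: lowercasing and accent stripping) runs before any other check.
print(backend_tokenizer.normalizer.normalize_str("Héllo"))  # 'hello'

# The pre-tokenizer splits on whitespace and punctuation; more than one piece means "multiword".
pieces = backend_tokenizer.pre_tokenizer.pre_tokenize_str("new york")
print(pieces)            # [('new', (0, 3)), ('york', (4, 8))]
print(len(pieces) != 1)  # True -> such an entry is counted as n_multiword and skipped
```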
