Commit 79900e7

Pringled and stephantul authored

feat: Added quantization for from_sentence_transformers (MinishLab#219)

* Added quantization for from_sentence_transformers
* Updates
* feat: remove flag argument (MinishLab#220)
* feat: remove flag argument
* fix typing
* add future anns

Co-authored-by: Stephan Tulkens <[email protected]>

1 parent 4c69a68 commit 79900e7

File tree

3 files changed: +84 -29 lines changed

* model2vec/hf_utils.py
* model2vec/model.py
* model2vec/quantization.py

model2vec/hf_utils.py

Lines changed: 17 additions & 9 deletions
@@ -90,7 +90,7 @@ def _create_model_card(
         library_name="model2vec",
         **kwargs,
     )
-    model_card = ModelCard.from_template(model_card_data, template_path=full_path)
+    model_card = ModelCard.from_template(model_card_data, template_path=str(full_path))
     model_card.save(folder_path / "README.md")


@@ -145,24 +145,32 @@ def load_pretrained(

     else:
         logger.info("Folder does not exist locally, attempting to use huggingface hub.")
-        embeddings_path = huggingface_hub.hf_hub_download(
-            folder_or_repo_path.as_posix(), model_file, token=token, subfolder=subfolder
+        embeddings_path = Path(
+            huggingface_hub.hf_hub_download(
+                folder_or_repo_path.as_posix(), model_file, token=token, subfolder=subfolder
+            )
         )

         try:
-            readme_path = huggingface_hub.hf_hub_download(
-                folder_or_repo_path.as_posix(), "README.md", token=token, subfolder=subfolder
+            readme_path = Path(
+                huggingface_hub.hf_hub_download(
+                    folder_or_repo_path.as_posix(), "README.md", token=token, subfolder=subfolder
+                )
             )
             metadata = _get_metadata_from_readme(Path(readme_path))
         except huggingface_hub.utils.EntryNotFoundError:
             logger.info("No README found in the model folder. No model card loaded.")
             metadata = {}

-        config_path = huggingface_hub.hf_hub_download(
-            folder_or_repo_path.as_posix(), config_name, token=token, subfolder=subfolder
+        config_path = Path(
+            huggingface_hub.hf_hub_download(
+                folder_or_repo_path.as_posix(), config_name, token=token, subfolder=subfolder
+            )
         )
-        tokenizer_path = huggingface_hub.hf_hub_download(
-            folder_or_repo_path.as_posix(), tokenizer_file, token=token, subfolder=subfolder
+        tokenizer_path = Path(
+            huggingface_hub.hf_hub_download(
+                folder_or_repo_path.as_posix(), tokenizer_file, token=token, subfolder=subfolder
+            )
         )

     opened_tensor_file = cast(SafeOpenProtocol, safetensors.safe_open(embeddings_path, framework="numpy"))
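The substance of this change: huggingface_hub.hf_hub_download returns the cached file location as a plain str, while the local-folder branch of load_pretrained works with pathlib.Path objects, so wrapping the return value unifies the two branches. A minimal sketch of the pattern; the repo id and filename below are placeholders, not taken from this commit:

from pathlib import Path

import huggingface_hub

# hf_hub_download returns a str; wrapping it in Path keeps the hub branch
# consistent with the local-folder branch of load_pretrained.
# "some-org/some-model" and "config.json" are illustrative placeholders.
config_path = Path(
    huggingface_hub.hf_hub_download("some-org/some-model", "config.json")
)
assert isinstance(config_path, Path)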

model2vec/model.py

Lines changed: 39 additions & 20 deletions
@@ -12,7 +12,7 @@
 from tokenizers import Encoding, Tokenizer
 from tqdm import tqdm

-from model2vec.quantization import DType, quantize_embeddings
+from model2vec.quantization import DType, quantize_and_reduce_dim
 from model2vec.utils import ProgressParallel, load_local_model

 PathLike = Union[Path, str]
@@ -171,28 +171,22 @@ def from_pretrained(
         :param dimensionality: The dimensionality of the model. If this is None, use the dimensionality of the model.
             This is useful if you want to load a model with a lower dimensionality.
             Note that this only applies if you have trained your model using mrl or PCA.
-        :return: A StaticModel
-        :raises: ValueError if the dimensionality is greater than the model dimensionality.
+        :return: A StaticModel.
         """
         from model2vec.hf_utils import load_pretrained

         embeddings, tokenizer, config, metadata = load_pretrained(
-            path, token=token, from_sentence_transformers=False, subfolder=subfolder
+            folder_or_repo_path=path,
+            token=token,
+            from_sentence_transformers=False,
+            subfolder=subfolder,
         )

-        if quantize_to is not None:
-            quantize_to = DType(quantize_to)
-            embeddings = quantize_embeddings(embeddings, quantize_to)
-        if dimensionality is not None:
-            if dimensionality > embeddings.shape[1]:
-                raise ValueError(
-                    f"Dimensionality {dimensionality} is greater than the model dimensionality {embeddings.shape[1]}"
-                )
-            embeddings = embeddings[:, :dimensionality]
-            if config.get("apply_pca", None) is None:
-                logger.warning(
-                    "You are reducing the dimensionality of the model, but we can't find a pca key in the model config. This might not work as expected."
-                )
+        embeddings = quantize_and_reduce_dim(
+            embeddings=embeddings,
+            quantize_to=quantize_to,
+            dimensionality=dimensionality,
+        )

         return cls(
             embeddings,
@@ -209,6 +203,8 @@ def from_sentence_transformers(
         path: PathLike,
         token: str | None = None,
         normalize: bool | None = None,
+        quantize_to: str | DType | None = None,
+        dimensionality: int | None = None,
     ) -> StaticModel:
         """
         Load a StaticModel trained with sentence transformers from a local path or huggingface hub path.
@@ -218,13 +214,36 @@ def from_sentence_transformers(
         :param path: The path to load your static model from.
         :param token: The huggingface token to use.
         :param normalize: Whether to normalize the embeddings.
-        :return: A StaticModel
+        :param quantize_to: The dtype to quantize the model to. If None, no quantization is done.
+            If a string is passed, it is converted to a DType.
+        :param dimensionality: The dimensionality of the model. If this is None, use the dimensionality of the model.
+            This is useful if you want to load a model with a lower dimensionality.
+            Note that this only applies if you have trained your model using mrl or PCA.
+        :return: A StaticModel.
         """
         from model2vec.hf_utils import load_pretrained

-        embeddings, tokenizer, config, _ = load_pretrained(path, token=token, from_sentence_transformers=True)
+        embeddings, tokenizer, config, metadata = load_pretrained(
+            folder_or_repo_path=path,
+            token=token,
+            from_sentence_transformers=True,
+            subfolder=None,
+        )
+
+        embeddings = quantize_and_reduce_dim(
+            embeddings=embeddings,
+            quantize_to=quantize_to,
+            dimensionality=dimensionality,
+        )

-        return cls(embeddings, tokenizer, config, normalize=normalize, base_model_name=None, language=None)
+        return cls(
+            embeddings,
+            tokenizer,
+            config,
+            normalize=normalize,
+            base_model_name=metadata.get("base_model"),
+            language=metadata.get("language"),
+        )

     def encode_as_sequence(
         self,
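With this change, quantization and dimensionality reduction become available directly when loading a model exported from sentence transformers, matching what from_pretrained already offered. A minimal usage sketch; the repo id and the "int8" dtype string are illustrative assumptions, not taken from this diff:

from model2vec import StaticModel

# Placeholder repo id; any static sentence-transformers model works here.
model = StaticModel.from_sentence_transformers(
    "some-org/static-sentence-transformers-model",
    quantize_to="int8",   # assumed to be a valid DType value
    dimensionality=256,   # must not exceed the model's native dimensionality
)
embeddings = model.encode(["a quantized, truncated static embedding"])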

model2vec/quantization.py

Lines changed: 28 additions & 0 deletions
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from enum import Enum

 import numpy as np
@@ -33,3 +35,29 @@ def quantize_embeddings(embeddings: np.ndarray, quantize_to: DType) -> np.ndarray:
         return quantized
     else:
         raise ValueError("Not a valid enum member of DType.")
+
+
+def quantize_and_reduce_dim(
+    embeddings: np.ndarray, quantize_to: str | DType | None, dimensionality: int | None
+) -> np.ndarray:
+    """
+    Quantize embeddings to a datatype and reduce dimensionality.
+
+    :param embeddings: The embeddings to quantize and reduce, as a numpy array.
+    :param quantize_to: The data type to quantize to. If None, no quantization is performed.
+    :param dimensionality: The number of dimensions to keep. If None, no dimensionality reduction is performed.
+    :return: The quantized and reduced embeddings.
+    :raises ValueError: If the passed dimensionality is not None and greater than the model dimensionality.
+    """
+    if quantize_to is not None:
+        quantize_to = DType(quantize_to)
+        embeddings = quantize_embeddings(embeddings, quantize_to)
+
+    if dimensionality is not None:
+        if dimensionality > embeddings.shape[1]:
+            raise ValueError(
+                f"Dimensionality {dimensionality} is greater than the model dimensionality {embeddings.shape[1]}"
+            )
+        embeddings = embeddings[:, :dimensionality]
+
+    return embeddings
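The new helper consolidates the quantize-then-truncate logic previously inlined in from_pretrained, so both loaders share one code path. A small sketch of calling it directly; the "float16" string assumes the DType enum accepts that value, which this diff does not show:

import numpy as np

from model2vec.quantization import quantize_and_reduce_dim

vectors = np.random.rand(10, 512).astype(np.float32)

# Quantize to a smaller dtype and keep only the first 128 dimensions.
# "float16" is assumed to be a valid DType value; the enum body is not in this diff.
reduced = quantize_and_reduce_dim(vectors, quantize_to="float16", dimensionality=128)
print(reduced.shape)  # (10, 128)

# Passing dimensionality greater than vectors.shape[1] raises ValueError,
# per the check in the function above.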
