@@ -12,7 +12,7 @@
from tokenizers import Encoding, Tokenizer
from tqdm import tqdm

-from model2vec.quantization import DType, quantize_embeddings
+from model2vec.quantization import DType, quantize_and_reduce_dim
from model2vec.utils import ProgressParallel, load_local_model

PathLike = Union[Path, str]
@@ -171,28 +171,22 @@ def from_pretrained(
        :param dimensionality: The dimensionality of the model. If this is None, use the dimensionality of the model.
            This is useful if you want to load a model with a lower dimensionality.
            Note that this only applies if you have trained your model using mrl or PCA.
-        :return: A StaticModel
-        :raises: ValueError if the dimensionality is greater than the model dimensionality.
+        :return: A StaticModel.
        """
        from model2vec.hf_utils import load_pretrained

        embeddings, tokenizer, config, metadata = load_pretrained(
-            path, token=token, from_sentence_transformers=False, subfolder=subfolder
+            folder_or_repo_path=path,
+            token=token,
+            from_sentence_transformers=False,
+            subfolder=subfolder,
        )

-        if quantize_to is not None:
-            quantize_to = DType(quantize_to)
-            embeddings = quantize_embeddings(embeddings, quantize_to)
-        if dimensionality is not None:
-            if dimensionality > embeddings.shape[1]:
-                raise ValueError(
-                    f"Dimensionality {dimensionality} is greater than the model dimensionality {embeddings.shape[1]}"
-                )
-            embeddings = embeddings[:, :dimensionality]
-            if config.get("apply_pca", None) is None:
-                logger.warning(
-                    "You are reducing the dimensionality of the model, but we can't find a pca key in the model config. This might not work as expected."
-                )
+        embeddings = quantize_and_reduce_dim(
+            embeddings=embeddings,
+            quantize_to=quantize_to,
+            dimensionality=dimensionality,
+        )

        return cls(
            embeddings,
@@ -209,6 +203,8 @@ def from_sentence_transformers(
        path: PathLike,
        token: str | None = None,
        normalize: bool | None = None,
+        quantize_to: str | DType | None = None,
+        dimensionality: int | None = None,
    ) -> StaticModel:
        """
        Load a StaticModel trained with sentence transformers from a local path or huggingface hub path.
@@ -218,13 +214,36 @@ def from_sentence_transformers(
        :param path: The path to load your static model from.
        :param token: The huggingface token to use.
        :param normalize: Whether to normalize the embeddings.
-        :return: A StaticModel
+        :param quantize_to: The dtype to quantize the model to. If None, no quantization is done.
+            If a string is passed, it is converted to a DType.
+        :param dimensionality: The dimensionality of the model. If this is None, use the dimensionality of the model.
+            This is useful if you want to load a model with a lower dimensionality.
+            Note that this only applies if you have trained your model using mrl or PCA.
+        :return: A StaticModel.
        """
        from model2vec.hf_utils import load_pretrained

-        embeddings, tokenizer, config, _ = load_pretrained(path, token=token, from_sentence_transformers=True)
+        embeddings, tokenizer, config, metadata = load_pretrained(
+            folder_or_repo_path=path,
+            token=token,
+            from_sentence_transformers=True,
+            subfolder=None,
+        )
+
+        embeddings = quantize_and_reduce_dim(
+            embeddings=embeddings,
+            quantize_to=quantize_to,
+            dimensionality=dimensionality,
+        )

-        return cls(embeddings, tokenizer, config, normalize=normalize, base_model_name=None, language=None)
+        return cls(
+            embeddings,
+            tokenizer,
+            config,
+            normalize=normalize,
+            base_model_name=metadata.get("base_model"),
+            language=metadata.get("language"),
+        )

    def encode_as_sequence(
        self,
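A minimal usage sketch of the two loaders after this change, not part of the diff itself: it assumes the top-level `StaticModel` export, that the model ids shown are valid repositories, and that `"float16"`/`"int8"` are accepted `DType` strings.

```python
# Sketch only: model ids and dtype strings are illustrative assumptions.
from model2vec import StaticModel

# Load a Model2Vec model, quantize its embeddings to float16 and truncate
# them to 256 dimensions (dimensionality reduction only behaves as expected
# if the model was trained with MRL or PCA, as the docstring notes).
model = StaticModel.from_pretrained(
    "minishlab/potion-base-8M",
    quantize_to="float16",
    dimensionality=256,
)

# The same options are now available when loading a sentence-transformers
# style static model.
st_model = StaticModel.from_sentence_transformers(
    "sentence-transformers/static-similarity-mrl-multilingual-v1",
    quantize_to="int8",
    dimensionality=128,
)

embeddings = model.encode(["hello world"])
```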