Skip to content

Commit 3f5786a

Browse files
authored
feat: add dimensionality during loading (MinishLab#216)
1 parent 4cc6148 commit 3f5786a

File tree

2 files changed

+46
-0
lines changed

2 files changed

+46
-0
lines changed

model2vec/model.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ def from_pretrained(
150150
path: PathLike,
151151
token: str | None = None,
152152
normalize: bool | None = None,
153+
dimensionality: int | None = None,
153154
) -> StaticModel:
154155
"""
155156
Load a StaticModel from a local path or huggingface hub path.
@@ -159,12 +160,27 @@ def from_pretrained(
159160
:param path: The path to load your static model from.
160161
:param token: The huggingface token to use.
161162
:param normalize: Whether to normalize the embeddings.
163+
:param dimensionality: The number of dimensions to keep. If None, the model's own dimensionality is used.
164+
This is useful if you want to load a model with a lower dimensionality.
165+
Note that this only works as intended if the model was trained with MRL (Matryoshka Representation Learning) or PCA.
162166
:return: A StaticModel
167+
:raises ValueError: If the requested dimensionality is greater than the model's dimensionality.
163168
"""
164169
from model2vec.hf_utils import load_pretrained
165170

166171
embeddings, tokenizer, config, metadata = load_pretrained(path, token=token, from_sentence_transformers=False)
167172

173+
if dimensionality is not None:
174+
if dimensionality > embeddings.shape[1]:
175+
raise ValueError(
176+
f"Dimensionality {dimensionality} is greater than the model dimensionality {embeddings.shape[1]}"
177+
)
178+
embeddings = embeddings[:, :dimensionality]
179+
if config.get("apply_pca", None) is None:
180+
logger.warning(
181+
"You are reducing the dimensionality of the model, but we can't find a pca key in the model config. This might not work as expected."
182+
)
183+
168184
return cls(
169185
embeddings,
170186
tokenizer,

tests/test_model.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,36 @@ def test_load_pretrained(
182182
assert loaded_model.config == mock_config
183183

184184

185+
def test_load_pretrained_dim(
    tmp_path: Path, mock_vectors: np.ndarray, mock_tokenizer: Tokenizer, mock_config: dict[str, str]
) -> None:
    """Test that the dimensionality argument is applied (and validated) when loading a saved model."""
    # Persist a model so it can be reloaded below.
    model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer, config=mock_config)
    target = tmp_path / "saved_model"
    model.save_pretrained(target)

    # Reloading with a reduced dimensionality truncates the embedding columns.
    reloaded = StaticModel.from_pretrained(target, dimensionality=2)
    np.testing.assert_array_equal(reloaded.embedding, mock_vectors[:, :2])
    assert reloaded.tokenizer.get_vocab() == mock_tokenizer.get_vocab()
    assert reloaded.config == mock_config

    # Reloading with dimensionality=None leaves the embeddings unchanged.
    reloaded = StaticModel.from_pretrained(target, dimensionality=None)
    np.testing.assert_array_equal(reloaded.embedding, mock_vectors)
    assert reloaded.tokenizer.get_vocab() == mock_tokenizer.get_vocab()
    assert reloaded.config == mock_config

    # Requesting more dimensions than the model has must raise.
    with pytest.raises(ValueError):
        StaticModel.from_pretrained(target, dimensionality=3000)
213+
214+
185215
def test_initialize_normalize(mock_vectors: np.ndarray, mock_tokenizer: Tokenizer) -> None:
186216
"""Tests whether the normalization initialization is correct."""
187217
model = StaticModel(mock_vectors, mock_tokenizer, {}, normalize=None)

0 commit comments

Comments
 (0)