Skip to content

Commit 533b54c

Browse files
authored
tests: introduce model cache to tests (#573)
* tests: introduce model cache to tests
* fix: fix not cached model deletion
* new: do not run CI tests on macOS and Windows on Python 3.10-3.12
* fix: lowercase cache keys, bm25 caching
* tests: do not run parallel processing on all CPUs in sparse text embed
* fix: fix models to cache names, do not run parallel=0
* fix: fix sparse embedding tests
* fix: bm42 language by lower case model name
1 parent ba1f605 commit 533b54c

11 files changed

+563
-405
lines changed

.github/workflows/python-tests.yml

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,20 @@ jobs:
2424
- ubuntu-latest
2525
- macos-latest
2626
- windows-latest
27+
exclude:
28+
# Exclude 3.10–3.12 for macOS and Windows
29+
- os: macos-latest
30+
python-version: '3.10.x'
31+
- os: macos-latest
32+
python-version: '3.11.x'
33+
- os: macos-latest
34+
python-version: '3.12.x'
35+
- os: windows-latest
36+
python-version: '3.10.x'
37+
- os: windows-latest
38+
python-version: '3.11.x'
39+
- os: windows-latest
40+
python-version: '3.12.x'
2741

2842
runs-on: ${{ matrix.os }}
2943

fastembed/sparse/bm42.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,17 @@
3131
),
3232
]
3333

34-
MODEL_TO_LANGUAGE = {
34+
35+
_MODEL_TO_LANGUAGE = {
3536
"Qdrant/bm42-all-minilm-l6-v2-attentions": "english",
3637
}
38+
MODEL_TO_LANGUAGE = {
39+
model_name.lower(): language for model_name, language in _MODEL_TO_LANGUAGE.items()
40+
}
41+
42+
43+
def get_language_by_model_name(model_name: str) -> str:
44+
return MODEL_TO_LANGUAGE[model_name.lower()]
3745

3846

3947
class Bm42(SparseTextEmbeddingBase, OnnxTextModel[SparseEmbedding]):
@@ -124,7 +132,7 @@ def __init__(
124132
self.special_tokens_ids: set[int] = set()
125133
self.punctuation = set(string.punctuation)
126134
self.stopwords = set(self._load_stopwords(self._model_dir))
127-
self.stemmer = SnowballStemmer(MODEL_TO_LANGUAGE[model_name])
135+
self.stemmer = SnowballStemmer(get_language_by_model_name(self.model_name))
128136
self.alpha = alpha
129137

130138
if not self.lazy_load:

fastembed/sparse/minicoil.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,16 @@
4646
),
4747
]
4848

49-
MODEL_TO_LANGUAGE = {
49+
_MODEL_TO_LANGUAGE = {
5050
"Qdrant/minicoil-v1": "english",
5151
}
52+
MODEL_TO_LANGUAGE = {
53+
model_name.lower(): language for model_name, language in _MODEL_TO_LANGUAGE.items()
54+
}
55+
56+
57+
def get_language_by_model_name(model_name: str) -> str:
58+
return MODEL_TO_LANGUAGE[model_name.lower()]
5259

5360

5461
class MiniCOIL(SparseTextEmbeddingBase, OnnxTextModel[SparseEmbedding]):
@@ -156,7 +163,7 @@ def load_onnx_model(self) -> None:
156163
self.special_tokens_ids = set(self.special_token_to_id.values())
157164
self.stopwords = set(self._load_stopwords(self._model_dir))
158165

159-
stemmer = SnowballStemmer(MODEL_TO_LANGUAGE[self.model_name])
166+
stemmer = SnowballStemmer(get_language_by_model_name(self.model_name))
160167

161168
self.vocab_resolver = VocabResolver(
162169
tokenizer=VocabTokenizer(self.tokenizer),

tests/test_attention_embeddings.py

Lines changed: 117 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
from contextlib import contextmanager
23

34
import numpy as np
45
import pytest
@@ -7,98 +8,119 @@
78
from tests.utils import delete_model_cache
89

910

10-
@pytest.mark.parametrize("model_name", ["Qdrant/bm42-all-minilm-l6-v2-attentions", "Qdrant/bm25"])
11-
def test_attention_embeddings(model_name: str) -> None:
12-
is_ci = os.getenv("CI")
13-
model = SparseTextEmbedding(model_name=model_name)
14-
15-
output = list(
16-
model.query_embed(
17-
[
18-
"I must not fear. Fear is the mind-killer.",
19-
]
20-
)
21-
)
22-
23-
assert len(output) == 1
24-
25-
for result in output:
26-
assert len(result.indices) == len(result.values)
27-
assert np.allclose(result.values, np.ones(len(result.values)))
28-
29-
quotes = [
30-
"I must not fear. Fear is the mind-killer.",
31-
"All animals are equal, but some animals are more equal than others.",
32-
"It was a pleasure to burn.",
33-
"The sky above the port was the color of television, tuned to a dead channel.",
34-
"In the beginning, the universe was created."
35-
" This has made a lot of people very angry and been widely regarded as a bad move.",
36-
"It's a truth universally acknowledged that a zombie in possession of brains must be in want of more brains.",
37-
"War is peace. Freedom is slavery. Ignorance is strength.",
38-
"We're not in Infinity; we're in the suburbs.",
39-
"I was a thousand times more evil than thou!",
40-
"History is merely a list of surprises... It can only prepare us to be surprised yet again.",
41-
".", # Empty string
42-
]
43-
44-
output = list(model.embed(quotes))
45-
46-
assert len(output) == len(quotes)
47-
48-
for result in output[:-1]:
49-
assert len(result.indices) == len(result.values)
50-
assert len(result.indices) > 0
51-
52-
assert len(output[-1].indices) == 0
53-
54-
# Test support for unknown languages
55-
output = list(
56-
model.query_embed(
57-
[
58-
"привет мир!",
59-
]
60-
)
61-
)
11+
_MODELS_TO_CACHE = ("Qdrant/bm42-all-minilm-l6-v2-attentions", "Qdrant/bm25")
12+
MODELS_TO_CACHE = tuple([x.lower() for x in _MODELS_TO_CACHE])
6213

63-
assert len(output) == 1
6414

65-
for result in output:
66-
assert len(result.indices) == len(result.values)
67-
assert len(result.indices) == 2
15+
@pytest.fixture(scope="module")
16+
def model_cache():
17+
is_ci = os.getenv("CI")
18+
cache = {}
19+
20+
@contextmanager
21+
def get_model(model_name: str):
22+
lowercase_model_name = model_name.lower()
23+
if lowercase_model_name not in cache:
24+
cache[lowercase_model_name] = SparseTextEmbedding(lowercase_model_name)
25+
yield cache[lowercase_model_name]
26+
if lowercase_model_name not in MODELS_TO_CACHE:
27+
print("deleting model")
28+
model_inst = cache.pop(lowercase_model_name)
29+
if is_ci:
30+
delete_model_cache(model_inst.model._model_dir)
31+
del model_inst
32+
33+
yield get_model
6834

6935
if is_ci:
70-
delete_model_cache(model.model._model_dir)
36+
for name, model in cache.items():
37+
delete_model_cache(model.model._model_dir)
38+
cache.clear()
7139

7240

7341
@pytest.mark.parametrize("model_name", ["Qdrant/bm42-all-minilm-l6-v2-attentions", "Qdrant/bm25"])
74-
def test_parallel_processing(model_name: str) -> None:
75-
is_ci = os.getenv("CI")
42+
def test_attention_embeddings(model_cache, model_name: str) -> None:
43+
with model_cache(model_name) as model:
44+
output = list(
45+
model.query_embed(
46+
[
47+
"I must not fear. Fear is the mind-killer.",
48+
]
49+
)
50+
)
7651

77-
model = SparseTextEmbedding(model_name=model_name)
52+
assert len(output) == 1
53+
54+
for result in output:
55+
assert len(result.indices) == len(result.values)
56+
assert np.allclose(result.values, np.ones(len(result.values)))
57+
58+
quotes = [
59+
"I must not fear. Fear is the mind-killer.",
60+
"All animals are equal, but some animals are more equal than others.",
61+
"It was a pleasure to burn.",
62+
"The sky above the port was the color of television, tuned to a dead channel.",
63+
"In the beginning, the universe was created."
64+
" This has made a lot of people very angry and been widely regarded as a bad move.",
65+
"It's a truth universally acknowledged that a zombie in possession of brains must be in want of more brains.",
66+
"War is peace. Freedom is slavery. Ignorance is strength.",
67+
"We're not in Infinity; we're in the suburbs.",
68+
"I was a thousand times more evil than thou!",
69+
"History is merely a list of surprises... It can only prepare us to be surprised yet again.",
70+
".", # Empty string
71+
]
72+
73+
output = list(model.embed(quotes))
74+
75+
assert len(output) == len(quotes)
76+
77+
for result in output[:-1]:
78+
assert len(result.indices) == len(result.values)
79+
assert len(result.indices) > 0
80+
81+
assert len(output[-1].indices) == 0
82+
83+
# Test support for unknown languages
84+
output = list(
85+
model.query_embed(
86+
[
87+
"привет мир!",
88+
]
89+
)
90+
)
7891

79-
docs = ["hello world", "attention embedding", "Mangez-vous vraiment des grenouilles?"] * 100
80-
embeddings = list(model.embed(docs, batch_size=10, parallel=2))
92+
assert len(output) == 1
8193

82-
embeddings_2 = list(model.embed(docs, batch_size=10, parallel=None))
94+
for result in output:
95+
assert len(result.indices) == len(result.values)
96+
assert len(result.indices) == 2
8397

84-
embeddings_3 = list(model.embed(docs, batch_size=10, parallel=0))
8598

86-
assert len(embeddings) == len(docs)
99+
@pytest.mark.parametrize("model_name", ["Qdrant/bm42-all-minilm-l6-v2-attentions", "Qdrant/bm25"])
100+
def test_parallel_processing(model_cache, model_name: str) -> None:
101+
with model_cache(model_name) as model:
102+
docs = [
103+
"hello world",
104+
"attention embedding",
105+
"Mangez-vous vraiment des grenouilles?",
106+
] * 100
107+
embeddings = list(model.embed(docs, batch_size=10, parallel=2))
87108

88-
for emb_1, emb_2, emb_3 in zip(embeddings, embeddings_2, embeddings_3):
89-
assert np.allclose(emb_1.indices, emb_2.indices)
90-
assert np.allclose(emb_1.indices, emb_3.indices)
91-
assert np.allclose(emb_1.values, emb_2.values)
92-
assert np.allclose(emb_1.values, emb_3.values)
109+
embeddings_2 = list(model.embed(docs, batch_size=10, parallel=None))
93110

94-
if is_ci:
95-
delete_model_cache(model.model._model_dir)
111+
embeddings_3 = list(model.embed(docs, batch_size=10, parallel=0))
96112

113+
assert len(embeddings) == len(docs)
114+
115+
for emb_1, emb_2, emb_3 in zip(embeddings, embeddings_2, embeddings_3):
116+
assert np.allclose(emb_1.indices, emb_2.indices)
117+
assert np.allclose(emb_1.indices, emb_3.indices)
118+
assert np.allclose(emb_1.values, emb_2.values)
119+
assert np.allclose(emb_1.values, emb_3.values)
97120

98-
@pytest.mark.parametrize("model_name", ["Qdrant/bm25"])
99-
def test_multilanguage(model_name: str) -> None:
100-
is_ci = os.getenv("CI")
101121

122+
@pytest.mark.parametrize("model_name", ["Qdrant/bm25"])
123+
def test_multilanguage(model_cache, model_name: str) -> None:
102124
docs = ["Mangez-vous vraiment des grenouilles?", "Je suis au lit"]
103125

104126
model = SparseTextEmbedding(model_name=model_name, language="french")
@@ -109,39 +131,30 @@ def test_multilanguage(model_name: str) -> None:
109131
assert embeddings[1].values.shape == (1,)
110132
assert embeddings[1].indices.shape == (1,)
111133

112-
model = SparseTextEmbedding(model_name=model_name, language="english")
113-
embeddings = list(model.embed(docs))[:2]
114-
assert embeddings[0].values.shape == (5,)
115-
assert embeddings[0].indices.shape == (5,)
134+
with model_cache(model_name) as model: # language = "english"
135+
embeddings = list(model.embed(docs))[:2]
136+
assert embeddings[0].values.shape == (5,)
137+
assert embeddings[0].indices.shape == (5,)
116138

117-
assert embeddings[1].values.shape == (4,)
118-
assert embeddings[1].indices.shape == (4,)
119-
120-
if is_ci:
121-
delete_model_cache(model.model._model_dir)
139+
assert embeddings[1].values.shape == (4,)
140+
assert embeddings[1].indices.shape == (4,)
122141

123142

124143
@pytest.mark.parametrize("model_name", ["Qdrant/bm25"])
125-
def test_special_characters(model_name: str) -> None:
126-
is_ci = os.getenv("CI")
127-
128-
docs = [
129-
"Über den größten Flüssen Österreichs äußern sich Experten häufig: Öko-Systeme müssen geschützt werden!",
130-
"L'élève français s'écrie : « Où est mon crayon ? J'ai besoin de finir cet exercice avant la récréation!",
131-
"Într-o zi însorită, Ștefan și Ioana au mâncat mămăligă cu brânză și au băut țuică la cabană.",
132-
"Üzgün öğretmen öğrencilere seslendi: Lütfen gürültü yapmayın, sınavınızı bitirmeye çalışıyorum!",
133-
"Ο Ξενοφών είπε: «Ψάχνω για ένα ωραίο δώρο για τη γιαγιά μου. Ίσως ένα φυτό ή ένα βιβλίο;»",
134-
"Hola! ¿Cómo estás? Estoy muy emocionado por el cumpleaños de mi hermano, ¡va a ser increíble! También quiero comprar un pastel de chocolate con fresas y un regalo especial: un libro titulado «Cien años de soledad",
135-
]
136-
137-
model = SparseTextEmbedding(model_name=model_name, language="english")
138-
embeddings = list(model.embed(docs))
139-
for idx, shape in enumerate([14, 18, 15, 10, 15]):
140-
assert embeddings[idx].values.shape == (shape,)
141-
assert embeddings[idx].indices.shape == (shape,)
142-
143-
if is_ci:
144-
delete_model_cache(model.model._model_dir)
144+
def test_special_characters(model_cache, model_name: str) -> None:
145+
with model_cache(model_name) as model:
146+
docs = [
147+
"Über den größten Flüssen Österreichs äußern sich Experten häufig: Öko-Systeme müssen geschützt werden!",
148+
"L'élève français s'écrie : « Où est mon crayon ? J'ai besoin de finir cet exercice avant la récréation!",
149+
"Într-o zi însorită, Ștefan și Ioana au mâncat mămăligă cu brânză și au băut țuică la cabană.",
150+
"Üzgün öğretmen öğrencilere seslendi: Lütfen gürültü yapmayın, sınavınızı bitirmeye çalışıyorum!",
151+
"Ο Ξενοφών είπε: «Ψάχνω για ένα ωραίο δώρο για τη γιαγιά μου. Ίσως ένα φυτό ή ένα βιβλίο;»",
152+
"Hola! ¿Cómo estás? Estoy muy emocionado por el cumpleaños de mi hermano, ¡va a ser increíble! También quiero comprar un pastel de chocolate con fresas y un regalo especial: un libro titulado «Cien años de soledad",
153+
]
154+
embeddings = list(model.embed(docs))
155+
for idx, shape in enumerate([14, 18, 15, 10, 15]):
156+
assert embeddings[idx].values.shape == (shape,)
157+
assert embeddings[idx].indices.shape == (shape,)
145158

146159

147160
@pytest.mark.parametrize("model_name", ["Qdrant/bm42-all-minilm-l6-v2-attentions"])

tests/test_custom_models.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,13 @@ def test_text_custom_model():
7070
assert embeddings.shape == (2, dim)
7171

7272
assert np.allclose(embeddings[0, : canonical_vector.shape[0]], canonical_vector, atol=1e-3)
73+
7374
if is_ci:
7475
delete_model_cache(model.model._model_dir)
7576

77+
CustomTextEmbedding.SUPPORTED_MODELS.clear()
78+
CustomTextEmbedding.POSTPROCESSING_MAPPING.clear()
79+
7680

7781
def test_cross_encoder_custom_model():
7882
is_ci = os.getenv("CI")
@@ -110,6 +114,8 @@ def test_cross_encoder_custom_model():
110114
if is_ci:
111115
delete_model_cache(model.model._model_dir)
112116

117+
CustomTextCrossEncoder.SUPPORTED_MODELS.clear()
118+
113119

114120
def test_mock_add_custom_models():
115121
dim = 5
@@ -169,6 +175,9 @@ def test_mock_add_custom_models():
169175
)
170176
assert np.allclose(post_processed_output, expected_output[model_name], atol=1e-3)
171177

178+
CustomTextEmbedding.SUPPORTED_MODELS.clear()
179+
CustomTextEmbedding.POSTPROCESSING_MAPPING.clear()
180+
172181

173182
def test_do_not_add_existing_model():
174183
existing_base_model = "sentence-transformers/all-MiniLM-L6-v2"
@@ -203,6 +212,9 @@ def test_do_not_add_existing_model():
203212
size_in_gb=0.47,
204213
)
205214

215+
CustomTextEmbedding.SUPPORTED_MODELS.clear()
216+
CustomTextEmbedding.POSTPROCESSING_MAPPING.clear()
217+
206218

207219
def test_do_not_add_existing_cross_encoder():
208220
existing_base_model = "Xenova/ms-marco-MiniLM-L-6-v2"
@@ -227,3 +239,5 @@ def test_do_not_add_existing_cross_encoder():
227239
sources=ModelSource(hf=custom_model_name),
228240
size_in_gb=0.08,
229241
)
242+
243+
CustomTextCrossEncoder.SUPPORTED_MODELS.clear()

0 commit comments

Comments
 (0)