11import os
2+ from contextlib import contextmanager
23
34import numpy as np
45import pytest
78from tests .utils import delete_model_cache
89
910
10- @pytest .mark .parametrize ("model_name" , ["Qdrant/bm42-all-minilm-l6-v2-attentions" , "Qdrant/bm25" ])
11- def test_attention_embeddings (model_name : str ) -> None :
12- is_ci = os .getenv ("CI" )
13- model = SparseTextEmbedding (model_name = model_name )
14-
15- output = list (
16- model .query_embed (
17- [
18- "I must not fear. Fear is the mind-killer." ,
19- ]
20- )
21- )
22-
23- assert len (output ) == 1
24-
25- for result in output :
26- assert len (result .indices ) == len (result .values )
27- assert np .allclose (result .values , np .ones (len (result .values )))
28-
29- quotes = [
30- "I must not fear. Fear is the mind-killer." ,
31- "All animals are equal, but some animals are more equal than others." ,
32- "It was a pleasure to burn." ,
33- "The sky above the port was the color of television, tuned to a dead channel." ,
34- "In the beginning, the universe was created."
35- " This has made a lot of people very angry and been widely regarded as a bad move." ,
36- "It's a truth universally acknowledged that a zombie in possession of brains must be in want of more brains." ,
37- "War is peace. Freedom is slavery. Ignorance is strength." ,
38- "We're not in Infinity; we're in the suburbs." ,
39- "I was a thousand times more evil than thou!" ,
40- "History is merely a list of surprises... It can only prepare us to be surprised yet again." ,
41- "." , # Empty string
42- ]
43-
44- output = list (model .embed (quotes ))
45-
46- assert len (output ) == len (quotes )
47-
48- for result in output [:- 1 ]:
49- assert len (result .indices ) == len (result .values )
50- assert len (result .indices ) > 0
51-
52- assert len (output [- 1 ].indices ) == 0
53-
54- # Test support for unknown languages
55- output = list (
56- model .query_embed (
57- [
58- "привет мир!" ,
59- ]
60- )
61- )
11+ _MODELS_TO_CACHE = ("Qdrant/bm42-all-minilm-l6-v2-attentions" , "Qdrant/bm25" )
12+ MODELS_TO_CACHE = tuple ([x .lower () for x in _MODELS_TO_CACHE ])
6213
63- assert len (output ) == 1
6414
65- for result in output :
66- assert len (result .indices ) == len (result .values )
67- assert len (result .indices ) == 2
15+ @pytest .fixture (scope = "module" )
16+ def model_cache ():
17+ is_ci = os .getenv ("CI" )
18+ cache = {}
19+
20+ @contextmanager
21+ def get_model (model_name : str ):
22+ lowercase_model_name = model_name .lower ()
23+ if lowercase_model_name not in cache :
24+ cache [lowercase_model_name ] = SparseTextEmbedding (lowercase_model_name )
25+ yield cache [lowercase_model_name ]
26+ if lowercase_model_name not in MODELS_TO_CACHE :
27+ print ("deleting model" )
28+ model_inst = cache .pop (lowercase_model_name )
29+ if is_ci :
30+ delete_model_cache (model_inst .model ._model_dir )
31+ del model_inst
32+
33+ yield get_model
6834
6935 if is_ci :
70- delete_model_cache (model .model ._model_dir )
36+ for name , model in cache .items ():
37+ delete_model_cache (model .model ._model_dir )
38+ cache .clear ()
7139
7240
7341@pytest .mark .parametrize ("model_name" , ["Qdrant/bm42-all-minilm-l6-v2-attentions" , "Qdrant/bm25" ])
74- def test_parallel_processing (model_name : str ) -> None :
75- is_ci = os .getenv ("CI" )
42+ def test_attention_embeddings (model_cache , model_name : str ) -> None :
43+ with model_cache (model_name ) as model :
44+ output = list (
45+ model .query_embed (
46+ [
47+ "I must not fear. Fear is the mind-killer." ,
48+ ]
49+ )
50+ )
7651
77- model = SparseTextEmbedding (model_name = model_name )
52+ assert len (output ) == 1
53+
54+ for result in output :
55+ assert len (result .indices ) == len (result .values )
56+ assert np .allclose (result .values , np .ones (len (result .values )))
57+
58+ quotes = [
59+ "I must not fear. Fear is the mind-killer." ,
60+ "All animals are equal, but some animals are more equal than others." ,
61+ "It was a pleasure to burn." ,
62+ "The sky above the port was the color of television, tuned to a dead channel." ,
63+ "In the beginning, the universe was created."
64+ " This has made a lot of people very angry and been widely regarded as a bad move." ,
65+ "It's a truth universally acknowledged that a zombie in possession of brains must be in want of more brains." ,
66+ "War is peace. Freedom is slavery. Ignorance is strength." ,
67+ "We're not in Infinity; we're in the suburbs." ,
68+ "I was a thousand times more evil than thou!" ,
69+ "History is merely a list of surprises... It can only prepare us to be surprised yet again." ,
70+ "." , # Empty string
71+ ]
72+
73+ output = list (model .embed (quotes ))
74+
75+ assert len (output ) == len (quotes )
76+
77+ for result in output [:- 1 ]:
78+ assert len (result .indices ) == len (result .values )
79+ assert len (result .indices ) > 0
80+
81+ assert len (output [- 1 ].indices ) == 0
82+
83+ # Test support for unknown languages
84+ output = list (
85+ model .query_embed (
86+ [
87+ "привет мир!" ,
88+ ]
89+ )
90+ )
7891
79- docs = ["hello world" , "attention embedding" , "Mangez-vous vraiment des grenouilles?" ] * 100
80- embeddings = list (model .embed (docs , batch_size = 10 , parallel = 2 ))
92+ assert len (output ) == 1
8193
82- embeddings_2 = list (model .embed (docs , batch_size = 10 , parallel = None ))
94+ for result in output :
95+ assert len (result .indices ) == len (result .values )
96+ assert len (result .indices ) == 2
8397
84- embeddings_3 = list (model .embed (docs , batch_size = 10 , parallel = 0 ))
8598
86- assert len (embeddings ) == len (docs )
99+ @pytest .mark .parametrize ("model_name" , ["Qdrant/bm42-all-minilm-l6-v2-attentions" , "Qdrant/bm25" ])
100+ def test_parallel_processing (model_cache , model_name : str ) -> None :
101+ with model_cache (model_name ) as model :
102+ docs = [
103+ "hello world" ,
104+ "attention embedding" ,
105+ "Mangez-vous vraiment des grenouilles?" ,
106+ ] * 100
107+ embeddings = list (model .embed (docs , batch_size = 10 , parallel = 2 ))
87108
88- for emb_1 , emb_2 , emb_3 in zip (embeddings , embeddings_2 , embeddings_3 ):
89- assert np .allclose (emb_1 .indices , emb_2 .indices )
90- assert np .allclose (emb_1 .indices , emb_3 .indices )
91- assert np .allclose (emb_1 .values , emb_2 .values )
92- assert np .allclose (emb_1 .values , emb_3 .values )
109+ embeddings_2 = list (model .embed (docs , batch_size = 10 , parallel = None ))
93110
94- if is_ci :
95- delete_model_cache (model .model ._model_dir )
111+ embeddings_3 = list (model .embed (docs , batch_size = 10 , parallel = 0 ))
96112
113+ assert len (embeddings ) == len (docs )
114+
115+ for emb_1 , emb_2 , emb_3 in zip (embeddings , embeddings_2 , embeddings_3 ):
116+ assert np .allclose (emb_1 .indices , emb_2 .indices )
117+ assert np .allclose (emb_1 .indices , emb_3 .indices )
118+ assert np .allclose (emb_1 .values , emb_2 .values )
119+ assert np .allclose (emb_1 .values , emb_3 .values )
97120
98- @pytest .mark .parametrize ("model_name" , ["Qdrant/bm25" ])
99- def test_multilanguage (model_name : str ) -> None :
100- is_ci = os .getenv ("CI" )
101121
122+ @pytest .mark .parametrize ("model_name" , ["Qdrant/bm25" ])
123+ def test_multilanguage (model_cache , model_name : str ) -> None :
102124 docs = ["Mangez-vous vraiment des grenouilles?" , "Je suis au lit" ]
103125
104126 model = SparseTextEmbedding (model_name = model_name , language = "french" )
@@ -109,39 +131,30 @@ def test_multilanguage(model_name: str) -> None:
109131 assert embeddings [1 ].values .shape == (1 ,)
110132 assert embeddings [1 ].indices .shape == (1 ,)
111133
112- model = SparseTextEmbedding (model_name = model_name , language = "english" )
113- embeddings = list (model .embed (docs ))[:2 ]
114- assert embeddings [0 ].values .shape == (5 ,)
115- assert embeddings [0 ].indices .shape == (5 ,)
134+ with model_cache (model_name ) as model : # language = "english"
135+ embeddings = list (model .embed (docs ))[:2 ]
136+ assert embeddings [0 ].values .shape == (5 ,)
137+ assert embeddings [0 ].indices .shape == (5 ,)
116138
117- assert embeddings [1 ].values .shape == (4 ,)
118- assert embeddings [1 ].indices .shape == (4 ,)
119-
120- if is_ci :
121- delete_model_cache (model .model ._model_dir )
139+ assert embeddings [1 ].values .shape == (4 ,)
140+ assert embeddings [1 ].indices .shape == (4 ,)
122141
123142
124143@pytest .mark .parametrize ("model_name" , ["Qdrant/bm25" ])
125- def test_special_characters (model_name : str ) -> None :
126- is_ci = os .getenv ("CI" )
127-
128- docs = [
129- "Über den größten Flüssen Österreichs äußern sich Experten häufig: Öko-Systeme müssen geschützt werden!" ,
130- "L'élève français s'écrie : « Où est mon crayon ? J'ai besoin de finir cet exercice avant la récréation!" ,
131- "Într-o zi însorită, Ștefan și Ioana au mâncat mămăligă cu brânză și au băut țuică la cabană." ,
132- "Üzgün öğretmen öğrencilere seslendi: Lütfen gürültü yapmayın, sınavınızı bitirmeye çalışıyorum!" ,
133- "Ο Ξενοφών είπε: «Ψάχνω για ένα ωραίο δώρο για τη γιαγιά μου. Ίσως ένα φυτό ή ένα βιβλίο;»" ,
134- "Hola! ¿Cómo estás? Estoy muy emocionado por el cumpleaños de mi hermano, ¡va a ser increíble! También quiero comprar un pastel de chocolate con fresas y un regalo especial: un libro titulado «Cien años de soledad" ,
135- ]
136-
137- model = SparseTextEmbedding (model_name = model_name , language = "english" )
138- embeddings = list (model .embed (docs ))
139- for idx , shape in enumerate ([14 , 18 , 15 , 10 , 15 ]):
140- assert embeddings [idx ].values .shape == (shape ,)
141- assert embeddings [idx ].indices .shape == (shape ,)
142-
143- if is_ci :
144- delete_model_cache (model .model ._model_dir )
144+ def test_special_characters (model_cache , model_name : str ) -> None :
145+ with model_cache (model_name ) as model :
146+ docs = [
147+ "Über den größten Flüssen Österreichs äußern sich Experten häufig: Öko-Systeme müssen geschützt werden!" ,
148+ "L'élève français s'écrie : « Où est mon crayon ? J'ai besoin de finir cet exercice avant la récréation!" ,
149+ "Într-o zi însorită, Ștefan și Ioana au mâncat mămăligă cu brânză și au băut țuică la cabană." ,
150+ "Üzgün öğretmen öğrencilere seslendi: Lütfen gürültü yapmayın, sınavınızı bitirmeye çalışıyorum!" ,
151+ "Ο Ξενοφών είπε: «Ψάχνω για ένα ωραίο δώρο για τη γιαγιά μου. Ίσως ένα φυτό ή ένα βιβλίο;»" ,
152+ "Hola! ¿Cómo estás? Estoy muy emocionado por el cumpleaños de mi hermano, ¡va a ser increíble! También quiero comprar un pastel de chocolate con fresas y un regalo especial: un libro titulado «Cien años de soledad" ,
153+ ]
154+ embeddings = list (model .embed (docs ))
155+ for idx , shape in enumerate ([14 , 18 , 15 , 10 , 15 ]):
156+ assert embeddings [idx ].values .shape == (shape ,)
157+ assert embeddings [idx ].indices .shape == (shape ,)
145158
146159
147160@pytest .mark .parametrize ("model_name" , ["Qdrant/bm42-all-minilm-l6-v2-attentions" ])
0 commit comments