Skip to content

Commit 99b4cb2

Browse files
authored
Merge pull request stanfordnlp#1211 from Anindyadeep/premai/vectorizer
Prem AI Vectorizer support
2 parents deff8ec + a8034f6 commit 99b4cb2

File tree

1 file changed

+88
-30
lines changed

1 file changed

+88
-30
lines changed

dsp/modules/sentence_vectorizer.py

Lines changed: 88 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,15 @@
66

77

88
class BaseSentenceVectorizer(abc.ABC):
9-
'''
9+
"""
1010
Base Class for Vectorizers. The main purpose is to vectorize text (doc/query)
1111
for ANN/KNN indexes. `__call__` method takes `List[Example]` as a single input, then extracts
1212
`field_to_vectorize` from every Example and convert them into embeddings.
1313
You can customize extraction logic in the `_extract_text_from_examples` method.
14-
'''
14+
"""
15+
1516
# embeddings will be computed based on the string in this attribute of Example object
16-
field_to_vectorize = 'text_to_vectorize'
17+
field_to_vectorize = "text_to_vectorize"
1718

1819
def __init__(self) -> None:
1920
pass
@@ -24,28 +25,29 @@ def __call__(self, inp_examples: List["Example"]) -> np.ndarray:
2425

2526
def _extract_text_from_examples(self, inp_examples: List) -> List[str]:
2627
if isinstance(inp_examples[0], str):
27-
return inp_examples
28+
return inp_examples
2829
return [" ".join([example[key] for key in example._input_keys]) for example in inp_examples]
2930

3031

3132
class SentenceTransformersVectorizer(BaseSentenceVectorizer):
32-
'''
33+
"""
3334
Vectorizer based on `SentenceTransformers` models. You can pick any model from this link:
3435
https://huggingface.co/models?library=sentence-transformers
3536
More details about models:
3637
https://www.sbert.net/docs/pretrained_models.html
37-
'''
38+
"""
39+
3840
def __init__(
3941
self,
40-
model_name_or_path: str = 'all-MiniLM-L6-v2',
42+
model_name_or_path: str = "all-MiniLM-L6-v2",
4143
vectorize_bs: int = 256,
4244
max_gpu_devices: int = 1,
4345
normalize_embeddings: bool = False,
4446
):
4547
# this isn't a good practice, but with top-level import the whole DSP
4648
# module import will be slow (>5 sec), because SentenceTransformer is doing
4749
# it's directory/file-related magic under the hood :(
48-
50+
4951
try:
5052
from sentence_transformers import SentenceTransformer
5153
except ImportError:
@@ -55,9 +57,9 @@ def __init__(
5557
"or simply run `pip install sentence-transformers",
5658
)
5759
from dsp.utils.ann_utils import determine_devices
58-
60+
5961
self.num_devices, self.is_gpu = determine_devices(max_gpu_devices)
60-
self.proxy_device = 'cuda' if self.is_gpu else 'cpu'
62+
self.proxy_device = "cuda" if self.is_gpu else "cpu"
6163

6264
self.model = SentenceTransformer(model_name_or_path, device=self.proxy_device)
6365

@@ -93,42 +95,42 @@ def __call__(self, inp_examples: List) -> np.ndarray:
9395

9496

9597
class NaiveGetFieldVectorizer(BaseSentenceVectorizer):
96-
'''
97-
If embeddings were precomputed, then we could just extract them from the proper field
98+
"""
99+
If embeddings were precomputed, then we could just extract them from the proper field
98100
(set by `field_with_embedding`) from each `Example`.
99-
'''
100-
def __init__(self, field_with_embedding: str = 'vectorized'):
101+
"""
102+
103+
def __init__(self, field_with_embedding: str = "vectorized"):
101104
self.field_with_embedding = field_with_embedding
102105

103106
def __call__(self, inp_examples: List["Example"]) -> np.ndarray:
104-
embeddings = [
105-
getattr(cur_example, self.field_with_embedding).reshape(1, -1)
106-
for cur_example in inp_examples
107-
]
107+
embeddings = [getattr(cur_example, self.field_with_embedding).reshape(1, -1) for cur_example in inp_examples]
108108
embeddings = np.concatenate(embeddings, axis=0).astype(np.float32)
109109
return embeddings
110110

111111

112112
class CohereVectorizer(BaseSentenceVectorizer):
113-
'''
113+
"""
114114
This vectorizer uses the Cohere API to convert texts to embeddings.
115115
More about the available models: https://docs.cohere.com/reference/embed
116116
`api_key` should be passed as an argument and can be retrieved
117117
from https://dashboard.cohere.com/api-keys
118-
'''
118+
"""
119+
119120
def __init__(
120121
self,
121122
api_key: str,
122-
model: str = 'embed-english-v3.0',
123+
model: str = "embed-english-v3.0",
123124
embed_batch_size: int = 96,
124-
embedding_type: str = 'search_document', # for details check Cohere embed docs
125+
embedding_type: str = "search_document", # for details check Cohere embed docs
125126
):
126127
self.model = model
127128
self.embed_batch_size = embed_batch_size
128129
self.embedding_type = embedding_type
129130

130131
import cohere
131-
self.client = cohere.Client(api_key, client_name='dspy')
132+
133+
self.client = cohere.Client(api_key, client_name="dspy")
132134

133135
def __call__(self, inp_examples: List["Example"]) -> np.ndarray:
134136
text_to_vectorize = self._extract_text_from_examples(inp_examples)
@@ -139,7 +141,7 @@ def __call__(self, inp_examples: List["Example"]) -> np.ndarray:
139141
for cur_batch_idx in range(n_batches):
140142
start_idx = cur_batch_idx * self.embed_batch_size
141143
end_idx = (cur_batch_idx + 1) * self.embed_batch_size
142-
cur_batch = text_to_vectorize[start_idx: end_idx]
144+
cur_batch = text_to_vectorize[start_idx:end_idx]
143145

144146
response = self.client.embed(
145147
texts=cur_batch,
@@ -160,14 +162,15 @@ def __call__(self, inp_examples: List["Example"]) -> np.ndarray:
160162

161163

162164
class OpenAIVectorizer(BaseSentenceVectorizer):
163-
'''
165+
"""
164166
This vectorizer uses OpenAI API to convert texts to embeddings. Changing `model` is not
165167
recommended. More about the model: https://openai.com/blog/new-and-improved-embedding-model/
166168
`api_key` should be passed as an argument or as env variable (`OPENAI_API_KEY`).
167-
'''
169+
"""
170+
168171
def __init__(
169172
self,
170-
model: str = 'text-embedding-ada-002',
173+
model: str = "text-embedding-ada-002",
171174
embed_batch_size: int = 1024,
172175
api_key: Optional[str] = None,
173176
):
@@ -191,19 +194,20 @@ def __call__(self, inp_examples: List["Example"]) -> np.ndarray:
191194
for cur_batch_idx in range(n_batches): # tqdm.tqdm?
192195
start_idx = cur_batch_idx * self.embed_batch_size
193196
end_idx = (cur_batch_idx + 1) * self.embed_batch_size
194-
cur_batch = text_to_vectorize[start_idx: end_idx]
197+
cur_batch = text_to_vectorize[start_idx:end_idx]
195198
# OpenAI API call:
196199
response = self.Embedding.create(
197200
model=self.model,
198201
input=cur_batch,
199202
)
200203

201-
cur_batch_embeddings = [cur_obj['embedding'] for cur_obj in response['data']]
204+
cur_batch_embeddings = [cur_obj["embedding"] for cur_obj in response["data"]]
202205
embeddings_list.extend(cur_batch_embeddings)
203206

204207
embeddings = np.array(embeddings_list, dtype=np.float32)
205208
return embeddings
206209

210+
207211
class FastEmbedVectorizer(BaseSentenceVectorizer):
208212
"""Sentence vectorizer implementaion using FastEmbed - https://qdrant.github.io/fastembed."""
209213

@@ -247,4 +251,58 @@ def __call__(self, inp_examples: List["Example"]) -> np.ndarray:
247251
texts_to_vectorize = self._extract_text_from_examples(inp_examples)
248252
embeddings = self._model.embed(texts_to_vectorize, batch_size=self._batch_size, parallel=self._parallel)
249253

250-
return np.array([embedding.tolist() for embedding in embeddings], dtype=np.float32)
254+
return np.array([embedding.tolist() for embedding in embeddings], dtype=np.float32)
255+
256+
257+
class PremAIVectorizer(BaseSentenceVectorizer):
258+
"""The PremAIVectorizer class utilizes the PremAI Embeddings API to convert text into embeddings.
259+
This vectorizer leverages various models provided by PremAI.
260+
261+
For detailed information on the supported models, visit: https://docs.premai.io/get-started/supported-models.
262+
263+
The `project_id` is a mandatory argument, while `api_key` and `model_name` are optional. The `api_key`
264+
can be supplied either as an argument or through an environment variable. By default, the `model_name`
265+
is set to "text-embedding-3-large", unless specified otherwise.
266+
267+
To learn more about getting started with PremAI, visit: https://docs.premai.io/introduction.
268+
"""
269+
270+
def __init__(
271+
self,
272+
project_id: str,
273+
api_key: Optional[str] = None,
274+
model_name: Optional[str] = "text-embedding-3-large",
275+
embed_batch_size: int = 32,
276+
):
277+
self.model_name, self.project_id = model_name, project_id
278+
self.embed_batch_size = embed_batch_size
279+
280+
try:
281+
from premai import Prem
282+
283+
from dsp.modules.premai import get_premai_api_key
284+
285+
api_key = get_premai_api_key(api_key=api_key)
286+
self.client = Prem(api_key=api_key)
287+
except ImportError as error:
288+
raise ImportError("Please install premai package using: pip install premai") from error
289+
290+
def __call__(self, inp_examples: List["Example"]) -> np.ndarray:
291+
text_to_vectorize = self._extract_text_from_examples(inp_examples)
292+
embedding_list = []
293+
294+
n_batches = (len(text_to_vectorize) - 1) // self.embed_batch_size + 1
295+
for cur_batch_idx in range(n_batches):
296+
start_idx = cur_batch_idx * self.embed_batch_size
297+
end_idx = (cur_batch_idx + 1) * self.embed_batch_size
298+
current_batch = text_to_vectorize[start_idx:end_idx]
299+
embeddings = self.client.embeddings.create(
300+
project_id=self.project_id,
301+
model=self.model_name,
302+
input=current_batch,
303+
).data
304+
current_batch_embeddings = [embedding.embedding for embedding in embeddings]
305+
embedding_list.extend(current_batch_embeddings)
306+
307+
embeddings = np.array(embedding_list, dtype=np.float32)
308+
return embeddings

0 commit comments

Comments
 (0)