Commit 833ded7

Merge pull request stanfordnlp#791 from usamajamil43/feature/add-MyScale-to-retrieve
feature(dspy): Add MyScale in Retrieve
2 parents 300f0ed + a0463cb commit 833ded7

File tree

4 files changed (+311, -3)

README.md

Lines changed: 3 additions & 3 deletions
@@ -72,11 +72,10 @@ Or open our intro notebook in Google Colab: [<img align="center" src="https://co
 
 By default, DSPy installs the latest `openai` from pip. However, if you installed an old version from before OpenAI changed their API (`openai~=0.28.1`), the library will use that just fine. Both are supported.
 
-For the optional (alphabetically sorted) [Chromadb](https://github.com/chroma-core/chroma), [Groq](https://github.com/groq/groq-python), [Qdrant](https://github.com/qdrant/qdrant), [Marqo](https://github.com/marqo-ai/marqo), Pinecone, [Snowflake](https://github.com/snowflakedb/snowpark-python) [Weaviate](https://github.com/weaviate/weaviate),
-or [Milvus](https://github.com/milvus-io/milvus) retrieval integration(s), include the extra(s) below:
+For the optional (alphabetically sorted) [Chromadb](https://github.com/chroma-core/chroma), [Groq](https://github.com/groq/groq-python), [Marqo](https://github.com/marqo-ai/marqo), [Milvus](https://github.com/milvus-io/milvus), [MongoDB](https://www.mongodb.com), [MyScaleDB](https://github.com/myscale/myscaledb), Pinecone, [Qdrant](https://github.com/qdrant/qdrant), [Snowflake](https://github.com/snowflakedb/snowpark-python), or [Weaviate](https://github.com/weaviate/weaviate) retrieval integration(s), include the extra(s) below:
 
 ```
-pip install dspy-ai[chromadb]  # or [groq] or [qdrant] or [marqo] or [mongodb] or [pinecone] or [snowflake] or [weaviate] or [milvus]
+pip install dspy-ai[chromadb]  # or [groq] or [marqo] or [milvus] or [mongodb] or [myscale] or [pinecone] or [qdrant] or [snowflake] or [weaviate]
 ```
 
 ## 2) Documentation
@@ -106,6 +105,7 @@ The DSPy documentation is divided into **tutorials** (step-by-step illustration
 - **Tracing in DSPy** with Arize Phoenix: [Tutorial for tracing your prompts and the steps of your DSPy programs](https://colab.research.google.com/github/Arize-ai/phoenix/blob/main/tutorials/tracing/dspy_tracing_tutorial.ipynb)
 - [DSPy: Not Your Average Prompt Engineering](https://jina.ai/news/dspy-not-your-average-prompt-engineering), why it's crucial for future prompt engineering, and yet why it is challenging for prompt engineers to learn.
 - **Tracing & Optimization Tracking in DSPy** with Parea AI: [Tutorial on tracing & evaluating a DSPy RAG program](https://docs.parea.ai/tutorials/dspy-rag-trace-evaluate/tutorial)
+- [DSPy: Not Your Average Prompt Engineering](https://jina.ai/news/dspy-not-your-average-prompt-engineering), why it's crucial for future prompt engineering, and yet why it is challenging for prompt engineers to learn.
 
 ### B) Guides
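The `myscale` extra introduced by this merge can also be installed on its own; a minimal sketch, assuming a shell such as zsh where extras need quoting:

```
pip install "dspy-ai[myscale]"
```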

docs/api/MyScaleRM.md

Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
---
sidebar_position: 8
---

# retrieve.MyScaleRM

## Constructor

Initializes an instance of the `MyScaleRM` class, which uses MyScaleDB (a ClickHouse fork optimized for vector similarity and full-text search) to retrieve documents based on query embeddings. The class supports embedding generation through either local models or OpenAI's API and manages database interactions efficiently.

### Syntax
```python
MyScaleRM(
    client: clickhouse_connect.driver.client.Client,
    table: str,
    database: str = 'default',
    metadata_columns: List[str] = ['text'],
    vector_column: str = 'vector',
    k: int = 3,
    openai_api_key: Optional[str] = None,
    openai_model: Optional[str] = None,
    local_embed_model: Optional[str] = None
)
```
## Parameters for `MyScaleRM` Constructor

- `client` (_clickhouse_connect.driver.client.Client_): A client connection to the MyScaleDB database, used to execute queries and manage interactions with the database.
- `table` (_str_): The table within MyScaleDB from which data will be retrieved. This table should include a vector column for conducting similarity searches.
- `database` (_str_, optional): The name of the database where the table is located, defaulting to `"default"`.
- `metadata_columns` (_List[str]_, optional): Columns to include as metadata in the output, defaulting to `["text"]`.
- `vector_column` (_str_, optional): The column that contains vector data, used for similarity searches, defaulting to `"vector"`.
- `k` (_int_, optional): The number of closest matches to return for each query, defaulting to 3.
- `openai_api_key` (_str_, optional): API key for accessing OpenAI services, required if using OpenAI for embedding generation.
- `openai_model` (_str_, optional): The OpenAI model to use for embeddings, required if an OpenAI API key is provided.
- `local_embed_model` (_str_, optional): A local model to use for embedding generation, if local computation is preferred.
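Putting the constructor together, a minimal end-to-end instantiation sketch follows. The host, credentials, table name, and embedding model are placeholders, and the example assumes a running MyScaleDB instance reachable through `clickhouse_connect`'s standard `get_client` helper:

```python
import clickhouse_connect

from dspy.retrieve.MyScaleRM import MyScaleRM

# Hypothetical connection details; substitute your own MyScaleDB endpoint.
client = clickhouse_connect.get_client(
    host="your-myscale-host",
    port=8443,
    username="default",
    password="your-password",
)

retriever = MyScaleRM(
    client=client,
    table="documents",  # assumed table with a populated vector column
    vector_column="vector",
    metadata_columns=["text"],
    local_embed_model="sentence-transformers/all-MiniLM-L6-v2",  # assumed local model
    k=3,
)
```

When both a local model and an OpenAI key are supplied, the constructor prefers the local model, so passing `local_embed_model` keeps the whole pipeline local.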
## Methods

### `forward`

Executes a retrieval operation based on a user's query and returns the top `k` relevant results using the embeddings generated by the configured method.

### Syntax
```python
def forward(self, user_query: str, k: Optional[int] = None) -> dspy.Prediction
```

## Parameters

- `user_query` (_str_): The query for which to retrieve matching passages.
- `k` (_Optional[int]_, optional): The number of top matches to retrieve. If not provided, it defaults to the `k` value set during class initialization.

## Returns

- `dspy.Prediction`: Contains the retrieved passages, formatted as a list of `dotdict` objects. Each entry includes:
  - **long_text (str)**: The text content of the retrieved passage.

## Description

The `forward` method leverages MyScaleDB's vector search capabilities to find the top `k` passages that best match the provided query, based on semantic similarity computed over embeddings from the chosen generation technique (either a local model or the OpenAI API).
## Quickstart

This section shows how to instantiate the `MyScaleRM` class and use it to retrieve data from MyScaleDB using text embeddings.

```python
from dspy.retrieve.MyScaleRM import MyScaleRM

MyScale_model = MyScaleRM(client=client,
                          table="table_name",
                          openai_api_key="sk-***",
                          openai_model="embeddings_model",
                          vector_column="vector_column_name",
                          metadata_columns=["add_your_columns_here"],
                          k=6)

results = MyScale_model("Please suggest me some funny movies")

passages = results.passages

# Loop through each passage and print the 'long_text'
for passage in passages:
    print(passage['long_text'], "\n")
```
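Because `forward` accepts a per-call `k` (see the signature above), the instance-wide default can be overridden for a single query; a small sketch reusing the hypothetical `MyScale_model` from the Quickstart:

```python
# Ask for only the top 2 matches on this call, overriding the constructor's k=6.
top_two = MyScale_model.forward("Please suggest me some funny movies", k=2)

for passage in top_two.passages:
    print(passage['long_text'])
```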

dspy/retrieve/MyScaleRM.py

Lines changed: 226 additions & 0 deletions
@@ -0,0 +1,226 @@
import functools
import os
from typing import List, Optional

import openai

import dspy
from dsp.modules.cache_utils import CacheMemory, NotebookCacheMemory, cache_turn_on
from dsp.utils import dotdict

# Check for necessary libraries and suggest installation if not found.
try:
    import clickhouse_connect
except ImportError:
    raise ImportError(
        "The 'myscale' extra is required to use MyScaleRM. Install it with `pip install dspy-ai[myscale]`",
    )

# Verify the compatibility of the installed OpenAI library version.
try:
    major, minor, _ = map(int, openai.__version__.split('.'))
    # Tuple comparison, so that e.g. 2.0 is also accepted (a bare
    # `major >= 1 and minor >= 16` would wrongly reject it).
    OPENAI_VERSION_COMPATIBLE = (major, minor) >= (1, 16)
except Exception:
    OPENAI_VERSION_COMPATIBLE = False

if not OPENAI_VERSION_COMPATIBLE:
    raise ImportError(
        "An incompatible OpenAI library version is installed. Ensure you have version 1.16.1 or later.",
    )

# Attempt to handle specific OpenAI errors; fall back to the general ones if necessary.
try:
    import openai.error
    ERRORS = (openai.error.RateLimitError, openai.error.ServiceUnavailableError, openai.error.APIError)
except Exception:
    ERRORS = (openai.RateLimitError, openai.APIError)
class MyScaleRM(dspy.Retrieve):
    """
    A retrieval module that uses MyScaleDB to return the top passages for a given query.

    MyScaleDB is a fork of ClickHouse that focuses on vector similarity search and full
    text search. MyScaleRM is designed to facilitate easy retrieval of information from
    MyScaleDB using embeddings. It supports embedding generation through either a local
    model or the OpenAI API. This class abstracts away the complexities of connecting to
    MyScaleDB, managing API keys, and processing queries to return semantically
    relevant results.

    Assumes that a table named `database.table` exists in MyScaleDB, that the table has
    a column named `vector_column` that stores vector data, and that a vector index has
    been created on this column. Other metadata are stored in `metadata_columns`.

    Args:
        client (clickhouse_connect.driver.client.Client): A client connection to MyScaleDB.
        table (str): Name of the table within the database to perform queries against.
        database (str, optional): Name of the database to query within MyScaleDB.
        metadata_columns (List[str], optional): A list of columns to include in the results.
        vector_column (str, optional): The name of the column in the table that stores vector data.
        k (int, optional): The number of closest matches to retrieve for a given query.
        openai_api_key (str, optional): The API key for accessing OpenAI's services.
        openai_model (str, optional): The OpenAI model to use for embedding generation.
        local_embed_model (str, optional): The local model to use for embedding generation, if preferred.
    """
    def __init__(self,
                 client: clickhouse_connect.driver.client.Client,
                 table: str,
                 database: str = "default",
                 metadata_columns: List[str] = ["text"],
                 vector_column: str = "vector",
                 k: int = 3,
                 openai_api_key: Optional[str] = None,
                 openai_model: Optional[str] = None,
                 local_embed_model: Optional[str] = None):
        self.client = client
        self.database = database
        self.table = table
        if not metadata_columns:
            raise ValueError("metadata_columns is required")
        self.metadata_columns = metadata_columns
        self.vector_column = vector_column
        self.k = k
        self.openai_api_key = openai_api_key
        self.model = openai_model
        self.use_local_model = False

        if local_embed_model:
            self.setup_local_model(local_embed_model)
        elif openai_api_key:
            os.environ['OPENAI_API_KEY'] = self.openai_api_key
    def setup_local_model(self, model_name: str):
        """
        Configures a local model for embedding generation, including model and tokenizer loading.

        Args:
            model_name: The name or path of the pre-trained model to load.

        Raises:
            ModuleNotFoundError: If the necessary libraries (torch or transformers) are not installed.
        """
        try:
            import torch
            from transformers import AutoModel, AutoTokenizer
        except ImportError as exc:
            raise ModuleNotFoundError(
                """You need to install PyTorch and Hugging Face's transformers library to use a local embedding model.
                Install torch with `pip install torch` and transformers with `pip install transformers`.""",
            ) from exc

        try:
            self._local_embed_model = AutoModel.from_pretrained(model_name)
            self._local_tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.use_local_model = True
        except Exception as e:
            raise ValueError(f"Failed to load model or tokenizer. Error: {str(e)}") from e

        # Prefer CUDA, then Apple's MPS, falling back to CPU.
        if torch.cuda.is_available():
            self.device = torch.device('cuda:0')
        elif torch.backends.mps.is_available():
            self.device = torch.device('mps')
        else:
            self.device = torch.device('cpu')

        self._local_embed_model.to(self.device)
    @functools.lru_cache(maxsize=None if cache_turn_on else 0)
    @NotebookCacheMemory.cache
    def get_embeddings(self, query: str) -> List[float]:
        """
        Determines the appropriate source (OpenAI or local model) for embedding generation
        based on the class configuration, and retrieves the embedding for the provided query.

        Note: the argument is a single string rather than a list so that it remains
        hashable for `functools.lru_cache`; passing a list here would raise a TypeError
        whenever caching is enabled.

        Args:
            query: The text query to generate an embedding for.

        Returns:
            The embedding corresponding to the query, as a list of floats.

        Raises:
            ValueError: If neither an OpenAI API key nor a local model has been configured.
        """
        if self.openai_api_key and self.model:
            return self._get_embeddings_from_openai(query)
        elif self.use_local_model:
            return self._get_embedding_from_local_model(query)
        else:
            raise ValueError("No valid method for obtaining embeddings is configured.")
    # TODO: move this method into a shared utility module outside MyScaleRM.
    @CacheMemory.cache
    def _get_embeddings_from_openai(self, query: str) -> List[float]:
        """
        Uses the OpenAI API to generate an embedding for the given query.

        Args:
            query: The string to generate an embedding for.

        Returns:
            A list of floats representing the query's embedding.
        """
        response = openai.embeddings.create(
            model=self.model,
            input=[query])
        # Only one input is sent, so the first (and only) embedding is returned.
        return response.data[0].embedding
    # TODO: move this method into a shared utility module outside MyScaleRM.
    @CacheMemory.cache
    def _get_embedding_from_local_model(self, query: str) -> List[float]:
        """
        Generates an embedding for a single query using the configured local model.

        Args:
            query: The text query to generate an embedding for.

        Returns:
            A list of floats representing the query's embedding.
        """
        import torch
        self._local_embed_model.eval()  # Ensure the model is in evaluation mode

        inputs = self._local_tokenizer(query, return_tensors="pt", padding=True, truncation=True).to(self.device)
        with torch.no_grad():
            output = self._local_embed_model(**inputs)

        # Mean-pool the last hidden state over the sequence dimension to get one vector.
        embedding = output.last_hidden_state.mean(dim=1).cpu().numpy().tolist()[0]

        return embedding
    def forward(self, user_query: str, k: Optional[int] = None) -> dspy.Prediction:
        """
        Executes a retrieval operation based on a user's query and returns the top k relevant results.

        Args:
            user_query: The query text to search for.
            k: Optional; the number of top matches to return. Defaults to the class's configured k value.

        Returns:
            A dspy.Prediction object containing the formatted retrieval results.

        Raises:
            ValueError: If the user_query is None.
        """
        if user_query is None:
            raise ValueError("Query is required")
        k = k if k is not None else self.k
        embedding = self.get_embeddings(user_query)
        columns_string = ', '.join(self.metadata_columns)
        result = self.client.query(f"""
        SELECT {columns_string},
        distance({self.vector_column}, {embedding}) as dist FROM {self.database}.{self.table} ORDER BY dist LIMIT {k}
        """)

        # Convert the metadata into strings to pass to dspy.Prediction.
        results = []
        for row in result.named_results():
            if len(self.metadata_columns) == 1:
                results.append(row[self.metadata_columns[0]])
            else:
                row_strings = [f"{column}: {row[column]}" for column in self.metadata_columns]  # Format row data
                row_string = "\n".join(row_strings)  # Combine formatted data
                results.append(row_string)  # Append to results

        return dspy.Prediction(passages=[dotdict({"long_text": passage}) for passage in results])
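The class docstring assumes that `database.table` already exists with embeddings stored in `vector_column` and a vector index built on it. For reference, a minimal table-setup sketch follows; the schema, the `MSTG` index type, and the 384-dimension constraint are assumptions to adapt to your own embedding model:

```python
import clickhouse_connect

client = clickhouse_connect.get_client(host="your-myscale-host", port=8443,
                                       username="default", password="your-password")

# Hypothetical schema matching MyScaleRM's defaults: a `text` metadata column
# plus a `vector` column holding fixed-length embeddings with a vector index.
client.command("""
CREATE TABLE IF NOT EXISTS default.documents (
    id UInt64,
    text String,
    vector Array(Float32),
    CONSTRAINT vec_len CHECK length(vector) = 384,
    VECTOR INDEX vec_idx vector TYPE MSTG
) ENGINE = MergeTree ORDER BY id
""")
```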

setup.py

Lines changed: 3 additions & 0 deletions
@@ -32,8 +32,11 @@
     "faiss-cpu": ["sentence_transformers", "faiss-cpu"],
     "milvus": ["pymilvus~=2.3.7"],
     "google-vertex-ai": ["google-cloud-aiplatform==1.43.0"],
+    "myscale": ["clickhouse-connect"],
     "snowflake": ["snowflake-snowpark-python"],
     "fastembed": ["fastembed"],
+    "google-vertex-ai": ["google-cloud-aiplatform==1.43.0"],
+    "myscale": ["clickhouse-connect"],
     "groq": ["groq~=0.8.0"],
 },
 classifiers=[

0 commit comments