Skip to content

Commit fc58b9c

Browse files
Merge pull request stanfordnlp#1068 from sky-2002/epsilla-retriever
feat(dspy): Add epsilla retriever
2 parents 833ded7 + 41c399f commit fc58b9c

File tree

3 files changed

+117
-1
lines changed

3 files changed

+117
-1
lines changed

dspy/retrieve/epsilla_rm.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
from collections import defaultdict # noqa: F401
2+
from typing import Dict, List, Union # noqa: UP035
3+
4+
import dspy
5+
from dsp.utils import dotdict
6+
7+
try:
8+
from pyepsilla import vectordb
9+
except ImportError:
10+
raise ImportError( # noqa: B904
11+
"The 'pyepsilla' extra is required to use EpsillaRM. Install it with `pip install dspy-ai[epsilla]`",
12+
)
13+
14+
15+
class EpsillaRM(dspy.Retrieve):
16+
def __init__(
17+
self,
18+
epsilla_client: vectordb.Client,
19+
db_name: str,
20+
db_path: str,
21+
table_name: str,
22+
k: int = 3,
23+
page_content: str = "document",
24+
):
25+
self._epsilla_client = epsilla_client
26+
self._epsilla_client.load_db(db_name=db_name, db_path=db_path)
27+
self._epsilla_client.use_db(db_name=db_name)
28+
self.page_content = page_content
29+
self.table_name = table_name
30+
31+
super().__init__(k=k)
32+
33+
def forward(self, query_or_queries: Union[str, List[str]], k: Union[int, None] = None, **kwargs) -> dspy.Prediction: # noqa: ARG002
34+
queries = [query_or_queries] if isinstance(query_or_queries, str) else query_or_queries
35+
queries = [q for q in queries if q]
36+
limit = k if k else self.k
37+
all_query_results: list = []
38+
39+
passages: Dict = defaultdict(float)
40+
41+
for result_dict in all_query_results:
42+
for result in result_dict:
43+
passages[result[self.page_content]] += result["@distance"]
44+
sorted_passages = sorted(passages.items(), key=lambda x: x[1], reverse=False)[:limit]
45+
return dspy.Prediction(passages=[dotdict({"long_text": passage}) for passage, _ in sorted_passages])

poetry.lock

Lines changed: 69 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ anthropic = ["anthropic~=0.18.0"]
4545
chromadb = ["chromadb~=0.4.14"]
4646
qdrant = ["qdrant-client>=1.6.2", "fastembed>=0.2.0"]
4747
marqo = ["marqo"]
48+
epsilla = ["pyepsilla~=0.3.7"]
4849
pinecone = ["pinecone-client~=2.2.4"]
4950
weaviate = ["weaviate-client~=4.5.4"]
5051
milvus = ["pymilvus~=2.3.7"]
@@ -99,6 +100,7 @@ anthropic = { version = "^0.18.0", optional = true }
99100
chromadb = { version = "^0.4.14", optional = true }
100101
fastembed = { version = ">=0.2.0", optional = true }
101102
marqo = { version = "*", optional = true }
103+
pyepsilla = {version = "^0.3.7", optional = true}
102104
qdrant-client = { version = "^1.6.2", optional = true }
103105
pinecone-client = { version = "^2.2.4", optional = true }
104106
weaviate-client = { version = "^4.5.4", optional = true }
@@ -140,6 +142,7 @@ ipykernel = "^6.29.4"
140142
chromadb = ["chromadb"]
141143
qdrant = ["qdrant-client", "fastembed"]
142144
marqo = ["marqo"]
145+
epsilla = ["pyepsilla"]
143146
pinecone = ["pinecone-client"]
144147
weaviate = ["weaviate-client"]
145148
milvus = ["pymilvus"]

0 commit comments

Comments
 (0)