Skip to content

Commit 081b637

Browse files
Merge pull request stanfordnlp#1168 from efenocchi/main
fix(dspy): fixed bug in deeplake_rm retriever part
2 parents 05a4923 + 386aa53 commit 081b637

File tree

1 file changed

+15
-19
lines changed

1 file changed

+15
-19
lines changed

dspy/retrieve/deeplake_rm.py

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,14 @@
1111
from dsp.utils import dotdict
1212

1313
try:
14-
import openai.error
14+
import openai
1515

1616
ERRORS = (
17-
openai.error.RateLimitError,
18-
openai.error.ServiceUnavailableError,
19-
openai.error.APIError,
17+
openai.RateLimitError,
18+
openai.APIError,
2019
)
2120
except Exception:
22-
ERRORS = (openai.error.RateLimitError, openai.error.APIError)
21+
ERRORS = (openai.RateLimitError, openai.APIError)
2322

2423

2524
class DeeplakeRM(dspy.Retrieve):
@@ -58,13 +57,15 @@ def __init__(
5857
k: int = 3,
5958
):
6059
try:
61-
from deeplake import VectorStore
60+
from deeplake import VectorStore
6261
except ImportError:
63-
raise ImportError(
64-
"The 'deeplake' extra is required to use DeepLakeRM. Install it with `pip install dspy-ai[deeplake]`",
65-
)
62+
raise ImportError("The 'deeplake' extra is required to use DeepLakeRM. Install it with `pip install dspy-ai[deeplake]`",)
63+
6664
self._deeplake_vectorstore_name = deeplake_vectorstore_name
67-
self._deeplake_client = deeplake_client
65+
self._deeplake_client = deeplake_client(
66+
path=self._deeplake_vectorstore_name,
67+
embedding_function=self.embedding_function,
68+
)
6869

6970
super().__init__(k=k)
7071

@@ -73,11 +74,9 @@ def embedding_function(self, texts, model="text-embedding-ada-002"):
7374
texts = [texts]
7475

7576
texts = [t.replace("\n", " ") for t in texts]
76-
return [
77-
data["embedding"]
78-
for data in openai.Embedding.create(input=texts, model=model)["data"]
79-
]
80-
77+
78+
return [data.embedding for data in openai.embeddings.create(input = texts, model=model).data]
79+
8180
def forward(
8281
self, query_or_queries: Union[str, List[str]], k: Optional[int],**kwargs,
8382
) -> dspy.Prediction:
@@ -103,10 +102,7 @@ def forward(
103102
passages = defaultdict(float)
104103
#deeplake doesn't support batch querying, manually querying each query and storing them
105104
for query in queries:
106-
results = self._deeplake_client(
107-
path=self._deeplake_vectorstore_name,
108-
embedding_function=self.embedding_function,
109-
).search(query, k=k,**kwargs)
105+
results = self._deeplake_client.search(query, k=k, **kwargs)
110106

111107
for score,text in zip(results.get('score',0.0),results.get('text',"")):
112108
passages[text] += score

0 commit comments

Comments
 (0)