Skip to content

Commit 362350b

Browse files
authored
Merge pull request stanfordnlp#881 from stanfordnlp/vllm-cache
Fixed VLLM Cache
2 parents 0643a14 + 8ba9a24 commit 362350b

File tree

1 file changed

+20
-7
lines changed

1 file changed

+20
-7
lines changed

dsp/modules/hf_client.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -130,13 +130,15 @@ def __init__(self, model, port, model_type: Literal['chat', 'text'] = 'text', ur
130130
else:
131131
raise ValueError(f"The url provided to `HFClientVLLM` is neither a string nor a list of strings. It is of type {type(url)}.")
132132

133+
self.urls_const = tuple(self.urls)
134+
self.port = port
133135
self.model_type = model_type
134136
self.headers = {"Content-Type": "application/json"}
135137
self.kwargs |= kwargs
136138
# kwargs needs to have model, port and url for the lm.copy() to work properly
137139
self.kwargs.update({
138140
'port': port,
139-
'url': url,
141+
'url': self.urls_const,
140142
})
141143

142144

@@ -157,8 +159,10 @@ def _generate(self, prompt, **kwargs):
157159
"messages": messages,
158160
**kwargs,
159161
}
160-
response = send_hfvllm_chat_request_v00(
162+
response = send_hfvllm_request_v01_wrapped(
161163
f"{url}/v1/chat/completions",
164+
url=self.urls_const,
165+
port=self.port,
162166
json=payload,
163167
headers=self.headers,
164168
)
@@ -181,9 +185,11 @@ def _generate(self, prompt, **kwargs):
181185
"prompt": prompt,
182186
**kwargs,
183187
}
184-
185-
response = send_hfvllm_request_v00(
188+
189+
response = send_hfvllm_request_v01_wrapped(
186190
f"{url}/v1/completions",
191+
url=self.urls_const,
192+
port=self.port,
187193
json=payload,
188194
headers=self.headers,
189195
)
@@ -201,13 +207,20 @@ def _generate(self, prompt, **kwargs):
201207
print("Failed to parse JSON response:", response.text)
202208
raise Exception("Received invalid JSON response from server")
203209

204-
205210
@CacheMemory.cache(ignore=['arg'])
def send_hfvllm_request_v01(arg, url, port, **kwargs):
    """POST a vLLM completion request to the endpoint given in ``arg``.

    The result is cached on ``url``, ``port`` and ``**kwargs``; ``arg`` is
    excluded from the cache key (``ignore=['arg']``) so the per-call target
    URL string does not fragment the cache.
    """
    response = requests.post(arg, **kwargs)
    return response
208213

214+
# @functools.lru_cache(maxsize=None if cache_turn_on else 0)
@NotebookCacheMemory.cache(ignore=['arg'])
def send_hfvllm_request_v01_wrapped(arg, url, port, **kwargs):
    """Notebook-cache wrapper around the vLLM request helper.

    Fix: the original delegated to ``send_hftgi_request_v01`` (the TGI
    client's helper) — a copy-paste slip; the vLLM wrapper must delegate
    to ``send_hfvllm_request_v01`` so vLLM calls hit the vLLM disk cache.
    """
    return send_hfvllm_request_v01(arg, url, port, **kwargs)
209218

210-
@CacheMemory.cache
def send_hfvllm_request_v00(arg, **kwargs):
    """Legacy vLLM completion request; cached on all arguments (no ignore list)."""
    result = requests.post(arg, **kwargs)
    return result
222+
@CacheMemory.cache
def send_hfvllm_chat_request_v00(arg, **kwargs):
    """Legacy vLLM chat request; cached on all arguments (no ignore list)."""
    result = requests.post(arg, **kwargs)
    return result
213226

0 commit comments

Comments (0)