Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 20 additions & 7 deletions dsp/modules/hf_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,13 +130,15 @@ def __init__(self, model, port, model_type: Literal['chat', 'text'] = 'text', ur
else:
raise ValueError(f"The url provided to `HFClientVLLM` is neither a string nor a list of strings. It is of type {type(url)}.")

self.urls_const = tuple(self.urls)
self.port = port
self.model_type = model_type
self.headers = {"Content-Type": "application/json"}
self.kwargs |= kwargs
# kwargs needs to have model, port and url for the lm.copy() to work properly
self.kwargs.update({
'port': port,
'url': url,
'url': self.urls_const,
})


Expand All @@ -157,8 +159,10 @@ def _generate(self, prompt, **kwargs):
"messages": messages,
**kwargs,
}
response = send_hfvllm_chat_request_v00(
response = send_hfvllm_request_v01_wrapped(
f"{url}/v1/chat/completions",
url=self.urls_const,
port=self.port,
json=payload,
headers=self.headers,
)
Expand All @@ -181,9 +185,11 @@ def _generate(self, prompt, **kwargs):
"prompt": prompt,
**kwargs,
}
response = send_hfvllm_request_v00(

response = send_hfvllm_request_v01_wrapped(
f"{url}/v1/completions",
url=self.urls_const,
port=self.port,
json=payload,
headers=self.headers,
)
Expand All @@ -201,13 +207,20 @@ def _generate(self, prompt, **kwargs):
print("Failed to parse JSON response:", response.text)
raise Exception("Received invalid JSON response from server")


@CacheMemory.cache(ignore=['arg'])
def send_hfvllm_request_v00(arg, **kwargs):
def send_hfvllm_request_v01(arg, url, port, **kwargs):
return requests.post(arg, **kwargs)

# @functools.lru_cache(maxsize=None if cache_turn_on else 0)
@NotebookCacheMemory.cache(ignore=['arg'])
def send_hfvllm_request_v01_wrapped(arg, url, port, **kwargs):
    """POST a request to a vLLM server through the notebook-level cache.

    This is a thin caching layer over ``send_hfvllm_request_v01``; the HTTP
    request itself is issued there.

    Args:
        arg: Full endpoint URL to POST to. It is named in the decorator's
            ``ignore`` list.
        url: Server URL(s) as configured on the client; forwarded unchanged.
        port: Server port as configured on the client; forwarded unchanged.
        **kwargs: Forwarded unchanged (callers in this file pass ``json``
            and ``headers``).

    Returns:
        Whatever ``send_hfvllm_request_v01`` returns.
    """
    # Delegate to the vLLM sender defined in this file. This wrapper used to
    # call the TGI helper (`send_hftgi_request_v01`), which sent vLLM traffic
    # through the wrong backend's cached function.
    # NOTE(review): responses stored via the TGI helper will presumably not
    # be reused after this change -- confirm that is acceptable.
    return send_hfvllm_request_v01(arg, url, port, **kwargs)

@CacheMemory.cache(ignore=['arg'])
@CacheMemory.cache
def send_hfvllm_request_v00(arg, **kwargs):
return requests.post(arg, **kwargs)

# POST to `arg`, forwarding `kwargs` to `requests.post`, and return the response.
# The call is wrapped by `CacheMemory.cache` with no `ignore` list, unlike the
# `_v01` senders in this file, which pass `ignore=['arg']`.
# NOTE(review): presumably `CacheMemory` is a joblib-style disk cache that
# also tracks the decorated function's source; if so, editing the lines below
# would discard stored entries -- confirm before changing them.
@CacheMemory.cache
def send_hfvllm_chat_request_v00(arg, **kwargs):
return requests.post(arg, **kwargs)

Expand Down