Skip to content

Commit 804a974

Browse files
Merge pull request stanfordnlp#1113 from Preemo-Inc/main
fix(dspy): together client response parsing
2 parents 3ac7b16 + 285dc90 commit 804a974

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

dsp/modules/hf_client.py

Lines changed: 8 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -307,11 +307,13 @@ def run_server(self, port, model_name=None, model_path=None, env_variable=None,
307307
docker_process.wait()
308308

309309
class Together(HFModel):
310-
def __init__(self, model, **kwargs):
310+
def __init__(self, model, api_base="https://api.together.xyz/v1", api_key=None, **kwargs):
311311
super().__init__(model=model, is_client=True)
312312
self.session = requests.Session()
313-
self.api_base = os.getenv("TOGETHER_API_BASE")
314-
self.token = os.getenv("TOGETHER_API_KEY")
313+
self.api_base = os.getenv("TOGETHER_API_BASE") or api_base
314+
assert not self.api_base.endswith("/"), "Together base URL shouldn't end with /"
315+
self.token = os.getenv("TOGETHER_API_KEY") or api_key
316+
315317
self.model = model
316318

317319
self.use_inst_template = False
@@ -338,8 +340,6 @@ def __init__(self, model, **kwargs):
338340
on_backoff=backoff_hdlr,
339341
)
340342
def _generate(self, prompt, use_chat_api=False, **kwargs):
341-
url = f"{self.api_base}"
342-
343343
kwargs = {**self.kwargs, **kwargs}
344344

345345
stop = kwargs.get("stop")
@@ -367,6 +367,7 @@ def _generate(self, prompt, use_chat_api=False, **kwargs):
367367
"stop": stop,
368368
}
369369
else:
370+
url = f"{self.api_base}/completions"
370371
body = {
371372
"model": self.model,
372373
"prompt": prompt,
@@ -384,9 +385,9 @@ def _generate(self, prompt, use_chat_api=False, **kwargs):
384385
with self.session.post(url, headers=headers, json=body) as resp:
385386
resp_json = resp.json()
386387
if use_chat_api:
387-
completions = [resp_json['output'].get('choices', [])[0].get('message', {}).get('content', "")]
388+
completions = [resp_json.get('choices', [])[0].get('message', {}).get('content', "")]
388389
else:
389-
completions = [resp_json['output'].get('choices', [])[0].get('text', "")]
390+
completions = [resp_json.get('choices', [])[0].get('text', "")]
390391
response = {"prompt": prompt, "choices": [{"text": c} for c in completions]}
391392
return response
392393
except Exception as e:

0 commit comments

Comments (0)