Skip to content

Commit 32a08de

Browse files
authored
Merge pull request stanfordnlp#371 from paxcema/fix/anyscale/n_kwarg
fix: manually iterate kwargs['n'] in anyscale
2 parents bb5d0de + 2c11aa0 commit 32a08de

File tree

1 file changed

+12
-8
lines changed

1 file changed

+12
-8
lines changed

dsp/modules/hf_client.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,8 @@ def _generate(self, prompt, use_chat_api=False, **kwargs):
293293
print(f"resp_json:{resp_json}")
294294
print(f"Failed to parse JSON response: {e}")
295295
raise Exception("Received invalid JSON response from server")
296+
297+
296298
class Anyscale(HFModel):
297299
def __init__(self, model, **kwargs):
298300
super().__init__(model=model, is_client=True)
@@ -337,14 +339,16 @@ def _generate(self, prompt, use_chat_api=False, **kwargs):
337339
headers = {"Authorization": f"Bearer {self.token}"}
338340

339341
try:
340-
with self.session.post(url, headers=headers, json=body) as resp:
341-
resp_json = resp.json()
342-
if use_chat_api:
343-
completions = [resp_json.get('choices', [])[0].get('message', {}).get('content', "")]
344-
else:
345-
completions = [resp_json.get('choices', [])[0].get('text', "")]
346-
response = {"prompt": prompt, "choices": [{"text": c} for c in completions]}
347-
return response
342+
completions = []
343+
for i in range(kwargs.get('n', 1)):
344+
with self.session.post(url, headers=headers, json=body) as resp:
345+
resp_json = resp.json()
346+
if use_chat_api:
347+
completions.extend([resp_json.get('choices', [])[0].get('message', {}).get('content', "")])
348+
else:
349+
completions.extend([resp_json.get('choices', [])[0].get('text', "")])
350+
response = {"prompt": prompt, "choices": [{"text": c} for c in completions]}
351+
return response
348352
except Exception as e:
349353
print(f"Failed to parse JSON response: {e}")
350354
raise Exception("Received invalid JSON response from server")

0 commit comments

Comments
 (0)