Skip to content

Commit c6219ca

Browse files
Merge pull request stanfordnlp#1188 from marshmellow77/fix-vertexai-gemini-candidate-count
Ensuring that candidate count for Gemini is always 1
2 parents 39cda49 + e3ba07b commit c6219ca

File tree

1 file changed

+58
-33
lines changed

1 file changed

+58
-33
lines changed

dsp/modules/googlevertexai.py

Lines changed: 58 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Module for interacting with Google Vertex AI."""
2+
23
from typing import Any, Dict
34

45
import backoff
@@ -16,9 +17,11 @@
1617

1718
def backoff_hdlr(details):
1819
"""Handler from https://pypi.org/project/backoff/"""
19-
print(f"Backing off {details['wait']:0.1f} seconds after {details['tries']} tries "
20-
f"calling function {details['target']} with kwargs "
21-
f"{details['kwargs']}")
20+
print(
21+
f"Backing off {details['wait']:0.1f} seconds after {details['tries']} tries "
22+
f"calling function {details['target']} with kwargs "
23+
f"{details['kwargs']}",
24+
)
2225

2326

2427
def giveup_hdlr(details):
@@ -27,14 +30,17 @@ def giveup_hdlr(details):
2730
return False
2831
return True
2932

33+
3034
class GoogleVertexAI(LM):
3135
"""Wrapper around GoogleVertexAI's API.
3236
3337
Currently supported models include `gemini-pro-1.0`.
3438
"""
3539

3640
def __init__(
37-
self, model: str = "text-bison@002", **kwargs,
41+
self,
42+
model: str = "text-bison@002",
43+
**kwargs,
3844
):
3945
"""
4046
Parameters
@@ -54,40 +60,42 @@ def __init__(
5460
if "code" in model:
5561
model_cls = CodeGenerationModel
5662
self.available_args = {
57-
'suffix',
58-
'max_output_tokens',
59-
'temperature',
60-
'stop_sequences',
61-
'candidate_count',
63+
"suffix",
64+
"max_output_tokens",
65+
"temperature",
66+
"stop_sequences",
67+
"candidate_count",
6268
}
6369
elif "gemini" in model:
6470
model_cls = GenerativeModel
6571
self.available_args = {
66-
'max_output_tokens',
67-
'temperature',
68-
'top_k',
69-
'top_p',
70-
'stop_sequences',
71-
'candidate_count',
72+
"max_output_tokens",
73+
"temperature",
74+
"top_k",
75+
"top_p",
76+
"stop_sequences",
77+
"candidate_count",
7278
}
73-
elif 'text' in model:
79+
elif "text" in model:
7480
model_cls = TextGenerationModel
7581
self.available_args = {
76-
'max_output_tokens',
77-
'temperature',
78-
'top_k',
79-
'top_p',
80-
'stop_sequences',
81-
'candidate_count',
82+
"max_output_tokens",
83+
"temperature",
84+
"top_k",
85+
"top_p",
86+
"stop_sequences",
87+
"candidate_count",
8288
}
8389
else:
8490
raise PydanticCustomError(
85-
'model',
91+
"model",
8692
'model name is not valid, got "{model_name}"',
8793
dict(wrong_value=model),
8894
)
8995
if self._is_gemini:
90-
self.client = model_cls(model_name=model, safety_settings=kwargs.get('safety_settings')) # pylint: disable=unexpected-keyword-arg,no-value-for-parameter
96+
self.client = model_cls(
97+
model_name=model, safety_settings=kwargs.get("safety_settings"),
98+
) # pylint: disable=unexpected-keyword-arg,no-value-for-parameter
9199
else:
92100
self.client = model_cls.from_pretrained(model)
93101
self.provider = "googlevertexai"
@@ -113,10 +121,18 @@ def _prepare_params(
113121
self,
114122
parameters: Any,
115123
) -> dict:
116-
stop_sequences = parameters.get('stop')
117-
params_mapping = {"n": "candidate_count", 'max_tokens':'max_output_tokens'}
124+
stop_sequences = parameters.get("stop")
125+
params_mapping = {"n": "candidate_count", "max_tokens": "max_output_tokens"}
118126
params = {params_mapping.get(k, k): v for k, v in parameters.items()}
119127
params = {**self.kwargs, "stop_sequences": stop_sequences, **params}
128+
129+
if self._is_gemini:
130+
if "candidate_count" in params and params["candidate_count"] != 1:
131+
print(
132+
f"As of now, Gemini only supports `candidate_count == 1` (see also https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/gemini#parameters). The current value for candidate_count of {params['candidate_count']} will be overridden.",
133+
)
134+
params["candidate_count"] = 1
135+
120136
return {k: params[k] for k in set(params.keys()) & self.available_args}
121137

122138
def basic_request(self, prompt: str, **kwargs):
@@ -131,11 +147,15 @@ def basic_request(self, prompt: str, **kwargs):
131147
"prompt": prompt,
132148
"response": {
133149
"prompt": prompt,
134-
"choices": [{
135-
"text": '\n'.join(v.text for v in c.content.parts),
136-
'safetyAttributes': {v.category: v.probability for v in c.safety_ratings},
150+
"choices": [
151+
{
152+
"text": "\n".join(v.text for v in c.content.parts),
153+
"safetyAttributes": {
154+
v.category: v.probability for v in c.safety_ratings
155+
},
137156
}
138-
for c in response.candidates],
157+
for c in response.candidates
158+
],
139159
},
140160
"kwargs": kwargs,
141161
"raw_kwargs": raw_kwargs,
@@ -146,15 +166,20 @@ def basic_request(self, prompt: str, **kwargs):
146166
"prompt": prompt,
147167
"response": {
148168
"prompt": prompt,
149-
"choices": [{"text": c["content"], 'safetyAttributes': c['safetyAttributes']}
150-
for c in response.predictions],
169+
"choices": [
170+
{
171+
"text": c["content"],
172+
"safetyAttributes": c["safetyAttributes"],
173+
}
174+
for c in response.predictions
175+
],
151176
},
152177
"kwargs": kwargs,
153178
"raw_kwargs": raw_kwargs,
154179
}
155180
self.history.append(history)
156181

157-
return [i['text'] for i in history['response']['choices']]
182+
return [i["text"] for i in history["response"]["choices"]]
158183

159184
@backoff.on_exception(
160185
backoff.expo,

0 commit comments

Comments (0)