11"""Module for interacting with Google Vertex AI."""
2+
23from typing import Any , Dict
34
45import backoff
1617
1718def backoff_hdlr (details ):
1819 """Handler from https://pypi.org/project/backoff/"""
19- print (f"Backing off { details ['wait' ]:0.1f} seconds after { details ['tries' ]} tries "
20- f"calling function { details ['target' ]} with kwargs "
21- f"{ details ['kwargs' ]} " )
20+ print (
21+ f"Backing off { details ['wait' ]:0.1f} seconds after { details ['tries' ]} tries "
22+ f"calling function { details ['target' ]} with kwargs "
23+ f"{ details ['kwargs' ]} " ,
24+ )
2225
2326
2427def giveup_hdlr (details ):
@@ -27,14 +30,17 @@ def giveup_hdlr(details):
         return False
     return True
 
+
 class GoogleVertexAI(LM):
     """Wrapper around GoogleVertexAI's API.
 
     Currently supported models include `gemini-pro-1.0`.
     """
 
     def __init__(
-        self, model: str = "text-bison@002", **kwargs,
+        self,
+        model: str = "text-bison@002",
+        **kwargs,
     ):
         """
         Parameters
@@ -54,40 +60,42 @@ def __init__(
5460 if "code" in model :
5561 model_cls = CodeGenerationModel
5662 self .available_args = {
57- ' suffix' ,
58- ' max_output_tokens' ,
59- ' temperature' ,
60- ' stop_sequences' ,
61- ' candidate_count' ,
63+ " suffix" ,
64+ " max_output_tokens" ,
65+ " temperature" ,
66+ " stop_sequences" ,
67+ " candidate_count" ,
6268 }
6369 elif "gemini" in model :
6470 model_cls = GenerativeModel
6571 self .available_args = {
66- ' max_output_tokens' ,
67- ' temperature' ,
68- ' top_k' ,
69- ' top_p' ,
70- ' stop_sequences' ,
71- ' candidate_count' ,
72+ " max_output_tokens" ,
73+ " temperature" ,
74+ " top_k" ,
75+ " top_p" ,
76+ " stop_sequences" ,
77+ " candidate_count" ,
7278 }
73- elif ' text' in model :
79+ elif " text" in model :
7480 model_cls = TextGenerationModel
7581 self .available_args = {
76- ' max_output_tokens' ,
77- ' temperature' ,
78- ' top_k' ,
79- ' top_p' ,
80- ' stop_sequences' ,
81- ' candidate_count' ,
82+ " max_output_tokens" ,
83+ " temperature" ,
84+ " top_k" ,
85+ " top_p" ,
86+ " stop_sequences" ,
87+ " candidate_count" ,
8288 }
8389 else :
8490 raise PydanticCustomError (
85- ' model' ,
91+ " model" ,
8692 'model name is not valid, got "{model_name}"' ,
8793 dict (wrong_value = model ),
8894 )
8995 if self ._is_gemini :
90- self .client = model_cls (model_name = model , safety_settings = kwargs .get ('safety_settings' )) # pylint: disable=unexpected-keyword-arg,no-value-for-parameter
96+ self .client = model_cls (
97+ model_name = model , safety_settings = kwargs .get ("safety_settings" ),
98+ ) # pylint: disable=unexpected-keyword-arg,no-value-for-parameter
9199 else :
92100 self .client = model_cls .from_pretrained (model )
93101 self .provider = "googlevertexai"
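
Note: the constructor above routes the model name to one of three Vertex AI client classes (CodeGenerationModel for code models, GenerativeModel for Gemini, TextGenerationModel otherwise) and records which generation arguments each accepts. A minimal usage sketch, assuming Vertex AI credentials are already configured in the environment; the model name and kwargs below are illustrative, not taken from this diff:

    # Illustrative only: "gemini" in the name selects GenerativeModel, and any
    # kwargs outside available_args are later dropped by _prepare_params.
    lm = GoogleVertexAI(model="gemini-pro", temperature=0.3, max_output_tokens=256)
    completions = lm.basic_request("Summarize unified diffs in one sentence.")
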
@@ -113,10 +121,18 @@ def _prepare_params(
         self,
         parameters: Any,
     ) -> dict:
-        stop_sequences = parameters.get('stop')
-        params_mapping = {"n": "candidate_count", 'max_tokens': 'max_output_tokens'}
+        stop_sequences = parameters.get("stop")
+        params_mapping = {"n": "candidate_count", "max_tokens": "max_output_tokens"}
         params = {params_mapping.get(k, k): v for k, v in parameters.items()}
         params = {**self.kwargs, "stop_sequences": stop_sequences, **params}
+
+        if self._is_gemini:
+            if "candidate_count" in params and params["candidate_count"] != 1:
+                print(
+                    f"As of now, Gemini only supports `candidate_count == 1` (see also https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/gemini#parameters). The current value for candidate_count of {params['candidate_count']} will be overridden.",
+                )
+                params["candidate_count"] = 1
+
         return {k: params[k] for k in set(params.keys()) & self.available_args}
 
     def basic_request(self, prompt: str, **kwargs):
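
The block added above is the substantive change in this hunk: `_prepare_params` already renamed OpenAI-style keys to their Vertex AI equivalents and filtered against `available_args`; it now also forces `candidate_count` to 1 for Gemini models, printing a warning first. A small runnable sketch of the renaming step, with hypothetical values:

    # Mirrors the mapping logic above, outside the class.
    params_mapping = {"n": "candidate_count", "max_tokens": "max_output_tokens"}
    parameters = {"n": 2, "max_tokens": 128, "stop": ["\n\n"]}
    params = {params_mapping.get(k, k): v for k, v in parameters.items()}
    print(params)  # {'candidate_count': 2, 'max_output_tokens': 128, 'stop': ['\n\n']}
    # In the method itself, "stop" is re-injected as "stop_sequences", keys outside
    # available_args are dropped, and on Gemini candidate_count is reset to 1.
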
@@ -131,11 +147,15 @@ def basic_request(self, prompt: str, **kwargs):
131147 "prompt" : prompt ,
132148 "response" : {
133149 "prompt" : prompt ,
134- "choices" : [{
135- "text" : '\n ' .join (v .text for v in c .content .parts ),
136- 'safetyAttributes' : {v .category : v .probability for v in c .safety_ratings },
150+ "choices" : [
151+ {
152+ "text" : "\n " .join (v .text for v in c .content .parts ),
153+ "safetyAttributes" : {
154+ v .category : v .probability for v in c .safety_ratings
155+ },
137156 }
138- for c in response .candidates ],
157+ for c in response .candidates
158+ ],
139159 },
140160 "kwargs" : kwargs ,
141161 "raw_kwargs" : raw_kwargs ,
@@ -146,15 +166,20 @@ def basic_request(self, prompt: str, **kwargs):
146166 "prompt" : prompt ,
147167 "response" : {
148168 "prompt" : prompt ,
149- "choices" : [{"text" : c ["content" ], 'safetyAttributes' : c ['safetyAttributes' ]}
150- for c in response .predictions ],
169+ "choices" : [
170+ {
171+ "text" : c ["content" ],
172+ "safetyAttributes" : c ["safetyAttributes" ],
173+ }
174+ for c in response .predictions
175+ ],
151176 },
152177 "kwargs" : kwargs ,
153178 "raw_kwargs" : raw_kwargs ,
154179 }
155180 self .history .append (history )
156181
157- return [i [' text' ] for i in history [' response' ][ ' choices' ]]
182+ return [i [" text" ] for i in history [" response" ][ " choices" ]]
158183
159184 @backoff .on_exception (
160185 backoff .expo ,
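
The `@backoff.on_exception` decorator is cut off at the end of this diff. The two handlers defined at the top of the module are typically wired in along the following lines; this is a sketch of a conventional backoff setup, not the exact truncated continuation (the exception type and max_time are assumptions):

    @backoff.on_exception(
        backoff.expo,
        Exception,  # assumed; the real code may catch a narrower type
        max_time=1000,  # assumed bound
        on_backoff=backoff_hdlr,  # prints wait time and try count on each retry
        giveup=giveup_hdlr,  # keep retrying only when "rate limits" is in the message
    )
    def request(self, prompt: str, **kwargs):
        return self.basic_request(prompt, **kwargs)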