Skip to content

Commit c8405aa

Browse files
authored
Merge pull request stanfordnlp#398 from CShorten/main
Create Google LM
2 parents a124260 + 6c208cb commit c8405aa

File tree

1 file changed

+104
-0
lines changed

1 file changed

+104
-0
lines changed

dsp/modules/google.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
import math
2+
from typing import Any, Optional
3+
import backoff
4+
5+
from dsp.modules.lm import LM
6+
7+
try:
8+
import google.generativeai as genai
9+
except ImportError:
10+
google_api_error = Exception
11+
print("Not loading Google because it is not installed.")
12+
13+
def backoff_hdlr(details):
14+
"""Handler from https://pypi.org/project/backoff/"""
15+
print(
16+
"Backing off {wait:0.1f} seconds after {tries} tries "
17+
"calling function {target} with kwargs "
18+
"{kwargs}".format(**details)
19+
)
20+
21+
22+
def giveup_hdlr(details):
23+
"""wrapper function that decides when to give up on retry"""
24+
if "rate limits" in details.message:
25+
return False
26+
return True
27+
28+
29+
class Google(LM):
30+
"""Wrapper around Google's API.
31+
32+
Currently supported models include `gemini-pro-1.0`.
33+
"""
34+
35+
def __init__(
36+
self,
37+
model: str = "gemini-pro-1.0",
38+
api_key: Optional[str] = None,
39+
**kwargs
40+
):
41+
"""
42+
Parameters
43+
----------
44+
model : str
45+
Which pre-trained model from Google to use?
46+
Choices are [`gemini-pro-1.0`]
47+
api_key : str
48+
The API key for Google.
49+
It can be obtained from https://cloud.google.com/generative-ai-studio
50+
**kwargs: dict
51+
Additional arguments to pass to the API provider.
52+
"""
53+
super().__init__(model)
54+
self.google = genai.configure(api_key=self.api_key)
55+
self.provider = "google"
56+
self.kwargs = {
57+
"model_name": model,
58+
"temperature": 0.0 if "temperature" not in kwargs else kwargs["temperature"],
59+
"max_output_tokens": 2048,
60+
"top_p": 1,
61+
"top_k": 1,
62+
**kwargs
63+
}
64+
65+
self.history: list[dict[str, Any]] = []
66+
67+
def basic_request(self, prompt: str, **kwargs):
68+
raw_kwargs = kwargs
69+
kwargs = {
70+
**self.kwargs,
71+
"prompt": prompt,
72+
**kwargs,
73+
}
74+
response = self.co.generate(**kwargs)
75+
76+
history = {
77+
"prompt": prompt,
78+
"response": response,
79+
"kwargs": kwargs,
80+
"raw_kwargs": raw_kwargs,
81+
}
82+
self.history.append(history)
83+
84+
return response
85+
86+
@backoff.on_exception(
87+
backoff.expo,
88+
(google_api_error),
89+
max_time=1000,
90+
on_backoff=backoff_hdlr,
91+
giveup=giveup_hdlr,
92+
)
93+
def request(self, prompt: str, **kwargs):
94+
"""Handles retrieval of completions from Google whilst handling API errors"""
95+
return self.basic_request(prompt, **kwargs)
96+
97+
def __call__(
98+
self,
99+
prompt: str,
100+
only_completed: bool = True,
101+
return_sorted: bool = False,
102+
**kwargs
103+
):
104+
return self.request(prompt, **kwargs)

0 commit comments

Comments
 (0)