@@ -14,21 +14,25 @@ def __init__(self, config: dict):
         """
         openai.api_key = config['api_key']
         self.config = config
-        self.initial_history = [{"role": "system", "content": config['assistant_prompt']}]
-        self.history = self.initial_history
+        self.sessions: dict[int: list] = dict()  # {chat_id: history}

-    def get_response(self, query) -> str:
+
+    def get_response(self, chat_id: int, query: str) -> str:
         """
         Gets a response from the GPT-3 model.
+        :param chat_id: The chat ID
         :param query: The query to send to the model
         :return: The answer from the model
         """
         try:
-            self.history.append({"role": "user", "content": query})
+            if chat_id not in self.sessions:
+                self.reset_history(chat_id)
+
+            self.__add_to_history(chat_id, role="user", content=query)

             response = openai.ChatCompletion.create(
                 model=self.config['model'],
-                messages=self.history,
+                messages=self.sessions[chat_id],
                 temperature=self.config['temperature'],
                 n=self.config['n_choices'],
                 max_tokens=self.config['max_tokens'],
@@ -42,13 +46,13 @@ def get_response(self, query) -> str:
             if len(response.choices) > 1 and self.config['n_choices'] > 1:
                 for index, choice in enumerate(response.choices):
                     if index == 0:
-                        self.history.append({"role": "assistant", "content": choice['message']['content']})
+                        self.__add_to_history(chat_id, role="assistant", content=choice['message']['content'])
                     answer += f'{index + 1}\u20e3\n'
                     answer += choice['message']['content']
                     answer += '\n\n'
             else:
                 answer = response.choices[0]['message']['content']
-                self.history.append({"role": "assistant", "content": answer})
+                self.__add_to_history(chat_id, role="assistant", content=answer)

             if self.config['show_usage']:
                 answer += "\n\n---\n" \
@@ -63,7 +67,7 @@ def get_response(self, query) -> str:

         except openai.error.RateLimitError as e:
             logging.exception(e)
-            return "⚠️ _OpenAI RateLimit exceeded_ ⚠️\nPlease try again in a while."
+            return f"⚠️ _OpenAI Rate Limit exceeded_ ⚠️\n{str(e)}"

         except openai.error.InvalidRequestError as e:
             logging.exception(e)
@@ -73,8 +77,19 @@ def get_response(self, query) -> str:
             logging.exception(e)
             return f"⚠️ _An error has occurred_ ⚠️\n{str(e)}"

-    def reset_history(self):
+
+    def reset_history(self, chat_id):
         """
         Resets the conversation history.
         """
-        self.history = self.initial_history
+        self.sessions[chat_id] = [{"role": "system", "content": self.config['assistant_prompt']}]
+
+
+    def __add_to_history(self, chat_id, role, content):
+        """
+        Adds a message to the conversation history.
+        :param chat_id: The chat ID
+        :param role: The role of the message sender
+        :param content: The message content
+        """
+        self.sessions[chat_id].append({"role": role, "content": content})
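Usage sketch (not part of the commit above): a minimal illustration of how a caller might drive the new per-chat session API. The class name OpenAIHelper is a placeholder, since the diff does not show the class itself, and the config values are examples; only the config keys (api_key, assistant_prompt, model, temperature, n_choices, max_tokens, show_usage) and the method signatures come from the diff. The point is that each chat_id now gets its own lazily created history, and reset_history clears one chat without touching the others.

# Hypothetical caller of the refactored per-chat session API.
# "OpenAIHelper" is an assumed placeholder class name; the diff does not show it.
helper = OpenAIHelper(config={
    'api_key': 'sk-...',                                   # placeholder key
    'assistant_prompt': 'You are a helpful assistant.',    # example system prompt
    'model': 'gpt-3.5-turbo',                              # example model name
    'temperature': 0.8,
    'n_choices': 1,
    'max_tokens': 1200,
    'show_usage': False,
})

# Two chats, two independent histories: each chat_id is initialised lazily on first
# use via reset_history(chat_id) and then extended by __add_to_history().
print(helper.get_response(chat_id=111, query='Remember the number 42.'))
print(helper.get_response(chat_id=222, query='Hello!'))  # separate session, no shared context

# Clearing one chat's history back to the system prompt leaves the other untouched.
helper.reset_history(chat_id=111)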