2020
2121from camel .agents import BaseAgent
2222from camel .configs import ChatGPTConfig
23- from camel .messages import ChatMessage , MessageType , SystemMessage
23+ from camel .messages import BaseMessage
2424from camel .models import BaseModelBackend , ModelFactory
2525from camel .typing import ModelType , RoleType
2626from camel .utils import num_tokens_from_messages , openai_api_key_required
@@ -31,15 +31,15 @@ class ChatAgentResponse:
3131 r"""Response of a ChatAgent.
3232
3333 Attributes:
34- msgs (List[ChatMessage ]): A list of zero, one or several messages.
34+ msgs (List[BaseMessage ]): A list of zero, one or several messages.
3535 If the list is empty, there is some error in message generation.
3636 If the list has one message, this is normal mode.
3737 If the list has several messages, this is the critic mode.
3838 terminated (bool): A boolean indicating whether the agent decided
3939 to terminate the chat session.
4040 info (Dict[str, Any]): Extra information about the chat message.
4141 """
42- msgs : List [ChatMessage ]
42+ msgs : List [BaseMessage ]
4343 terminated : bool
4444 info : Dict [str , Any ]
4545
@@ -51,11 +51,32 @@ def msg(self):
5151 return self .msgs [0 ]
5252
5353
54+ @dataclass (frozen = True )
55+ class ChatRecord :
56+ r"""Historical records of who made what message.
57+
58+ Attributes:
59+ role_at_backend (str): Role of the message that mirrors OpenAI
60+ message role that may be `system` or `user` or `assistant`.
61+ message (BaseMessage): Message payload.
62+ """
63+ role_at_backend : str
64+ message : BaseMessage
65+
66+ def to_openai_message (self ):
67+ r"""Converts the payload message to OpenAI-compatible format.
68+
69+ Returns:
70+ OpenAIMessage: OpenAI-compatible message
71+ """
72+ return self .message .to_openai_message (self .role_at_backend )
73+
74+
5475class ChatAgent (BaseAgent ):
5576 r"""Class for managing conversations of CAMEL Chat Agents.
5677
5778 Args:
58- system_message (SystemMessage ): The system message for the chat agent.
79+ system_message (BaseMessage ): The system message for the chat agent.
5980 model (ModelType, optional): The LLM model to use for generating
6081 responses. (default :obj:`ModelType.GPT_3_5_TURBO`)
6182 model_config (Any, optional): Configuration options for the LLM model.
@@ -69,14 +90,14 @@ class ChatAgent(BaseAgent):
6990
7091 def __init__ (
7192 self ,
72- system_message : SystemMessage ,
93+ system_message : BaseMessage ,
7394 model : Optional [ModelType ] = None ,
7495 model_config : Optional [Any ] = None ,
7596 message_window_size : Optional [int ] = None ,
7697 output_language : Optional [str ] = None ,
7798 ) -> None :
7899
79- self .system_message : SystemMessage = system_message
100+ self .system_message : BaseMessage = system_message
80101 self .role_name : str = system_message .role_name
81102 self .role_type : RoleType = system_message .role_type
82103 self .output_language : Optional [str ] = output_language
@@ -93,20 +114,21 @@ def __init__(
93114 self .model_token_limit : int = self .model_backend .token_limit
94115
95116 self .terminated : bool = False
117+ self .stored_messages : List [ChatRecord ]
96118 self .init_messages ()
97119
98- def reset (self ) -> List [MessageType ]:
120+ def reset (self ) -> List [ChatRecord ]:
99121 r"""Resets the :obj:`ChatAgent` to its initial state and returns the
100122 stored messages.
101123
102124 Returns:
103- List[MessageType ]: The stored messages.
125+ List[BaseMessage ]: The stored messages.
104126 """
105127 self .terminated = False
106128 self .init_messages ()
107129 return self .stored_messages
108130
109- def set_output_language (self , output_language : str ) -> SystemMessage :
131+ def set_output_language (self , output_language : str ) -> BaseMessage :
110132 r"""Sets the output language for the system message. This method
111133 updates the output language for the system message. The output
112134 language determines the language in which the output text should be
@@ -116,7 +138,7 @@ def set_output_language(self, output_language: str) -> SystemMessage:
116138 output_language (str): The desired output language.
117139
118140 Returns:
119- SystemMessage : The updated system message object.
141+ BaseMessage : The updated system message object.
120142 """
121143 self .output_language = output_language
122144 content = (self .system_message .content +
@@ -156,32 +178,46 @@ def init_messages(self) -> None:
156178 r"""Initializes the stored messages list with the initial system
157179 message.
158180 """
159- self .stored_messages : List [ MessageType ] = [self .system_message ]
181+ self .stored_messages = [ChatRecord ( 'system' , self .system_message ) ]
160182
161- def update_messages (self , message : ChatMessage ) -> List [MessageType ]:
183+ def update_messages (self , role : str ,
184+ message : BaseMessage ) -> List [ChatRecord ]:
162185 r"""Updates the stored messages list with a new message.
163186
164187 Args:
165- message (ChatMessage ): The new message to add to the stored
188+ message (BaseMessage ): The new message to add to the stored
166189 messages.
167190
168191 Returns:
169- List[ChatMessage ]: The updated stored messages.
192+ List[BaseMessage ]: The updated stored messages.
170193 """
171- self .stored_messages .append (message )
194+ if role not in {'system' , 'user' , 'assistant' }:
195+ raise ValueError (f"Unsupported role { role } " )
196+ self .stored_messages .append (ChatRecord (role , message ))
172197 return self .stored_messages
173198
199+ def submit_message (self , message : BaseMessage ) -> None :
200+ r"""Submits the externaly provided message as if it were an answer of
201+ the chat LLM from the backend. Currently the choise of the critic is
202+ submitted with this method.
203+
204+ Args:
205+ message (BaseMessage): An external message to be added as an
206+ assistant response.
207+ """
208+ self .stored_messages .append (ChatRecord ('assistant' , message ))
209+
174210 @retry (wait = wait_exponential (min = 5 , max = 60 ), stop = stop_after_attempt (5 ))
175211 @openai_api_key_required
176212 def step (
177213 self ,
178- input_message : ChatMessage ,
214+ input_message : BaseMessage ,
179215 ) -> ChatAgentResponse :
180216 r"""Performs a single step in the chat session by generating a response
181217 to the input message.
182218
183219 Args:
184- input_message (ChatMessage ): The input message to the agent.
220+ input_message (BaseMessage ): The input message to the agent.
185221 Its `role` field that specifies the role at backen may be either
186222 `user` or `assistant` but it will be set to `user` anyway since
187223 for the self agent any incoming message is external.
@@ -192,25 +228,25 @@ def step(
192228 the chat session has terminated, and information about the chat
193229 session.
194230 """
195- msg_user_at_backend = input_message .set_user_role_at_backend ()
196- messages = self .update_messages (msg_user_at_backend )
231+ messages = self .update_messages ('user' , input_message )
197232 if self .message_window_size is not None and len (
198233 messages ) > self .message_window_size :
199- messages = [self .system_message
234+ messages = [ChatRecord ( 'system' , self .system_message )
200235 ] + messages [- self .message_window_size :]
201- openai_messages = [message .to_openai_message () for message in messages ]
236+ openai_messages = [record .to_openai_message () for record in messages ]
202237 num_tokens = num_tokens_from_messages (openai_messages , self .model )
203238
204- output_messages : Optional [List [ChatMessage ]]
239+ output_messages : Optional [List [BaseMessage ]]
205240 info : Dict [str , Any ]
206241
207242 if num_tokens < self .model_token_limit :
208243 response = self .model_backend .run (openai_messages )
209244 if not isinstance (response , dict ):
210245 raise RuntimeError ("OpenAI returned unexpected struct" )
211246 output_messages = [
212- ChatMessage (role_name = self .role_name , role_type = self .role_type ,
213- meta_dict = dict (), ** dict (choice ["message" ]))
247+ BaseMessage (role_name = self .role_name , role_type = self .role_type ,
248+ meta_dict = dict (),
249+ content = choice ["message" ]['content' ])
214250 for choice in response ["choices" ]
215251 ]
216252 info = self .get_info (
0 commit comments