
Commit 16c1ac3

Removed all derived classes of BaseMessage and its role field (camel-ai#177)

1 parent cde10f8, commit 16c1ac3

33 files changed (+272, -526 lines)
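In short, the role-specific message classes (AssistantChatMessage, ChatMessage, SystemMessage, MessageType, ...) are gone; a single BaseMessage now carries the payload, and the backend role is only supplied when a message is converted to the OpenAI wire format. A rough before/after sketch of the call sites touched in the hunks below (not an excerpt from the commit; the constructor arguments shown for the old class are assumptions, and RoleType.ASSISTANT is assumed to exist as used elsewhere in the code base):

    # Before (pre-16c1ac3): the backend role was baked into the subclass.
    # msg = AssistantChatMessage(role_name="Assistant", content="Hello")
    # openai_msg = msg.to_openai_message()

    # After: one message class; the backend role is passed at conversion time.
    from camel.messages import BaseMessage
    from camel.typing import RoleType

    msg = BaseMessage(role_name="Assistant", role_type=RoleType.ASSISTANT,
                      meta_dict=dict(), content="Hello")
    openai_msg = msg.to_openai_message("assistant")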

apps/agents/agents.py

Lines changed: 6 additions & 7 deletions
@@ -29,7 +29,7 @@
 
 from apps.agents.text_utils import split_markdown_code
 from camel.agents import TaskSpecifyAgent
-from camel.messages import AssistantChatMessage
+from camel.messages import BaseMessage
 from camel.societies import RolePlaying
 
 REPO_ROOT = os.path.realpath(
@@ -43,17 +43,16 @@ class State:
     session: Optional[RolePlaying]
     max_messages: int
     chat: ChatBotHistory
-    saved_assistant_msg: Optional[AssistantChatMessage]
+    saved_assistant_msg: Optional[BaseMessage]
 
     @classmethod
     def empty(cls) -> 'State':
         return cls(None, 0, [], None)
 
     @staticmethod
-    def construct_inplace(
-            state: 'State', session: Optional[RolePlaying], max_messages: int,
-            chat: ChatBotHistory,
-            saved_assistant_msg: Optional[AssistantChatMessage]) -> None:
+    def construct_inplace(state: 'State', session: Optional[RolePlaying],
+                          max_messages: int, chat: ChatBotHistory,
+                          saved_assistant_msg: Optional[BaseMessage]) -> None:
         state.session = session
         state.max_messages = max_messages
         state.chat = chat
@@ -216,7 +215,7 @@ def role_playing_chat_init(state) -> \
 
     try:
        init_assistant_msg, _ = session.init_chat()
-        init_assistant_msg: AssistantChatMessage
+        init_assistant_msg: BaseMessage
    except (openai.error.RateLimitError, tenacity.RetryError,
            RuntimeError) as ex:
        print("OpenAI API exception 1 " + str(ex))

camel/agents/chat_agent.py

Lines changed: 60 additions & 24 deletions
@@ -20,7 +20,7 @@
 
 from camel.agents import BaseAgent
 from camel.configs import ChatGPTConfig
-from camel.messages import ChatMessage, MessageType, SystemMessage
+from camel.messages import BaseMessage
 from camel.models import BaseModelBackend, ModelFactory
 from camel.typing import ModelType, RoleType
 from camel.utils import num_tokens_from_messages, openai_api_key_required
@@ -31,15 +31,15 @@ class ChatAgentResponse:
     r"""Response of a ChatAgent.
 
     Attributes:
-        msgs (List[ChatMessage]): A list of zero, one or several messages.
+        msgs (List[BaseMessage]): A list of zero, one or several messages.
             If the list is empty, there is some error in message generation.
            If the list has one message, this is normal mode.
            If the list has several messages, this is the critic mode.
        terminated (bool): A boolean indicating whether the agent decided
            to terminate the chat session.
        info (Dict[str, Any]): Extra information about the chat message.
    """
-    msgs: List[ChatMessage]
+    msgs: List[BaseMessage]
    terminated: bool
    info: Dict[str, Any]
 
@@ -51,11 +51,32 @@ def msg(self):
         return self.msgs[0]
 
 
+@dataclass(frozen=True)
+class ChatRecord:
+    r"""Historical records of who made what message.
+
+    Attributes:
+        role_at_backend (str): Role of the message that mirrors OpenAI
+            message role that may be `system` or `user` or `assistant`.
+        message (BaseMessage): Message payload.
+    """
+    role_at_backend: str
+    message: BaseMessage
+
+    def to_openai_message(self):
+        r"""Converts the payload message to OpenAI-compatible format.
+
+        Returns:
+            OpenAIMessage: OpenAI-compatible message
+        """
+        return self.message.to_openai_message(self.role_at_backend)
+
+
 class ChatAgent(BaseAgent):
     r"""Class for managing conversations of CAMEL Chat Agents.
 
     Args:
-        system_message (SystemMessage): The system message for the chat agent.
+        system_message (BaseMessage): The system message for the chat agent.
         model (ModelType, optional): The LLM model to use for generating
             responses. (default :obj:`ModelType.GPT_3_5_TURBO`)
         model_config (Any, optional): Configuration options for the LLM model.
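The new ChatRecord class added above is what reintroduces the backend role, but as data held next to the message rather than as a subclass. A hypothetical usage sketch (the import path for ChatRecord and the RoleType member are assumptions; the ChatRecord and BaseMessage signatures come from this diff):

    from camel.agents.chat_agent import ChatRecord   # assumed import path
    from camel.messages import BaseMessage
    from camel.typing import RoleType                # ASSISTANT member assumed

    sys_msg = BaseMessage(role_name="Assistant", role_type=RoleType.ASSISTANT,
                          meta_dict=dict(),
                          content="You are a helpful assistant.")
    record = ChatRecord('system', sys_msg)
    # The payload supplies the content; the record supplies the OpenAI role.
    print(record.to_openai_message())  # expected: {'role': 'system', 'content': '...'}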
@@ -69,14 +90,14 @@ class ChatAgent(BaseAgent):
 
     def __init__(
         self,
-        system_message: SystemMessage,
+        system_message: BaseMessage,
         model: Optional[ModelType] = None,
         model_config: Optional[Any] = None,
         message_window_size: Optional[int] = None,
         output_language: Optional[str] = None,
     ) -> None:
 
-        self.system_message: SystemMessage = system_message
+        self.system_message: BaseMessage = system_message
         self.role_name: str = system_message.role_name
         self.role_type: RoleType = system_message.role_type
         self.output_language: Optional[str] = output_language
@@ -93,20 +114,21 @@ def __init__(
         self.model_token_limit: int = self.model_backend.token_limit
 
         self.terminated: bool = False
+        self.stored_messages: List[ChatRecord]
         self.init_messages()
 
-    def reset(self) -> List[MessageType]:
+    def reset(self) -> List[ChatRecord]:
         r"""Resets the :obj:`ChatAgent` to its initial state and returns the
         stored messages.
 
         Returns:
-            List[MessageType]: The stored messages.
+            List[BaseMessage]: The stored messages.
         """
         self.terminated = False
         self.init_messages()
         return self.stored_messages
 
-    def set_output_language(self, output_language: str) -> SystemMessage:
+    def set_output_language(self, output_language: str) -> BaseMessage:
         r"""Sets the output language for the system message. This method
         updates the output language for the system message. The output
         language determines the language in which the output text should be
@@ -116,7 +138,7 @@ def set_output_language(self, output_language: str) -> SystemMessage:
             output_language (str): The desired output language.
 
         Returns:
-            SystemMessage: The updated system message object.
+            BaseMessage: The updated system message object.
         """
         self.output_language = output_language
         content = (self.system_message.content +
@@ -156,32 +178,46 @@ def init_messages(self) -> None:
         r"""Initializes the stored messages list with the initial system
         message.
         """
-        self.stored_messages: List[MessageType] = [self.system_message]
+        self.stored_messages = [ChatRecord('system', self.system_message)]
 
-    def update_messages(self, message: ChatMessage) -> List[MessageType]:
+    def update_messages(self, role: str,
+                        message: BaseMessage) -> List[ChatRecord]:
         r"""Updates the stored messages list with a new message.
 
         Args:
-            message (ChatMessage): The new message to add to the stored
+            message (BaseMessage): The new message to add to the stored
                 messages.
 
         Returns:
-            List[ChatMessage]: The updated stored messages.
+            List[BaseMessage]: The updated stored messages.
         """
-        self.stored_messages.append(message)
+        if role not in {'system', 'user', 'assistant'}:
+            raise ValueError(f"Unsupported role {role}")
+        self.stored_messages.append(ChatRecord(role, message))
         return self.stored_messages
 
+    def submit_message(self, message: BaseMessage) -> None:
+        r"""Submits the externally provided message as if it were an answer of
+        the chat LLM from the backend. Currently the choice of the critic is
+        submitted with this method.
+
+        Args:
+            message (BaseMessage): An external message to be added as an
+                assistant response.
+        """
+        self.stored_messages.append(ChatRecord('assistant', message))
+
     @retry(wait=wait_exponential(min=5, max=60), stop=stop_after_attempt(5))
     @openai_api_key_required
     def step(
         self,
-        input_message: ChatMessage,
+        input_message: BaseMessage,
     ) -> ChatAgentResponse:
         r"""Performs a single step in the chat session by generating a response
         to the input message.
 
         Args:
-            input_message (ChatMessage): The input message to the agent.
+            input_message (BaseMessage): The input message to the agent.
                 Its `role` field that specifies the role at backend may be either
                 `user` or `assistant` but it will be set to `user` anyway since
                 for the self agent any incoming message is external.
192228
the chat session has terminated, and information about the chat
193229
session.
194230
"""
195-
msg_user_at_backend = input_message.set_user_role_at_backend()
196-
messages = self.update_messages(msg_user_at_backend)
231+
messages = self.update_messages('user', input_message)
197232
if self.message_window_size is not None and len(
198233
messages) > self.message_window_size:
199-
messages = [self.system_message
234+
messages = [ChatRecord('system', self.system_message)
200235
] + messages[-self.message_window_size:]
201-
openai_messages = [message.to_openai_message() for message in messages]
236+
openai_messages = [record.to_openai_message() for record in messages]
202237
num_tokens = num_tokens_from_messages(openai_messages, self.model)
203238

204-
output_messages: Optional[List[ChatMessage]]
239+
output_messages: Optional[List[BaseMessage]]
205240
info: Dict[str, Any]
206241

207242
if num_tokens < self.model_token_limit:
208243
response = self.model_backend.run(openai_messages)
209244
if not isinstance(response, dict):
210245
raise RuntimeError("OpenAI returned unexpected struct")
211246
output_messages = [
212-
ChatMessage(role_name=self.role_name, role_type=self.role_type,
213-
meta_dict=dict(), **dict(choice["message"]))
247+
BaseMessage(role_name=self.role_name, role_type=self.role_type,
248+
meta_dict=dict(),
249+
content=choice["message"]['content'])
214250
for choice in response["choices"]
215251
]
216252
info = self.get_info(
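The windowing logic above re-inserts a 'system' ChatRecord before truncating to the last `message_window_size` records, and only then flattens everything to OpenAI messages. A standalone, simplified illustration of that truncation (plain strings stand in for ChatRecord objects; no camel imports needed):

    def truncate_history(system_record, records, window):
        # Mirror of the check in step(): once the history grows past the
        # window, keep the system record plus the most recent `window` records.
        if window is not None and len(records) > window:
            return [system_record] + records[-window:]
        return records

    history = ['sys', 'user-1', 'asst-1', 'user-2', 'asst-2', 'user-3']
    print(truncate_history('sys', history, window=3))
    # -> ['sys', 'user-2', 'asst-2', 'user-3']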

camel/agents/critic_agent.py

Lines changed: 16 additions & 18 deletions
@@ -19,7 +19,7 @@
 from colorama import Fore
 
 from camel.agents import ChatAgent
-from camel.messages import ChatMessage, SystemMessage
+from camel.messages import BaseMessage
 from camel.typing import ModelType
 from camel.utils import get_first_int, print_text_animated
 
@@ -28,7 +28,7 @@ class CriticAgent(ChatAgent):
     r"""A class for the critic agent that assists in selecting an option.
 
     Args:
-        system_message (SystemMessage): The system message for the critic
+        system_message (BaseMessage): The system message for the critic
             agent.
         model (ModelType, optional): The LLM model to use for generating
             responses. (default :obj:`ModelType.GPT_3_5_TURBO`)
@@ -46,7 +46,7 @@ class CriticAgent(ChatAgent):
 
     def __init__(
         self,
-        system_message: SystemMessage,
+        system_message: BaseMessage,
         model: ModelType = ModelType.GPT_3_5_TURBO,
         model_config: Optional[Any] = None,
         message_window_size: int = 6,
@@ -61,11 +61,11 @@ def __init__(
         self.verbose = verbose
         self.logger_color = logger_color
 
-    def flatten_options(self, messages: Sequence[ChatMessage]) -> str:
+    def flatten_options(self, messages: Sequence[BaseMessage]) -> str:
         r"""Flattens the options to the critic.
 
         Args:
-            messages (Sequence[ChatMessage]): A list of `ChatMessage` objects.
+            messages (Sequence[BaseMessage]): A list of `BaseMessage` objects.
 
         Returns:
             str: A string containing the flattened options to the critic.
@@ -83,11 +83,11 @@ def flatten_options(self, messages: Sequence[ChatMessage]) -> str:
                   "and then your explanation and comparison: ")
         return flatten_options + format
 
-    def get_option(self, input_message: ChatMessage) -> str:
+    def get_option(self, input_message: BaseMessage) -> str:
         r"""Gets the option selected by the critic.
 
         Args:
-            input_message (ChatMessage): A `ChatMessage` object representing
+            input_message (BaseMessage): A `BaseMessage` object representing
                 the input message.
 
         Returns:
@@ -104,8 +104,8 @@ def get_option(self, input_message: ChatMessage) -> str:
         if critic_response.terminated:
             raise RuntimeError("Critic step failed.")
 
-        critic_msg = critic_response.msgs[0]
-        self.update_messages(critic_msg)
+        critic_msg = critic_response.msg
+        self.update_messages('assistant', critic_msg)
         if self.verbose:
             print_text_animated(self.logger_color + "\n> Critic response: "
                                 f"\x1b[3m{critic_msg.content}\x1b[0m\n")
@@ -114,11 +114,10 @@ def get_option(self, input_message: ChatMessage) -> str:
             if choice in self.options_dict:
                 return self.options_dict[choice]
             else:
-                input_message = ChatMessage(
+                input_message = BaseMessage(
                     role_name=input_message.role_name,
                     role_type=input_message.role_type,
                     meta_dict=input_message.meta_dict,
-                    role=input_message.role,
                     content="> Invalid choice. Please choose again.\n" +
                     msg_content,
                 )
@@ -128,11 +127,11 @@ def get_option(self, input_message: ChatMessage) -> str:
                                 "Returning a random option.")
         return random.choice(list(self.options_dict.values()))
 
-    def parse_critic(self, critic_msg: ChatMessage) -> Optional[str]:
+    def parse_critic(self, critic_msg: BaseMessage) -> Optional[str]:
         r"""Parses the critic's message and extracts the choice.
 
         Args:
-            critic_msg (ChatMessage): A `ChatMessage` object representing the
+            critic_msg (BaseMessage): A `BaseMessage` object representing the
                 critic's response.
 
         Returns:
@@ -142,22 +141,21 @@ def parse_critic(self, critic_msg: ChatMessage) -> Optional[str]:
         choice = str(get_first_int(critic_msg.content))
         return choice
 
-    def step(self, messages: Sequence[ChatMessage]) -> ChatMessage:
+    def step(self, messages: Sequence[BaseMessage]) -> BaseMessage:
         r"""Performs one step of the conversation by flattening options to the
         critic, getting the option, and parsing the choice.
 
         Args:
-            messages (Sequence[ChatMessage]): A list of ChatMessage objects.
+            messages (Sequence[BaseMessage]): A list of BaseMessage objects.
 
         Returns:
-            ChatMessage: A `ChatMessage` object representing the critic's
+            BaseMessage: A `BaseMessage` object representing the critic's
                 choice.
         """
-        meta_chat_message = ChatMessage(
+        meta_chat_message = BaseMessage(
             role_name=messages[0].role_name,
             role_type=messages[0].role_type,
             meta_dict=messages[0].meta_dict,
-            role=messages[0].role,
             content="",
         )
 