Skip to content

Commit a77218c

Browse files
authored
Fix mypy errors in step functions of EmbodiedAgent, CriticAgent and Human (camel-ai#192)
1 parent aea5b86 commit a77218c

File tree

13 files changed

+77
-59
lines changed

13 files changed

+77
-59
lines changed

camel/agents/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# limitations under the License.
1313
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
1414
from .base import BaseAgent
15-
from .chat_agent import ChatAgent
15+
from .chat_agent import ChatAgent, ChatAgentResponse
1616
from .task_agent import TaskPlannerAgent, TaskSpecifyAgent
1717
from .critic_agent import CriticAgent
1818
from .tool_agents.base import BaseToolAgent
@@ -22,6 +22,7 @@
2222
__all__ = [
2323
'BaseAgent',
2424
'ChatAgent',
25+
'ChatAgentResponse',
2526
'TaskSpecifyAgent',
2627
'TaskPlannerAgent',
2728
'CriticAgent',

camel/agents/base.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,18 @@
1212
# limitations under the License.
1313
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
1414
from abc import ABC, abstractmethod
15+
from typing import Any
1516

1617

1718
class BaseAgent(ABC):
1819
r"""An abstract base class for all CAMEL agents."""
1920

2021
@abstractmethod
21-
def reset(self) -> None:
22+
def reset(self, *args: Any, **kwargs: Any) -> Any:
2223
r"""Resets the agent to its initial state."""
2324
pass
2425

2526
@abstractmethod
26-
def step(self) -> None:
27+
def step(self, *args: Any, **kwargs: Any) -> Any:
2728
r"""Performs a single step of the agent."""
2829
pass

camel/agents/chat_agent.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def __init__(
107107
self.role_name: str = system_message.role_name
108108
self.role_type: RoleType = system_message.role_type
109109
self.output_language: Optional[str] = output_language
110-
if output_language is not None:
110+
if self.output_language is not None:
111111
self.set_output_language(self.output_language)
112112

113113
self.model: ModelType = (model if model is not None else
@@ -320,32 +320,32 @@ def handle_stream_response(
320320
tuple: A tuple of list of output `ChatMessage`, list of
321321
finish reasons, usage dictionary, and response id.
322322
"""
323-
content_dict = defaultdict(lambda: "")
324-
finish_reasons = defaultdict(lambda: "")
323+
content_dict: defaultdict = defaultdict(lambda: "")
324+
finish_reasons_dict: defaultdict = defaultdict(lambda: "")
325325
output_messages: List[BaseMessage] = []
326326
response_id: str = ""
327327
# All choices in one response share one role
328328
role: str = ""
329329
for chunk in response:
330330
response_id = chunk["id"]
331331
for choice in chunk["choices"]:
332-
index = choice["index"]
333-
delta = choice["delta"]
332+
index: int = choice["index"]
333+
delta: Dict = choice["delta"]
334334
if len(delta) != 0:
335335
# When response has not been stopped
336336
# Notice that only the first chunk has the "role"
337337
role = delta.get("role", role)
338338
delta_content = delta.get("content", "")
339339
content_dict[index] += delta_content
340340
else:
341-
finish_reasons[index] = choice["finish_reason"]
341+
finish_reasons_dict[index] = choice["finish_reason"]
342342
chat_message = BaseMessage(role_name=self.role_name,
343343
role_type=self.role_type,
344344
meta_dict=dict(),
345345
content=content_dict[index])
346346
output_messages.append(chat_message)
347347
finish_reasons = [
348-
finish_reasons[i] for i in range(len(finish_reasons))
348+
finish_reasons_dict[i] for i in range(len(finish_reasons_dict))
349349
]
350350
usage_dict = self.get_usage_dict(output_messages, prompt_tokens)
351351
return output_messages, finish_reasons, usage_dict, response_id

camel/agents/critic_agent.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,13 @@
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License.
1313
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
14-
import copy
1514
import random
1615
import warnings
1716
from typing import Any, Dict, Optional, Sequence
1817

1918
from colorama import Fore
2019

21-
from camel.agents import ChatAgent
20+
from camel.agents import ChatAgent, ChatAgentResponse
2221
from camel.messages import BaseMessage
2322
from camel.typing import ModelType
2423
from camel.utils import get_first_int, print_text_animated
@@ -141,33 +140,35 @@ def parse_critic(self, critic_msg: BaseMessage) -> Optional[str]:
141140
choice = str(get_first_int(critic_msg.content))
142141
return choice
143142

144-
def step(self, messages: Sequence[BaseMessage]) -> BaseMessage:
143+
def reduce_step(
144+
self,
145+
input_messages: Sequence[BaseMessage],
146+
) -> ChatAgentResponse:
145147
r"""Performs one step of the conversation by flattening options to the
146148
critic, getting the option, and parsing the choice.
147149
148150
Args:
149151
input_messages (Sequence[BaseMessage]): A list of BaseMessage objects.
150152
151153
Returns:
152-
BaseMessage: A `BaseMessage` object representing the critic's
153-
choice.
154+
ChatAgentResponse: A `ChatAgentResponse` object includes the
155+
critic's choice.
154156
"""
155157
meta_chat_message = BaseMessage(
156-
role_name=messages[0].role_name,
157-
role_type=messages[0].role_type,
158-
meta_dict=messages[0].meta_dict,
158+
role_name=input_messages[0].role_name,
159+
role_type=input_messages[0].role_type,
160+
meta_dict=input_messages[0].meta_dict,
159161
content="",
160162
)
161163

162-
flatten_options = self.flatten_options(messages)
164+
flatten_options = self.flatten_options(input_messages)
163165
if self.verbose:
164166
print_text_animated(self.logger_color +
165167
f"\x1b[3m{flatten_options}\x1b[0m\n")
166-
input_msg = copy.deepcopy(meta_chat_message)
167-
input_msg.content = flatten_options
168+
input_msg = meta_chat_message.create_new_instance(flatten_options)
168169

169170
option = self.get_option(input_msg)
170-
output_msg = copy.deepcopy(meta_chat_message)
171-
output_msg.content = option
171+
output_msg = meta_chat_message.create_new_instance(option)
172172

173-
return output_msg
173+
# TODO: The return `info` can be improved.
174+
return ChatAgentResponse([output_msg], terminated=False, info={})

camel/agents/embodied_agent.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,16 @@
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License.
1313
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
14-
from typing import Any, Dict, List, Optional, Tuple
14+
from typing import Any, List, Optional
1515

1616
from colorama import Fore
1717

18-
from camel.agents import BaseToolAgent, ChatAgent, HuggingFaceToolAgent
18+
from camel.agents import (
19+
BaseToolAgent,
20+
ChatAgent,
21+
ChatAgentResponse,
22+
HuggingFaceToolAgent,
23+
)
1924
from camel.messages import BaseMessage
2025
from camel.typing import ModelType
2126
from camel.utils import print_text_animated
@@ -80,16 +85,16 @@ def get_action_space_prompt(self) -> str:
8085
def step(
8186
self,
8287
input_message: BaseMessage,
83-
) -> Tuple[BaseMessage, bool, Dict[str, Any]]:
88+
) -> ChatAgentResponse:
8489
r"""Performs a step in the conversation.
8590
8691
Args:
8792
input_message (BaseMessage): The input message.
8893
8994
Returns:
90-
Tuple[BaseMessage, bool, Dict[str, Any]]: A tuple
91-
containing the output messages, termination status, and
92-
additional information.
95+
ChatAgentResponse: A struct containing the output messages,
96+
a boolean indicating whether the chat session has terminated,
97+
and information about the chat session.
9398
"""
9499
response = super().step(input_message)
95100

@@ -128,4 +133,4 @@ def step(
128133
f"\n> Embodied Actions:\n{content}")
129134
message = BaseMessage(input_message.role_name, input_message.role_type,
130135
input_message.meta_dict, content)
131-
return message, response.terminated, response.info
136+
return ChatAgentResponse([message], response.terminated, response.info)

camel/agents/task_agent.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def __init__(
5151
model_config: Optional[Any] = None,
5252
task_specify_prompt: Optional[Union[str, TextPrompt]] = None,
5353
word_limit: int = DEFAULT_WORD_LIMIT,
54-
output_language: str = None,
54+
output_language: Optional[str] = None,
5555
) -> None:
5656

5757
if task_specify_prompt is None:
@@ -135,7 +135,7 @@ def __init__(
135135
self,
136136
model: Optional[ModelType] = None,
137137
model_config: Any = None,
138-
output_language: str = None,
138+
output_language: Optional[str] = None,
139139
) -> None:
140140

141141
self.task_planner_prompt = TextPrompt(

camel/human.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from colorama import Fore
1717

18+
from camel.agents import ChatAgentResponse
1819
from camel.messages import BaseMessage
1920
from camel.utils import print_text_animated
2021

@@ -86,36 +87,35 @@ def get_input(self) -> str:
8687

8788
return human_input
8889

89-
def parse_input(self, human_input: str,
90-
meta_chat_message: BaseMessage) -> BaseMessage:
90+
def parse_input(self, human_input: str) -> str:
9191
r"""Parses the user's input and returns the content as a `str` object.
9292
9393
Args:
9494
human_input (str): The user's input.
95-
meta_chat_message (BaseMessage): A `BaseMessage` object.
9695
9796
Returns:
98-
BaseMessage: A `BaseMessage` object.
97+
content: A `str` object representing the user's input.
9998
"""
10099
if self.options_dict[human_input] == self.input_button:
101-
meta_chat_message.content = input(self.logger_color +
102-
"Please enter your message: ")
103-
return meta_chat_message
100+
content = input(self.logger_color + "Please enter your message: ")
104101
elif self.options_dict[human_input] == self.kill_button:
105102
exit(self.logger_color + f"Killed by {self.name}.")
106103
else:
107-
meta_chat_message.content = self.options_dict[human_input]
108-
return meta_chat_message
104+
content = self.options_dict[human_input]
109105

110-
def step(self, messages: Sequence[BaseMessage]) -> BaseMessage:
106+
return content
107+
108+
def reduce_step(self,
109+
messages: Sequence[BaseMessage]) -> ChatAgentResponse:
111110
r"""Performs one step of the conversation by displaying options to the
112111
user, getting their input, and parsing their choice.
113112
114113
Args:
115114
messages (Sequence[BaseMessage]): A list of BaseMessage objects.
116115
117116
Returns:
118-
BaseMessage: A `BaseMessage` object representing the user's choice.
117+
ChatAgentResponse: A `ChatAgentResponse` object representing the
118+
user's choice.
119119
"""
120120
meta_chat_message = BaseMessage(
121121
role_name=messages[0].role_name,
@@ -125,4 +125,6 @@ def step(self, messages: Sequence[BaseMessage]) -> BaseMessage:
125125
)
126126
self.display_options(messages)
127127
human_input = self.get_input()
128-
return self.parse_input(human_input, meta_chat_message)
128+
content = self.parse_input(human_input)
129+
message = meta_chat_message.create_new_instance(content)
130+
return ChatAgentResponse([message], terminated=False, info={})

camel/societies/role_playing.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def __init__(
8888
sys_msg_generator_kwargs: Optional[Dict] = None,
8989
extend_sys_msg_meta_dicts: Optional[List[Dict]] = None,
9090
extend_task_specify_meta_dict: Optional[Dict] = None,
91-
output_language: str = None,
91+
output_language: Optional[str] = None,
9292
) -> None:
9393
self.with_task_specify = with_task_specify
9494
self.with_task_planner = with_task_planner
@@ -250,7 +250,8 @@ def reduce_message_options(
250250
raise ValueError("Got more than one message to process. "
251251
f"Num of messages: {len(messages)}.")
252252
elif self.with_critic_in_the_loop and self.critic is not None:
253-
processed_msg = self.critic.step(messages)
253+
critic_response = self.critic.reduce_step(messages)
254+
processed_msg = critic_response.msg
254255
else:
255256
processed_msg = messages[0]
256257

examples/embodiment/hugging_face_tool.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,8 @@ def main():
4646
"caption the image content, "
4747
"save the images by species name."),
4848
)
49-
output_message, _, _ = embodied_agent.step(user_msg)
50-
print(output_message.content)
49+
response = embodied_agent.step(user_msg)
50+
print(response.msg.content)
5151

5252

5353
if __name__ == "__main__":

pyproject.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,3 +69,9 @@ markers = [
6969

7070
[tool.coverage.report]
7171
include_namespace_packages = true
72+
73+
[[tool.mypy.overrides]]
74+
module = [
75+
"transformers.tools",
76+
]
77+
ignore_missing_imports = true

0 commit comments

Comments
 (0)