Skip to content

Introduce streaming tool and support streaming for AgentTool and TeamTool. #6712

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
ModelFamily,
SystemMessage,
)
from autogen_core.tools import BaseTool, FunctionTool, StaticWorkbench, Workbench
from autogen_core.tools import BaseTool, FunctionTool, StaticStreamWorkbench, ToolResult, Workbench
from pydantic import BaseModel
from typing_extensions import Self

Expand Down Expand Up @@ -754,7 +754,7 @@
else:
self._workbench = [workbench]
else:
self._workbench = [StaticWorkbench(self._tools)]
self._workbench = [StaticStreamWorkbench(self._tools)]

if model_context is not None:
self._model_context = model_context
Expand Down Expand Up @@ -1051,18 +1051,44 @@
yield tool_call_msg

# STEP 4B: Execute tool calls
executed_calls_and_results = await asyncio.gather(
*[
cls._execute_tool_call(
tool_call=call,
workbench=workbench,
handoff_tools=handoff_tools,
agent_name=agent_name,
cancellation_token=cancellation_token,
)
for call in model_result.content
]
)
# Use a queue to handle streaming results from tool calls.
stream = asyncio.Queue[BaseAgentEvent | BaseChatMessage | None]()

async def _execute_tool_calls(
function_calls: List[FunctionCall],
) -> List[Tuple[FunctionCall, FunctionExecutionResult]]:
results = await asyncio.gather(
*[
cls._execute_tool_call(
tool_call=call,
workbench=workbench,
handoff_tools=handoff_tools,
agent_name=agent_name,
cancellation_token=cancellation_token,
stream=stream,
)
for call in function_calls
]
)
# Signal the end of streaming by putting None in the queue.
stream.put_nowait(None)
return results

task = asyncio.create_task(_execute_tool_calls(model_result.content))

while True:
event = await stream.get()
if event is None:
# End of streaming, break the loop.
break
if isinstance(event, BaseAgentEvent) or isinstance(event, BaseChatMessage):
yield event
inner_messages.append(event)
else:
raise RuntimeError(f"Unexpected event type: {type(event)}")

Check warning on line 1088 in python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py#L1088

Added line #L1088 was not covered by tests

# Wait for all tool calls to complete.
executed_calls_and_results = await task
exec_results = [result for _, result in executed_calls_and_results]

# Yield ToolCallExecutionEvent
Expand Down Expand Up @@ -1302,6 +1328,7 @@
handoff_tools: List[BaseTool[Any, Any]],
agent_name: str,
cancellation_token: CancellationToken,
stream: asyncio.Queue[BaseAgentEvent | BaseChatMessage | None],
) -> Tuple[FunctionCall, FunctionExecutionResult]:
"""Execute a single tool call and return the result."""
# Load the arguments from the tool call.
Expand Down Expand Up @@ -1339,18 +1366,38 @@
for wb in workbench:
tools = await wb.list_tools()
if any(t["name"] == tool_call.name for t in tools):
result = await wb.call_tool(
name=tool_call.name,
arguments=arguments,
cancellation_token=cancellation_token,
call_id=tool_call.id,
)
if isinstance(wb, StaticStreamWorkbench):
tool_result: ToolResult | None = None
async for event in wb.call_tool_stream(
name=tool_call.name,
arguments=arguments,
cancellation_token=cancellation_token,
call_id=tool_call.id,
):
if isinstance(event, ToolResult):
tool_result = event
elif isinstance(event, BaseAgentEvent) or isinstance(event, BaseChatMessage):
await stream.put(event)
else:
warnings.warn(

Check warning on line 1382 in python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py#L1382

Added line #L1382 was not covered by tests
f"Unexpected event type: {type(event)} in tool call streaming.",
UserWarning,
stacklevel=2,
)
assert isinstance(tool_result, ToolResult), "Tool result should not be None in streaming mode."
else:
tool_result = await wb.call_tool(
name=tool_call.name,
arguments=arguments,
cancellation_token=cancellation_token,
call_id=tool_call.id,
)
return (
tool_call,
FunctionExecutionResult(
content=result.to_text(),
content=tool_result.to_text(),
call_id=tool_call.id,
is_error=result.is_error,
is_error=tool_result.is_error,
name=tool_call.name,
),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ class AgentToolConfig(BaseModel):
"""Configuration for the AgentTool."""

agent: ComponentModel
"""The agent to be used for running the task."""

return_value_as_last_message: bool = False
"""Whether to return the value as the last message of the task result."""


class AgentTool(TaskRunnerTool, Component[AgentToolConfig]):
Expand All @@ -20,6 +24,11 @@ class AgentTool(TaskRunnerTool, Component[AgentToolConfig]):

Args:
agent (BaseChatAgent): The agent to be used for running the task.
return_value_as_last_message (bool): Whether to use the last message content of the task result
as the return value of the tool in :meth:`~autogen_agentchat.tools.TaskRunnerTool.return_value_as_string`.
If set to True, the last message content will be returned as a string.
If set to False, the tool will return all messages in the task result as a string concatenated together,
with each message prefixed by its source (e.g., "writer: ...", "assistant: ...").

Example:

Expand Down Expand Up @@ -57,15 +66,18 @@ async def main() -> None:
component_config_schema = AgentToolConfig
component_provider_override = "autogen_agentchat.tools.AgentTool"

def __init__(self, agent: BaseChatAgent) -> None:
def __init__(self, agent: BaseChatAgent, return_value_as_last_message: bool = False) -> None:
self._agent = agent
super().__init__(agent, agent.name, agent.description)
super().__init__(
agent, agent.name, agent.description, return_value_as_last_message=return_value_as_last_message
)

def _to_config(self) -> AgentToolConfig:
return AgentToolConfig(
agent=self._agent.dump_component(),
return_value_as_last_message=self._return_value_as_last_message,
)

@classmethod
def _from_config(cls, config: AgentToolConfig) -> Self:
return cls(BaseChatAgent.load_component(config.agent))
return cls(BaseChatAgent.load_component(config.agent), config.return_value_as_last_message)
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from abc import ABC
from typing import Annotated, Any, List, Mapping
from typing import Annotated, Any, AsyncGenerator, List, Mapping

from autogen_core import CancellationToken
from autogen_core.tools import BaseTool
from autogen_core.tools import BaseStreamTool
from pydantic import BaseModel

from ..agents import BaseChatAgent
from ..base import TaskResult
from ..messages import BaseChatMessage
from ..messages import BaseAgentEvent, BaseChatMessage
from ..teams import BaseGroupChat


Expand All @@ -17,13 +17,20 @@
task: Annotated[str, "The task to be executed."]


class TaskRunnerTool(BaseTool[TaskRunnerToolArgs, TaskResult], ABC):
class TaskRunnerTool(BaseStreamTool[TaskRunnerToolArgs, BaseAgentEvent | BaseChatMessage, TaskResult], ABC):
"""An base class for tool that can be used to run a task using a team or an agent."""

component_type = "tool"

def __init__(self, task_runner: BaseGroupChat | BaseChatAgent, name: str, description: str) -> None:
def __init__(
self,
task_runner: BaseGroupChat | BaseChatAgent,
name: str,
description: str,
return_value_as_last_message: bool,
) -> None:
self._task_runner = task_runner
self._return_value_as_last_message = return_value_as_last_message
super().__init__(
args_type=TaskRunnerToolArgs,
return_type=TaskResult,
Expand All @@ -32,10 +39,23 @@
)

async def run(self, args: TaskRunnerToolArgs, cancellation_token: CancellationToken) -> TaskResult:
"""Run the task and return the result."""
return await self._task_runner.run(task=args.task, cancellation_token=cancellation_token)

async def run_stream(
self, args: TaskRunnerToolArgs, cancellation_token: CancellationToken
) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | TaskResult, None]:
"""Run the task and yield events or messages as they are produced, the final :class:`TaskResult`
will be yielded at the end."""
async for event in self._task_runner.run_stream(task=args.task, cancellation_token=cancellation_token):
yield event

def return_value_as_string(self, value: TaskResult) -> str:
"""Convert the task result to a string."""
if self._return_value_as_last_message:
if value.messages and isinstance(value.messages[-1], BaseChatMessage):
return value.messages[-1].to_model_text()
raise ValueError("The last message is not a BaseChatMessage.")

Check warning on line 58 in python/packages/autogen-agentchat/src/autogen_agentchat/tools/_task_runner_tool.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-agentchat/src/autogen_agentchat/tools/_task_runner_tool.py#L58

Added line #L58 was not covered by tests
parts: List[str] = []
for message in value.messages:
if isinstance(message, BaseChatMessage):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,13 @@ class TeamToolConfig(BaseModel):
"""Configuration for the TeamTool."""

name: str
"""The name of the tool."""
description: str
"""The name and description of the tool."""
team: ComponentModel
"""The team to be used for running the task."""
return_value_as_last_message: bool = False
"""Whether to return the value as the last message of the task result."""


class TeamTool(TaskRunnerTool, Component[TeamToolConfig]):
Expand All @@ -24,22 +29,92 @@ class TeamTool(TaskRunnerTool, Component[TeamToolConfig]):
team (BaseGroupChat): The team to be used for running the task.
name (str): The name of the tool.
description (str): The description of the tool.
return_value_as_last_message (bool): Whether to use the last message content of the task result
as the return value of the tool in :meth:`~autogen_agentchat.tools.TaskRunnerTool.return_value_as_string`.
If set to True, the last message content will be returned as a string.
If set to False, the tool will return all messages in the task result as a string concatenated together,
with each message prefixed by its source (e.g., "writer: ...", "assistant: ...").

Example:

.. code-block:: python

from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.conditions import SourceMatchTermination
from autogen_agentchat.teams import RoundRobinGroupChat
from autogen_agentchat.tools import TeamTool
from autogen_agentchat.ui import Console
from autogen_ext.models.ollama import OllamaChatCompletionClient


async def main() -> None:
model_client = OllamaChatCompletionClient(model="llama3.2")

writer = AssistantAgent(name="writer", model_client=model_client, system_message="You are a helpful assistant.")
reviewer = AssistantAgent(
name="reviewer", model_client=model_client, system_message="You are a critical reviewer."
)
summarizer = AssistantAgent(
name="summarizer",
model_client=model_client,
system_message="You combine the review and produce a revised response.",
)
team = RoundRobinGroupChat(
[writer, reviewer, summarizer], termination_condition=SourceMatchTermination(sources=["summarizer"])
)

# Create a TeamTool that uses the team to run tasks, returning the last message as the result.
tool = TeamTool(
team=team, name="writing_team", description="A tool for writing tasks.", return_value_as_last_message=True
)

main_agent = AssistantAgent(
name="main_agent",
model_client=model_client,
system_message="You are a helpful assistant that can use the writing tool.",
tools=[tool],
)
# For handling each event manually.
# async for message in main_agent.run_stream(
# task="Write a short story about a robot learning to love.",
# ):
# print(message)
# Use Console to display the messages in a more readable format.
await Console(
main_agent.run_stream(
task="Write a short story about a robot learning to love.",
)
)


if __name__ == "__main__":
import asyncio

asyncio.run(main())
"""

component_config_schema = TeamToolConfig
component_provider_override = "autogen_agentchat.tools.TeamTool"

def __init__(self, team: BaseGroupChat, name: str, description: str) -> None:
def __init__(
self, team: BaseGroupChat, name: str, description: str, return_value_as_last_message: bool = False
) -> None:
self._team = team
super().__init__(team, name, description)
super().__init__(team, name, description, return_value_as_last_message=return_value_as_last_message)

def _to_config(self) -> TeamToolConfig:
return TeamToolConfig(
name=self._name,
description=self._description,
team=self._team.dump_component(),
return_value_as_last_message=self._return_value_as_last_message,
)

@classmethod
def _from_config(cls, config: TeamToolConfig) -> Self:
return cls(BaseGroupChat.load_component(config.team), config.name, config.description)
return cls(
BaseGroupChat.load_component(config.team),
config.name,
config.description,
config.return_value_as_last_message,
)
Loading
Loading