From 54eaa2bb4ec63ac4485d42a2bb7a35410192434d Mon Sep 17 00:00:00 2001 From: Eric Zhu Date: Mon, 7 Oct 2024 09:38:24 -0700 Subject: [PATCH] Move tools to agent in `agentchat`; refactored logging to support tool events (#3665) * Move tool to agent; refactor logging in agentchat * Update notebook --- .../src/autogen_agentchat/agents/__init__.py | 2 + .../agents/_base_chat_agent.py | 21 +++- .../agents/_tool_use_assistant_agent.py | 12 +-- .../src/autogen_agentchat/teams/_events.py | 39 +++++-- .../src/autogen_agentchat/teams/_logging.py | 101 +++++++++++------- .../group_chat/_base_chat_agent_container.py | 72 +++++++------ .../group_chat/_base_group_chat_manager.py | 7 +- .../group_chat/_round_robin_group_chat.py | 45 +++++--- .../tests/test_group_chat.py | 4 +- .../tests/test_sequential_routed_agent.py | 6 +- .../guides/tool_use.ipynb | 51 +++++---- 11 files changed, 236 insertions(+), 124 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/__init__.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/__init__.py index 3c521ae7ac3..fda137f5f6b 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/__init__.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/__init__.py @@ -1,5 +1,6 @@ from ._base_chat_agent import ( BaseChatAgent, + BaseToolUseChatAgent, ChatMessage, MultiModalMessage, StopMessage, @@ -13,6 +14,7 @@ __all__ = [ "BaseChatAgent", + "BaseToolUseChatAgent", "ChatMessage", "TextMessage", "MultiModalMessage", diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py index 6f90185b3c3..b5274596874 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py @@ -4,6 +4,7 @@ from autogen_core.base import CancellationToken from autogen_core.components import FunctionCall, Image from autogen_core.components.models import FunctionExecutionResult +from autogen_core.components.tools import Tool from pydantic import BaseModel @@ -49,7 +50,7 @@ class StopMessage(BaseMessage): """The content for the stop message.""" -ChatMessage = TextMessage | MultiModalMessage | ToolCallMessage | ToolCallResultMessage | StopMessage +ChatMessage = TextMessage | MultiModalMessage | StopMessage | ToolCallMessage | ToolCallResultMessage """A message used by agents in a team.""" @@ -79,3 +80,21 @@ def description(self) -> str: async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage: """Handle incoming messages and return a response message.""" ... + + +class BaseToolUseChatAgent(BaseChatAgent): + """Base class for a chat agent that can use tools. + + Subclass this base class to create an agent class that uses tools by returning + ToolCallMessage message from the :meth:`on_messages` method and receiving + ToolCallResultMessage message from the input to the :meth:`on_messages` method. + """ + + def __init__(self, name: str, description: str, registered_tools: List[Tool]) -> None: + super().__init__(name, description) + self._registered_tools = registered_tools + + @property + def registered_tools(self) -> List[Tool]: + """The list of tools that the agent can use.""" + return self._registered_tools diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_tool_use_assistant_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_tool_use_assistant_agent.py index 7154e663f2a..7e287b75bd5 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_tool_use_assistant_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_tool_use_assistant_agent.py @@ -10,10 +10,10 @@ SystemMessage, UserMessage, ) -from autogen_core.components.tools import ToolSchema +from autogen_core.components.tools import Tool from ._base_chat_agent import ( - BaseChatAgent, + BaseToolUseChatAgent, ChatMessage, MultiModalMessage, StopMessage, @@ -23,7 +23,7 @@ ) -class ToolUseAssistantAgent(BaseChatAgent): +class ToolUseAssistantAgent(BaseToolUseChatAgent): """An agent that provides assistance with tool use. It responds with a StopMessage when 'terminate' is detected in the response. @@ -33,15 +33,15 @@ def __init__( self, name: str, model_client: ChatCompletionClient, - tool_schema: List[ToolSchema], + registered_tools: List[Tool], *, description: str = "An agent that provides assistance with ability to use tools.", system_message: str = "You are a helpful AI assistant. Solve tasks using your tools. Reply 'TERMINATE' in the end when the task is completed.", ): - super().__init__(name=name, description=description) + super().__init__(name=name, description=description, registered_tools=registered_tools) self._model_client = model_client self._system_messages = [SystemMessage(content=system_message)] - self._tool_schema = tool_schema + self._tool_schema = [tool.schema for tool in registered_tools] self._model_context: List[LLMMessage] = [] async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage: diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_events.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_events.py index 01554fb25d1..ba15fd21a3f 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_events.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_events.py @@ -1,8 +1,7 @@ -from typing import Optional +from autogen_core.base import AgentId +from pydantic import BaseModel, ConfigDict -from pydantic import BaseModel - -from ..agents import ChatMessage +from ..agents import MultiModalMessage, StopMessage, TextMessage, ToolCallMessage, ToolCallResultMessage class ContentPublishEvent(BaseModel): @@ -11,9 +10,13 @@ class ContentPublishEvent(BaseModel): content of the event. """ - agent_message: ChatMessage + agent_message: TextMessage | MultiModalMessage | StopMessage """The message published by the agent.""" - source: Optional[str] = None + + source: AgentId | None = None + """The agent ID that published the message.""" + + model_config = ConfigDict(arbitrary_types_allowed=True) class ContentRequestEvent(BaseModel): @@ -22,3 +25,27 @@ class ContentRequestEvent(BaseModel): """ ... + + +class ToolCallEvent(BaseModel): + """An event produced when requesting a tool call.""" + + agent_message: ToolCallMessage + """The tool call message.""" + + source: AgentId + """The sender of the tool call message.""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + + +class ToolCallResultEvent(BaseModel): + """An event produced when a tool call is completed.""" + + agent_message: ToolCallResultMessage + """The tool call result message.""" + + source: AgentId + """The sender of the tool call result message.""" + + model_config = ConfigDict(arbitrary_types_allowed=True) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_logging.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_logging.py index 8e8ff89d2e9..41fba1948f6 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_logging.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_logging.py @@ -3,13 +3,14 @@ import sys from dataclasses import asdict, is_dataclass from datetime import datetime -from typing import Any, Dict, List, Sequence, Union +from typing import Any, Dict, List, Union +from autogen_core.base import AgentId from autogen_core.components import FunctionCall, Image from autogen_core.components.models import FunctionExecutionResult from ..agents import ChatMessage, MultiModalMessage, StopMessage, TextMessage, ToolCallMessage, ToolCallResultMessage -from ._events import ContentPublishEvent +from ._events import ContentPublishEvent, ToolCallEvent, ToolCallResultEvent EVENT_LOGGER_NAME = "autogen_agentchat.events" ContentType = Union[str, List[Union[str, Image]], List[FunctionCall], List[FunctionExecutionResult]] @@ -17,7 +18,8 @@ class BaseLogHandler(logging.Handler): def serialize_content( - self, content: Union[ContentType, Sequence[ChatMessage], ChatMessage] + self, + content: Union[ContentType, ChatMessage], ) -> Union[List[Any], Dict[str, Any], str]: if isinstance(content, (str, list)): return content @@ -41,19 +43,35 @@ def json_serializer(obj: Any) -> Any: class ConsoleLogHandler(BaseLogHandler): + def _format_message( + self, + *, + source_agent_id: AgentId | None, + message: ChatMessage, + timestamp: str, + ) -> str: + body = f"{self.serialize_content(message.content)}\nFrom: {message.source}" + if source_agent_id is None: + console_message = f"\n{'-'*75} \n" f"\033[91m[{timestamp}]:\033[0m\n" f"\n{body}" + else: + # Display the source agent type rather than agent ID for better readability. + # Also in AgentChat the agent type is unique for each agent. + console_message = f"\n{'-'*75} \n" f"\033[91m[{timestamp}], {source_agent_id.type}:\033[0m\n" f"\n{body}" + return console_message + def emit(self, record: logging.LogRecord) -> None: - try: - ts = datetime.fromtimestamp(record.created).isoformat() - if isinstance(record.msg, ContentPublishEvent): - console_message = ( - f"\n{'-'*75} \n" - f"\033[91m[{ts}], {record.msg.agent_message.source}:\033[0m\n" - f"\n{self.serialize_content(record.msg.agent_message.content)}" + ts = datetime.fromtimestamp(record.created).isoformat() + if isinstance(record.msg, ContentPublishEvent | ToolCallEvent | ToolCallResultEvent): + sys.stdout.write( + self._format_message( + source_agent_id=record.msg.source, + message=record.msg.agent_message, + timestamp=ts, ) - sys.stdout.write(console_message) - sys.stdout.flush() - except Exception: - self.handleError(record) + ) + sys.stdout.flush() + else: + raise ValueError(f"Unexpected log record: {record.msg}") class FileLogHandler(BaseLogHandler): @@ -62,32 +80,37 @@ def __init__(self, filename: str) -> None: self.filename = filename self.file_handler = logging.FileHandler(filename) - def emit(self, record: logging.LogRecord) -> None: - try: - ts = datetime.fromtimestamp(record.created).isoformat() - if isinstance(record.msg, ContentPublishEvent): - log_entry = json.dumps( - { - "timestamp": ts, - "source": record.msg.agent_message.source, - "message": self.serialize_content(record.msg.agent_message.content), - "type": "OrchestrationEvent", - }, - default=self.json_serializer, - ) + def _format_entry(self, *, source: AgentId | None, message: ChatMessage, timestamp: str) -> Dict[str, Any]: + return { + "timestamp": timestamp, + "source": source, + "message": self.serialize_content(message), + "type": "OrchestrationEvent", + } - file_record = logging.LogRecord( - name=record.name, - level=record.levelno, - pathname=record.pathname, - lineno=record.lineno, - msg=log_entry, - args=(), - exc_info=record.exc_info, - ) - self.file_handler.emit(file_record) - except Exception: - self.handleError(record) + def emit(self, record: logging.LogRecord) -> None: + ts = datetime.fromtimestamp(record.created).isoformat() + if isinstance(record.msg, ContentPublishEvent | ToolCallEvent | ToolCallResultEvent): + log_entry = json.dumps( + self._format_entry( + source=record.msg.source, + message=record.msg.agent_message, + timestamp=ts, + ), + default=self.json_serializer, + ) + else: + raise ValueError(f"Unexpected log record: {record.msg}") + file_record = logging.LogRecord( + name=record.name, + level=record.levelno, + pathname=record.pathname, + lineno=record.lineno, + msg=log_entry, + args=(), + exc_info=record.exc_info, + ) + self.file_handler.emit(file_record) def close(self) -> None: self.file_handler.close() diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/group_chat/_base_chat_agent_container.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/group_chat/_base_chat_agent_container.py index 409d0e0c9f1..23f6fdc623a 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/group_chat/_base_chat_agent_container.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/group_chat/_base_chat_agent_container.py @@ -8,10 +8,12 @@ from autogen_core.components.tool_agent import ToolException from ...agents import BaseChatAgent, MultiModalMessage, StopMessage, TextMessage, ToolCallMessage, ToolCallResultMessage -from .._events import ContentPublishEvent, ContentRequestEvent +from .._events import ContentPublishEvent, ContentRequestEvent, ToolCallEvent, ToolCallResultEvent from .._logging import EVENT_LOGGER_NAME from ._sequential_routed_agent import SequentialRoutedAgent +event_logger = logging.getLogger(EVENT_LOGGER_NAME) + class BaseChatAgentContainer(SequentialRoutedAgent): """A core agent class that delegates message handling to an @@ -21,16 +23,15 @@ class BaseChatAgentContainer(SequentialRoutedAgent): Args: parent_topic_type (str): The topic type of the parent orchestrator. agent (BaseChatAgent): The agent to delegate message handling to. - tool_agent_type (AgentType): The agent type of the tool agent to use for tool calls. + tool_agent_type (AgentType, optional): The agent type of the tool agent. Defaults to None. """ - def __init__(self, parent_topic_type: str, agent: BaseChatAgent, tool_agent_type: AgentType) -> None: + def __init__(self, parent_topic_type: str, agent: BaseChatAgent, tool_agent_type: AgentType | None = None) -> None: super().__init__(description=agent.description) self._parent_topic_type = parent_topic_type self._agent = agent self._message_buffer: List[TextMessage | MultiModalMessage | StopMessage] = [] - self._tool_agent_id = AgentId(type=tool_agent_type, key=self.id.key) - self._logger = self.logger = logging.getLogger(EVENT_LOGGER_NAME) + self._tool_agent_id = AgentId(type=tool_agent_type, key=self.id.key) if tool_agent_type else None @event async def handle_content_publish(self, message: ContentPublishEvent, ctx: MessageContext) -> None: @@ -48,38 +49,43 @@ async def handle_content_request(self, message: ContentRequestEvent, ctx: Messag to the delegate agent and publish the response.""" response = await self._agent.on_messages(self._message_buffer, ctx.cancellation_token) - # Handle tool calls. - while isinstance(response, ToolCallMessage): - self._logger.info(ContentPublishEvent(agent_message=response)) + if self._tool_agent_id is not None: + # Handle tool calls. + while isinstance(response, ToolCallMessage): + # Log the tool call. + event_logger.info(ToolCallEvent(agent_message=response, source=self.id)) - results: List[FunctionExecutionResult | BaseException] = await asyncio.gather( - *[ - self.send_message( - message=call, - recipient=self._tool_agent_id, - cancellation_token=ctx.cancellation_token, - ) - for call in response.content - ] - ) - # Combine the results in to a single response and handle exceptions. - function_results: List[FunctionExecutionResult] = [] - for result in results: - if isinstance(result, FunctionExecutionResult): - function_results.append(result) - elif isinstance(result, ToolException): - function_results.append(FunctionExecutionResult(content=f"Error: {result}", call_id=result.call_id)) - elif isinstance(result, BaseException): - raise result # Unexpected exception. - # Create a new tool call result message. - feedback = ToolCallResultMessage(content=function_results, source=self._tool_agent_id.type) - # TODO: use logging instead of print - self._logger.info(ContentPublishEvent(agent_message=feedback, source=self._tool_agent_id.type)) - response = await self._agent.on_messages([feedback], ctx.cancellation_token) + results: List[FunctionExecutionResult | BaseException] = await asyncio.gather( + *[ + self.send_message( + message=call, + recipient=self._tool_agent_id, + cancellation_token=ctx.cancellation_token, + ) + for call in response.content + ] + ) + # Combine the results in to a single response and handle exceptions. + function_results: List[FunctionExecutionResult] = [] + for result in results: + if isinstance(result, FunctionExecutionResult): + function_results.append(result) + elif isinstance(result, ToolException): + function_results.append( + FunctionExecutionResult(content=f"Error: {result}", call_id=result.call_id) + ) + elif isinstance(result, BaseException): + raise result # Unexpected exception. + # Create a new tool call result message. + feedback = ToolCallResultMessage(content=function_results, source=self._tool_agent_id.type) + # Log the feedback. + event_logger.info(ToolCallResultEvent(agent_message=feedback, source=self._tool_agent_id)) + response = await self._agent.on_messages([feedback], ctx.cancellation_token) # Publish the response. assert isinstance(response, TextMessage | MultiModalMessage | StopMessage) self._message_buffer.clear() await self.publish_message( - ContentPublishEvent(agent_message=response), topic_id=DefaultTopicId(type=self._parent_topic_type) + ContentPublishEvent(agent_message=response, source=self.id), + topic_id=DefaultTopicId(type=self._parent_topic_type), ) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/group_chat/_base_group_chat_manager.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/group_chat/_base_group_chat_manager.py index 04a61820938..7059e11fdf5 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/group_chat/_base_group_chat_manager.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/group_chat/_base_group_chat_manager.py @@ -9,6 +9,8 @@ from .._logging import EVENT_LOGGER_NAME from ._sequential_routed_agent import SequentialRoutedAgent +event_logger = logging.getLogger(EVENT_LOGGER_NAME) + class BaseGroupChatManager(SequentialRoutedAgent): """Base class for a group chat manager that manages a group chat with multiple participants. @@ -50,7 +52,6 @@ def __init__( self._participant_topic_types = participant_topic_types self._participant_descriptions = participant_descriptions self._message_thread: List[ChatMessage] = [] - self._logger = self.logger = logging.getLogger(EVENT_LOGGER_NAME + ".agentchatchat") @event async def handle_content_publish(self, message: ContentPublishEvent, ctx: MessageContext) -> None: @@ -63,9 +64,7 @@ async def handle_content_publish(self, message: ContentPublishEvent, ctx: Messag assert ctx.topic_id is not None group_chat_topic_id = TopicId(type=self._group_topic_type, source=ctx.topic_id.source) - # TODO: use something else other than print. - - self._logger.info(ContentPublishEvent(agent_message=message.agent_message)) + event_logger.info(message) # Process event from parent. if ctx.topic_id.type == self._parent_topic_type: diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/group_chat/_round_robin_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/group_chat/_round_robin_group_chat.py index 86cb3ed14d2..b58cf28f085 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/group_chat/_round_robin_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/group_chat/_round_robin_group_chat.py @@ -9,7 +9,7 @@ from autogen_agentchat.agents._base_chat_agent import ChatMessage -from ...agents import BaseChatAgent, TextMessage +from ...agents import BaseChatAgent, BaseToolUseChatAgent, TextMessage from .._base_team import BaseTeam, TeamRunResult from .._events import ContentPublishEvent, ContentRequestEvent from ._base_chat_agent_container import BaseChatAgentContainer @@ -37,8 +37,8 @@ class RoundRobinGroupChat(BaseTeam): from autogen_agentchat.agents import ToolUseAssistantAgent from autogen_agentchat.teams import RoundRobinGroupChat - assistant = ToolUseAssistantAgent("Assistant", model_client=..., tool_schema=[...]) - team = RoundRobinGroupChat([assistant], tools=[...]) + assistant = ToolUseAssistantAgent("Assistant", model_client=..., registered_tools=...) + team = RoundRobinGroupChat([assistant]) await team.run("What's the weather in New York?") A team with multiple participants: @@ -55,17 +55,21 @@ class RoundRobinGroupChat(BaseTeam): """ - def __init__(self, participants: List[BaseChatAgent], *, tools: List[Tool] | None = None): + def __init__(self, participants: List[BaseChatAgent]): if len(participants) == 0: raise ValueError("At least one participant is required.") if len(participants) != len(set(participant.name for participant in participants)): raise ValueError("The participant names must be unique.") + for participant in participants: + if isinstance(participant, BaseToolUseChatAgent) and not participant.registered_tools: + raise ValueError( + f"Participant '{participant.name}' is a tool use agent so it must have registered tools." + ) self._participants = participants self._team_id = str(uuid.uuid4()) - self._tools = tools or [] def _create_factory( - self, parent_topic_type: str, agent: BaseChatAgent, tool_agent_type: AgentType + self, parent_topic_type: str, agent: BaseChatAgent, tool_agent_type: AgentType | None ) -> Callable[[], BaseChatAgentContainer]: def _factory() -> BaseChatAgentContainer: id = AgentInstantiationContext.current_agent_id() @@ -76,6 +80,16 @@ def _factory() -> BaseChatAgentContainer: return _factory + def _create_tool_agent_factory( + self, + caller_name: str, + tools: List[Tool], + ) -> Callable[[], ToolAgent]: + def _factory() -> ToolAgent: + return ToolAgent(f"Tool agent for {caller_name}", tools) + + return _factory + async def run(self, task: str) -> TeamRunResult: """Run the team and return the result.""" # Create the runtime. @@ -87,16 +101,23 @@ async def run(self, task: str) -> TeamRunResult: group_topic_type = "round_robin_group_topic" team_topic_type = "team_topic" - # Register the tool agent. - tool_agent_type = await ToolAgent.register( - runtime, "tool_agent", lambda: ToolAgent("Tool agent for round-robin group chat", self._tools) - ) - # No subscriptions are needed for the tool agent, which will be called via direct messages. - # Register participants. participant_topic_types: List[str] = [] participant_descriptions: List[str] = [] for participant in self._participants: + if isinstance(participant, BaseToolUseChatAgent): + assert participant.registered_tools is not None and len(participant.registered_tools) > 0 + # Register the tool agent. + tool_agent_type = await ToolAgent.register( + runtime, + f"tool_agent_for_{participant.name}", + self._create_tool_agent_factory(participant.name, participant.registered_tools), + ) + # No subscriptions are needed for the tool agent, which will be called via direct messages. + else: + # No tool agent is needed. + tool_agent_type = None + # Use the participant name as the agent type and topic type. agent_type = participant.name topic_type = participant.name diff --git a/python/packages/autogen-agentchat/tests/test_group_chat.py b/python/packages/autogen-agentchat/tests/test_group_chat.py index f728d5736c2..7c2800c2b16 100644 --- a/python/packages/autogen-agentchat/tests/test_group_chat.py +++ b/python/packages/autogen-agentchat/tests/test_group_chat.py @@ -171,10 +171,10 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch tool_use_agent = ToolUseAssistantAgent( "tool_use_agent", model_client=OpenAIChatCompletionClient(model=model, api_key=""), - tool_schema=[tool.schema], + registered_tools=[tool], ) echo_agent = _EchoAgent("echo_agent", description="echo agent") - team = RoundRobinGroupChat(participants=[tool_use_agent, echo_agent], tools=[tool]) + team = RoundRobinGroupChat(participants=[tool_use_agent, echo_agent]) await team.run("Write a program that prints 'Hello, world!'") context = tool_use_agent._model_context # pyright: ignore assert context[0].content == "Write a program that prints 'Hello, world!'" diff --git a/python/packages/autogen-agentchat/tests/test_sequential_routed_agent.py b/python/packages/autogen-agentchat/tests/test_sequential_routed_agent.py index 8e8957e273c..912043168a1 100644 --- a/python/packages/autogen-agentchat/tests/test_sequential_routed_agent.py +++ b/python/packages/autogen-agentchat/tests/test_sequential_routed_agent.py @@ -16,7 +16,7 @@ class Message: @default_subscription -class TestAgent(SequentialRoutedAgent): +class _TestAgent(SequentialRoutedAgent): def __init__(self, description: str) -> None: super().__init__(description=description) self.messages: List[Message] = [] @@ -32,11 +32,11 @@ async def handle_content_publish(self, message: Message, ctx: MessageContext) -> async def test_sequential_routed_agent() -> None: runtime = SingleThreadedAgentRuntime() runtime.start() - await TestAgent.register(runtime, type="test_agent", factory=lambda: TestAgent(description="Test Agent")) + await _TestAgent.register(runtime, type="test_agent", factory=lambda: _TestAgent(description="Test Agent")) test_agent_id = AgentId(type="test_agent", key="default") for i in range(100): await runtime.publish_message(Message(content=f"{i}"), topic_id=DefaultTopicId()) await runtime.stop_when_idle() - test_agent = await runtime.try_get_underlying_agent_instance(test_agent_id, TestAgent) + test_agent = await runtime.try_get_underlying_agent_instance(test_agent_id, _TestAgent) for i in range(100): assert test_agent.messages[i].content == f"{i}" diff --git a/python/packages/autogen-core/docs/src/agentchat-user-guide/guides/tool_use.ipynb b/python/packages/autogen-core/docs/src/agentchat-user-guide/guides/tool_use.ipynb index a77b8b6c78e..1c254305282 100644 --- a/python/packages/autogen-core/docs/src/agentchat-user-guide/guides/tool_use.ipynb +++ b/python/packages/autogen-core/docs/src/agentchat-user-guide/guides/tool_use.ipynb @@ -17,7 +17,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -29,7 +29,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -42,28 +42,43 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "--------------------------------------------------------------------------------\n", - "user:\n", + "\n", + "--------------------------------------------------------------------------- \n", + "\u001b[91m[2024-10-04T17:59:55.737430]:\u001b[0m\n", + "\n", "What's the weather in New York?\n", - "--------------------------------------------------------------------------------\n", - "Weather_Assistant:\n", - "[FunctionCall(id='call_I8mFF4D73eoC3hhO81ldmIG3', arguments='{\"city\":\"New York\"}', name='get_weather')]\n", - "--------------------------------------------------------------------------------\n", - "tool_agent:\n", - "[FunctionExecutionResult(content='Sunny', call_id='call_I8mFF4D73eoC3hhO81ldmIG3')]\n", - "--------------------------------------------------------------------------------\n", - "Weather_Assistant:\n", + "From: user" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "--------------------------------------------------------------------------- \n", + "\u001b[91m[2024-10-04T17:59:56.310787], Weather_Assistant:\u001b[0m\n", + "\n", + "[FunctionCall(id='call_zxmdHPEQ1QMd2NwvYUSgxxDV', arguments='{\"city\":\"New York\"}', name='get_weather')]\n", + "From: Weather_Assistant\n", + "--------------------------------------------------------------------------- \n", + "\u001b[91m[2024-10-04T17:59:56.312084], tool_agent_for_Weather_Assistant:\u001b[0m\n", + "\n", + "[FunctionExecutionResult(content='Sunny', call_id='call_zxmdHPEQ1QMd2NwvYUSgxxDV')]\n", + "From: tool_agent_for_Weather_Assistant\n", + "--------------------------------------------------------------------------- \n", + "\u001b[91m[2024-10-04T17:59:56.767874], Weather_Assistant:\u001b[0m\n", + "\n", "The weather in New York is sunny. \n", "\n", "TERMINATE\n", - "TeamRunResult(result='The weather in New York is sunny. \\n\\nTERMINATE')\n" + "From: Weather_Assistant" ] } ], @@ -71,11 +86,11 @@ "assistant = ToolUseAssistantAgent(\n", " \"Weather_Assistant\",\n", " model_client=OpenAIChatCompletionClient(model=\"gpt-4o-mini\"),\n", - " tool_schema=[get_weather_tool.schema],\n", + " registered_tools=[get_weather_tool],\n", ")\n", - "team = RoundRobinGroupChat([assistant], tools=[get_weather_tool])\n", + "team = RoundRobinGroupChat([assistant])\n", "result = await team.run(\"What's the weather in New York?\")\n", - "print(result)" + "# print(result)" ] } ], @@ -95,7 +110,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.9" + "version": "3.12.6" } }, "nbformat": 4,