Skip to content

Commit

Permalink
Move tools to agent in agentchat; refactored logging to support too…
Browse files Browse the repository at this point in the history
…l events (#3665)

* Move tool to agent; refactor logging in agentchat

* Update notebook
  • Loading branch information
ekzhu authored Oct 7, 2024
1 parent be5c0b5 commit 54eaa2b
Show file tree
Hide file tree
Showing 11 changed files with 236 additions and 124 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from ._base_chat_agent import (
BaseChatAgent,
BaseToolUseChatAgent,
ChatMessage,
MultiModalMessage,
StopMessage,
Expand All @@ -13,6 +14,7 @@

__all__ = [
"BaseChatAgent",
"BaseToolUseChatAgent",
"ChatMessage",
"TextMessage",
"MultiModalMessage",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from autogen_core.base import CancellationToken
from autogen_core.components import FunctionCall, Image
from autogen_core.components.models import FunctionExecutionResult
from autogen_core.components.tools import Tool
from pydantic import BaseModel


Expand Down Expand Up @@ -49,7 +50,7 @@ class StopMessage(BaseMessage):
"""The content for the stop message."""


ChatMessage = TextMessage | MultiModalMessage | ToolCallMessage | ToolCallResultMessage | StopMessage
ChatMessage = TextMessage | MultiModalMessage | StopMessage | ToolCallMessage | ToolCallResultMessage
"""A message used by agents in a team."""


Expand Down Expand Up @@ -79,3 +80,21 @@ def description(self) -> str:
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
"""Handle incoming messages and return a response message."""
...


class BaseToolUseChatAgent(BaseChatAgent):
"""Base class for a chat agent that can use tools.
Subclass this base class to create an agent class that uses tools by returning
ToolCallMessage message from the :meth:`on_messages` method and receiving
ToolCallResultMessage message from the input to the :meth:`on_messages` method.
"""

def __init__(self, name: str, description: str, registered_tools: List[Tool]) -> None:
super().__init__(name, description)
self._registered_tools = registered_tools

@property
def registered_tools(self) -> List[Tool]:
"""The list of tools that the agent can use."""
return self._registered_tools
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
SystemMessage,
UserMessage,
)
from autogen_core.components.tools import ToolSchema
from autogen_core.components.tools import Tool

from ._base_chat_agent import (
BaseChatAgent,
BaseToolUseChatAgent,
ChatMessage,
MultiModalMessage,
StopMessage,
Expand All @@ -23,7 +23,7 @@
)


class ToolUseAssistantAgent(BaseChatAgent):
class ToolUseAssistantAgent(BaseToolUseChatAgent):
"""An agent that provides assistance with tool use.
It responds with a StopMessage when 'terminate' is detected in the response.
Expand All @@ -33,15 +33,15 @@ def __init__(
self,
name: str,
model_client: ChatCompletionClient,
tool_schema: List[ToolSchema],
registered_tools: List[Tool],
*,
description: str = "An agent that provides assistance with ability to use tools.",
system_message: str = "You are a helpful AI assistant. Solve tasks using your tools. Reply 'TERMINATE' in the end when the task is completed.",
):
super().__init__(name=name, description=description)
super().__init__(name=name, description=description, registered_tools=registered_tools)
self._model_client = model_client
self._system_messages = [SystemMessage(content=system_message)]
self._tool_schema = tool_schema
self._tool_schema = [tool.schema for tool in registered_tools]
self._model_context: List[LLMMessage] = []

async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from typing import Optional
from autogen_core.base import AgentId
from pydantic import BaseModel, ConfigDict

from pydantic import BaseModel

from ..agents import ChatMessage
from ..agents import MultiModalMessage, StopMessage, TextMessage, ToolCallMessage, ToolCallResultMessage


class ContentPublishEvent(BaseModel):
Expand All @@ -11,9 +10,13 @@ class ContentPublishEvent(BaseModel):
content of the event.
"""

agent_message: ChatMessage
agent_message: TextMessage | MultiModalMessage | StopMessage
"""The message published by the agent."""
source: Optional[str] = None

source: AgentId | None = None
"""The agent ID that published the message."""

model_config = ConfigDict(arbitrary_types_allowed=True)


class ContentRequestEvent(BaseModel):
Expand All @@ -22,3 +25,27 @@ class ContentRequestEvent(BaseModel):
"""

...


class ToolCallEvent(BaseModel):
"""An event produced when requesting a tool call."""

agent_message: ToolCallMessage
"""The tool call message."""

source: AgentId
"""The sender of the tool call message."""

model_config = ConfigDict(arbitrary_types_allowed=True)


class ToolCallResultEvent(BaseModel):
"""An event produced when a tool call is completed."""

agent_message: ToolCallResultMessage
"""The tool call result message."""

source: AgentId
"""The sender of the tool call result message."""

model_config = ConfigDict(arbitrary_types_allowed=True)
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,23 @@
import sys
from dataclasses import asdict, is_dataclass
from datetime import datetime
from typing import Any, Dict, List, Sequence, Union
from typing import Any, Dict, List, Union

from autogen_core.base import AgentId
from autogen_core.components import FunctionCall, Image
from autogen_core.components.models import FunctionExecutionResult

from ..agents import ChatMessage, MultiModalMessage, StopMessage, TextMessage, ToolCallMessage, ToolCallResultMessage
from ._events import ContentPublishEvent
from ._events import ContentPublishEvent, ToolCallEvent, ToolCallResultEvent

EVENT_LOGGER_NAME = "autogen_agentchat.events"
ContentType = Union[str, List[Union[str, Image]], List[FunctionCall], List[FunctionExecutionResult]]


class BaseLogHandler(logging.Handler):
def serialize_content(
self, content: Union[ContentType, Sequence[ChatMessage], ChatMessage]
self,
content: Union[ContentType, ChatMessage],
) -> Union[List[Any], Dict[str, Any], str]:
if isinstance(content, (str, list)):
return content
Expand All @@ -41,19 +43,35 @@ def json_serializer(obj: Any) -> Any:


class ConsoleLogHandler(BaseLogHandler):
def _format_message(
self,
*,
source_agent_id: AgentId | None,
message: ChatMessage,
timestamp: str,
) -> str:
body = f"{self.serialize_content(message.content)}\nFrom: {message.source}"
if source_agent_id is None:
console_message = f"\n{'-'*75} \n" f"\033[91m[{timestamp}]:\033[0m\n" f"\n{body}"
else:
# Display the source agent type rather than agent ID for better readability.
# Also in AgentChat the agent type is unique for each agent.
console_message = f"\n{'-'*75} \n" f"\033[91m[{timestamp}], {source_agent_id.type}:\033[0m\n" f"\n{body}"
return console_message

def emit(self, record: logging.LogRecord) -> None:
try:
ts = datetime.fromtimestamp(record.created).isoformat()
if isinstance(record.msg, ContentPublishEvent):
console_message = (
f"\n{'-'*75} \n"
f"\033[91m[{ts}], {record.msg.agent_message.source}:\033[0m\n"
f"\n{self.serialize_content(record.msg.agent_message.content)}"
ts = datetime.fromtimestamp(record.created).isoformat()
if isinstance(record.msg, ContentPublishEvent | ToolCallEvent | ToolCallResultEvent):
sys.stdout.write(
self._format_message(
source_agent_id=record.msg.source,
message=record.msg.agent_message,
timestamp=ts,
)
sys.stdout.write(console_message)
sys.stdout.flush()
except Exception:
self.handleError(record)
)
sys.stdout.flush()
else:
raise ValueError(f"Unexpected log record: {record.msg}")


class FileLogHandler(BaseLogHandler):
Expand All @@ -62,32 +80,37 @@ def __init__(self, filename: str) -> None:
self.filename = filename
self.file_handler = logging.FileHandler(filename)

def emit(self, record: logging.LogRecord) -> None:
try:
ts = datetime.fromtimestamp(record.created).isoformat()
if isinstance(record.msg, ContentPublishEvent):
log_entry = json.dumps(
{
"timestamp": ts,
"source": record.msg.agent_message.source,
"message": self.serialize_content(record.msg.agent_message.content),
"type": "OrchestrationEvent",
},
default=self.json_serializer,
)
def _format_entry(self, *, source: AgentId | None, message: ChatMessage, timestamp: str) -> Dict[str, Any]:
return {
"timestamp": timestamp,
"source": source,
"message": self.serialize_content(message),
"type": "OrchestrationEvent",
}

file_record = logging.LogRecord(
name=record.name,
level=record.levelno,
pathname=record.pathname,
lineno=record.lineno,
msg=log_entry,
args=(),
exc_info=record.exc_info,
)
self.file_handler.emit(file_record)
except Exception:
self.handleError(record)
def emit(self, record: logging.LogRecord) -> None:
ts = datetime.fromtimestamp(record.created).isoformat()
if isinstance(record.msg, ContentPublishEvent | ToolCallEvent | ToolCallResultEvent):
log_entry = json.dumps(
self._format_entry(
source=record.msg.source,
message=record.msg.agent_message,
timestamp=ts,
),
default=self.json_serializer,
)
else:
raise ValueError(f"Unexpected log record: {record.msg}")
file_record = logging.LogRecord(
name=record.name,
level=record.levelno,
pathname=record.pathname,
lineno=record.lineno,
msg=log_entry,
args=(),
exc_info=record.exc_info,
)
self.file_handler.emit(file_record)

def close(self) -> None:
self.file_handler.close()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@
from autogen_core.components.tool_agent import ToolException

from ...agents import BaseChatAgent, MultiModalMessage, StopMessage, TextMessage, ToolCallMessage, ToolCallResultMessage
from .._events import ContentPublishEvent, ContentRequestEvent
from .._events import ContentPublishEvent, ContentRequestEvent, ToolCallEvent, ToolCallResultEvent
from .._logging import EVENT_LOGGER_NAME
from ._sequential_routed_agent import SequentialRoutedAgent

event_logger = logging.getLogger(EVENT_LOGGER_NAME)


class BaseChatAgentContainer(SequentialRoutedAgent):
"""A core agent class that delegates message handling to an
Expand All @@ -21,16 +23,15 @@ class BaseChatAgentContainer(SequentialRoutedAgent):
Args:
parent_topic_type (str): The topic type of the parent orchestrator.
agent (BaseChatAgent): The agent to delegate message handling to.
tool_agent_type (AgentType): The agent type of the tool agent to use for tool calls.
tool_agent_type (AgentType, optional): The agent type of the tool agent. Defaults to None.
"""

def __init__(self, parent_topic_type: str, agent: BaseChatAgent, tool_agent_type: AgentType) -> None:
def __init__(self, parent_topic_type: str, agent: BaseChatAgent, tool_agent_type: AgentType | None = None) -> None:
super().__init__(description=agent.description)
self._parent_topic_type = parent_topic_type
self._agent = agent
self._message_buffer: List[TextMessage | MultiModalMessage | StopMessage] = []
self._tool_agent_id = AgentId(type=tool_agent_type, key=self.id.key)
self._logger = self.logger = logging.getLogger(EVENT_LOGGER_NAME)
self._tool_agent_id = AgentId(type=tool_agent_type, key=self.id.key) if tool_agent_type else None

@event
async def handle_content_publish(self, message: ContentPublishEvent, ctx: MessageContext) -> None:
Expand All @@ -48,38 +49,43 @@ async def handle_content_request(self, message: ContentRequestEvent, ctx: Messag
to the delegate agent and publish the response."""
response = await self._agent.on_messages(self._message_buffer, ctx.cancellation_token)

# Handle tool calls.
while isinstance(response, ToolCallMessage):
self._logger.info(ContentPublishEvent(agent_message=response))
if self._tool_agent_id is not None:
# Handle tool calls.
while isinstance(response, ToolCallMessage):
# Log the tool call.
event_logger.info(ToolCallEvent(agent_message=response, source=self.id))

results: List[FunctionExecutionResult | BaseException] = await asyncio.gather(
*[
self.send_message(
message=call,
recipient=self._tool_agent_id,
cancellation_token=ctx.cancellation_token,
)
for call in response.content
]
)
# Combine the results in to a single response and handle exceptions.
function_results: List[FunctionExecutionResult] = []
for result in results:
if isinstance(result, FunctionExecutionResult):
function_results.append(result)
elif isinstance(result, ToolException):
function_results.append(FunctionExecutionResult(content=f"Error: {result}", call_id=result.call_id))
elif isinstance(result, BaseException):
raise result # Unexpected exception.
# Create a new tool call result message.
feedback = ToolCallResultMessage(content=function_results, source=self._tool_agent_id.type)
# TODO: use logging instead of print
self._logger.info(ContentPublishEvent(agent_message=feedback, source=self._tool_agent_id.type))
response = await self._agent.on_messages([feedback], ctx.cancellation_token)
results: List[FunctionExecutionResult | BaseException] = await asyncio.gather(
*[
self.send_message(
message=call,
recipient=self._tool_agent_id,
cancellation_token=ctx.cancellation_token,
)
for call in response.content
]
)
# Combine the results in to a single response and handle exceptions.
function_results: List[FunctionExecutionResult] = []
for result in results:
if isinstance(result, FunctionExecutionResult):
function_results.append(result)
elif isinstance(result, ToolException):
function_results.append(
FunctionExecutionResult(content=f"Error: {result}", call_id=result.call_id)
)
elif isinstance(result, BaseException):
raise result # Unexpected exception.
# Create a new tool call result message.
feedback = ToolCallResultMessage(content=function_results, source=self._tool_agent_id.type)
# Log the feedback.
event_logger.info(ToolCallResultEvent(agent_message=feedback, source=self._tool_agent_id))
response = await self._agent.on_messages([feedback], ctx.cancellation_token)

# Publish the response.
assert isinstance(response, TextMessage | MultiModalMessage | StopMessage)
self._message_buffer.clear()
await self.publish_message(
ContentPublishEvent(agent_message=response), topic_id=DefaultTopicId(type=self._parent_topic_type)
ContentPublishEvent(agent_message=response, source=self.id),
topic_id=DefaultTopicId(type=self._parent_topic_type),
)
Loading

0 comments on commit 54eaa2b

Please sign in to comment.