Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move tools to agent in agentchat; refactored logging to support tool events #3665

Merged
merged 3 commits into from
Oct 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from ._base_chat_agent import (
BaseChatAgent,
BaseToolUseChatAgent,
ChatMessage,
MultiModalMessage,
StopMessage,
Expand All @@ -13,6 +14,7 @@

__all__ = [
"BaseChatAgent",
"BaseToolUseChatAgent",
"ChatMessage",
"TextMessage",
"MultiModalMessage",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from autogen_core.base import CancellationToken
from autogen_core.components import FunctionCall, Image
from autogen_core.components.models import FunctionExecutionResult
from autogen_core.components.tools import Tool
from pydantic import BaseModel


Expand Down Expand Up @@ -49,7 +50,7 @@ class StopMessage(BaseMessage):
"""The content for the stop message."""


ChatMessage = TextMessage | MultiModalMessage | ToolCallMessage | ToolCallResultMessage | StopMessage
ChatMessage = TextMessage | MultiModalMessage | StopMessage | ToolCallMessage | ToolCallResultMessage
"""A message used by agents in a team."""


Expand Down Expand Up @@ -79,3 +80,21 @@ def description(self) -> str:
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
"""Handle incoming messages and return a response message."""
...


class BaseToolUseChatAgent(BaseChatAgent):
"""Base class for a chat agent that can use tools.

Subclass this base class to create an agent class that uses tools by returning
ToolCallMessage message from the :meth:`on_messages` method and receiving
ToolCallResultMessage message from the input to the :meth:`on_messages` method.
"""

def __init__(self, name: str, description: str, registered_tools: List[Tool]) -> None:
super().__init__(name, description)
self._registered_tools = registered_tools

@property
def registered_tools(self) -> List[Tool]:
"""The list of tools that the agent can use."""
return self._registered_tools
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
SystemMessage,
UserMessage,
)
from autogen_core.components.tools import ToolSchema
from autogen_core.components.tools import Tool

from ._base_chat_agent import (
BaseChatAgent,
BaseToolUseChatAgent,
ChatMessage,
MultiModalMessage,
StopMessage,
Expand All @@ -23,7 +23,7 @@
)


class ToolUseAssistantAgent(BaseChatAgent):
class ToolUseAssistantAgent(BaseToolUseChatAgent):
"""An agent that provides assistance with tool use.

It responds with a StopMessage when 'terminate' is detected in the response.
Expand All @@ -33,15 +33,15 @@ def __init__(
self,
name: str,
model_client: ChatCompletionClient,
tool_schema: List[ToolSchema],
registered_tools: List[Tool],
*,
description: str = "An agent that provides assistance with ability to use tools.",
system_message: str = "You are a helpful AI assistant. Solve tasks using your tools. Reply 'TERMINATE' in the end when the task is completed.",
):
super().__init__(name=name, description=description)
super().__init__(name=name, description=description, registered_tools=registered_tools)
self._model_client = model_client
self._system_messages = [SystemMessage(content=system_message)]
self._tool_schema = tool_schema
self._tool_schema = [tool.schema for tool in registered_tools]
self._model_context: List[LLMMessage] = []

async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from typing import Optional
from autogen_core.base import AgentId
from pydantic import BaseModel, ConfigDict

from pydantic import BaseModel

from ..agents import ChatMessage
from ..agents import MultiModalMessage, StopMessage, TextMessage, ToolCallMessage, ToolCallResultMessage


class ContentPublishEvent(BaseModel):
Expand All @@ -11,9 +10,13 @@ class ContentPublishEvent(BaseModel):
content of the event.
"""

agent_message: ChatMessage
agent_message: TextMessage | MultiModalMessage | StopMessage
"""The message published by the agent."""
source: Optional[str] = None

source: AgentId | None = None
"""The agent ID that published the message."""

model_config = ConfigDict(arbitrary_types_allowed=True)


class ContentRequestEvent(BaseModel):
Expand All @@ -22,3 +25,27 @@ class ContentRequestEvent(BaseModel):
"""

...


class ToolCallEvent(BaseModel):
"""An event produced when requesting a tool call."""

agent_message: ToolCallMessage
"""The tool call message."""

source: AgentId
"""The sender of the tool call message."""

model_config = ConfigDict(arbitrary_types_allowed=True)


class ToolCallResultEvent(BaseModel):
"""An event produced when a tool call is completed."""

agent_message: ToolCallResultMessage
"""The tool call result message."""

source: AgentId
"""The sender of the tool call result message."""

model_config = ConfigDict(arbitrary_types_allowed=True)
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,23 @@
import sys
from dataclasses import asdict, is_dataclass
from datetime import datetime
from typing import Any, Dict, List, Sequence, Union
from typing import Any, Dict, List, Union

from autogen_core.base import AgentId
from autogen_core.components import FunctionCall, Image
from autogen_core.components.models import FunctionExecutionResult

from ..agents import ChatMessage, MultiModalMessage, StopMessage, TextMessage, ToolCallMessage, ToolCallResultMessage
from ._events import ContentPublishEvent
from ._events import ContentPublishEvent, ToolCallEvent, ToolCallResultEvent

EVENT_LOGGER_NAME = "autogen_agentchat.events"
ContentType = Union[str, List[Union[str, Image]], List[FunctionCall], List[FunctionExecutionResult]]


class BaseLogHandler(logging.Handler):
def serialize_content(
self, content: Union[ContentType, Sequence[ChatMessage], ChatMessage]
self,
content: Union[ContentType, ChatMessage],
) -> Union[List[Any], Dict[str, Any], str]:
if isinstance(content, (str, list)):
return content
Expand All @@ -41,19 +43,35 @@ def json_serializer(obj: Any) -> Any:


class ConsoleLogHandler(BaseLogHandler):
def _format_message(
self,
*,
source_agent_id: AgentId | None,
message: ChatMessage,
timestamp: str,
) -> str:
body = f"{self.serialize_content(message.content)}\nFrom: {message.source}"
if source_agent_id is None:
console_message = f"\n{'-'*75} \n" f"\033[91m[{timestamp}]:\033[0m\n" f"\n{body}"
else:
# Display the source agent type rather than agent ID for better readability.
# Also in AgentChat the agent type is unique for each agent.
console_message = f"\n{'-'*75} \n" f"\033[91m[{timestamp}], {source_agent_id.type}:\033[0m\n" f"\n{body}"
return console_message

def emit(self, record: logging.LogRecord) -> None:
try:
ts = datetime.fromtimestamp(record.created).isoformat()
if isinstance(record.msg, ContentPublishEvent):
console_message = (
f"\n{'-'*75} \n"
f"\033[91m[{ts}], {record.msg.agent_message.source}:\033[0m\n"
f"\n{self.serialize_content(record.msg.agent_message.content)}"
ts = datetime.fromtimestamp(record.created).isoformat()
if isinstance(record.msg, ContentPublishEvent | ToolCallEvent | ToolCallResultEvent):
sys.stdout.write(
self._format_message(
source_agent_id=record.msg.source,
message=record.msg.agent_message,
timestamp=ts,
)
sys.stdout.write(console_message)
sys.stdout.flush()
except Exception:
self.handleError(record)
)
sys.stdout.flush()
else:
raise ValueError(f"Unexpected log record: {record.msg}")


class FileLogHandler(BaseLogHandler):
Expand All @@ -62,32 +80,37 @@ def __init__(self, filename: str) -> None:
self.filename = filename
self.file_handler = logging.FileHandler(filename)

def emit(self, record: logging.LogRecord) -> None:
try:
ts = datetime.fromtimestamp(record.created).isoformat()
if isinstance(record.msg, ContentPublishEvent):
log_entry = json.dumps(
{
"timestamp": ts,
"source": record.msg.agent_message.source,
"message": self.serialize_content(record.msg.agent_message.content),
"type": "OrchestrationEvent",
},
default=self.json_serializer,
)
def _format_entry(self, *, source: AgentId | None, message: ChatMessage, timestamp: str) -> Dict[str, Any]:
return {
"timestamp": timestamp,
"source": source,
"message": self.serialize_content(message),
"type": "OrchestrationEvent",
}

file_record = logging.LogRecord(
name=record.name,
level=record.levelno,
pathname=record.pathname,
lineno=record.lineno,
msg=log_entry,
args=(),
exc_info=record.exc_info,
)
self.file_handler.emit(file_record)
except Exception:
self.handleError(record)
def emit(self, record: logging.LogRecord) -> None:
ts = datetime.fromtimestamp(record.created).isoformat()
if isinstance(record.msg, ContentPublishEvent | ToolCallEvent | ToolCallResultEvent):
log_entry = json.dumps(
self._format_entry(
source=record.msg.source,
message=record.msg.agent_message,
timestamp=ts,
),
default=self.json_serializer,
)
else:
raise ValueError(f"Unexpected log record: {record.msg}")
file_record = logging.LogRecord(
name=record.name,
level=record.levelno,
pathname=record.pathname,
lineno=record.lineno,
msg=log_entry,
args=(),
exc_info=record.exc_info,
)
self.file_handler.emit(file_record)

def close(self) -> None:
self.file_handler.close()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@
from autogen_core.components.tool_agent import ToolException

from ...agents import BaseChatAgent, MultiModalMessage, StopMessage, TextMessage, ToolCallMessage, ToolCallResultMessage
from .._events import ContentPublishEvent, ContentRequestEvent
from .._events import ContentPublishEvent, ContentRequestEvent, ToolCallEvent, ToolCallResultEvent
from .._logging import EVENT_LOGGER_NAME
from ._sequential_routed_agent import SequentialRoutedAgent

event_logger = logging.getLogger(EVENT_LOGGER_NAME)


class BaseChatAgentContainer(SequentialRoutedAgent):
"""A core agent class that delegates message handling to an
Expand All @@ -21,16 +23,15 @@ class BaseChatAgentContainer(SequentialRoutedAgent):
Args:
parent_topic_type (str): The topic type of the parent orchestrator.
agent (BaseChatAgent): The agent to delegate message handling to.
tool_agent_type (AgentType): The agent type of the tool agent to use for tool calls.
tool_agent_type (AgentType, optional): The agent type of the tool agent. Defaults to None.
"""

def __init__(self, parent_topic_type: str, agent: BaseChatAgent, tool_agent_type: AgentType) -> None:
def __init__(self, parent_topic_type: str, agent: BaseChatAgent, tool_agent_type: AgentType | None = None) -> None:
super().__init__(description=agent.description)
self._parent_topic_type = parent_topic_type
self._agent = agent
self._message_buffer: List[TextMessage | MultiModalMessage | StopMessage] = []
self._tool_agent_id = AgentId(type=tool_agent_type, key=self.id.key)
self._logger = self.logger = logging.getLogger(EVENT_LOGGER_NAME)
self._tool_agent_id = AgentId(type=tool_agent_type, key=self.id.key) if tool_agent_type else None

@event
async def handle_content_publish(self, message: ContentPublishEvent, ctx: MessageContext) -> None:
Expand All @@ -48,38 +49,43 @@ async def handle_content_request(self, message: ContentRequestEvent, ctx: Messag
to the delegate agent and publish the response."""
response = await self._agent.on_messages(self._message_buffer, ctx.cancellation_token)

# Handle tool calls.
while isinstance(response, ToolCallMessage):
self._logger.info(ContentPublishEvent(agent_message=response))
if self._tool_agent_id is not None:
# Handle tool calls.
while isinstance(response, ToolCallMessage):
# Log the tool call.
event_logger.info(ToolCallEvent(agent_message=response, source=self.id))

results: List[FunctionExecutionResult | BaseException] = await asyncio.gather(
*[
self.send_message(
message=call,
recipient=self._tool_agent_id,
cancellation_token=ctx.cancellation_token,
)
for call in response.content
]
)
# Combine the results in to a single response and handle exceptions.
function_results: List[FunctionExecutionResult] = []
for result in results:
if isinstance(result, FunctionExecutionResult):
function_results.append(result)
elif isinstance(result, ToolException):
function_results.append(FunctionExecutionResult(content=f"Error: {result}", call_id=result.call_id))
elif isinstance(result, BaseException):
raise result # Unexpected exception.
# Create a new tool call result message.
feedback = ToolCallResultMessage(content=function_results, source=self._tool_agent_id.type)
# TODO: use logging instead of print
self._logger.info(ContentPublishEvent(agent_message=feedback, source=self._tool_agent_id.type))
response = await self._agent.on_messages([feedback], ctx.cancellation_token)
results: List[FunctionExecutionResult | BaseException] = await asyncio.gather(
*[
self.send_message(
message=call,
recipient=self._tool_agent_id,
cancellation_token=ctx.cancellation_token,
)
for call in response.content
]
)
# Combine the results in to a single response and handle exceptions.
function_results: List[FunctionExecutionResult] = []
for result in results:
if isinstance(result, FunctionExecutionResult):
function_results.append(result)
elif isinstance(result, ToolException):
function_results.append(
FunctionExecutionResult(content=f"Error: {result}", call_id=result.call_id)
)
elif isinstance(result, BaseException):
raise result # Unexpected exception.
# Create a new tool call result message.
feedback = ToolCallResultMessage(content=function_results, source=self._tool_agent_id.type)
# Log the feedback.
event_logger.info(ToolCallResultEvent(agent_message=feedback, source=self._tool_agent_id))
response = await self._agent.on_messages([feedback], ctx.cancellation_token)

# Publish the response.
assert isinstance(response, TextMessage | MultiModalMessage | StopMessage)
self._message_buffer.clear()
await self.publish_message(
ContentPublishEvent(agent_message=response), topic_id=DefaultTopicId(type=self._parent_topic_type)
ContentPublishEvent(agent_message=response, source=self.id),
topic_id=DefaultTopicId(type=self._parent_topic_type),
)
Loading
Loading