Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor agentchat +implement base chat agent run method #3913

Merged
merged 6 commits into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from ._base_chat_agent import BaseChatAgent, BaseToolUseChatAgent
from ._code_executor_agent import CodeExecutorAgent
from ._coding_assistant_agent import CodingAssistantAgent
from ._tool_use_assistant_agent import ToolUseAssistantAgent

__all__ = [
"BaseChatAgent",
"BaseToolUseChatAgent",
"CodeExecutorAgent",
"CodingAssistantAgent",
"ToolUseAssistantAgent",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
from autogen_core.base import CancellationToken
from autogen_core.components.tools import Tool

from ..base import ChatAgent, TaskResult, TerminationCondition, ToolUseChatAgent
from ..messages import ChatMessage
from ._base_task import TaskResult, TaskRunner
from ..teams import RoundRobinGroupChat


class BaseChatAgent(TaskRunner, ABC):
"""Base class for a chat agent that can participant in a team."""
class BaseChatAgent(ChatAgent, ABC):
"""Base class for a chat agent."""

def __init__(self, name: str, description: str) -> None:
self._name = name
Expand All @@ -36,13 +37,23 @@ async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token:
...

async def run(
self, task: str, *, source: str = "user", cancellation_token: CancellationToken | None = None
self,
task: str,
*,
cancellation_token: CancellationToken | None = None,
termination_condition: TerminationCondition | None = None,
) -> TaskResult:
# TODO: Implement this method.
raise NotImplementedError
"""Run the agent with the given task and return the result."""
group_chat = RoundRobinGroupChat(participants=[self])
ekzhu marked this conversation as resolved.
Show resolved Hide resolved
result = await group_chat.run(
task=task,
cancellation_token=cancellation_token,
termination_condition=termination_condition,
)
return result


class BaseToolUseChatAgent(BaseChatAgent):
class BaseToolUseChatAgent(BaseChatAgent, ToolUseChatAgent):
"""Base class for a chat agent that can use tools.

Subclass this base class to create an agent class that uses tools by returning
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from autogen_core.base import CancellationToken
from autogen_core.components.code_executor import CodeBlock, CodeExecutor, extract_markdown_code_blocks

from ..base import BaseChatAgent
from ..messages import ChatMessage, TextMessage
from ._base_chat_agent import BaseChatAgent


class CodeExecutorAgent(BaseChatAgent):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
UserMessage,
)

from ..base import BaseChatAgent
from ..messages import ChatMessage, MultiModalMessage, StopMessage, TextMessage
from ._base_chat_agent import BaseChatAgent


class CodingAssistantAgent(BaseChatAgent):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
)
from autogen_core.components.tools import FunctionTool, Tool

from ..base import BaseToolUseChatAgent
from ..messages import (
ChatMessage,
MultiModalMessage,
Expand All @@ -21,6 +20,7 @@
ToolCallMessage,
ToolCallResultMessage,
)
from ._base_chat_agent import BaseToolUseChatAgent


class ToolUseAssistantAgent(BaseToolUseChatAgent):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from ._base_chat_agent import BaseChatAgent, BaseToolUseChatAgent
from ._base_task import TaskResult, TaskRunner
from ._base_team import Team
from ._base_termination import TerminatedException, TerminationCondition
from ._chat_agent import ChatAgent, ToolUseChatAgent
from ._task import TaskResult, TaskRunner
from ._team import Team
from ._termination import TerminatedException, TerminationCondition

__all__ = [
"BaseChatAgent",
"BaseToolUseChatAgent",
"ChatAgent",
"ToolUseChatAgent",
"Team",
"TerminatedException",
"TerminationCondition",
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from typing import List, Protocol, Sequence, runtime_checkable

from autogen_core.base import CancellationToken
from autogen_core.components.tools import Tool

from ..messages import ChatMessage
from ._task import TaskResult, TaskRunner
from ._termination import TerminationCondition


@runtime_checkable
class ChatAgent(TaskRunner, Protocol):
"""Protocol for a chat agent."""

@property
def name(self) -> str:
"""The name of the agent. This is used by team to uniquely identify
the agent. It should be unique within the team."""
...

@property
def description(self) -> str:
"""The description of the agent. This is used by team to
make decisions about which agents to use. The description should
describe the agent's capabilities and how to interact with it."""
...

async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
"""Handle incoming messages and return a response message."""
...

async def run(
self,
task: str,
*,
cancellation_token: CancellationToken | None = None,
termination_condition: TerminationCondition | None = None,
) -> TaskResult:
"""Run the agent with the given task and return the result."""
...


@runtime_checkable
class ToolUseChatAgent(ChatAgent, Protocol):
"""Protocol for a chat agent that can use tools."""

@property
def registered_tools(self) -> List[Tool]:
"""The list of tools that the agent can use."""
...
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from dataclasses import dataclass
from typing import Protocol, Sequence

from autogen_core.base import CancellationToken

from ..messages import ChatMessage
from ._termination import TerminationCondition


@dataclass
Expand All @@ -15,6 +18,12 @@ class TaskResult:
class TaskRunner(Protocol):
"""A task runner."""

async def run(self, task: str) -> TaskResult:
async def run(
self,
task: str,
*,
cancellation_token: CancellationToken | None = None,
termination_condition: TerminationCondition | None = None,
) -> TaskResult:
"""Run the task."""
...
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from typing import Protocol

from autogen_core.base import CancellationToken

from ._task import TaskResult, TaskRunner
from ._termination import TerminationCondition


class Team(TaskRunner, Protocol):
async def run(
self,
task: str,
*,
cancellation_token: CancellationToken | None = None,
termination_condition: TerminationCondition | None = None,
) -> TaskResult:
"""Run the team on a given task until the termination condition is met."""
...
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from ._terminations import MaxMessageTermination, StopMessageTermination, TextMentionTermination

__all__ = [
"MaxMessageTermination",
"TextMentionTermination",
"StopMessageTermination",
]
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
from ._group_chat._round_robin_group_chat import RoundRobinGroupChat
from ._group_chat._selector_group_chat import SelectorGroupChat
from ._terminations import MaxMessageTermination, StopMessageTermination, TextMentionTermination

__all__ = [
"MaxMessageTermination",
"TextMentionTermination",
"StopMessageTermination",
"RoundRobinGroupChat",
"SelectorGroupChat",
]
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from autogen_core.components.tool_agent import ToolException

from ... import EVENT_LOGGER_NAME
from ...base import BaseChatAgent
from ...base import ChatAgent
from ...messages import MultiModalMessage, StopMessage, TextMessage, ToolCallMessage, ToolCallResultMessage
from .._events import ContentPublishEvent, ContentRequestEvent, ToolCallEvent, ToolCallResultEvent
from ._sequential_routed_agent import SequentialRoutedAgent
Expand All @@ -27,7 +27,7 @@ class BaseChatAgentContainer(SequentialRoutedAgent):
tool_agent_type (AgentType, optional): The agent type of the tool agent. Defaults to None.
"""

def __init__(self, parent_topic_type: str, agent: BaseChatAgent, tool_agent_type: AgentType | None = None) -> None:
def __init__(self, parent_topic_type: str, agent: ChatAgent, tool_agent_type: AgentType | None = None) -> None:
super().__init__(description=agent.description)
self._parent_topic_type = parent_topic_type
self._agent = agent
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,20 @@
from typing import Callable, List

from autogen_core.application import SingleThreadedAgentRuntime
from autogen_core.base import AgentId, AgentInstantiationContext, AgentRuntime, AgentType, MessageContext, TopicId
from autogen_core.base import (
AgentId,
AgentInstantiationContext,
AgentRuntime,
AgentType,
CancellationToken,
MessageContext,
TopicId,
)
from autogen_core.components import ClosureAgent, TypeSubscription
from autogen_core.components.tool_agent import ToolAgent
from autogen_core.components.tools import Tool

from ...base import BaseChatAgent, BaseToolUseChatAgent, TaskResult, Team, TerminationCondition
from ...base import ChatAgent, TaskResult, Team, TerminationCondition, ToolUseChatAgent
from ...messages import ChatMessage, TextMessage
from .._events import ContentPublishEvent, ContentRequestEvent
from ._base_chat_agent_container import BaseChatAgentContainer
Expand All @@ -22,13 +30,13 @@ class BaseGroupChat(Team, ABC):
create a subclass of :class:`BaseGroupChat` that uses the group chat manager.
"""

def __init__(self, participants: List[BaseChatAgent], group_chat_manager_class: type[BaseGroupChatManager]):
def __init__(self, participants: List[ChatAgent], group_chat_manager_class: type[BaseGroupChatManager]):
if len(participants) == 0:
raise ValueError("At least one participant is required.")
if len(participants) != len(set(participant.name for participant in participants)):
raise ValueError("The participant names must be unique.")
for participant in participants:
if isinstance(participant, BaseToolUseChatAgent) and not participant.registered_tools:
if isinstance(participant, ToolUseChatAgent) and not participant.registered_tools:
raise ValueError(
f"Participant '{participant.name}' is a tool use agent so it must have registered tools."
)
Expand All @@ -47,7 +55,7 @@ def _create_group_chat_manager_factory(
) -> Callable[[], BaseGroupChatManager]: ...

def _create_participant_factory(
self, parent_topic_type: str, agent: BaseChatAgent, tool_agent_type: AgentType | None
self, parent_topic_type: str, agent: ChatAgent, tool_agent_type: AgentType | None
) -> Callable[[], BaseChatAgentContainer]:
def _factory() -> BaseChatAgentContainer:
id = AgentInstantiationContext.current_agent_id()
Expand All @@ -68,7 +76,13 @@ def _factory() -> ToolAgent:

return _factory

async def run(self, task: str, *, termination_condition: TerminationCondition | None = None) -> TaskResult:
async def run(
self,
task: str,
*,
cancellation_token: CancellationToken | None = None,
termination_condition: TerminationCondition | None = None,
) -> TaskResult:
"""Run the team and return the result."""
# Create intervention handler for termination.

Expand All @@ -85,7 +99,7 @@ async def run(self, task: str, *, termination_condition: TerminationCondition |
participant_topic_types: List[str] = []
participant_descriptions: List[str] = []
for participant in self._participants:
if isinstance(participant, BaseToolUseChatAgent):
if isinstance(participant, ToolUseChatAgent):
assert participant.registered_tools is not None and len(participant.registered_tools) > 0
# Register the tool agent.
tool_agent_type = await ToolAgent.register(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Callable, List

from ...base import BaseChatAgent, TerminationCondition
from ...base import ChatAgent, TerminationCondition
from .._events import ContentPublishEvent
from ._base_group_chat import BaseGroupChat
from ._base_group_chat_manager import BaseGroupChatManager
Expand Down Expand Up @@ -73,7 +73,7 @@ class RoundRobinGroupChat(BaseGroupChat):

"""

def __init__(self, participants: List[BaseChatAgent]):
def __init__(self, participants: List[ChatAgent]):
super().__init__(participants, group_chat_manager_class=RoundRobinGroupChatManager)

def _create_group_chat_manager_factory(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from autogen_core.components.models import ChatCompletionClient, SystemMessage

from ... import EVENT_LOGGER_NAME, TRACE_LOGGER_NAME
from ...base import BaseChatAgent, TerminationCondition
from ...base import ChatAgent, TerminationCondition
from ...messages import MultiModalMessage, StopMessage, TextMessage
from .._events import ContentPublishEvent, SelectSpeakerEvent
from ._base_group_chat import BaseGroupChat
Expand Down Expand Up @@ -178,7 +178,7 @@ class SelectorGroupChat(BaseGroupChat):

def __init__(
self,
participants: List[BaseChatAgent],
participants: List[ChatAgent],
model_client: ChatCompletionClient,
*,
selector_prompt: str = """You are in a role play game. The following roles are available:
Expand Down
4 changes: 2 additions & 2 deletions python/packages/autogen-agentchat/tests/test_group_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,17 @@
import pytest
from autogen_agentchat import EVENT_LOGGER_NAME
from autogen_agentchat.agents import (
BaseChatAgent,
CodeExecutorAgent,
CodingAssistantAgent,
ToolUseAssistantAgent,
)
from autogen_agentchat.base import BaseChatAgent
from autogen_agentchat.logging import FileLogHandler
from autogen_agentchat.messages import ChatMessage, StopMessage, TextMessage
from autogen_agentchat.task import StopMessageTermination
from autogen_agentchat.teams import (
RoundRobinGroupChat,
SelectorGroupChat,
StopMessageTermination,
)
from autogen_core.base import CancellationToken
from autogen_core.components import FunctionCall
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
from autogen_agentchat.messages import StopMessage, TextMessage
from autogen_agentchat.teams import MaxMessageTermination, StopMessageTermination, TextMentionTermination
from autogen_agentchat.task import MaxMessageTermination, StopMessageTermination, TextMentionTermination


@pytest.mark.asyncio
Expand Down
Loading
Loading