Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MagenticOneGroupChat to AGS #4595

Merged
merged 8 commits into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,17 @@

from .... import TRACE_LOGGER_NAME
from ....base import Response, TerminationCondition
from ....messages import AgentMessage, ChatMessage, MultiModalMessage, StopMessage, TextMessage
from ....messages import (
AgentMessage,
ChatMessage,
HandoffMessage,
MultiModalMessage,
StopMessage,
TextMessage,
ToolCallMessage,
ToolCallResultMessage,
)

from ....state import MagenticOneOrchestratorState
from .._base_group_chat_manager import BaseGroupChatManager
from .._events import (
Expand Down Expand Up @@ -418,7 +428,12 @@ def _thread_to_context(self) -> List[LLMMessage]:
"""Convert the message thread to a context for the model."""
context: List[LLMMessage] = []
for m in self._message_thread:
if m.source == self._name:
if isinstance(m, ToolCallMessage | ToolCallResultMessage):
# Ignore tool call messages.
continue
elif isinstance(m, StopMessage | HandoffMessage):
context.append(UserMessage(content=m.content, source=m.source))
elif m.source == self._name:
assert isinstance(m, TextMessage)
context.append(AssistantMessage(content=m.content, source=m.source))
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@
import yaml
from autogen_agentchat.agents import AssistantAgent, UserProxyAgent
from autogen_agentchat.conditions import MaxMessageTermination, StopMessageTermination, TextMentionTermination
from autogen_agentchat.teams import RoundRobinGroupChat, SelectorGroupChat
from autogen_agentchat.teams import RoundRobinGroupChat, SelectorGroupChat, MagenticOneGroupChat
from autogen_core.components.tools import FunctionTool
from autogen_ext.agents.web_surfer import MultimodalWebSurfer
from autogen_ext.agents.file_surfer import FileSurfer
from autogen_ext.agents.magentic_one import MagenticOneCoderAgent
from autogen_ext.models import OpenAIChatCompletionClient

from ..datamodel.types import (
Expand All @@ -32,8 +34,8 @@

logger = logging.getLogger(__name__)

TeamComponent = Union[RoundRobinGroupChat, SelectorGroupChat]
AgentComponent = Union[AssistantAgent, MultimodalWebSurfer]
TeamComponent = Union[RoundRobinGroupChat, SelectorGroupChat, MagenticOneGroupChat]
AgentComponent = Union[AssistantAgent, MultimodalWebSurfer, UserProxyAgent, FileSurfer, MagenticOneCoderAgent]
ModelComponent = Union[OpenAIChatCompletionClient]
ToolComponent = Union[FunctionTool] # Will grow with more tool types
TerminationComponent = Union[MaxMessageTermination, StopMessageTermination, TextMentionTermination]
Expand Down Expand Up @@ -243,6 +245,15 @@ async def load_team(self, config: TeamConfig, input_func: Optional[Callable] = N
termination_condition=termination,
selector_prompt=selector_prompt,
)
elif config.team_type == TeamTypes.MAGENTIC_ONE:
if not model_client:
raise ValueError("MagenticOneGroupChat requires a model_client")
return MagenticOneGroupChat(
participants=participants,
model_client=model_client,
termination_condition=termination if termination is not None else None,
max_turns=config.max_turns if config.max_turns is not None else 20,
)
else:
raise ValueError(f"Unsupported team type: {config.team_type}")

Expand Down Expand Up @@ -292,7 +303,16 @@ async def load_agent(self, config: AgentConfig, input_func: Optional[Callable] =
use_ocr=config.use_ocr if config.use_ocr is not None else False,
animate_actions=config.animate_actions if config.animate_actions is not None else False,
)

elif config.agent_type == AgentTypes.FILE_SURFER:
return FileSurfer(
name=config.name,
model_client=model_client,
)
elif config.agent_type == AgentTypes.MAGENTIC_ONE_CODER:
return MagenticOneCoderAgent(
name=config.name,
model_client=model_client,
)
else:
raise ValueError(f"Unsupported agent type: {config.agent_type}")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,14 @@ class AgentTypes(str, Enum):
ASSISTANT = "AssistantAgent"
USERPROXY = "UserProxyAgent"
MULTIMODAL_WEBSURFER = "MultimodalWebSurfer"
FILE_SURFER = "FileSurfer"
MAGENTIC_ONE_CODER = "MagenticOneCoderAgent"


class TeamTypes(str, Enum):
ROUND_ROBIN = "RoundRobinGroupChat"
SELECTOR = "SelectorGroupChat"
MAGENTIC_ONE = "MagenticOneGroupChat"


class TerminationTypes(str, Enum):
Expand Down Expand Up @@ -103,6 +106,7 @@ class TeamConfig(BaseConfig):
selector_prompt: Optional[str] = None
termination_condition: Optional[TerminationConfig] = None
component_type: ComponentTypes = ComponentTypes.TEAM
max_turns: Optional[int] = None


class TeamResult(BaseModel):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,15 @@ async def start_stream(self, run_id: UUID, task: str, team_config: dict) -> None
await self._send_message(run_id, formatted_message)

# Save message if it's a content message
if isinstance(message, (AgentMessage, ChatMessage)):
if isinstance(message, TextMessage):
await self._save_message(run_id, message)
elif isinstance(message, MultiModalMessage):
await self._save_message(run_id, message)
# Capture final result if it's a TeamResult
elif isinstance(message, TeamResult):
final_result = message.model_dump()

elif isinstance(message, (AgentMessage, ChatMessage)):
await self._save_message(run_id, message)
if not cancellation_token.is_cancelled() and run_id not in self._closed_connections:
if final_result:
await self._update_run(run_id, RunStatus.COMPLETE, team_result=final_result)
Expand Down Expand Up @@ -285,6 +288,7 @@ def _format_message(self, message: Any) -> Optional[dict]:
Returns:
Optional[dict]: Formatted message or None if formatting fails
"""

try:
if isinstance(message, MultiModalMessage):
message_dump = message.model_dump()
Expand All @@ -296,7 +300,8 @@ def _format_message(self, message: Any) -> Optional[dict]:
},
]
return {"type": "message", "data": message_dump}
elif isinstance(message, (AgentMessage, ChatMessage)):

elif isinstance(message, TextMessage):
return {"type": "message", "data": message.model_dump()}

elif isinstance(message, TeamResult):
Expand All @@ -305,6 +310,9 @@ def _format_message(self, message: Any) -> Optional[dict]:
"data": message.model_dump(),
"status": "complete",
}
elif isinstance(message, (AgentMessage, ChatMessage)):
return {"type": "message", "data": message.model_dump()}

return None
except Exception as e:
logger.error(f"Message formatting error: {e}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,11 @@ export type ModelTypes = "OpenAIChatCompletionClient";
export type AgentTypes =
| "AssistantAgent"
| "CodingAssistantAgent"
| "MultimodalWebSurfer";
| "MultimodalWebSurfer"
| "FileSurfer"
| "MagenticOneCoderAgent";

export type TeamTypes = "RoundRobinGroupChat" | "SelectorGroupChat";
export type TeamTypes = "RoundRobinGroupChat" | "SelectorGroupChat" | "MagenticOneGroupChat";

// class ComponentType(str, Enum):
// TEAM = "team"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ export const TeamEditor: React.FC<TeamEditorProps> = ({
throw new Error("Participants must be an array");
}
if (
!["RoundRobinGroupChat", "SelectorGroupChat"].includes(parsed.team_type)
!["RoundRobinGroupChat", "SelectorGroupChat", "MagenticOneGroupChat"].includes(parsed.team_type)
) {
throw new Error("Invalid team_type");
}
Expand Down Expand Up @@ -169,7 +169,7 @@ export const TeamEditor: React.FC<TeamEditorProps> = ({
>
<div className="mb-2 text-xs text-gray-500">
Required fields: name (string), team_type ("RoundRobinGroupChat" |
"SelectorGroupChat"), participants (array)
"SelectorGroupChat" | "MagenticOneGroupChat"), participants (array)
</div>

<div className="h-[500px] mb-4">
Expand Down
Loading
Loading