diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/task/__init__.py b/python/packages/autogen-agentchat/src/autogen_agentchat/task/__init__.py index dd7b6265ad4..e1e6766338d 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/task/__init__.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/task/__init__.py @@ -7,6 +7,7 @@ TextMentionTermination, TimeoutTermination, TokenUsageTermination, + SourceMatchTermination, ) __all__ = [ @@ -17,5 +18,6 @@ "HandoffTermination", "TimeoutTermination", "ExternalTermination", + "SourceMatchTermination", "Console", ] diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/task/_terminations.py b/python/packages/autogen-agentchat/src/autogen_agentchat/task/_terminations.py index f8d79cef285..81cb5cca7d6 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/task/_terminations.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/task/_terminations.py @@ -1,5 +1,5 @@ import time -from typing import Sequence +from typing import Sequence, List from ..base import TerminatedException, TerminationCondition from ..messages import AgentMessage, HandoffMessage, MultiModalMessage, StopMessage, TextMessage @@ -251,3 +251,36 @@ async def __call__(self, messages: Sequence[AgentMessage]) -> StopMessage | None async def reset(self) -> None: self._terminated = False self._setted = False + + +class SourceMatchTermination(TerminationCondition): + """Terminate the conversation after a specific source responds. + + Args: + sources (List[str]): List of source names to terminate the conversation. + + Raises: + TerminatedException: If the termination condition has already been reached. + """ + + def __init__(self, sources: List[str]) -> None: + self._sources = sources + self._terminated = False + + @property + def terminated(self) -> bool: + return self._terminated + + async def __call__(self, messages: Sequence[AgentMessage]) -> StopMessage | None: + if self._terminated: + raise TerminatedException("Termination condition has already been reached") + if not messages: + return None + for message in messages: + if message.source in self._sources: + self._terminated = True + return StopMessage(content=f"'{message.source}' answered", source="SourceMatchTermination") + return None + + async def reset(self) -> None: + self._terminated = False diff --git a/python/packages/autogen-agentchat/tests/test_termination_condition.py b/python/packages/autogen-agentchat/tests/test_termination_condition.py index c09e0e1c14a..f4aa5d2a720 100644 --- a/python/packages/autogen-agentchat/tests/test_termination_condition.py +++ b/python/packages/autogen-agentchat/tests/test_termination_condition.py @@ -1,6 +1,7 @@ import asyncio import pytest +from autogen_agentchat.base import TerminatedException from autogen_agentchat.messages import HandoffMessage, StopMessage, TextMessage from autogen_agentchat.task import ( ExternalTermination, @@ -10,6 +11,7 @@ TextMentionTermination, TimeoutTermination, TokenUsageTermination, + SourceMatchTermination, ) from autogen_core.components.models import RequestUsage @@ -242,3 +244,26 @@ async def test_external_termination() -> None: await termination.reset() assert await termination([]) is None + + +@pytest.mark.asyncio +async def test_source_match_termination() -> None: + termination = SourceMatchTermination(sources=["Assistant"]) + assert await termination([]) is None + + continue_messages = [TextMessage(content="Hello", source="agent"), TextMessage(content="Hello", source="user")] + assert await termination(continue_messages) is None + + terminate_messages = [ + TextMessage(content="Hello", source="agent"), + TextMessage(content="Hello", source="Assistant"), + TextMessage(content="Hello", source="user"), + ] + result = await termination(terminate_messages) + assert isinstance(result, StopMessage) + assert termination.terminated + + with pytest.raises(TerminatedException): + await termination([]) + await termination.reset() + assert not termination.terminated