Skip to content

Commit

Permalink
Agent name termination (microsoft#4123)
Browse files Browse the repository at this point in the history
  • Loading branch information
thainduy authored Nov 23, 2024
1 parent 8f4d8c8 commit 0b5eaf1
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
TextMentionTermination,
TimeoutTermination,
TokenUsageTermination,
SourceMatchTermination,
)

__all__ = [
Expand All @@ -17,5 +18,6 @@
"HandoffTermination",
"TimeoutTermination",
"ExternalTermination",
"SourceMatchTermination",
"Console",
]
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import time
from typing import Sequence
from typing import Sequence, List

from ..base import TerminatedException, TerminationCondition
from ..messages import AgentMessage, HandoffMessage, MultiModalMessage, StopMessage, TextMessage
Expand Down Expand Up @@ -251,3 +251,36 @@ async def __call__(self, messages: Sequence[AgentMessage]) -> StopMessage | None
async def reset(self) -> None:
self._terminated = False
self._setted = False


class SourceMatchTermination(TerminationCondition):
"""Terminate the conversation after a specific source responds.
Args:
sources (List[str]): List of source names to terminate the conversation.
Raises:
TerminatedException: If the termination condition has already been reached.
"""

def __init__(self, sources: List[str]) -> None:
self._sources = sources
self._terminated = False

@property
def terminated(self) -> bool:
return self._terminated

async def __call__(self, messages: Sequence[AgentMessage]) -> StopMessage | None:
if self._terminated:
raise TerminatedException("Termination condition has already been reached")
if not messages:
return None
for message in messages:
if message.source in self._sources:
self._terminated = True
return StopMessage(content=f"'{message.source}' answered", source="SourceMatchTermination")
return None

async def reset(self) -> None:
self._terminated = False
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio

import pytest
from autogen_agentchat.base import TerminatedException
from autogen_agentchat.messages import HandoffMessage, StopMessage, TextMessage
from autogen_agentchat.task import (
ExternalTermination,
Expand All @@ -10,6 +11,7 @@
TextMentionTermination,
TimeoutTermination,
TokenUsageTermination,
SourceMatchTermination,
)
from autogen_core.components.models import RequestUsage

Expand Down Expand Up @@ -242,3 +244,26 @@ async def test_external_termination() -> None:

await termination.reset()
assert await termination([]) is None


@pytest.mark.asyncio
async def test_source_match_termination() -> None:
termination = SourceMatchTermination(sources=["Assistant"])
assert await termination([]) is None

continue_messages = [TextMessage(content="Hello", source="agent"), TextMessage(content="Hello", source="user")]
assert await termination(continue_messages) is None

terminate_messages = [
TextMessage(content="Hello", source="agent"),
TextMessage(content="Hello", source="Assistant"),
TextMessage(content="Hello", source="user"),
]
result = await termination(terminate_messages)
assert isinstance(result, StopMessage)
assert termination.terminated

with pytest.raises(TerminatedException):
await termination([])
await termination.reset()
assert not termination.terminated

0 comments on commit 0b5eaf1

Please sign in to comment.