-
Notifications
You must be signed in to change notification settings - Fork 5.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add group chat pattern, create separate folder for patterns (#117)
* add tool use example; refactor example directory * update * add more examples * fix * fix * doc * move * add group chat example, create patterns folder
- Loading branch information
Showing
7 changed files
with
212 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,166 @@ | ||
import asyncio | ||
from dataclasses import dataclass | ||
from typing import Any, List | ||
|
||
from agnext.application import SingleThreadedAgentRuntime | ||
from agnext.components import TypeRoutedAgent, message_handler | ||
from agnext.components.models import ( | ||
AssistantMessage, | ||
ChatCompletionClient, | ||
LLMMessage, | ||
OpenAI, | ||
SystemMessage, | ||
UserMessage, | ||
) | ||
from agnext.core import AgentId, CancellationToken | ||
from agnext.core.intervention import DefaultInterventionHandler | ||
|
||
|
||
@dataclass | ||
class Message: | ||
source: str | ||
content: str | ||
|
||
|
||
@dataclass | ||
class RequestToSpeak: | ||
pass | ||
|
||
|
||
@dataclass | ||
class Termination: | ||
pass | ||
|
||
|
||
class RoundRobinGroupChatManager(TypeRoutedAgent): | ||
def __init__( | ||
self, | ||
description: str, | ||
participants: List[AgentId], | ||
num_rounds: int, | ||
) -> None: | ||
super().__init__(description) | ||
self._participants = participants | ||
self._num_rounds = num_rounds | ||
self._round_count = 0 | ||
|
||
@message_handler | ||
async def handle_message(self, message: Message, cancellation_token: CancellationToken) -> None: | ||
# Select the next speaker in a round-robin fashion | ||
speaker = self._participants[self._round_count % len(self._participants)] | ||
self._round_count += 1 | ||
if self._round_count == self._num_rounds * len(self._participants): | ||
# End the conversation after the specified number of rounds. | ||
self.publish_message(Termination()) | ||
return | ||
# Send a request to speak message to the selected speaker. | ||
self.send_message(RequestToSpeak(), speaker) | ||
|
||
|
||
class GroupChatParticipant(TypeRoutedAgent): | ||
def __init__( | ||
self, | ||
description: str, | ||
system_messages: List[SystemMessage], | ||
model_client: ChatCompletionClient, | ||
) -> None: | ||
super().__init__(description) | ||
self._system_messages = system_messages | ||
self._model_client = model_client | ||
self._memory: List[Message] = [] | ||
|
||
@message_handler | ||
async def handle_message(self, message: Message, cancellation_token: CancellationToken) -> None: | ||
self._memory.append(message) | ||
|
||
@message_handler | ||
async def handle_request_to_speak(self, message: RequestToSpeak, cancellation_token: CancellationToken) -> None: | ||
# Generate a response to the last message in the memory | ||
if not self._memory: | ||
return | ||
llm_messages: List[LLMMessage] = [] | ||
for m in self._memory[-10:]: | ||
if m.source == self.metadata["name"]: | ||
llm_messages.append(AssistantMessage(content=m.content, source=self.metadata["name"])) | ||
else: | ||
llm_messages.append(UserMessage(content=m.content, source=m.source)) | ||
response = await self._model_client.create(self._system_messages + llm_messages) | ||
assert isinstance(response.content, str) | ||
speach = Message(content=response.content, source=self.metadata["name"]) | ||
self._memory.append(speach) | ||
self.publish_message(speach) | ||
|
||
|
||
class TerminationHandler(DefaultInterventionHandler): | ||
"""A handler that listens for termination messages.""" | ||
|
||
def __init__(self) -> None: | ||
self._terminated = False | ||
|
||
async def on_publish(self, message: Any, *, sender: AgentId | None) -> Any: | ||
if isinstance(message, Termination): | ||
self._terminated = True | ||
return message | ||
|
||
@property | ||
def terminated(self) -> bool: | ||
return self._terminated | ||
|
||
|
||
async def main() -> None: | ||
# Create the termination handler. | ||
termination_handler = TerminationHandler() | ||
|
||
# Create the runtime. | ||
runtime = SingleThreadedAgentRuntime(intervention_handler=termination_handler) | ||
|
||
# Register the participants. | ||
agent1 = runtime.register_and_get( | ||
"DataScientist", | ||
lambda: GroupChatParticipant( | ||
description="A data scientist", | ||
system_messages=[SystemMessage("You are a data scientist.")], | ||
model_client=OpenAI(model="gpt-3.5-turbo"), | ||
), | ||
) | ||
agent2 = runtime.register_and_get( | ||
"Engineer", | ||
lambda: GroupChatParticipant( | ||
description="An engineer", | ||
system_messages=[SystemMessage("You are an engineer.")], | ||
model_client=OpenAI(model="gpt-3.5-turbo"), | ||
), | ||
) | ||
agent3 = runtime.register_and_get( | ||
"Artist", | ||
lambda: GroupChatParticipant( | ||
description="An artist", | ||
system_messages=[SystemMessage("You are an artist.")], | ||
model_client=OpenAI(model="gpt-3.5-turbo"), | ||
), | ||
) | ||
|
||
# Register the group chat manager. | ||
runtime.register( | ||
"GroupChatManager", | ||
lambda: RoundRobinGroupChatManager( | ||
description="A group chat manager", | ||
participants=[agent1, agent2, agent3], | ||
num_rounds=3, | ||
), | ||
) | ||
|
||
# Start the conversation. | ||
runtime.publish_message(Message(content="Hello, everyone!", source="Moderator"), namespace="default") | ||
|
||
# Run the runtime until termination. | ||
while not termination_handler.terminated: | ||
await runtime.process_next() | ||
|
||
|
||
if __name__ == "__main__": | ||
import logging | ||
|
||
logging.basicConfig(level=logging.WARNING) | ||
logging.getLogger("agnext").setLevel(logging.DEBUG) | ||
asyncio.run(main()) |
File renamed without changes.
File renamed without changes.