Skip to content

Commit

Permalink
Merge branch 'staging' into ekzhu-tools
Browse files Browse the repository at this point in the history
  • Loading branch information
ekzhu authored Oct 7, 2024
2 parents 4e06442 + be5c0b5 commit db9226c
Show file tree
Hide file tree
Showing 3 changed files with 138 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def __init__(self) -> None:
self._pending_responses: Dict[int, Dict[str, Future[Any]]] = {}
self._background_tasks: Set[Task[Any]] = set()
self._subscription_manager = SubscriptionManager()
self._client_id_to_subscription_id_mapping: Dict[int, set[str]] = {}

async def OpenChannel( # type: ignore
self,
Expand Down Expand Up @@ -68,13 +69,18 @@ async def OpenChannel( # type: ignore
for future in self._pending_responses.pop(client_id, {}).values():
future.cancel()
# Remove the client id from the agent type to client id mapping.
async with self._agent_type_to_client_id_lock:
agent_types = [
agent_type for agent_type, id_ in self._agent_type_to_client_id.items() if id_ == client_id
]
for agent_type in agent_types:
del self._agent_type_to_client_id[agent_type]
logger.info(f"Client {client_id} disconnected.")
await self._on_client_disconnect(client_id)

async def _on_client_disconnect(self, client_id: int) -> None:
async with self._agent_type_to_client_id_lock:
agent_types = [agent_type for agent_type, id_ in self._agent_type_to_client_id.items() if id_ == client_id]
for agent_type in agent_types:
logger.info(f"Removing agent type {agent_type} from agent type to client id mapping")
del self._agent_type_to_client_id[agent_type]
for sub_id in self._client_id_to_subscription_id_mapping.get(client_id, []):
logger.info(f"Client id {client_id} disconnected. Removing corresponding subscription with id {id}")
await self._subscription_manager.remove_subscription(sub_id)
logger.info(f"Client {client_id} disconnected successfully")

def _raise_on_exception(self, task: Task[Any]) -> None:
exception = task.exception()
Expand Down Expand Up @@ -220,6 +226,8 @@ async def _process_add_subscription_request(
)
try:
await self._subscription_manager.add_subscription(type_subscription)
subscription_ids = self._client_id_to_subscription_id_mapping.setdefault(client_id, set())
subscription_ids.add(type_subscription.id)
success = True
error = None
except ValueError as e:
Expand Down
21 changes: 21 additions & 0 deletions python/packages/autogen-core/tests/test_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,24 @@ def __init__(self) -> None:

async def on_message(self, message: Any, ctx: MessageContext) -> Any:
raise NotImplementedError


@dataclass
class MyMessage:
content: str


@default_subscription
class MyAgent(RoutedAgent):
def __init__(self, name: str) -> None:
super().__init__("My agent")
self._name = name
self._counter = 0

@message_handler
async def my_message_handler(self, message: MyMessage, ctx: MessageContext) -> None:
self._counter += 1
if self._counter > 5:
return
content = f"{self._name}: Hello x {self._counter}"
await self.publish_message(MyMessage(content=content), DefaultTopicId())
103 changes: 102 additions & 1 deletion python/packages/autogen-core/tests/test_worker_runtime.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import logging
import os
from typing import List

import pytest
Expand All @@ -10,13 +11,14 @@
TopicId,
try_get_known_serializers_for_type,
)
from autogen_core.base._subscription import Subscription
from autogen_core.components import (
DefaultTopicId,
TypeSubscription,
default_subscription,
type_subscription,
)
from test_utils import CascadingAgent, CascadingMessageType, LoopbackAgent, MessageType, NoopAgent
from test_utils import CascadingAgent, CascadingMessageType, LoopbackAgent, MessageType, MyAgent, MyMessage, NoopAgent


@pytest.mark.asyncio
Expand Down Expand Up @@ -300,3 +302,102 @@ class LoopbackAgentWithSubscription(LoopbackAgent): ...
await worker.stop()
await publisher.stop()
await host.stop()


@pytest.mark.asyncio
async def test_duplicate_subscription() -> None:
host_address = "localhost:50059"
host = WorkerAgentRuntimeHost(address=host_address)
worker1 = WorkerAgentRuntime(host_address=host_address)
worker1_2 = WorkerAgentRuntime(host_address=host_address)
host.start()
try:
worker1.start()
await MyAgent.register(worker1, "worker1", lambda: MyAgent("worker1"))

worker1_2.start()

# Note: This passes because worker1 is still running
with pytest.raises(RuntimeError, match="Agent type worker1 already registered"):
await MyAgent.register(worker1_2, "worker1", lambda: MyAgent("worker1_2"))

# This is somehow covered in test_disconnected_agent as well as a stop will also disconnect the agent.
# Will keep them both for now as we might replace the way we simulate a disconnect
await worker1.stop()

with pytest.raises(ValueError):
await MyAgent.register(worker1_2, "worker1", lambda: MyAgent("worker1_2"))

except Exception as ex:
raise ex
finally:
await worker1_2.stop()
await host.stop()


@pytest.mark.asyncio
async def test_disconnected_agent() -> None:
host_address = "localhost:50059"
host = WorkerAgentRuntimeHost(address=host_address)
host.start()
worker1 = WorkerAgentRuntime(host_address=host_address)
worker1_2 = WorkerAgentRuntime(host_address=host_address)

# TODO: Implementing `get_current_subscriptions` and `get_subscribed_recipients` requires access
# to some private properties. This needs to be updated once they are available publicly

def get_current_subscriptions() -> List[Subscription]:
return host._servicer._subscription_manager._subscriptions # type: ignore[reportPrivateUsage]

async def get_subscribed_recipients() -> List[AgentId]:
return await host._servicer._subscription_manager.get_subscribed_recipients(DefaultTopicId()) # type: ignore[reportPrivateUsage]

try:
worker1.start()
await MyAgent.register(worker1, "worker1", lambda: MyAgent("worker1"))

subscriptions1 = get_current_subscriptions()
assert len(subscriptions1) == 1
recipients1 = await get_subscribed_recipients()
assert AgentId(type="worker1", key="default") in recipients1

first_subscription_id = subscriptions1[0].id

await worker1.publish_message(MyMessage(content="Hello!"), DefaultTopicId())
# This is a simple simulation of worker disconnct
if worker1._host_connection is not None: # type: ignore[reportPrivateUsage]
try:
await worker1._host_connection.close() # type: ignore[reportPrivateUsage]
except asyncio.CancelledError:
pass

await asyncio.sleep(1)

subscriptions2 = get_current_subscriptions()
assert len(subscriptions2) == 0
recipients2 = await get_subscribed_recipients()
assert len(recipients2) == 0
await asyncio.sleep(1)

worker1_2.start()
await MyAgent.register(worker1_2, "worker1", lambda: MyAgent("worker1"))

subscriptions3 = get_current_subscriptions()
assert len(subscriptions3) == 1
assert first_subscription_id not in [x.id for x in subscriptions3]

recipients3 = await get_subscribed_recipients()
assert len(set(recipients2)) == len(recipients2) # Make sure there are no duplicates
assert AgentId(type="worker1", key="default") in recipients3
except Exception as ex:
raise ex
finally:
await worker1.stop()
await worker1_2.stop()
await host.stop()


if __name__ == "__main__":
os.environ["GRPC_VERBOSITY"] = "DEBUG"
os.environ["GRPC_TRACE"] = "all"
asyncio.run(test_disconnected_agent())

0 comments on commit db9226c

Please sign in to comment.