diff --git a/python/packages/autogen-core/src/autogen_core/application/_worker_runtime_host_servicer.py b/python/packages/autogen-core/src/autogen_core/application/_worker_runtime_host_servicer.py
index 847d6125caf..9308edbcd8d 100644
--- a/python/packages/autogen-core/src/autogen_core/application/_worker_runtime_host_servicer.py
+++ b/python/packages/autogen-core/src/autogen_core/application/_worker_runtime_host_servicer.py
@@ -27,6 +27,7 @@ def __init__(self) -> None:
         self._pending_responses: Dict[int, Dict[str, Future[Any]]] = {}
         self._background_tasks: Set[Task[Any]] = set()
         self._subscription_manager = SubscriptionManager()
+        self._client_id_to_subscription_id_mapping: Dict[int, set[str]] = {}
 
     async def OpenChannel(  # type: ignore
         self,
@@ -68,13 +69,20 @@ async def OpenChannel(  # type: ignore
             for future in self._pending_responses.pop(client_id, {}).values():
                 future.cancel()
             # Remove the client id from the agent type to client id mapping.
-            async with self._agent_type_to_client_id_lock:
-                agent_types = [
-                    agent_type for agent_type, id_ in self._agent_type_to_client_id.items() if id_ == client_id
-                ]
-                for agent_type in agent_types:
-                    del self._agent_type_to_client_id[agent_type]
-            logger.info(f"Client {client_id} disconnected.")
+            await self._on_client_disconnect(client_id)
+
+    async def _on_client_disconnect(self, client_id: int) -> None:
+        """Remove the agent types and subscriptions registered by a disconnected client."""
+        async with self._agent_type_to_client_id_lock:
+            agent_types = [agent_type for agent_type, id_ in self._agent_type_to_client_id.items() if id_ == client_id]
+            for agent_type in agent_types:
+                logger.info(f"Removing agent type {agent_type} from agent type to client id mapping")
+                del self._agent_type_to_client_id[agent_type]
+            # Pop (rather than get) so the per-client entry does not outlive the connection.
+            for sub_id in self._client_id_to_subscription_id_mapping.pop(client_id, set()):
+                logger.info(f"Client id {client_id} disconnected. Removing corresponding subscription with id {sub_id}")
+                await self._subscription_manager.remove_subscription(sub_id)
+        logger.info(f"Client {client_id} disconnected successfully")
 
     def _raise_on_exception(self, task: Task[Any]) -> None:
         exception = task.exception()
@@ -220,6 +228,8 @@ async def _process_add_subscription_request(
                 )
                 try:
                     await self._subscription_manager.add_subscription(type_subscription)
+                    subscription_ids = self._client_id_to_subscription_id_mapping.setdefault(client_id, set())
+                    subscription_ids.add(type_subscription.id)
                     success = True
                     error = None
                 except ValueError as e:
diff --git a/python/packages/autogen-core/tests/test_utils/__init__.py b/python/packages/autogen-core/tests/test_utils/__init__.py
index 4c377b2185d..eb64be1d8b2 100644
--- a/python/packages/autogen-core/tests/test_utils/__init__.py
+++ b/python/packages/autogen-core/tests/test_utils/__init__.py
@@ -46,3 +46,25 @@ def __init__(self) -> None:
 
     async def on_message(self, message: Any, ctx: MessageContext) -> Any:
         raise NotImplementedError
+
+
+@dataclass
+class MyMessage:
+    content: str
+
+
+@default_subscription
+class MyAgent(RoutedAgent):
+    """Test agent that publishes a greeting for each of the first five messages it handles."""
+    def __init__(self, name: str) -> None:
+        super().__init__("My agent")
+        self._name = name
+        self._counter = 0
+
+    @message_handler
+    async def my_message_handler(self, message: MyMessage, ctx: MessageContext) -> None:
+        self._counter += 1
+        if self._counter > 5:
+            return
+        content = f"{self._name}: Hello x {self._counter}"
+        await self.publish_message(MyMessage(content=content), DefaultTopicId())
diff --git a/python/packages/autogen-core/tests/test_worker_runtime.py b/python/packages/autogen-core/tests/test_worker_runtime.py
index edc5455bae0..1f9698c04fb 100644
--- a/python/packages/autogen-core/tests/test_worker_runtime.py
+++ b/python/packages/autogen-core/tests/test_worker_runtime.py
@@ -1,5 +1,6 @@
 import asyncio
 import logging
+import os
 from typing import List
 
 import pytest
@@ -10,13 +11,14 @@
     TopicId,
     try_get_known_serializers_for_type,
 )
+from autogen_core.base._subscription import Subscription
 from autogen_core.components import (
     DefaultTopicId,
     TypeSubscription,
     default_subscription,
     type_subscription,
 )
-from test_utils import CascadingAgent, CascadingMessageType, LoopbackAgent, MessageType, NoopAgent
+from test_utils import CascadingAgent, CascadingMessageType, LoopbackAgent, MessageType, MyAgent, MyMessage, NoopAgent
 
 
 @pytest.mark.asyncio
@@ -300,3 +302,104 @@ class LoopbackAgentWithSubscription(LoopbackAgent): ...
     await worker.stop()
     await publisher.stop()
     await host.stop()
+
+
+@pytest.mark.asyncio
+async def test_duplicate_subscription() -> None:
+    """Registering an already-registered agent type fails both while worker1 runs and after it stops."""
+    host_address = "localhost:50059"
+    host = WorkerAgentRuntimeHost(address=host_address)
+    worker1 = WorkerAgentRuntime(host_address=host_address)
+    worker1_2 = WorkerAgentRuntime(host_address=host_address)
+    host.start()
+    try:
+        worker1.start()
+        await MyAgent.register(worker1, "worker1", lambda: MyAgent("worker1"))
+
+        worker1_2.start()
+
+        # Note: This passes because worker1 is still running
+        with pytest.raises(RuntimeError, match="Agent type worker1 already registered"):
+            await MyAgent.register(worker1_2, "worker1", lambda: MyAgent("worker1_2"))
+
+        # This is somehow covered in test_disconnected_agent as well as a stop will also disconnect the agent.
+        # Will keep them both for now as we might replace the way we simulate a disconnect
+        await worker1.stop()
+
+        with pytest.raises(ValueError):
+            await MyAgent.register(worker1_2, "worker1", lambda: MyAgent("worker1_2"))
+
+    except Exception as ex:
+        raise ex
+    finally:
+        await worker1_2.stop()
+        await host.stop()
+
+
+@pytest.mark.asyncio
+async def test_disconnected_agent() -> None:
+    """The host drops a worker's subscriptions on disconnect so a new worker can register the same type."""
+    host_address = "localhost:50059"
+    host = WorkerAgentRuntimeHost(address=host_address)
+    host.start()
+    worker1 = WorkerAgentRuntime(host_address=host_address)
+    worker1_2 = WorkerAgentRuntime(host_address=host_address)
+
+    # TODO: Implementing `get_current_subscriptions` and `get_subscribed_recipients` requires access
+    # to some private properties. This needs to be updated once they are available publicly
+
+    def get_current_subscriptions() -> List[Subscription]:
+        return host._servicer._subscription_manager._subscriptions  # type: ignore[reportPrivateUsage]
+
+    async def get_subscribed_recipients() -> List[AgentId]:
+        return await host._servicer._subscription_manager.get_subscribed_recipients(DefaultTopicId())  # type: ignore[reportPrivateUsage]
+
+    try:
+        worker1.start()
+        await MyAgent.register(worker1, "worker1", lambda: MyAgent("worker1"))
+
+        subscriptions1 = get_current_subscriptions()
+        assert len(subscriptions1) == 1
+        recipients1 = await get_subscribed_recipients()
+        assert AgentId(type="worker1", key="default") in recipients1
+
+        first_subscription_id = subscriptions1[0].id
+
+        await worker1.publish_message(MyMessage(content="Hello!"), DefaultTopicId())
+        # This is a simple simulation of worker disconnect
+        if worker1._host_connection is not None:  # type: ignore[reportPrivateUsage]
+            try:
+                await worker1._host_connection.close()  # type: ignore[reportPrivateUsage]
+            except asyncio.CancelledError:
+                pass
+
+        await asyncio.sleep(1)
+
+        subscriptions2 = get_current_subscriptions()
+        assert len(subscriptions2) == 0
+        recipients2 = await get_subscribed_recipients()
+        assert len(recipients2) == 0
+        await asyncio.sleep(1)
+
+        worker1_2.start()
+        await MyAgent.register(worker1_2, "worker1", lambda: MyAgent("worker1"))
+
+        subscriptions3 = get_current_subscriptions()
+        assert len(subscriptions3) == 1
+        assert first_subscription_id not in [x.id for x in subscriptions3]
+
+        recipients3 = await get_subscribed_recipients()
+        assert len(set(recipients3)) == len(recipients3)  # Make sure there are no duplicates
+        assert AgentId(type="worker1", key="default") in recipients3
+    except Exception as ex:
+        raise ex
+    finally:
+        await worker1.stop()
+        await worker1_2.stop()
+        await host.stop()
+
+
+if __name__ == "__main__":
+    os.environ["GRPC_VERBOSITY"] = "DEBUG"
+    os.environ["GRPC_TRACE"] = "all"
+    asyncio.run(test_disconnected_agent())