Skip to content

Commit

Permalink
add a way to provide extra grpc options (#3667)
Browse files Browse the repository at this point in the history
* rebase and address PR comments

* address PR feedback
  • Loading branch information
MohMaz authored Oct 8, 2024
1 parent e400567 commit 29c23d5
Show file tree
Hide file tree
Showing 8 changed files with 171 additions and 71 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,22 @@ This allows your agent to work in a distributed environment a well as a local on
An {py:class}`autogen_core.base.AgentId` is composed of a `type` and a `key`. The type corresponds to the factory that created the agent, and the key is a runtime, data dependent key for this instance.

The key can correspond to a user id, a session id, or could just be "default" if you don't need to differentiate between instances. Each unique key will create a new instance of the agent, based on the factory provided. This allows the system to automatically scale to different instances of the same agent, and to manage the lifecycle of each instance independently based on how you choose to handle keys in your application.

## How do I increase the GRPC message size?

If you need to provide custom gRPC options, such as overriding the `max_send_message_length` and `max_receive_message_length`, you can define an `extra_grpc_config` variable and pass it to both the `WorkerAgentRuntimeHost` and `WorkerAgentRuntime` instances.

```python
# Define custom gRPC options
extra_grpc_config = [
("grpc.max_send_message_length", new_max_size),
("grpc.max_receive_message_length", new_max_size),
]

# Create instances of WorkerAgentRuntimeHost and WorkerAgentRuntime with the custom gRPC options

host = WorkerAgentRuntimeHost(address=host_address, extra_grpc_config=extra_grpc_config)
worker1 = WorkerAgentRuntime(host_address=host_address, extra_grpc_config=extra_grpc_config)
```

**Note**: When `WorkerAgentRuntime` creates a host connection for the clients, it uses `DEFAULT_GRPC_CONFIG` from `HostConnection` class as default set of values which will can be overriden if you pass parameters with the same name using `extra_grpc_config`.
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,4 @@
from ._worker_runtime import WorkerAgentRuntime
from ._worker_runtime_host import WorkerAgentRuntimeHost

__all__ = [
"SingleThreadedAgentRuntime",
"WorkerAgentRuntime",
"WorkerAgentRuntimeHost",
]
__all__ = ["SingleThreadedAgentRuntime", "WorkerAgentRuntime", "WorkerAgentRuntimeHost"]
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

from autogen_core.base import JSON_DATA_CONTENT_TYPE
from autogen_core.base._serialization import MessageSerializer, SerializationRegistry
from autogen_core.base._type_helpers import ChannelArgumentType

from ..base import (
Agent,
Expand Down Expand Up @@ -63,6 +64,7 @@
P = ParamSpec("P")
T = TypeVar("T", bound=Agent)


type_func_alias = type


Expand All @@ -78,20 +80,27 @@ def __aiter__(self) -> AsyncIterator[Any]:


class HostConnection:
DEFAULT_GRPC_CONFIG: ClassVar[Mapping[str, Any]] = {
"methodConfig": [
{
"name": [{}],
"retryPolicy": {
"maxAttempts": 3,
"initialBackoff": "0.01s",
"maxBackoff": "5s",
"backoffMultiplier": 2,
"retryableStatusCodes": ["UNAVAILABLE"],
},
}
],
}
DEFAULT_GRPC_CONFIG: ClassVar[ChannelArgumentType] = [
(
"grpc.service_config",
json.dumps(
{
"methodConfig": [
{
"name": [{}],
"retryPolicy": {
"maxAttempts": 3,
"initialBackoff": "0.01s",
"maxBackoff": "5s",
"backoffMultiplier": 2,
"retryableStatusCodes": ["UNAVAILABLE"],
},
}
],
}
),
)
]

def __init__(self, channel: grpc.aio.Channel) -> None: # type: ignore
self._channel = channel
Expand All @@ -100,9 +109,17 @@ def __init__(self, channel: grpc.aio.Channel) -> None: # type: ignore
self._connection_task: Task[None] | None = None

@classmethod
def from_host_address(cls, host_address: str, grpc_config: Mapping[str, Any] = DEFAULT_GRPC_CONFIG) -> Self:
def from_host_address(cls, host_address: str, extra_grpc_config: ChannelArgumentType = DEFAULT_GRPC_CONFIG) -> Self:
logger.info("Connecting to %s", host_address)
channel = grpc.aio.insecure_channel(host_address, options=[("grpc.service_config", json.dumps(grpc_config))])
# Always use DEFAULT_GRPC_CONFIG and override it with provided grpc_config
merged_options = [
(k, v) for k, v in {**dict(HostConnection.DEFAULT_GRPC_CONFIG), **dict(extra_grpc_config)}.items()
]

channel = grpc.aio.insecure_channel(
host_address,
options=merged_options,
)
instance = cls(channel)
instance._connection_task = asyncio.create_task(
instance._connect(channel, instance._send_queue, instance._recv_queue)
Expand Down Expand Up @@ -150,7 +167,12 @@ async def recv(self) -> agent_worker_pb2.Message:


class WorkerAgentRuntime(AgentRuntime):
def __init__(self, host_address: str, tracer_provider: TracerProvider | None = None) -> None:
def __init__(
self,
host_address: str,
tracer_provider: TracerProvider | None = None,
extra_grpc_config: ChannelArgumentType | None = None,
) -> None:
self._host_address = host_address
self._trace_helper = TraceHelper(tracer_provider, MessageRuntimeTracingConfig("Worker Runtime"))
self._per_type_subscribers: DefaultDict[tuple[str, str], Set[AgentId]] = defaultdict(set)
Expand All @@ -168,13 +190,16 @@ def __init__(self, host_address: str, tracer_provider: TracerProvider | None = N
self._background_tasks: Set[Task[Any]] = set()
self._subscription_manager = SubscriptionManager()
self._serialization_registry = SerializationRegistry()
self._extra_grpc_config = extra_grpc_config or []

def start(self) -> None:
"""Start the runtime in a background task."""
if self._running:
raise ValueError("Runtime is already running.")
logger.info(f"Connecting to host: {self._host_address}")
self._host_connection = HostConnection.from_host_address(self._host_address)
self._host_connection = HostConnection.from_host_address(
self._host_address, extra_grpc_config=self._extra_grpc_config
)
logger.info("Connection established")
if self._read_task is None:
self._read_task = asyncio.create_task(self._run_read_loop())
Expand Down
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
import asyncio
import logging
import signal
from typing import Sequence
from typing import Optional, Sequence

import grpc

from autogen_core.base._type_helpers import ChannelArgumentType

from ._worker_runtime_host_servicer import WorkerAgentRuntimeHostServicer
from .protos import agent_worker_pb2_grpc

logger = logging.getLogger("autogen_core")


class WorkerAgentRuntimeHost:
def __init__(self, address: str) -> None:
self._server = grpc.aio.server()
def __init__(self, address: str, extra_grpc_config: Optional[ChannelArgumentType] = None) -> None:
self._server = grpc.aio.server(options=extra_grpc_config)
self._servicer = WorkerAgentRuntimeHostServicer()
agent_worker_pb2_grpc.add_AgentRpcServicer_to_server(self._servicer, self._server)
self._server.add_insecure_port(address)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from collections.abc import Sequence
from types import NoneType, UnionType
from typing import Any, Optional, Type, Union, get_args, get_origin
from typing import Any, Optional, Tuple, Type, Union, get_args, get_origin

# Had to redefine this from grpc.aio._typing as using that one was causing mypy errors
ChannelArgumentType = Sequence[Tuple[str, Any]]


def is_union(t: object) -> bool:
Expand Down
16 changes: 8 additions & 8 deletions python/packages/autogen-core/tests/test_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,17 @@
from autogen_core.components import (
DefaultTopicId,
TypeSubscription,
default_subscription,
type_subscription,
)
from opentelemetry.sdk.trace import TracerProvider
from test_utils import CascadingAgent, CascadingMessageType, LoopbackAgent, MessageType, NoopAgent
from test_utils import (
CascadingAgent,
CascadingMessageType,
LoopbackAgent,
LoopbackAgentWithDefaultSubscription,
MessageType,
NoopAgent,
)
from test_utils.telemetry_test_utils import TestExporter, get_test_tracer_provider

test_exporter = TestExporter()
Expand Down Expand Up @@ -218,9 +224,6 @@ async def test_default_subscription() -> None:
runtime = SingleThreadedAgentRuntime()
runtime.start()

@default_subscription
class LoopbackAgentWithDefaultSubscription(LoopbackAgent): ...

await LoopbackAgentWithDefaultSubscription.register(runtime, "name", LoopbackAgentWithDefaultSubscription)

agent_id = AgentId("name", key="default")
Expand Down Expand Up @@ -267,9 +270,6 @@ async def test_default_subscription_publish_to_other_source() -> None:
runtime = SingleThreadedAgentRuntime()
runtime.start()

@default_subscription
class LoopbackAgentWithDefaultSubscription(LoopbackAgent): ...

await LoopbackAgentWithDefaultSubscription.register(runtime, "name", LoopbackAgentWithDefaultSubscription)

agent_id = AgentId("name", key="default")
Expand Down
34 changes: 12 additions & 22 deletions python/packages/autogen-core/tests/test_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,28 @@ class CascadingMessageType:
round: int


@dataclass
class ContentMessage:
content: str


class LoopbackAgent(RoutedAgent):
def __init__(self) -> None:
super().__init__("A loop back agent.")
self.num_calls = 0

@message_handler
async def on_new_message(self, message: MessageType, ctx: MessageContext) -> MessageType:
async def on_new_message(
self, message: MessageType | ContentMessage, ctx: MessageContext
) -> MessageType | ContentMessage:
self.num_calls += 1
return message


@default_subscription
class LoopbackAgentWithDefaultSubscription(LoopbackAgent): ...


@default_subscription
class CascadingAgent(RoutedAgent):
def __init__(self, max_rounds: int) -> None:
Expand All @@ -46,24 +57,3 @@ def __init__(self) -> None:

async def on_message(self, message: Any, ctx: MessageContext) -> Any:
raise NotImplementedError


@dataclass
class MyMessage:
content: str


@default_subscription
class MyAgent(RoutedAgent):
def __init__(self, name: str) -> None:
super().__init__("My agent")
self._name = name
self._counter = 0

@message_handler
async def my_message_handler(self, message: MyMessage, ctx: MessageContext) -> None:
self._counter += 1
if self._counter > 5:
return
content = f"{self._name}: Hello x {self._counter}"
await self.publish_message(MyMessage(content=content), DefaultTopicId())
Loading

0 comments on commit 29c23d5

Please sign in to comment.