From 293af39f7baed6126a70303064ea23d9d098dcf2 Mon Sep 17 00:00:00 2001 From: Cindy Zhang Date: Fri, 8 Nov 2024 12:50:02 -0800 Subject: [PATCH] [serve] improve router resolve request arg func (#48658) ## Why are these changes needed? Refactor resolve request arg func. Separate into: 1. actual type checking + argument resolution (injected into the Router as a function) 2. the logic for executing that resolution and replacing args/kwargs (kept as part of Router) --------- Signed-off-by: Cindy Zhang Signed-off-by: mohitjain2504 --- python/ray/serve/_private/default_impl.py | 4 +- python/ray/serve/_private/router.py | 47 +++++++++++++++++--- python/ray/serve/_private/utils.py | 53 ++++------------------- 3 files changed, 52 insertions(+), 52 deletions(-) diff --git a/python/ray/serve/_private/default_impl.py b/python/ray/serve/_private/default_impl.py index 3b2abf6e0829..45314de071a8 100644 --- a/python/ray/serve/_private/default_impl.py +++ b/python/ray/serve/_private/default_impl.py @@ -27,7 +27,7 @@ get_current_actor_id, get_head_node_id, inside_ray_client_context, - resolve_request_args, + resolve_deployment_response, ) # NOTE: Please read carefully before changing! @@ -124,7 +124,7 @@ def create_router( not is_inside_ray_client_context and RAY_SERVE_ENABLE_STRICT_MAX_ONGOING_REQUESTS ), - resolve_request_args_func=resolve_request_args, + resolve_request_arg_func=resolve_deployment_response, ) diff --git a/python/ray/serve/_private/router.py b/python/ray/serve/_private/router.py index 2b83ff2d86d7..9cd8c10f5f82 100644 --- a/python/ray/serve/_private/router.py +++ b/python/ray/serve/_private/router.py @@ -8,7 +8,7 @@ from collections import defaultdict from contextlib import contextmanager from functools import partial -from typing import Any, Coroutine, DefaultDict, List, Optional, Tuple, Union +from typing import Any, Coroutine, DefaultDict, Dict, List, Optional, Tuple, Union import ray from ray.actor import ActorHandle @@ -31,7 +31,7 @@ from ray.serve._private.metrics_utils import InMemoryMetricsStore, MetricsPusher from ray.serve._private.replica_result import ReplicaResult from ray.serve._private.replica_scheduler import PendingRequest, ReplicaScheduler -from ray.serve._private.utils import resolve_request_args +from ray.serve._private.utils import resolve_deployment_response from ray.serve.config import AutoscalingConfig from ray.serve.exceptions import BackPressureError from ray.util import metrics @@ -342,7 +342,7 @@ def __init__( event_loop: asyncio.BaseEventLoop, replica_scheduler: Optional[ReplicaScheduler], enable_strict_max_ongoing_requests: bool, - resolve_request_args_func: Coroutine = resolve_request_args, + resolve_request_arg_func: Coroutine = resolve_deployment_response, ): """Used to assign requests to downstream replicas for a deployment. @@ -355,7 +355,7 @@ def __init__( self._enable_strict_max_ongoing_requests = enable_strict_max_ongoing_requests self._replica_scheduler: ReplicaScheduler = replica_scheduler - self._resolve_request_args = resolve_request_args_func + self._resolve_request_arg_func = resolve_request_arg_func # Flipped to `True` once the router has received a non-empty # replica set at least once. @@ -429,6 +429,43 @@ def update_deployment_config(self, deployment_config: DeploymentConfig): curr_num_replicas=len(self._replica_scheduler.curr_replicas), ) + async def _resolve_request_arguments( + self, request_args: Tuple[Any], request_kwargs: Dict[str, Any] + ) -> Tuple[Tuple[Any], Dict[str, Any]]: + """Asynchronously resolve and replace top-level request args and kwargs.""" + new_args = list(request_args) + new_kwargs = request_kwargs.copy() + + # Map from index -> task for resolving positional arg + resolve_arg_tasks = {} + for i, obj in enumerate(request_args): + task = await self._resolve_request_arg_func(obj) + if task is not None: + resolve_arg_tasks[i] = task + + # Map from key -> task for resolving key-word arg + resolve_kwarg_tasks = {} + for k, obj in request_kwargs.items(): + task = await self._resolve_request_arg_func(obj) + if task is not None: + resolve_kwarg_tasks[k] = task + + # Gather all argument resolution tasks concurrently. + if resolve_arg_tasks or resolve_kwarg_tasks: + all_tasks = list(resolve_arg_tasks.values()) + list( + resolve_kwarg_tasks.values() + ) + await asyncio.wait(all_tasks) + + # Update new args and new kwargs with resolved arguments + for index, task in resolve_arg_tasks.items(): + new_args[index] = task.result() + for key, task in resolve_kwarg_tasks.items(): + new_kwargs[key] = task.result() + + # Return new args and new kwargs + return new_args, new_kwargs + def _process_finished_request( self, replica_id: ReplicaID, @@ -548,7 +585,7 @@ async def assign_request( replica_result = None try: - request_args, request_kwargs = await self._resolve_request_args( + request_args, request_kwargs = await self._resolve_request_arguments( request_args, request_kwargs ) replica_result, replica_id = await self.schedule_and_send_request( diff --git a/python/ray/serve/_private/utils.py b/python/ray/serve/_private/utils.py index b57b8518b22f..1193f7722b63 100644 --- a/python/ray/serve/_private/utils.py +++ b/python/ray/serve/_private/utils.py @@ -12,7 +12,7 @@ from decimal import ROUND_HALF_UP, Decimal from enum import Enum from functools import wraps -from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union +from typing import Any, Callable, Dict, List, Optional, TypeVar, Union import requests @@ -594,52 +594,15 @@ def validate_route_prefix(route_prefix: Union[DEFAULT, None, str]): ) -async def resolve_request_args( - request_args: Tuple[Any], request_kwargs: Dict[str, Any] -) -> Tuple[Tuple[Any], Dict[str, Any]]: - """Replaces top-level `DeploymentResponse` objects with resolved object refs. +async def resolve_deployment_response(obj: Any): + """Resolve `DeploymentResponse` objects to underlying object references. This enables composition without explicitly calling `_to_object_ref`. """ from ray.serve.handle import DeploymentResponse, DeploymentResponseGenerator - new_args = [None for _ in range(len(request_args))] - new_kwargs = {} - - arg_tasks = [] - response_indices = [] - for i, obj in enumerate(request_args): - if isinstance(obj, DeploymentResponseGenerator): - raise GENERATOR_COMPOSITION_NOT_SUPPORTED_ERROR - elif isinstance(obj, DeploymentResponse): - # Launch async task to convert DeploymentResponse to an object ref, and - # keep track of the argument index in the original `request_args` - response_indices.append(i) - arg_tasks.append(asyncio.create_task(obj._to_object_ref())) - else: - new_args[i] = obj - - kwarg_tasks = [] - response_keys = [] - for k, obj in request_kwargs.items(): - if isinstance(obj, DeploymentResponseGenerator): - raise GENERATOR_COMPOSITION_NOT_SUPPORTED_ERROR - elif isinstance(obj, DeploymentResponse): - # Launch async task to convert DeploymentResponse to an object ref, and - # keep track of the corresponding key in the original `request_kwargs` - response_keys.append(k) - kwarg_tasks.append(asyncio.create_task(obj._to_object_ref())) - else: - new_kwargs[k] = obj - - # Gather `DeploymentResponse` object refs concurrently. - arg_obj_refs = await asyncio.gather(*arg_tasks) - kwarg_obj_refs = await asyncio.gather(*kwarg_tasks) - - # Update new args and new kwargs with resolved object refs - for index, obj_ref in zip(response_indices, arg_obj_refs): - new_args[index] = obj_ref - new_kwargs.update((zip(response_keys, kwarg_obj_refs))) - - # Return new args and new kwargs - return new_args, new_kwargs + if isinstance(obj, DeploymentResponseGenerator): + raise GENERATOR_COMPOSITION_NOT_SUPPORTED_ERROR + elif isinstance(obj, DeploymentResponse): + # Launch async task to convert DeploymentResponse to an object ref + return asyncio.create_task(obj._to_object_ref())