Skip to content

Commit

Permalink
[serve] improve router resolve request arg func (ray-project#48658)
Browse files Browse the repository at this point in the history
## Why are these changes needed?

Refactor the function that resolves request arguments (`resolve_request_args`).

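Split it into two parts:
1. The type checking and resolution of a single argument
(`resolve_deployment_response`), which is injected into the Router as a function.
2. The logic that runs that resolution over every arg/kwarg and replaces each
one with its resolved value (`_resolve_request_arguments`), which stays part of the Router.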

---------

Signed-off-by: Cindy Zhang <[email protected]>
Signed-off-by: mohitjain2504 <[email protected]>
  • Loading branch information
zcin authored and mohitjain2504 committed Nov 15, 2024
1 parent 129da13 commit 293af39
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 52 deletions.
4 changes: 2 additions & 2 deletions python/ray/serve/_private/default_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
get_current_actor_id,
get_head_node_id,
inside_ray_client_context,
resolve_request_args,
resolve_deployment_response,
)

# NOTE: Please read carefully before changing!
Expand Down Expand Up @@ -124,7 +124,7 @@ def create_router(
not is_inside_ray_client_context
and RAY_SERVE_ENABLE_STRICT_MAX_ONGOING_REQUESTS
),
resolve_request_args_func=resolve_request_args,
resolve_request_arg_func=resolve_deployment_response,
)


Expand Down
47 changes: 42 additions & 5 deletions python/ray/serve/_private/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from collections import defaultdict
from contextlib import contextmanager
from functools import partial
from typing import Any, Coroutine, DefaultDict, List, Optional, Tuple, Union
from typing import Any, Coroutine, DefaultDict, Dict, List, Optional, Tuple, Union

import ray
from ray.actor import ActorHandle
Expand All @@ -31,7 +31,7 @@
from ray.serve._private.metrics_utils import InMemoryMetricsStore, MetricsPusher
from ray.serve._private.replica_result import ReplicaResult
from ray.serve._private.replica_scheduler import PendingRequest, ReplicaScheduler
from ray.serve._private.utils import resolve_request_args
from ray.serve._private.utils import resolve_deployment_response
from ray.serve.config import AutoscalingConfig
from ray.serve.exceptions import BackPressureError
from ray.util import metrics
Expand Down Expand Up @@ -342,7 +342,7 @@ def __init__(
event_loop: asyncio.BaseEventLoop,
replica_scheduler: Optional[ReplicaScheduler],
enable_strict_max_ongoing_requests: bool,
resolve_request_args_func: Coroutine = resolve_request_args,
resolve_request_arg_func: Coroutine = resolve_deployment_response,
):
"""Used to assign requests to downstream replicas for a deployment.
Expand All @@ -355,7 +355,7 @@ def __init__(
self._enable_strict_max_ongoing_requests = enable_strict_max_ongoing_requests

self._replica_scheduler: ReplicaScheduler = replica_scheduler
self._resolve_request_args = resolve_request_args_func
self._resolve_request_arg_func = resolve_request_arg_func

# Flipped to `True` once the router has received a non-empty
# replica set at least once.
Expand Down Expand Up @@ -429,6 +429,43 @@ def update_deployment_config(self, deployment_config: DeploymentConfig):
curr_num_replicas=len(self._replica_scheduler.curr_replicas),
)

async def _resolve_request_arguments(
    self, request_args: Tuple[Any, ...], request_kwargs: Dict[str, Any]
) -> Tuple[List[Any], Dict[str, Any]]:
    """Asynchronously resolve and replace top-level request args and kwargs.

    Every top-level positional and keyword value is passed to
    `self._resolve_request_arg_func`. When that call returns `None` the value
    is kept unchanged; otherwise the returned task is awaited (together with
    all other such tasks) and its result replaces the original value.

    Args:
        request_args: Positional arguments of the request. Not mutated.
        request_kwargs: Keyword arguments of the request. Not mutated.

    Returns:
        A `(new_args, new_kwargs)` pair: a new list of positional arguments
        and a new dict of keyword arguments with resolved values substituted.

    Raises:
        Exception: Anything raised by `self._resolve_request_arg_func`, or the
            exception of the first failed resolution task (positional
            arguments are checked before keyword arguments).
    """
    new_args = list(request_args)
    new_kwargs = request_kwargs.copy()

    # Map from index -> task for resolving positional arg
    resolve_arg_tasks = {}
    for i, obj in enumerate(request_args):
        task = await self._resolve_request_arg_func(obj)
        if task is not None:
            resolve_arg_tasks[i] = task

    # Map from key -> task for resolving key-word arg
    resolve_kwarg_tasks = {}
    for k, obj in request_kwargs.items():
        task = await self._resolve_request_arg_func(obj)
        if task is not None:
            resolve_kwarg_tasks[k] = task

    # Wait for all argument resolution tasks concurrently.
    all_tasks = [*resolve_arg_tasks.values(), *resolve_kwarg_tasks.values()]
    if all_tasks:
        await asyncio.wait(all_tasks)

        # `asyncio.wait` never raises on behalf of its tasks, and only the
        # first failure is re-raised below. Mark every task's exception as
        # retrieved so the remaining failures do not trigger asyncio's
        # "Task exception was never retrieved" log when they are collected.
        for task in all_tasks:
            if not task.cancelled():
                task.exception()

    # Update new args and new kwargs with resolved arguments. `result()`
    # re-raises the task's exception (or cancellation) if resolution failed.
    for index, task in resolve_arg_tasks.items():
        new_args[index] = task.result()
    for key, task in resolve_kwarg_tasks.items():
        new_kwargs[key] = task.result()

    return new_args, new_kwargs

def _process_finished_request(
self,
replica_id: ReplicaID,
Expand Down Expand Up @@ -548,7 +585,7 @@ async def assign_request(

replica_result = None
try:
request_args, request_kwargs = await self._resolve_request_args(
request_args, request_kwargs = await self._resolve_request_arguments(
request_args, request_kwargs
)
replica_result, replica_id = await self.schedule_and_send_request(
Expand Down
53 changes: 8 additions & 45 deletions python/ray/serve/_private/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from decimal import ROUND_HALF_UP, Decimal
from enum import Enum
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union
from typing import Any, Callable, Dict, List, Optional, TypeVar, Union

import requests

Expand Down Expand Up @@ -594,52 +594,15 @@ def validate_route_prefix(route_prefix: Union[DEFAULT, None, str]):
)


async def resolve_request_args(
request_args: Tuple[Any], request_kwargs: Dict[str, Any]
) -> Tuple[Tuple[Any], Dict[str, Any]]:
"""Replaces top-level `DeploymentResponse` objects with resolved object refs.
async def resolve_deployment_response(obj: Any):
"""Resolve `DeploymentResponse` objects to underlying object references.
This enables composition without explicitly calling `_to_object_ref`.
"""
from ray.serve.handle import DeploymentResponse, DeploymentResponseGenerator

new_args = [None for _ in range(len(request_args))]
new_kwargs = {}

arg_tasks = []
response_indices = []
for i, obj in enumerate(request_args):
if isinstance(obj, DeploymentResponseGenerator):
raise GENERATOR_COMPOSITION_NOT_SUPPORTED_ERROR
elif isinstance(obj, DeploymentResponse):
# Launch async task to convert DeploymentResponse to an object ref, and
# keep track of the argument index in the original `request_args`
response_indices.append(i)
arg_tasks.append(asyncio.create_task(obj._to_object_ref()))
else:
new_args[i] = obj

kwarg_tasks = []
response_keys = []
for k, obj in request_kwargs.items():
if isinstance(obj, DeploymentResponseGenerator):
raise GENERATOR_COMPOSITION_NOT_SUPPORTED_ERROR
elif isinstance(obj, DeploymentResponse):
# Launch async task to convert DeploymentResponse to an object ref, and
# keep track of the corresponding key in the original `request_kwargs`
response_keys.append(k)
kwarg_tasks.append(asyncio.create_task(obj._to_object_ref()))
else:
new_kwargs[k] = obj

# Gather `DeploymentResponse` object refs concurrently.
arg_obj_refs = await asyncio.gather(*arg_tasks)
kwarg_obj_refs = await asyncio.gather(*kwarg_tasks)

# Update new args and new kwargs with resolved object refs
for index, obj_ref in zip(response_indices, arg_obj_refs):
new_args[index] = obj_ref
new_kwargs.update((zip(response_keys, kwarg_obj_refs)))

# Return new args and new kwargs
return new_args, new_kwargs
if isinstance(obj, DeploymentResponseGenerator):
raise GENERATOR_COMPOSITION_NOT_SUPPORTED_ERROR
elif isinstance(obj, DeploymentResponse):
# Launch async task to convert DeploymentResponse to an object ref
return asyncio.create_task(obj._to_object_ref())

0 comments on commit 293af39

Please sign in to comment.