Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[serve] improve router resolve request arg func #48658

Merged
merged 4 commits into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions python/ray/serve/_private/default_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
get_current_actor_id,
get_head_node_id,
inside_ray_client_context,
resolve_request_args,
resolve_deployment_response,
)

# NOTE: Please read carefully before changing!
Expand Down Expand Up @@ -124,7 +124,7 @@ def create_router(
not is_inside_ray_client_context
and RAY_SERVE_ENABLE_STRICT_MAX_ONGOING_REQUESTS
),
resolve_request_args_func=resolve_request_args,
resolve_request_arg_func=resolve_deployment_response,
)


Expand Down
47 changes: 42 additions & 5 deletions python/ray/serve/_private/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from collections import defaultdict
from contextlib import contextmanager
from functools import partial
from typing import Any, Coroutine, DefaultDict, List, Optional, Tuple, Union
from typing import Any, Coroutine, DefaultDict, Dict, List, Optional, Tuple, Union

import ray
from ray.actor import ActorHandle
Expand All @@ -31,7 +31,7 @@
from ray.serve._private.metrics_utils import InMemoryMetricsStore, MetricsPusher
from ray.serve._private.replica_result import ReplicaResult
from ray.serve._private.replica_scheduler import PendingRequest, ReplicaScheduler
from ray.serve._private.utils import resolve_request_args
from ray.serve._private.utils import resolve_deployment_response
from ray.serve.config import AutoscalingConfig
from ray.serve.exceptions import BackPressureError
from ray.util import metrics
Expand Down Expand Up @@ -342,7 +342,7 @@ def __init__(
event_loop: asyncio.BaseEventLoop,
replica_scheduler: Optional[ReplicaScheduler],
enable_strict_max_ongoing_requests: bool,
resolve_request_args_func: Coroutine = resolve_request_args,
resolve_request_arg_func: Coroutine = resolve_deployment_response,
):
"""Used to assign requests to downstream replicas for a deployment.

Expand All @@ -355,7 +355,7 @@ def __init__(
self._enable_strict_max_ongoing_requests = enable_strict_max_ongoing_requests

self._replica_scheduler: ReplicaScheduler = replica_scheduler
self._resolve_request_args = resolve_request_args_func
self._resolve_request_arg_func = resolve_request_arg_func

# Flipped to `True` once the router has received a non-empty
# replica set at least once.
Expand Down Expand Up @@ -429,6 +429,43 @@ def update_deployment_config(self, deployment_config: DeploymentConfig):
curr_num_replicas=len(self._replica_scheduler.curr_replicas),
)

async def _resolve_request_arguments(
self, request_args: Tuple[Any], request_kwargs: Dict[str, Any]
) -> Tuple[Tuple[Any], Dict[str, Any]]:
"""Asynchronously resolve and replace top-level request args and kwargs."""
new_args = list(request_args)
new_kwargs = request_kwargs.copy()

# Map from index -> task for resolving positional arg
resolve_arg_tasks = {}
for i, obj in enumerate(request_args):
task = await self._resolve_request_arg_func(obj)
if task is not None:
resolve_arg_tasks[i] = task

# Map from key -> task for resolving key-word arg
resolve_kwarg_tasks = {}
for k, obj in request_kwargs.items():
task = await self._resolve_request_arg_func(obj)
if task is not None:
resolve_kwarg_tasks[k] = task

# Gather all argument resolution tasks concurrently.
if resolve_arg_tasks or resolve_kwarg_tasks:
all_tasks = list(resolve_arg_tasks.values()) + list(
resolve_kwarg_tasks.values()
)
await asyncio.wait(all_tasks)

# Update new args and new kwargs with resolved arguments
for index, task in resolve_arg_tasks.items():
new_args[index] = task.result()
for key, task in resolve_kwarg_tasks.items():
new_kwargs[key] = task.result()

# Return new args and new kwargs
return new_args, new_kwargs

def _process_finished_request(
self,
replica_id: ReplicaID,
Expand Down Expand Up @@ -548,7 +585,7 @@ async def assign_request(

replica_result = None
try:
request_args, request_kwargs = await self._resolve_request_args(
request_args, request_kwargs = await self._resolve_request_arguments(
request_args, request_kwargs
)
replica_result, replica_id = await self.schedule_and_send_request(
Expand Down
53 changes: 8 additions & 45 deletions python/ray/serve/_private/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from decimal import ROUND_HALF_UP, Decimal
from enum import Enum
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union
from typing import Any, Callable, Dict, List, Optional, TypeVar, Union

import requests

Expand Down Expand Up @@ -594,52 +594,15 @@ def validate_route_prefix(route_prefix: Union[DEFAULT, None, str]):
)


async def resolve_deployment_response(obj: Any):
    """Resolve `DeploymentResponse` objects to underlying object references.

    This enables composition without explicitly calling `_to_object_ref`.

    Returns a task resolving to the object ref for a `DeploymentResponse`,
    and `None` for any other input (no resolution needed).
    """
    # Imported inside the function — presumably to avoid a circular import
    # at module load time; confirm against the module's import graph.
    from ray.serve.handle import DeploymentResponse, DeploymentResponseGenerator

    if isinstance(obj, DeploymentResponseGenerator):
        # Generator responses cannot be composed as request arguments.
        raise GENERATOR_COMPOSITION_NOT_SUPPORTED_ERROR

    if isinstance(obj, DeploymentResponse):
        # Start the conversion in the background; the caller gathers the task.
        return asyncio.create_task(obj._to_object_ref())

    # Plain arguments need no resolution.
    return None