-
Notifications
You must be signed in to change notification settings - Fork 5.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Serve] add type hints for controller and backend_worker #10288
Changes from 23 commits
fac557b
1737b46
c9a17fd
cadc974
898db81
8d48a39
0994641
4a07eb3
f54d9ea
6ec1e7e
080dab4
1f0d1dd
82d99d0
d81261b
5ea3642
a3fc043
eda50b8
3b069b1
cc0f07e
2b6e3d0
720996a
2851bf4
d197a7b
de01cf3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,7 +5,7 @@ | |
from collections import defaultdict | ||
from itertools import groupby | ||
from operator import attrgetter | ||
from typing import Union | ||
from typing import Union, List, Any, Callable, Type | ||
import time | ||
|
||
import ray | ||
|
@@ -20,32 +20,33 @@ | |
from ray.experimental import metrics | ||
from ray.serve.config import BackendConfig | ||
from ray.serve.router import Query | ||
from ray.exceptions import RayTaskError | ||
|
||
logger = _get_logger() | ||
|
||
|
||
class BatchQueue: | ||
def __init__(self, max_batch_size, timeout_s): | ||
def __init__(self, max_batch_size: int, timeout_s: float) -> None: | ||
self.queue = asyncio.Queue() | ||
self.full_batch_event = asyncio.Event() | ||
self.max_batch_size = max_batch_size | ||
self.timeout_s = timeout_s | ||
|
||
def set_config(self, max_batch_size, timeout_s): | ||
def set_config(self, max_batch_size: int, timeout_s: float) -> None: | ||
self.max_batch_size = max_batch_size | ||
self.timeout_s = timeout_s | ||
|
||
def put(self, request): | ||
def put(self, request: Query) -> None: | ||
self.queue.put_nowait(request) | ||
# Signal when the full batch is ready. The event will be reset | ||
# in wait_for_batch. | ||
if self.queue.qsize() == self.max_batch_size: | ||
self.full_batch_event.set() | ||
|
||
def qsize(self): | ||
def qsize(self) -> int: | ||
return self.queue.qsize() | ||
|
||
async def wait_for_batch(self): | ||
async def wait_for_batch(self) -> List[Query]: | ||
"""Wait for batch respecting self.max_batch_size and self.timeout_s. | ||
|
||
Returns a batch of up to self.max_batch_size items, waiting for up | ||
|
@@ -89,7 +90,7 @@ async def wait_for_batch(self): | |
return batch | ||
|
||
|
||
def create_backend_worker(func_or_class): | ||
def create_backend_worker(func_or_class: Union[Callable, Type[Callable]]): | ||
"""Creates a worker class wrapping the provided function or class.""" | ||
|
||
if inspect.isfunction(func_or_class): | ||
|
@@ -99,6 +100,7 @@ def create_backend_worker(func_or_class): | |
else: | ||
assert False, "func_or_class must be function or class." | ||
|
||
# TODO(architkulkarni): Add type hints after upgrading cloudpickle | ||
class RayServeWrappedWorker(object): | ||
def __init__(self, | ||
backend_tag, | ||
|
@@ -129,7 +131,7 @@ def ready(self): | |
return RayServeWrappedWorker | ||
|
||
|
||
def wrap_to_ray_error(exception): | ||
def wrap_to_ray_error(exception: Exception) -> RayTaskError: | ||
"""Utility method to wrap exceptions in user code.""" | ||
|
||
try: | ||
|
@@ -140,7 +142,7 @@ def wrap_to_ray_error(exception): | |
return ray.exceptions.RayTaskError(str(e), traceback_str, e.__class__) | ||
|
||
|
||
def ensure_async(func): | ||
def ensure_async(func: Callable) -> Callable: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. -> AsyncCallable There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Couldn't find AsyncCallable in the project or on Google, so I made it Coroutine instead. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Oops. Sorry I confused types in There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. No problem, done. |
||
if inspect.iscoroutinefunction(func): | ||
return func | ||
else: | ||
|
@@ -150,8 +152,8 @@ def ensure_async(func): | |
class RayServeWorker: | ||
"""Handles requests with the provided callable.""" | ||
|
||
def __init__(self, backend_tag, replica_tag, _callable, | ||
backend_config: BackendConfig, is_function): | ||
def __init__(self, backend_tag: str, replica_tag: str, _callable: Callable, | ||
backend_config: BackendConfig, is_function: bool) -> None: | ||
self.backend_tag = backend_tag | ||
self.replica_tag = replica_tag | ||
self.callable = _callable | ||
|
@@ -182,7 +184,7 @@ def __init__(self, backend_tag, replica_tag, _callable, | |
|
||
asyncio.get_event_loop().create_task(self.main_loop()) | ||
|
||
def get_runner_method(self, request_item): | ||
def get_runner_method(self, request_item: Query) -> Callable: | ||
method_name = request_item.call_method | ||
if not hasattr(self.callable, method_name): | ||
raise RayServeException("Backend doesn't have method {} " | ||
|
@@ -193,7 +195,7 @@ def get_runner_method(self, request_item): | |
return self.callable | ||
return getattr(self.callable, method_name) | ||
|
||
def has_positional_args(self, f): | ||
def has_positional_args(self, f: Callable) -> bool: | ||
# NOTE: | ||
# In the case of simple functions, not actors, the f will be | ||
# function.__call__, but we need to inspect the function itself. | ||
|
@@ -207,13 +209,13 @@ def has_positional_args(self, f): | |
return True | ||
return False | ||
|
||
def _reset_context(self): | ||
def _reset_context(self) -> None: | ||
# NOTE(simon): context management won't work in async mode because | ||
# many concurrent queries might be running at the same time. | ||
serve_context.web = None | ||
serve_context.batch_size = None | ||
|
||
async def invoke_single(self, request_item): | ||
async def invoke_single(self, request_item: Query) -> Any: | ||
args, kwargs, is_web_context = parse_request_item(request_item) | ||
serve_context.web = is_web_context | ||
|
||
|
@@ -231,7 +233,7 @@ async def invoke_single(self, request_item): | |
|
||
return result | ||
|
||
async def invoke_batch(self, request_item_list): | ||
async def invoke_batch(self, request_item_list: List[Query]) -> List[Any]: | ||
arg_list = [] | ||
kwargs_list = defaultdict(list) | ||
context_flags = set() | ||
|
@@ -308,7 +310,7 @@ async def invoke_batch(self, request_item_list): | |
self._reset_context() | ||
return [wrapped_exception for _ in range(batch_size)] | ||
|
||
async def main_loop(self): | ||
async def main_loop(self) -> None: | ||
while True: | ||
# NOTE(simon): There's an issue when user updated batch size and | ||
# batch wait timeout during the execution, these values will not be | ||
|
@@ -338,12 +340,13 @@ async def main_loop(self): | |
# it will not be raised. | ||
await asyncio.wait(all_evaluated_futures) | ||
|
||
def update_config(self, new_config: BackendConfig): | ||
def update_config(self, new_config: BackendConfig) -> None: | ||
self.config = new_config | ||
self.batch_queue.set_config(self.config.max_batch_size or 1, | ||
self.config.batch_wait_timeout) | ||
|
||
async def handle_request(self, request: Union[Query, bytes]): | ||
async def handle_request(self, | ||
request: Union[Query, bytes]) -> asyncio.Future: | ||
if isinstance(request, bytes): | ||
request = Query.ray_deserialize(request) | ||
logger.debug("Worker {} got request {}".format(self.replica_tag, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is fine. I think cloudpickle only breaks for `Optional` in generated classes like `RayServeWrappedWorker`.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see, that makes sense. I forgot that the tests for the type hints for the API were passing earlier. I'll revert this change.