-
-
Notifications
You must be signed in to change notification settings - Fork 5.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Bugfix][Frontend] Fix Issues Under High Load With zeromq
Frontend
#7394
Changes from 6 commits
b2e29a5
8d31115
6d2b3df
c73e943
f1768fb
601a461
1a47d94
eeecb09
2770e40
5a85618
938db1d
ea2f03e
5cebc65
2c12436
9afd6ba
c262088
3e580d5
d9e10e0
4e3a63a
e54bf8a
6544f3a
81f4da8
b3374bc
dd1817a
5b56365
05ff816
e551d30
3d7f65f
e7e6f1e
eaaebcc
21b5239
7e15b00
74c4166
450e949
8348f1f
545956e
7a34611
50abb94
1723687
3dfc9ef
8d40f2d
3397460
ef132dc
10ef204
b67718f
c3c1dbe
ee6efcf
3fdc2fe
4cacb56
724eb31
4d5e6b7
b9e4168
8f9bc23
9a2be3f
dee38f0
e78f443
cc2d7db
b40e269
03eed9c
f697226
fd642ab
762c2ed
c805ed2
2e1652e
3e1ede4
b3bf7ef
708bd34
10a88ec
db8aebc
b16c64b
a9ecaa9
6f8d5e8
bab177f
7b58281
adf45d1
9e827b0
47dca36
f84c341
0ce78f8
8054348
5ddbdab
1ebbe9e
53d639b
a36b381
415ee39
26440e6
2442a9d
b72f84f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,8 +9,10 @@ | |
|
||
from vllm import AsyncEngineArgs, AsyncLLMEngine | ||
from vllm.entrypoints.openai.rpc import (VLLM_RPC_HEALTHY_STR, | ||
VLLM_RPC_SUCCESS_STR, RPCAbortRequest, | ||
RPCGenerateRequest, RPCUtilityRequest) | ||
VLLM_RPC_SUCCESS_STR, | ||
VLLM_RPC_ZMQ_MAX_SOCKETS, | ||
RPCAbortRequest, RPCGenerateRequest, | ||
RPCUtilityRequest) | ||
from vllm.logger import init_logger | ||
from vllm.usage.usage_lib import UsageContext | ||
|
||
|
@@ -27,6 +29,7 @@ def __init__(self, async_engine_args: AsyncEngineArgs, | |
|
||
# Initialize context. | ||
self.context = zmq.asyncio.Context() | ||
self.context.set(zmq.constants.MAX_SOCKETS, VLLM_RPC_ZMQ_MAX_SOCKETS) | ||
|
||
# Init socket for readiness state. | ||
self.socket = self.context.socket(zmq.constants.ROUTER) | ||
|
@@ -37,64 +40,55 @@ def cleanup(self): | |
self.socket.close() | ||
self.context.destroy() | ||
|
||
async def get_model_config(self, identity): | ||
"""Send the ModelConfig""" | ||
model_config = await self.engine.get_model_config() | ||
|
||
await self.socket.send_multipart( | ||
[identity, cloudpickle.dumps(model_config)]) | ||
|
||
async def get_decoding_config(self, identity): | ||
"""Send the DecodingConfig""" | ||
decoding_config = await self.engine.get_decoding_config() | ||
|
||
await self.socket.send_multipart( | ||
[identity, cloudpickle.dumps(decoding_config)]) | ||
|
||
async def get_lora_config(self, identity): | ||
lora_config = await self.engine.get_lora_config() | ||
|
||
await self.socket.send_multipart( | ||
[identity, cloudpickle.dumps(lora_config)]) | ||
|
||
async def get_scheduler_config(self, identity): | ||
"""Send the SchedulerConfig""" | ||
parallel_config = await self.engine.get_scheduler_config() | ||
|
||
await self.socket.send_multipart( | ||
[identity, cloudpickle.dumps(parallel_config)]) | ||
async def get_config(self, identity, part2, request): | ||
try: | ||
if request == RPCUtilityRequest.GET_MODEL_CONFIG: | ||
config = await self.engine.get_model_config() | ||
elif request == RPCUtilityRequest.GET_DECODING_CONFIG: | ||
config = await self.engine.get_decoding_config() | ||
elif request == RPCUtilityRequest.GET_LORA_CONFIG: | ||
config = await self.engine.get_lora_config() | ||
elif request == RPCUtilityRequest.GET_SCHEDULER_CONFIG: | ||
config = await self.engine.get_scheduler_config() | ||
elif request == RPCUtilityRequest.GET_PARALLEL_CONFIG: | ||
config = await self.engine.get_parallel_config() | ||
else: | ||
raise ValueError("Unknown Config Request: %s", request) | ||
|
||
async def get_parallel_config(self, identity): | ||
"""Send the ParallelConfig""" | ||
parallel_config = await self.engine.get_parallel_config() | ||
await self.socket.send_multipart( | ||
[identity, part2, cloudpickle.dumps(config)]) | ||
|
||
await self.socket.send_multipart( | ||
[identity, cloudpickle.dumps(parallel_config)]) | ||
except Exception as e: | ||
### Notify client of all failures | ||
await self.socket.send_multipart( | ||
[identity, part2, cloudpickle.dumps(e)]) | ||
|
||
async def is_tracing_enabled(self, identity): | ||
async def is_tracing_enabled(self, identity, part2): | ||
"""Send the is_tracing_enabled flag""" | ||
tracing_flag = await self.engine.is_tracing_enabled() | ||
|
||
await self.socket.send_multipart( | ||
[identity, cloudpickle.dumps(tracing_flag)]) | ||
[identity, part2, cloudpickle.dumps(tracing_flag)]) | ||
|
||
async def do_log_stats(self, identity): | ||
async def do_log_stats(self, identity, part2): | ||
"""Log stats and confirm success.""" | ||
await self.engine.do_log_stats() | ||
|
||
await self.socket.send_multipart([ | ||
identity, | ||
part2, | ||
cloudpickle.dumps(VLLM_RPC_SUCCESS_STR), | ||
]) | ||
|
||
async def is_server_ready(self, identity): | ||
async def is_server_ready(self, identity, part2): | ||
"""Notify the client that we are ready.""" | ||
await self.socket.send_multipart([ | ||
identity, | ||
part2, | ||
cloudpickle.dumps(VLLM_RPC_SUCCESS_STR), | ||
]) | ||
|
||
async def abort(self, identity, request: RPCAbortRequest): | ||
async def abort(self, identity, part2, request: RPCAbortRequest): | ||
"""Abort request and notify the client of success.""" | ||
try: | ||
# Abort the request in the llm engine. | ||
|
@@ -105,10 +99,12 @@ async def abort(self, identity, request: RPCAbortRequest): | |
# Send confirmation to the client. | ||
await self.socket.send_multipart([ | ||
identity, | ||
part2, | ||
cloudpickle.dumps(VLLM_RPC_SUCCESS_STR), | ||
]) | ||
|
||
async def generate(self, identity, generate_request: RPCGenerateRequest): | ||
async def generate(self, identity, part2, | ||
generate_request: RPCGenerateRequest): | ||
try: | ||
results_generator = self.engine.generate( | ||
generate_request.inputs, | ||
|
@@ -120,51 +116,53 @@ async def generate(self, identity, generate_request: RPCGenerateRequest): | |
|
||
async for request_output in results_generator: | ||
await self.socket.send_multipart( | ||
[identity, cloudpickle.dumps(request_output)]) | ||
[identity, part2, | ||
cloudpickle.dumps(request_output)]) | ||
|
||
except Exception as e: | ||
### Notify client of all failures | ||
await self.socket.send_multipart([identity, cloudpickle.dumps(e)]) | ||
await self.socket.send_multipart( | ||
[identity, part2, cloudpickle.dumps(e)]) | ||
|
||
async def check_health(self, identity): | ||
async def check_health(self, identity, part2): | ||
try: | ||
await self.engine.check_health() | ||
await self.socket.send_multipart( | ||
[identity, cloudpickle.dumps(VLLM_RPC_HEALTHY_STR)]) | ||
[identity, part2, | ||
cloudpickle.dumps(VLLM_RPC_HEALTHY_STR)]) | ||
|
||
except Exception as e: | ||
await self.socket.send_multipart([identity, cloudpickle.dumps(e)]) | ||
await self.socket.send_multipart( | ||
[identity, part2, cloudpickle.dumps(e)]) | ||
|
||
def _make_handler_coro(self, identity, | ||
def _make_handler_coro(self, identity, part2, | ||
message) -> Coroutine[Any, Any, Never]: | ||
"""Route the zmq message to the handler coroutine.""" | ||
|
||
request = cloudpickle.loads(message) | ||
|
||
if isinstance(request, RPCGenerateRequest): | ||
return self.generate(identity, request) | ||
return self.generate(identity, part2, request) | ||
|
||
elif isinstance(request, RPCAbortRequest): | ||
return self.abort(identity, request) | ||
return self.abort(identity, part2, request) | ||
|
||
elif isinstance(request, RPCUtilityRequest): | ||
if request == RPCUtilityRequest.GET_MODEL_CONFIG: | ||
return self.get_model_config(identity) | ||
elif request == RPCUtilityRequest.GET_PARALLEL_CONFIG: | ||
return self.get_parallel_config(identity) | ||
elif request == RPCUtilityRequest.GET_DECODING_CONFIG: | ||
return self.get_decoding_config(identity) | ||
elif request == RPCUtilityRequest.GET_SCHEDULER_CONFIG: | ||
return self.get_scheduler_config(identity) | ||
elif request == RPCUtilityRequest.GET_LORA_CONFIG: | ||
return self.get_lora_config(identity) | ||
if request in [ | ||
RPCUtilityRequest.GET_MODEL_CONFIG, | ||
RPCUtilityRequest.GET_PARALLEL_CONFIG, | ||
RPCUtilityRequest.GET_DECODING_CONFIG, | ||
RPCUtilityRequest.GET_SCHEDULER_CONFIG, | ||
RPCUtilityRequest.GET_LORA_CONFIG | ||
]: | ||
return self.get_config(identity, part2, request) | ||
elif request == RPCUtilityRequest.DO_LOG_STATS: | ||
return self.do_log_stats(identity) | ||
return self.do_log_stats(identity, part2) | ||
elif request == RPCUtilityRequest.IS_SERVER_READY: | ||
return self.is_server_ready(identity) | ||
return self.is_server_ready(identity, part2) | ||
elif request == RPCUtilityRequest.CHECK_HEALTH: | ||
return self.check_health(identity) | ||
return self.check_health(identity, part2) | ||
elif request == RPCUtilityRequest.IS_TRACING_ENABLED: | ||
return self.is_tracing_enabled(identity) | ||
return self.is_tracing_enabled(identity, part2) | ||
else: | ||
raise ValueError(f"Unknown RPCUtilityRequest type: {request}") | ||
|
||
|
@@ -177,11 +175,11 @@ async def run_server_loop(self): | |
running_tasks = set() | ||
while True: | ||
# Wait for a request. | ||
identity, message = await self.socket.recv_multipart() | ||
identity, part2, message = await self.socket.recv_multipart() | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. For future readers it'd be nice to add a link to some zmq docs here or give this a descriptive name to say what There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It's related to the use of |
||
|
||
# Process the request async. | ||
task = asyncio.create_task( | ||
self._make_handler_coro(identity, message)) | ||
self._make_handler_coro(identity, part2, message)) | ||
|
||
# We need to keep around a strong reference to the task, | ||
# to avoid the task disappearing mid-execution as running tasks | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
make ./format happy