
[Bugfix][Frontend] Fix Issues Under High Load With zeromq Frontend #7394

Merged — 88 commits, merged Aug 21, 2024

Changes from 6 commits (88 commits total)
b2e29a5
added proxy to limit use of uniz sockets
robertgshaw2-redhat Aug 10, 2024
8d31115
Merge branch 'main' into fix-zmq-max-sockets
robertgshaw2-redhat Aug 10, 2024
6d2b3df
comment
robertgshaw2-redhat Aug 10, 2024
c73e943
use random inproc path
robertgshaw2-redhat Aug 10, 2024
f1768fb
format
robertgshaw2-redhat Aug 10, 2024
601a461
foamt
robertgshaw2-redhat Aug 10, 2024
1a47d94
format
robertgshaw2-redhat Aug 10, 2024
eeecb09
Update vllm/entrypoints/openai/rpc/client.py
robertgshaw2-redhat Aug 10, 2024
2770e40
cleaning
robertgshaw2-redhat Aug 14, 2024
5a85618
Merge branch 'main' into fix-zmq-max-sockets
robertgshaw2-redhat Aug 18, 2024
938db1d
Merge branch 'fix-zmq-max-sockets' of https://github.com/neuralmagic/…
robertgshaw2-redhat Aug 18, 2024
ea2f03e
remove logging
robertgshaw2-redhat Aug 18, 2024
5cebc65
add info message re: concurrency
robertgshaw2-redhat Aug 18, 2024
2c12436
update comment
robertgshaw2-redhat Aug 18, 2024
9afd6ba
update
robertgshaw2-redhat Aug 18, 2024
c262088
format
robertgshaw2-redhat Aug 18, 2024
3e580d5
reorder
robertgshaw2-redhat Aug 18, 2024
d9e10e0
reverT
robertgshaw2-redhat Aug 18, 2024
4e3a63a
fix
robertgshaw2-redhat Aug 18, 2024
e54bf8a
fix
robertgshaw2-redhat Aug 18, 2024
6544f3a
fix abort logic
robertgshaw2-redhat Aug 18, 2024
81f4da8
reduce LOC change
robertgshaw2-redhat Aug 18, 2024
b3374bc
cleanup
robertgshaw2-redhat Aug 18, 2024
dd1817a
cleanup
robertgshaw2-redhat Aug 18, 2024
5b56365
format
robertgshaw2-redhat Aug 18, 2024
05ff816
fix client
robertgshaw2-redhat Aug 18, 2024
e551d30
revert unneccessary change
robertgshaw2-redhat Aug 18, 2024
3d7f65f
revert startup probe changes to separate PR
robertgshaw2-redhat Aug 18, 2024
e7e6f1e
stash
robertgshaw2-redhat Aug 18, 2024
eaaebcc
Merge branch 'main' into fix-zmq-max-sockets
robertgshaw2-redhat Aug 18, 2024
21b5239
stash draining
robertgshaw2-redhat Aug 19, 2024
7e15b00
update
robertgshaw2-redhat Aug 19, 2024
74c4166
stash
robertgshaw2-redhat Aug 19, 2024
450e949
convert RPCServer to use DEALER
robertgshaw2-redhat Aug 19, 2024
8348f1f
stash
robertgshaw2-redhat Aug 19, 2024
545956e
fix
robertgshaw2-redhat Aug 19, 2024
7a34611
cleaning
robertgshaw2-redhat Aug 19, 2024
50abb94
stash
robertgshaw2-redhat Aug 19, 2024
1723687
remove awk
robertgshaw2-redhat Aug 19, 2024
3dfc9ef
nits
robertgshaw2-redhat Aug 20, 2024
8d40f2d
format
robertgshaw2-redhat Aug 20, 2024
3397460
format
robertgshaw2-redhat Aug 20, 2024
ef132dc
nit
robertgshaw2-redhat Aug 20, 2024
10ef204
change
robertgshaw2-redhat Aug 20, 2024
b67718f
clean
robertgshaw2-redhat Aug 20, 2024
c3c1dbe
Update vllm/entrypoints/openai/rpc/server.py
robertgshaw2-redhat Aug 20, 2024
ee6efcf
format
robertgshaw2-redhat Aug 20, 2024
3fdc2fe
cleanup abort logic
robertgshaw2-redhat Aug 20, 2024
4cacb56
nit
robertgshaw2-redhat Aug 20, 2024
724eb31
added load test
robertgshaw2-redhat Aug 21, 2024
4d5e6b7
update load test
robertgshaw2-redhat Aug 21, 2024
b9e4168
updated
robertgshaw2-redhat Aug 21, 2024
8f9bc23
format
robertgshaw2-redhat Aug 21, 2024
9a2be3f
updated
robertgshaw2-redhat Aug 21, 2024
dee38f0
revert suurious change
robertgshaw2-redhat Aug 21, 2024
e78f443
convert to even smaller model
robertgshaw2-redhat Aug 21, 2024
cc2d7db
20k requests
robertgshaw2-redhat Aug 21, 2024
b40e269
convert to 10k requests
robertgshaw2-redhat Aug 21, 2024
03eed9c
clean up closing logic
robertgshaw2-redhat Aug 21, 2024
f697226
use constant
robertgshaw2-redhat Aug 21, 2024
fd642ab
fix bad cleanup
robertgshaw2-redhat Aug 21, 2024
762c2ed
remove useless argument
robertgshaw2-redhat Aug 21, 2024
c805ed2
up to 20k requests
robertgshaw2-redhat Aug 21, 2024
2e1652e
revert to 10k requests
robertgshaw2-redhat Aug 21, 2024
3e1ede4
revert suprious argument
robertgshaw2-redhat Aug 21, 2024
b3bf7ef
revert to 20k
robertgshaw2-redhat Aug 21, 2024
708bd34
format
robertgshaw2-redhat Aug 21, 2024
10a88ec
[BugFix] Raise all exception variations in async generator
njhill Aug 20, 2024
db8aebc
Fix possible premature generator completion; add tests
njhill Aug 21, 2024
b16c64b
format
robertgshaw2-redhat Aug 21, 2024
a9ecaa9
added test accuracy
robertgshaw2-redhat Aug 21, 2024
6f8d5e8
format
robertgshaw2-redhat Aug 21, 2024
bab177f
updated test pipeline
robertgshaw2-redhat Aug 21, 2024
7b58281
fix lm eval
robertgshaw2-redhat Aug 21, 2024
adf45d1
cleanup
robertgshaw2-redhat Aug 21, 2024
9e827b0
updated
robertgshaw2-redhat Aug 21, 2024
47dca36
Merge branch 'main' into fix-zmq-max-sockets
robertgshaw2-redhat Aug 21, 2024
f84c341
added sleep time
robertgshaw2-redhat Aug 21, 2024
0ce78f8
actually sleep
robertgshaw2-redhat Aug 21, 2024
8054348
formatting
robertgshaw2-redhat Aug 21, 2024
5ddbdab
format
robertgshaw2-redhat Aug 21, 2024
1ebbe9e
mypy
robertgshaw2-redhat Aug 21, 2024
53d639b
mypy
robertgshaw2-redhat Aug 21, 2024
a36b381
format
robertgshaw2-redhat Aug 21, 2024
415ee39
remove test load
robertgshaw2-redhat Aug 21, 2024
26440e6
stash
robertgshaw2-redhat Aug 21, 2024
2442a9d
Merge branch 'fix-zmq-max-sockets' of https://github.com/neuralmagic/…
robertgshaw2-redhat Aug 21, 2024
b72f84f
Merge branch 'fix-raise-cancelled' into fix-zmq-max-sockets
robertgshaw2-redhat Aug 21, 2024
4 changes: 2 additions & 2 deletions tests/tracing/test_tracing.py
@@ -114,5 +114,5 @@ def test_traces(trace_service):
SpanAttributes.LLM_LATENCY_TIME_TO_FIRST_TOKEN) == ttft
e2e_time = metrics.finished_time - metrics.arrival_time
assert attributes.get(SpanAttributes.LLM_LATENCY_E2E) == e2e_time
assert attributes.get(SpanAttributes.LLM_LATENCY_TIME_IN_SCHEDULER
) == metrics.scheduler_time
assert attributes.get(
robertgshaw2-redhat (Collaborator, Author) commented on Aug 11, 2024:
"make ./format happy"

SpanAttributes.LLM_LATENCY_TIME_IN_SCHEDULER) == metrics.scheduler_time
2 changes: 2 additions & 0 deletions vllm/entrypoints/openai/rpc/__init__.py
@@ -9,6 +9,8 @@

VLLM_RPC_SUCCESS_STR = "SUCCESS"
VLLM_RPC_HEALTHY_STR = "HEALTHY"
# TODO: figure out if this can be set to inf.
VLLM_RPC_ZMQ_MAX_SOCKETS = 1000000


@dataclass
49 changes: 40 additions & 9 deletions vllm/entrypoints/openai/rpc/client.py
@@ -1,5 +1,7 @@
import asyncio
from contextlib import contextmanager
from typing import Any, AsyncGenerator, Mapping, Optional
from uuid import uuid4

import cloudpickle
import zmq
@@ -9,8 +11,10 @@
ParallelConfig, SchedulerConfig)
from vllm.entrypoints.openai.rpc import (RPC_REQUEST_TYPE,
VLLM_RPC_HEALTHY_STR,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
RPCGenerateRequest, RPCUtilityRequest)
VLLM_RPC_SUCCESS_STR,
VLLM_RPC_ZMQ_MAX_SOCKETS,
RPCAbortRequest, RPCGenerateRequest,
RPCUtilityRequest)
from vllm.inputs import PromptInputs
from vllm.lora.request import LoRARequest
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
@@ -21,12 +25,40 @@
# Time to wait before checking it the server process is alive.
SERVER_START_TIMEOUT_MS = 1000

# Inprocess path
INPROC_PATH = f"inproc://{uuid4()}"


class AsyncEngineRPCClient:

def __init__(self, rpc_path: str):
self.context = zmq.asyncio.Context()
self.rpc_path = rpc_path
self.context.set(zmq.constants.MAX_SOCKETS, VLLM_RPC_ZMQ_MAX_SOCKETS)

# PROXY
self.from_client = self.context.socket(zmq.constants.ROUTER)
self.from_client.bind(INPROC_PATH)

# Connection to RPC Server.
self.to_server = self.context.socket(zmq.constants.DEALER)
self.to_server.connect(rpc_path)

self.proxy_task = asyncio.create_task(
self.run_proxy(self.from_client, self.to_server))

async def run_proxy(self, socket_from, socket_to):
poller = zmq.asyncio.Poller()
poller.register(socket_from, zmq.constants.POLLIN)
poller.register(socket_to, zmq.constants.POLLIN)
while True:
events = await poller.poll()
events = dict(events)
if socket_from in events:
msg = await socket_from.recv_multipart()
await socket_to.send_multipart(msg)
elif socket_to in events:
msg = await socket_to.recv_multipart()
await socket_from.send_multipart(msg)

async def setup(self):
"""Setup the client before it starts sending server requests."""
@@ -62,7 +94,7 @@ def socket(self):
# to enable streaming.
socket = self.context.socket(zmq.constants.DEALER)
try:
socket.connect(self.rpc_path)
socket.connect(INPROC_PATH)
yield socket
finally:
# linger == 0 means discard unsent messages
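Each request in the client checks out a fresh DEALER socket through the `socket()` context manager above, and the `finally` clause closes it with `linger=0` so unsent messages are discarded rather than blocking shutdown. The shape of that pattern can be sketched with `contextlib` alone (the `FakeSocket` class below is a hypothetical stand-in for a zmq socket, not pyzmq code):

```python
from contextlib import contextmanager

class FakeSocket:
    """Hypothetical stand-in for a zmq DEALER socket."""
    def __init__(self):
        self.connected_to = None
        self.closed = False
        self.linger = None

    def connect(self, path: str):
        self.connected_to = path

    def close(self, linger: int = -1):
        # linger=0 means: drop unsent messages instead of blocking on close.
        self.closed = True
        self.linger = linger

@contextmanager
def socket(path: str):
    sock = FakeSocket()
    try:
        sock.connect(path)
        yield sock
    finally:
        sock.close(linger=0)   # always runs, even if the body raises

with socket("inproc://example") as s:
    in_scope = s

print(in_scope.closed, in_scope.linger)  # True 0
```

The context manager guarantees cleanup per request, which matters once thousands of short-lived sockets are being created under load.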
@@ -82,9 +114,8 @@ async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,
"""Send an RPC request that is expecting data back."""

with self.socket() as socket:

# Ping RPCServer with a request.
await socket.send(cloudpickle.dumps(request))
await socket.send_multipart([cloudpickle.dumps(request)])

# Await the data from the Server.
data = cloudpickle.loads(await socket.recv())
@@ -105,7 +136,7 @@ async def _send_one_way_rpc_request(self,
"""Send one-way RPC request to trigger an action."""
with self.socket() as socket:
# Ping RPC Server with request.
await socket.send(cloudpickle.dumps(request))
await socket.send_multipart([cloudpickle.dumps(request)])

# Await acknowledgement from RPCServer.
if timeout is not None and await socket.poll(timeout=timeout) == 0:
@@ -269,8 +300,8 @@ async def check_health(self) -> None:
with self.socket() as socket:

# Ping RPCServer with CHECK_HEALTH request.
await socket.send(cloudpickle.dumps(RPCUtilityRequest.CHECK_HEALTH)
)
await socket.send_multipart(
[cloudpickle.dumps(RPCUtilityRequest.CHECK_HEALTH)])

# Await the reply from the server.
# TODO: do we need an internal timeout here?
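The client change above multiplexes many short-lived in-process DEALER sockets onto a single connection to the server via a ROUTER-to-DEALER proxy loop. The bidirectional forwarding at the heart of `run_proxy` can be illustrated with plain asyncio queues (a hedged sketch of the pattern, not vLLM's code — the queue names here are hypothetical stand-ins for the zmq sockets and poller):

```python
import asyncio

async def run_proxy(from_client: asyncio.Queue, to_server: asyncio.Queue,
                    from_server: asyncio.Queue, to_client: asyncio.Queue):
    """Forward messages in both directions, like the zmq poller loop."""
    async def pump(src: asyncio.Queue, dst: asyncio.Queue):
        while True:
            msg = await src.get()
            await dst.put(msg)
    # Run both directions concurrently; the real code uses zmq.asyncio.Poller.
    await asyncio.gather(pump(from_client, to_server),
                         pump(from_server, to_client))

async def demo():
    fc, ts, fs, tc = (asyncio.Queue() for _ in range(4))
    proxy = asyncio.create_task(run_proxy(fc, ts, fs, tc))
    await fc.put(b"request")              # client -> proxy
    assert await ts.get() == b"request"   # proxy -> server
    await fs.put(b"reply")                # server -> proxy
    result = await tc.get()               # proxy -> client
    proxy.cancel()
    return result

print(asyncio.run(demo()))  # b'reply'
```

The point of the indirection is resource use: only one socket ever crosses the process boundary, while per-request sockets live on a cheap inproc transport.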
126 changes: 62 additions & 64 deletions vllm/entrypoints/openai/rpc/server.py
@@ -9,8 +9,10 @@

from vllm import AsyncEngineArgs, AsyncLLMEngine
from vllm.entrypoints.openai.rpc import (VLLM_RPC_HEALTHY_STR,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
RPCGenerateRequest, RPCUtilityRequest)
VLLM_RPC_SUCCESS_STR,
VLLM_RPC_ZMQ_MAX_SOCKETS,
RPCAbortRequest, RPCGenerateRequest,
RPCUtilityRequest)
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext

@@ -27,6 +29,7 @@ def __init__(self, async_engine_args: AsyncEngineArgs,

# Initialize context.
self.context = zmq.asyncio.Context()
self.context.set(zmq.constants.MAX_SOCKETS, VLLM_RPC_ZMQ_MAX_SOCKETS)

# Init socket for readiness state.
self.socket = self.context.socket(zmq.constants.ROUTER)
@@ -37,64 +40,55 @@ def cleanup(self):
self.socket.close()
self.context.destroy()

async def get_model_config(self, identity):
"""Send the ModelConfig"""
model_config = await self.engine.get_model_config()

await self.socket.send_multipart(
[identity, cloudpickle.dumps(model_config)])

async def get_decoding_config(self, identity):
"""Send the DecodingConfig"""
decoding_config = await self.engine.get_decoding_config()

await self.socket.send_multipart(
[identity, cloudpickle.dumps(decoding_config)])

async def get_lora_config(self, identity):
lora_config = await self.engine.get_lora_config()

await self.socket.send_multipart(
[identity, cloudpickle.dumps(lora_config)])

async def get_scheduler_config(self, identity):
"""Send the SchedulerConfig"""
parallel_config = await self.engine.get_scheduler_config()

await self.socket.send_multipart(
[identity, cloudpickle.dumps(parallel_config)])
async def get_config(self, identity, part2, request):
try:
if request == RPCUtilityRequest.GET_MODEL_CONFIG:
config = await self.engine.get_model_config()
elif request == RPCUtilityRequest.GET_DECODING_CONFIG:
config = await self.engine.get_decoding_config()
elif request == RPCUtilityRequest.GET_LORA_CONFIG:
config = await self.engine.get_lora_config()
elif request == RPCUtilityRequest.GET_SCHEDULER_CONFIG:
config = await self.engine.get_scheduler_config()
elif request == RPCUtilityRequest.GET_PARALLEL_CONFIG:
config = await self.engine.get_parallel_config()
else:
raise ValueError("Unknown Config Request: %s", request)

async def get_parallel_config(self, identity):
"""Send the ParallelConfig"""
parallel_config = await self.engine.get_parallel_config()
await self.socket.send_multipart(
[identity, part2, cloudpickle.dumps(config)])

await self.socket.send_multipart(
[identity, cloudpickle.dumps(parallel_config)])
except Exception as e:
### Notify client of all failures
await self.socket.send_multipart(
[identity, part2, cloudpickle.dumps(e)])

async def is_tracing_enabled(self, identity):
async def is_tracing_enabled(self, identity, part2):
"""Send the is_tracing_enabled flag"""
tracing_flag = await self.engine.is_tracing_enabled()

await self.socket.send_multipart(
[identity, cloudpickle.dumps(tracing_flag)])
[identity, part2, cloudpickle.dumps(tracing_flag)])

async def do_log_stats(self, identity):
async def do_log_stats(self, identity, part2):
"""Log stats and confirm success."""
await self.engine.do_log_stats()

await self.socket.send_multipart([
identity,
part2,
cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
])

async def is_server_ready(self, identity):
async def is_server_ready(self, identity, part2):
"""Notify the client that we are ready."""
await self.socket.send_multipart([
identity,
part2,
cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
])

async def abort(self, identity, request: RPCAbortRequest):
async def abort(self, identity, part2, request: RPCAbortRequest):
"""Abort request and notify the client of success."""
try:
# Abort the request in the llm engine.
@@ -105,10 +99,12 @@ async def abort(self, identity, request: RPCAbortRequest):
# Send confirmation to the client.
await self.socket.send_multipart([
identity,
part2,
cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
])

async def generate(self, identity, generate_request: RPCGenerateRequest):
async def generate(self, identity, part2,
generate_request: RPCGenerateRequest):
try:
results_generator = self.engine.generate(
generate_request.inputs,
@@ -120,51 +116,53 @@ async def generate(self, identity, part2,

async for request_output in results_generator:
await self.socket.send_multipart(
[identity, cloudpickle.dumps(request_output)])
[identity, part2,
cloudpickle.dumps(request_output)])

except Exception as e:
### Notify client of all failures
await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
await self.socket.send_multipart(
[identity, part2, cloudpickle.dumps(e)])

async def check_health(self, identity):
async def check_health(self, identity, part2):
try:
await self.engine.check_health()
await self.socket.send_multipart(
[identity, cloudpickle.dumps(VLLM_RPC_HEALTHY_STR)])
[identity, part2,
cloudpickle.dumps(VLLM_RPC_HEALTHY_STR)])

except Exception as e:
await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
await self.socket.send_multipart(
[identity, part2, cloudpickle.dumps(e)])

def _make_handler_coro(self, identity,
def _make_handler_coro(self, identity, part2,
message) -> Coroutine[Any, Any, Never]:
"""Route the zmq message to the handler coroutine."""

request = cloudpickle.loads(message)

if isinstance(request, RPCGenerateRequest):
return self.generate(identity, request)
return self.generate(identity, part2, request)

elif isinstance(request, RPCAbortRequest):
return self.abort(identity, request)
return self.abort(identity, part2, request)

elif isinstance(request, RPCUtilityRequest):
if request == RPCUtilityRequest.GET_MODEL_CONFIG:
return self.get_model_config(identity)
elif request == RPCUtilityRequest.GET_PARALLEL_CONFIG:
return self.get_parallel_config(identity)
elif request == RPCUtilityRequest.GET_DECODING_CONFIG:
return self.get_decoding_config(identity)
elif request == RPCUtilityRequest.GET_SCHEDULER_CONFIG:
return self.get_scheduler_config(identity)
elif request == RPCUtilityRequest.GET_LORA_CONFIG:
return self.get_lora_config(identity)
if request in [
RPCUtilityRequest.GET_MODEL_CONFIG,
RPCUtilityRequest.GET_PARALLEL_CONFIG,
RPCUtilityRequest.GET_DECODING_CONFIG,
RPCUtilityRequest.GET_SCHEDULER_CONFIG,
RPCUtilityRequest.GET_LORA_CONFIG
]:
return self.get_config(identity, part2, request)
elif request == RPCUtilityRequest.DO_LOG_STATS:
return self.do_log_stats(identity)
return self.do_log_stats(identity, part2)
elif request == RPCUtilityRequest.IS_SERVER_READY:
return self.is_server_ready(identity)
return self.is_server_ready(identity, part2)
elif request == RPCUtilityRequest.CHECK_HEALTH:
return self.check_health(identity)
return self.check_health(identity, part2)
elif request == RPCUtilityRequest.IS_TRACING_ENABLED:
return self.is_tracing_enabled(identity)
return self.is_tracing_enabled(identity, part2)
else:
raise ValueError(f"Unknown RPCUtilityRequest type: {request}")

@@ -177,11 +175,11 @@ async def run_server_loop(self):
running_tasks = set()
while True:
# Wait for a request.
identity, message = await self.socket.recv_multipart()
identity, part2, message = await self.socket.recv_multipart()
Collaborator comment: For future readers it'd be nice to add a link to some zmq docs here or give this a descriptive name to say what part2 is. From context here I'm guessing this is routing information for the client-side proxy?

robertgshaw2-redhat (Collaborator, Author): Its related to the use of ROUTER, will do

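As the review thread suggests, `part2` is an extra routing frame: with two ROUTER hops in the path (the client-side proxy's ROUTER plus the server's ROUTER), each hop prepends an identity frame on receive, so the payload arrives as `[identity, part2, message]`. A rough simulation of that enveloping with plain lists (an illustration of zmq ROUTER behavior under assumed identities, not pyzmq code):

```python
def router_recv(frames: list, hop_identity: bytes) -> list:
    """Simulate a ROUTER socket prepending the sending peer's identity frame."""
    return [hop_identity] + frames

payload = [b"<pickled RPCGenerateRequest>"]
# The client-side proxy's ROUTER prepends the inproc DEALER's identity...
via_proxy = router_recv(payload, b"inproc-dealer-id")
# ...and the server's ROUTER prepends the proxy DEALER's identity.
at_server = router_recv(via_proxy, b"client-dealer-id")

identity, part2, message = at_server
print(identity, part2)  # b'client-dealer-id' b'inproc-dealer-id'
```

Echoing both frames back on every reply is what lets the response retrace the route through the proxy to the originating per-request socket.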
# Process the request async.
task = asyncio.create_task(
self._make_handler_coro(identity, message))
self._make_handler_coro(identity, part2, message))

# We need to keep around a strong reference to the task,
# to avoid the task disappearing mid-execution as running tasks
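The comment above points at a real asyncio pitfall: the event loop keeps only a weak reference to tasks, so a fire-and-forget `create_task` result can be garbage-collected mid-run. The usual fix, which the server loop applies, is to hold each task in a set and discard it on completion — a minimal sketch with hypothetical handler names:

```python
import asyncio

running_tasks = set()

async def handle(msg: str) -> str:
    await asyncio.sleep(0)          # stand-in for real request handling
    return f"handled {msg}"

def submit(msg: str) -> asyncio.Task:
    task = asyncio.create_task(handle(msg))
    # Strong reference: without this the task may vanish mid-execution.
    running_tasks.add(task)
    # Drop the reference once the task finishes, so the set doesn't grow.
    task.add_done_callback(running_tasks.discard)
    return task

async def main():
    return await asyncio.gather(*(submit(m) for m in ("a", "b")))

print(asyncio.run(main()))  # ['handled a', 'handled b']
```

Under a 20k-request load test, skipping the done-callback cleanup would leak one finished task per request.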