[ Do Not Merge ] pyzmq based openai server prototypes (w/ protobuf) #6880

Draft. Wants to merge 41 commits into base: main.
Commits (41)
bed649a  :alembic: add backend proto file  (joerunde, Jul 25, 2024)
7de9d49  :recycle: move proto to grpc/pb  (joerunde, Jul 25, 2024)
9394a62  :sparkles: add proto compilation  (joerunde, Jul 25, 2024)
dd8bf96  updated  (robertgshaw2-neuralmagic, Jul 25, 2024)
5c7fbff  kinda working  (robertgshaw2-neuralmagic, Jul 25, 2024)
952e8ef  :construction: more wip  (joerunde, Jul 25, 2024)
e8eac95  fixed  (robertgshaw2-neuralmagic, Jul 25, 2024)
938a843  :bug: fixup race condition  (joerunde, Jul 25, 2024)
2b8d7cd  :bug: remove timeout  (joerunde, Jul 25, 2024)
ea02d39  format  (robertgshaw2-neuralmagic, Jul 26, 2024)
4a2dc46  streaming  (robertgshaw2-neuralmagic, Jul 26, 2024)
30f2bc9  removed breaks  (robertgshaw2-neuralmagic, Jul 26, 2024)
c718b68  pushing current state  (robertgshaw2-neuralmagic, Jul 26, 2024)
b3d25c6  :alembic: try unix sockets  (joerunde, Jul 26, 2024)
2765b17  :zap: no background loop  (joerunde, Jul 26, 2024)
b219778  spurious change  (robertgshaw2-neuralmagic, Jul 26, 2024)
932ea23  remove spurious change  (robertgshaw2-neuralmagic, Jul 26, 2024)
f029114  spurious changes  (robertgshaw2-neuralmagic, Jul 26, 2024)
6854758  spurioous change  (robertgshaw2-neuralmagic, Jul 26, 2024)
3b5ff66  :bug: whoops  (joerunde, Jul 26, 2024)
79247c3  :memo: log stuff  (joerunde, Jul 26, 2024)
a39ebc0  stash  (robertgshaw2-neuralmagic, Jul 26, 2024)
ef257f1  pushing up  (robertgshaw2-neuralmagic, Jul 26, 2024)
a6c9bc5  stash  (robertgshaw2-neuralmagic, Jul 28, 2024)
d7490bc  actually working  (robertgshaw2-neuralmagic, Jul 28, 2024)
f68fd60  cleanup  (robertgshaw2-neuralmagic, Jul 28, 2024)
38b5b9c  more cleanup  (robertgshaw2-neuralmagic, Jul 28, 2024)
bc54311  cleanup  (robertgshaw2-neuralmagic, Jul 28, 2024)
3cccebb  stash  (robertgshaw2-neuralmagic, Jul 28, 2024)
4b78e29  more cleanup  (robertgshaw2-neuralmagic, Jul 28, 2024)
345bfdd  setup  (robertgshaw2-neuralmagic, Jul 28, 2024)
cfbb001  cleanup  (robertgshaw2-neuralmagic, Jul 28, 2024)
d811b42  format  (robertgshaw2-neuralmagic, Jul 28, 2024)
852534e  cleaning up  (robertgshaw2-neuralmagic, Jul 28, 2024)
e42be96  zlib  (robertgshaw2-neuralmagic, Jul 28, 2024)
5202a59  Revert "zlib"  (robertgshaw2-neuralmagic, Jul 28, 2024)
e1cfb85  stash  (robertgshaw2-neuralmagic, Jul 29, 2024)
50987b9  working  (robertgshaw2-neuralmagic, Jul 29, 2024)
bc072ec  add missing files  (robertgshaw2-neuralmagic, Jul 29, 2024)
48f1b81  add build proto.sh  (robertgshaw2-neuralmagic, Jul 29, 2024)
ab4db75  make it work with script  (robertgshaw2-neuralmagic, Jul 29, 2024)
Files changed
2 changes: 2 additions & 0 deletions benchmarks/benchmark_serving.py
@@ -362,6 +362,8 @@ async def benchmark(
)

print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='='))
print("{:<40} {:<10}".format("TOKENS PER REQUESTS:",
metrics.total_output // metrics.completed))
print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
print("{:<40} {:<10.2f}".format("Benchmark duration (s):",
benchmark_duration))
1 change: 1 addition & 0 deletions build_proto.sh
@@ -0,0 +1 @@
python -m grpc_tools.protoc --proto_path=. --python_out=. --grpc_python_out=. vllm/grpc/pb/generate.proto
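
For reference, the same stub generation can be driven from Python instead of the shell. This is a minimal sketch, not part of this PR, assuming the grpcio-tools package is installed and the command is run from the repository root; it mirrors the protoc flags used in build_proto.sh above.

# Hypothetical Python equivalent of build_proto.sh (not part of this PR).
from grpc_tools import protoc

exit_code = protoc.main([
    "grpc_tools.protoc",            # placeholder argv[0] expected by protoc
    "--proto_path=.",               # resolve imports relative to the repo root
    "--python_out=.",               # emits vllm/grpc/pb/generate_pb2.py
    "--grpc_python_out=.",          # emits vllm/grpc/pb/generate_pb2_grpc.py
    "vllm/grpc/pb/generate.proto",
])
assert exit_code == 0, "protoc failed to compile generate.proto"
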
7 changes: 3 additions & 4 deletions examples/openai_completion_client.py
@@ -14,14 +14,13 @@
model = models.data[0].id

# Completion API
stream = False
stream = True
completion = client.completions.create(
model=model,
prompt="A robot may not injure a human being",
echo=False,
n=2,
stream=stream,
logprobs=3)
n=1,
stream=stream)

print("Completion results:")
if stream:
85 changes: 45 additions & 40 deletions vllm/entrypoints/openai/api_server.py
@@ -64,10 +64,10 @@ async def _force_log():
await asyncio.sleep(10)
await engine.do_log_stats()

if not engine_args.disable_log_stats:
task = asyncio.create_task(_force_log())
_running_tasks.add(task)
task.add_done_callback(_running_tasks.remove)
# if not engine_args.disable_log_stats:
# task = asyncio.create_task(_force_log())
# _running_tasks.add(task)
# task.add_done_callback(_running_tasks.remove)

yield

@@ -221,19 +221,24 @@ async def build_server(
) -> uvicorn.Server:
app = build_app(args)

if args.served_model_name is not None:
served_model_names = args.served_model_name
else:
served_model_names = [args.model]
# if args.served_model_name is not None:
# served_model_names = args.served_model_name
# else:
# served_model_names = [args.model]

served_model_names = "meta-llama/Meta-Llama-3-8B-Instruct"

from vllm.grpc.client import RPCClient
engine = RPCClient()

global engine, engine_args
# global engine, engine_args

engine_args = AsyncEngineArgs.from_cli_args(args)
engine = (llm_engine
if llm_engine is not None else AsyncLLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.OPENAI_API_SERVER))
# engine_args = AsyncEngineArgs.from_cli_args(args)
# engine = (llm_engine
# if llm_engine is not None else AsyncLLMEngine.from_engine_args(
# engine_args, usage_context=UsageContext.OPENAI_API_SERVER))

model_config = await engine.get_model_config()
# model_config = await engine.get_model_config()

if args.disable_log_requests:
request_logger = None
@@ -245,40 +250,40 @@
global openai_serving_embedding
global openai_serving_tokenization

openai_serving_chat = OpenAIServingChat(
engine,
model_config,
served_model_names,
args.response_role,
lora_modules=args.lora_modules,
prompt_adapters=args.prompt_adapters,
request_logger=request_logger,
chat_template=args.chat_template,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
)
# openai_serving_chat = OpenAIServingChat(
# engine,
# model_config,
# served_model_names,
# args.response_role,
# lora_modules=args.lora_modules,
# prompt_adapters=args.prompt_adapters,
# request_logger=request_logger,
# chat_template=args.chat_template,
# return_tokens_as_token_ids=args.return_tokens_as_token_ids,
# )
openai_serving_completion = OpenAIServingCompletion(
engine,
model_config,
# model_config,
served_model_names,
lora_modules=args.lora_modules,
prompt_adapters=args.prompt_adapters,
request_logger=request_logger,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
)
openai_serving_embedding = OpenAIServingEmbedding(
engine,
model_config,
served_model_names,
request_logger=request_logger,
)
openai_serving_tokenization = OpenAIServingTokenization(
engine,
model_config,
served_model_names,
lora_modules=args.lora_modules,
request_logger=request_logger,
chat_template=args.chat_template,
)
# openai_serving_embedding = OpenAIServingEmbedding(
# engine,
# model_config,
# served_model_names,
# request_logger=request_logger,
# )
# openai_serving_tokenization = OpenAIServingTokenization(
# engine,
# model_config,
# served_model_names,
# lora_modules=args.lora_modules,
# request_logger=request_logger,
# chat_template=args.chat_template,
# )
app.root_path = args.root_path

logger.info("Available routes are:")
45 changes: 23 additions & 22 deletions vllm/entrypoints/openai/serving_completion.py
@@ -45,7 +45,7 @@ class OpenAIServingCompletion(OpenAIServing):
def __init__(
self,
engine: AsyncLLMEngine,
model_config: ModelConfig,
# model_config: ModelConfig,
served_model_names: List[str],
*,
lora_modules: Optional[List[LoRAModulePath]],
@@ -54,7 +54,7 @@ def __init__(
return_tokens_as_token_ids: bool = False,
):
super().__init__(engine=engine,
model_config=model_config,
# model_config=model_config,
served_model_names=served_model_names,
lora_modules=lora_modules,
prompt_adapters=prompt_adapters,
@@ -96,18 +96,18 @@ async def create_completion(self, request: CompletionRequest,
tokenizer = await self.engine.get_tokenizer(lora_request)

sampling_params = request.to_sampling_params()
decoding_config = await self.engine.get_decoding_config()
guided_decoding_backend = request.guided_decoding_backend \
or decoding_config.guided_decoding_backend
guided_decode_logit_processor = (
await
get_guided_decoding_logits_processor(guided_decoding_backend,
request, tokenizer))
if guided_decode_logit_processor is not None:
if sampling_params.logits_processors is None:
sampling_params.logits_processors = []
sampling_params.logits_processors.append(
guided_decode_logit_processor)
# decoding_config = await self.engine.get_decoding_config()
# guided_decoding_backend = request.guided_decoding_backend \
# or decoding_config.guided_decoding_backend
# guided_decode_logit_processor = (
# await
# get_guided_decoding_logits_processor(guided_decoding_backend,
# request, tokenizer))
# if guided_decode_logit_processor is not None:
# if sampling_params.logits_processors is None:
# sampling_params.logits_processors = []
# sampling_params.logits_processors.append(
# guided_decode_logit_processor)

prompts = list(
self._tokenize_prompt_input_or_inputs(
@@ -128,21 +128,21 @@ async def create_completion(self, request: CompletionRequest,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)

is_tracing_enabled = await self.engine.is_tracing_enabled()
trace_headers = None
if is_tracing_enabled:
trace_headers = extract_trace_headers(raw_request.headers)
if not is_tracing_enabled and contains_trace_headers(
raw_request.headers):
log_tracing_disabled_warning()
# is_tracing_enabled = await self.engine.is_tracing_enabled()
# trace_headers = None
# if is_tracing_enabled:
# trace_headers = extract_trace_headers(raw_request.headers)
# if not is_tracing_enabled and contains_trace_headers(
# raw_request.headers):
# log_tracing_disabled_warning()

generator = self.engine.generate(
{"prompt_token_ids": prompt_inputs["prompt_token_ids"]},
sampling_params,
request_id_item,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
trace_headers=trace_headers,
# trace_headers=trace_headers,
)

generators.append(generator)
@@ -286,6 +286,7 @@ async def completion_stream_generator(

previous_texts[i] = output.text
previous_num_tokens[i] = len(output.token_ids)
# finish_reason = None if output.finish_reason == "" else output.finish_reason
finish_reason = output.finish_reason
stop_reason = output.stop_reason

7 changes: 4 additions & 3 deletions vllm/entrypoints/openai/serving_engine.py
@@ -62,7 +62,7 @@ class OpenAIServing:
def __init__(
self,
engine: AsyncLLMEngine,
model_config: ModelConfig,
# model_config: ModelConfig,
served_model_names: List[str],
*,
lora_modules: Optional[List[LoRAModulePath]],
@@ -73,8 +73,9 @@ def __init__(
super().__init__()

self.engine = engine
self.model_config = model_config
self.max_model_len = model_config.max_model_len
# self.model_config = model_config
# self.max_model_len = model_config.max_model_len
self.max_model_len = 4096

self.served_model_names = served_model_names

Empty file added vllm/grpc/__init__.py
126 changes: 126 additions & 0 deletions vllm/grpc/client.py
@@ -0,0 +1,126 @@
from vllm import AsyncLLMEngine
from vllm.grpc.pb import generate_pb2
from typing import AsyncIterator, List, Optional, Mapping

from vllm.inputs import PromptInputs
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.outputs import CompletionOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from transformers import AutoTokenizer
from dataclasses import dataclass

import time
import zmq
import zmq.asyncio
import pickle

MODEL = "meta-llama/Meta-Llama-3-8B-Instruct"

@dataclass
class RCPRequest:
inputs: PromptInputs
sampling_params: SamplingParams
request_id: str


class RPCClient(AsyncLLMEngine):
def __init__(self):
self.engine_use_ray = False
self.worker_use_ray = False
self.log_requests = False
self.engine = None

self.tokenizer = AutoTokenizer.from_pretrained(MODEL)

self.context = zmq.asyncio.Context()


@property
def is_running(self) -> bool:
return True

@property
def is_stopped(self) -> bool:
return False

@property
def errored(self) -> bool:
return False

async def get_tokenizer(
self,
lora_request: Optional[LoRARequest] = None,
) -> "PreTrainedTokenizer":
# TODO: what to return :/
return self.tokenizer

def start_background_loop(self):
# TODO something lol
pass

async def generate(
self,
inputs: PromptInputs,
sampling_params: SamplingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> AsyncIterator[RequestOutput]:
socket = self.context.socket(zmq.DEALER)
socket.connect('tcp://localhost:5570')

# socket.send_multipart([
# pickle.dumps(
# RCPRequest(
# inputs=inputs,
# sampling_params=sampling_params,
# request_id=request_id
# ), pickle.HIGHEST_PROTOCOL
# )
# ])
prompt: str = inputs.get('prompt', "")
prompt_token_ids: List[int] = inputs.get('prompt_token_ids', [])
proto = generate_pb2.GenerateRequest(
prompt_inputs=generate_pb2.PromptInputs(
prompt=prompt,
prompt_token_ids=prompt_token_ids),
request_id=request_id,
)
await socket.send_multipart([proto.SerializeToString()])

while True:
message = await socket.recv()
# request_output = pickle.loads(message)
generate_response = generate_pb2.GenerateResponse()
generate_response.ParseFromString(message)

completion_outputs = [
CompletionOutput(
index=output.index,
text=output.text,
token_ids=output.token_ids,
cumulative_logprob=0.0,
logprobs=None,
finish_reason=(None if output.finish_reason == "" else output.finish_reason),
) for output in generate_response.outputs
]

request_output = RequestOutput(
request_id=request_id,
prompt_token_ids=[],
outputs=completion_outputs,
finished=(completion_outputs[0].finish_reason is not None),
prompt_logprobs=None,
prompt=prompt,
)

if request_output.finished:
break

yield request_output

socket.close()
yield request_output
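
As a rough smoke test of the client above, a hypothetical driver could look like the sketch below. It assumes the matching pyzmq server from this PR is already listening on tcp://localhost:5570 (the address hard-coded in generate()), and the request id is a placeholder. Note that in this prototype only the prompt inputs and request id are serialized into the GenerateRequest, so the SamplingParams passed here are not yet forwarded over the socket.

# Hypothetical smoke test for RPCClient (not part of this PR).
import asyncio

from vllm.grpc.client import RPCClient
from vllm.sampling_params import SamplingParams


async def main() -> None:
    client = RPCClient()
    tokenizer = await client.get_tokenizer()
    prompt = "A robot may not injure a human being"

    # Mirror serving_completion.py: send pre-tokenized prompt ids and stream
    # RequestOutputs until the server reports a finish_reason.
    async for request_output in client.generate(
            inputs={"prompt_token_ids": tokenizer(prompt).input_ids},
            sampling_params=SamplingParams(max_tokens=16),
            request_id="smoke-test-0"):
        print(request_output.outputs[0].text)


if __name__ == "__main__":
    asyncio.run(main())
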
Empty file added vllm/grpc/pb/__init__.py