Add vllm Predictor #20

Merged
merged 44 commits into from
Jan 18, 2024
Changes from 19 commits
44 commits
0d9e1d5
add vllm_predictor
xwu99 Jan 3, 2024
f5f0360
add tests skeleton
xwu99 Jan 3, 2024
69f1612
add tests skeleton
xwu99 Jan 3, 2024
a62ec75
add pytest.ini
xwu99 Jan 3, 2024
a8e6b6d
wip
xwu99 Jan 4, 2024
565c616
complete, debug wip
xwu99 Jan 4, 2024
a554202
nit
xwu99 Jan 5, 2024
c0c1661
nit
xwu99 Jan 5, 2024
cb263ea
nit
xwu99 Jan 5, 2024
c0f5cea
complete generate supporting str and List[str]
xwu99 Jan 8, 2024
af76998
add model
xwu99 Jan 8, 2024
064246a
add streaming
xwu99 Jan 8, 2024
b572b19
remove tests
xwu99 Jan 10, 2024
2deaa47
Add install-vllm-cpu script
xwu99 Jan 10, 2024
7573a7b
nit
xwu99 Jan 10, 2024
c6f0dc9
nit
xwu99 Jan 10, 2024
7cd95fe
merge upstream
xwu99 Jan 12, 2024
57aa0d1
nit
xwu99 Jan 12, 2024
17c1206
fix package inference
xwu99 Jan 14, 2024
9c83f0c
update install script and add doc
xwu99 Jan 15, 2024
a5875ab
nit
xwu99 Jan 15, 2024
9db8d49
nit
xwu99 Jan 15, 2024
71688f5
nit
xwu99 Jan 15, 2024
bde72f3
add dtype support
xwu99 Jan 16, 2024
dfda735
nit
xwu99 Jan 16, 2024
da88b78
nit
xwu99 Jan 16, 2024
176e766
nit
xwu99 Jan 16, 2024
c890594
Merge remote-tracking branch 'upstream/main' into vllm-predictor
xwu99 Jan 16, 2024
9ec0311
add ci
xwu99 Jan 16, 2024
226f3d2
nit
xwu99 Jan 16, 2024
3b09a6e
nit
xwu99 Jan 16, 2024
0cdb27b
add libpthread-stubs0-dev
xwu99 Jan 17, 2024
fd0fb29
fix install-vllm-cpu
xwu99 Jan 17, 2024
3f9ba57
fix
xwu99 Jan 17, 2024
f89008d
revert inference.inference_config
xwu99 Jan 17, 2024
7d85569
debug ci
xwu99 Jan 17, 2024
1cadcd9
debug ci
xwu99 Jan 17, 2024
8759cb5
debug ci
xwu99 Jan 17, 2024
1cfc3b7
debug ci
xwu99 Jan 17, 2024
964d69a
debug ci
xwu99 Jan 17, 2024
c6f7686
debug ci
xwu99 Jan 17, 2024
e959bb2
debug ci
xwu99 Jan 17, 2024
a42f300
debug ci
xwu99 Jan 17, 2024
f647769
update
xwu99 Jan 18, 2024
13 changes: 13 additions & 0 deletions dev/scripts/install-vllm-cpu.sh
@@ -0,0 +1,13 @@
#!/usr/bin/env bash

# This script installs vllm-cpu into the current conda environment.
# Use the following commands to create a new conda env if necessary:
# $ conda create -n vllm-cpu python=3.10
# $ conda activate vllm-cpu

# Install g++ 12.3 for building
conda install -y -c conda-forge gxx=12.3 gxx_linux-64=12.3

# Install from source
# TODO: verify whether the conda env needs to be reactivated to pick up the g++ toolchain
MAX_JOBS=8 pip install -v git+https://github.com/bigPYJ1151/vllm@PR_Branch
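
A quick sanity check after running the script (a sketch, not part of this PR; the model name below is only a small example and must be downloadable from the Hugging Face Hub):

```python
# Minimal offline generation with the vLLM Python API to confirm the CPU
# build works. The model choice is an assumption for illustration.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
```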
4 changes: 2 additions & 2 deletions inference/api_server_openai.py
@@ -34,8 +34,8 @@

import os
from ray import serve
from inference.api_openai_backend.query_client import RouterQueryClient
from inference.api_openai_backend.router_app import Router, router_app
from api_openai_backend.query_client import RouterQueryClient
from api_openai_backend.router_app import Router, router_app


def router_application(deployments):
2 changes: 1 addition & 1 deletion inference/deepspeed_predictor.py
@@ -14,7 +14,7 @@
import os
from predictor import Predictor
from utils import get_torch_dtype
from inference.inference_config import (
from inference_config import (
InferenceConfig,
DEVICE_CPU,
DEVICE_XPU,
1 change: 1 addition & 0 deletions inference/inference_config.py
@@ -89,6 +89,7 @@ class InferenceConfig(BaseModel):
gpus_per_worker: int = 0
hpus_per_worker: int = 0
deepspeed: bool = False
vllm: bool = False
workers_per_group: int = 2
device: str = DEVICE_CPU
ipex: Ipex = Ipex()
@@ -1,6 +1,6 @@
import yaml
import os
from inference.inference_config import InferenceConfig
from inference_config import InferenceConfig

ic = InferenceConfig()

25 changes: 25 additions & 0 deletions inference/models/vllm/llama-2-7b-chat-hf-vllm.yaml
@@ -0,0 +1,25 @@
port: 8000
name: llama-2-7b-chat-hf
route_prefix: /llama-2-7b-chat-hf
cpus_per_worker: 24
gpus_per_worker: 0
deepspeed: false
vllm: true
workers_per_group: 2
device: "cpu"
ipex:
enabled: true
precision: bf16
model_description:
model_id_or_path: meta-llama/Llama-2-7b-chat-hf
tokenizer_name_or_path: meta-llama/Llama-2-7b-chat-hf
chat_processor: ChatModelLLama
prompt:
intro: ''
human_id: '[INST] {msg} [/INST]

'
bot_id: ''
stop_words: []
config:
use_auth_token: ''
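
For reference, the YAML above is parsed into the pydantic InferenceConfig model extended in this PR (note the new vllm flag). A minimal loading sketch, assuming PyYAML and that the inference/ directory is on sys.path, as in the other modules in this diff:

```python
# Sketch: load and validate the vLLM model YAML into InferenceConfig.
# The file path mirrors the diff above; nested sections such as ipex and
# model_description are validated by the corresponding pydantic models.
import yaml
from inference_config import InferenceConfig

with open("inference/models/vllm/llama-2-7b-chat-hf-vllm.yaml") as f:
    infer_conf = InferenceConfig(**yaml.safe_load(f))

print(infer_conf.vllm, infer_conf.device)  # expected: True cpu
```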
16 changes: 13 additions & 3 deletions inference/predictor.py
@@ -1,8 +1,9 @@
import re
import torch
from transformers import AutoTokenizer, StoppingCriteriaList
from inference.inference_config import InferenceConfig
from inference_config import InferenceConfig
from utils import StoppingCriteriaSub
from typing import List, AsyncGenerator, Union


class Predictor:
@@ -72,11 +73,20 @@ def configure_tokenizer(self, model_name):
tokenizer.pad_token = tokenizer.eos_token
model.generation_config.pad_token_id = model.generation_config.eos_token_id

def generate(self, prompt, **config):
def generate(self, prompts: Union[str, List[str]], **config) -> Union[str, List[str]]:
pass

def streaming_generate(self, prompt, streamer, **config):
async def generate_async(
self, prompts: Union[str, List[str]], **config
) -> Union[str, List[str]]:
pass

# Output is streamed into the provided streamer
def streaming_generate(self, prompt: str, streamer, **config) -> None:
pass

def get_streamer(self):
pass

async def stream_results(self, results_generator) -> AsyncGenerator[str, None]:
pass
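
To illustrate the widened interface (a hypothetical subclass, for illustration only, not part of the diff): a backend fills in either the synchronous or the asynchronous hooks, and the deployment below chooses which one to call.

```python
# Hypothetical EchoPredictor showing which Predictor hooks a backend is
# expected to implement after this change.
from typing import AsyncGenerator, List, Union

from predictor import Predictor  # the base class shown in the diff above


class EchoPredictor(Predictor):
    def generate(self, prompts: Union[str, List[str]], **config) -> Union[str, List[str]]:
        # Synchronous path, used by the transformers/DeepSpeed backends.
        return prompts

    async def generate_async(
        self, prompts: Union[str, List[str]], **config
    ) -> Union[str, List[str]]:
        # Asynchronous path, used by the vLLM backend.
        return prompts

    async def stream_results(self, results_generator) -> AsyncGenerator[str, None]:
        async for chunk in results_generator:
            yield chunk
```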
32 changes: 26 additions & 6 deletions inference/predictor_deployment.py
@@ -23,10 +23,10 @@
from queue import Empty
import torch
from transformers import TextIteratorStreamer
from inference.inference_config import InferenceConfig
from inference_config import InferenceConfig
from typing import Union, Dict, Any
from starlette.responses import StreamingResponse
from inference.api_openai_backend.openai_protocol import ModelResponse
from api_openai_backend.openai_protocol import ModelResponse


@serve.deployment
@@ -53,11 +53,17 @@ def __init__(self, infer_conf: InferenceConfig):
self.process_tool = chat_processor(**prompt.dict())

self.use_deepspeed = infer_conf.deepspeed
self.use_vllm = infer_conf.vllm

if self.use_deepspeed:
from deepspeed_predictor import DeepSpeedPredictor

self.predictor = DeepSpeedPredictor(infer_conf)
self.streamer = self.predictor.get_streamer()
elif self.use_vllm:
from vllm_predictor import VllmPredictor

self.predictor = VllmPredictor(infer_conf)
else:
from transformer_predictor import TransformerPredictor

@@ -94,23 +100,37 @@ async def __call__(self, http_request: Request) -> Union[StreamingResponse, str]
prompts.extend(text)
else:
prompts.append(text)

if not streaming_response:
return self.predictor.generate(prompts, **config)
if self.use_vllm:
return await self.predictor.generate_async(prompts, **config)
else:
return self.predictor.generate(prompts, **config)

if self.use_deepspeed:
self.predictor.streaming_generate(prompts, self.streamer, **config)
return StreamingResponse(
self.consume_streamer(), status_code=200, media_type="text/plain"
)
elif self.use_vllm:
# TODO: streaming only supports a single prompt for now
# This is a workaround for the current situation; a follow-up PR is needed to address it
if isinstance(prompts, list):
prompt = prompts[0]
results_generator = await self.predictor.streaming_generate_async(prompt, **config)
return StreamingResponse(
self.predictor.stream_results(results_generator),
status_code=200,
media_type="text/plain",
)
else:
streamer = self.predictor.get_streamer()
self.loop.run_in_executor(
None,
functools.partial(self.predictor.streaming_generate, prompts, streamer, **config),
)
return StreamingResponse(
self.consume_streamer_async(streamer),
status_code=200,
media_type="text/plain",
self.consume_streamer_async(streamer), status_code=200, media_type="text/plain"
)

async def stream_response(self, prompt, config):
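For context, a client-side sketch of exercising the new vLLM streaming path (the port and route follow the YAML above; the JSON keys are assumptions for illustration, since the request parsing is outside this hunk):

```python
# Sketch only: the payload shape ("text", "stream") is assumed, not taken
# from this diff; adjust to the simple API server's actual schema.
import requests

resp = requests.post(
    "http://localhost:8000/llama-2-7b-chat-hf",
    json={"text": "[INST] What is Ray Serve? [/INST]", "stream": True},
    stream=True,
)
for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
    print(chunk, end="", flush=True)
```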
2 changes: 1 addition & 1 deletion inference/serve.py
@@ -21,7 +21,7 @@
from api_server_simple import serve_run
from api_server_openai import openai_serve_run
from predictor_deployment import PredictorDeployment
from inference.inference_config import ModelDescription, InferenceConfig, all_models
from inference_config import ModelDescription, InferenceConfig, all_models


def get_deployed_models(args):
3 changes: 1 addition & 2 deletions inference/transformer_predictor.py
@@ -1,15 +1,14 @@
import torch
from transformers import AutoModelForCausalLM, AutoConfig
from transformers import TextIteratorStreamer
from inference.inference_config import InferenceConfig, IPEX_PRECISION_BF16
from inference_config import InferenceConfig, IPEX_PRECISION_BF16
from predictor import Predictor
from utils import get_torch_dtype


class TransformerPredictor(Predictor):
def __init__(self, infer_conf: InferenceConfig):
super().__init__(infer_conf)

model_desc = infer_conf.model_description
model_config = model_desc.config
hf_config = AutoConfig.from_pretrained(
2 changes: 1 addition & 1 deletion inference/utils.py
@@ -16,7 +16,7 @@

from transformers import StoppingCriteria
import torch
from inference.inference_config import InferenceConfig, DEVICE_CPU
from inference_config import InferenceConfig, DEVICE_CPU
from typing import Dict, Any


67 changes: 67 additions & 0 deletions inference/vllm_predictor.py
@@ -0,0 +1,67 @@
from typing import AsyncGenerator, List, Union
from predictor import Predictor
from inference_config import InferenceConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid
import asyncio


class VllmPredictor(Predictor):
def __init__(self, infer_conf: InferenceConfig):
super().__init__(infer_conf)

model_desc = infer_conf.model_description
model_config = model_desc.config

args = AsyncEngineArgs(
model=model_desc.model_id_or_path,
trust_remote_code=model_config.trust_remote_code,
device=infer_conf.device,
)

self.engine = AsyncLLMEngine.from_engine_args(args)

async def _get_generator_output(self, results_generator):
async for request_output in results_generator:
if request_output.finished:
return request_output.outputs[0].text
return None

async def generate_async(
self, prompts: Union[str, List[str]], **config
) -> Union[str, List[str]]:
sampling_params = SamplingParams(**config)
if isinstance(prompts, str):
request_id = random_uuid()
results_generator = self.engine.generate(prompts, sampling_params, request_id)
async for request_output in results_generator:
if request_output.finished:
return request_output.outputs[0].text
else:
results_generators = [
self.engine.generate(prompt, sampling_params, random_uuid()) for prompt in prompts
]
results = [
self._get_generator_output(results_generator)
for results_generator in results_generators
]
return await asyncio.gather(*results)

return ""

async def streaming_generate_async(self, prompt, **config):
sampling_params = SamplingParams(**config)
request_id = random_uuid()
results_generator = self.engine.generate(prompt, sampling_params, request_id)
return results_generator

async def stream_results(self, results_generator) -> AsyncGenerator[str, None]:
num_returned = 0
async for request_output in results_generator:
text_outputs = [output.text for output in request_output.outputs]
assert len(text_outputs) == 1
text_output = text_outputs[0][num_returned:]
yield text_output
num_returned += len(text_output)
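
A usage sketch for the new predictor outside Ray Serve (assuming an InferenceConfig built from the vLLM YAML shown earlier; the sampling kwargs are passed straight through to SamplingParams):

```python
# Sketch: drive VllmPredictor directly from an asyncio program.
import asyncio

from inference_config import InferenceConfig
from vllm_predictor import VllmPredictor


async def main(infer_conf: InferenceConfig) -> None:
    predictor = VllmPredictor(infer_conf)

    # Batched, non-streaming generation (kwargs become SamplingParams).
    texts = await predictor.generate_async(
        ["What is Ray?", "What is vLLM?"], max_tokens=64, temperature=0.7
    )
    print(texts)

    # Single-prompt streaming.
    gen = await predictor.streaming_generate_async("What is Ray Serve?", max_tokens=64)
    async for chunk in predictor.stream_results(gen):
        print(chunk, end="", flush=True)

# asyncio.run(main(infer_conf))  # infer_conf: an InferenceConfig loaded as above
```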
4 changes: 4 additions & 0 deletions pyproject.toml
@@ -68,6 +68,10 @@ bigdl-cpu = [
"bigdl-llm[all]"
]

vllm = [
"vllm>=0.2.6"
]

[tool.setuptools]
packages = ["finetune", "inference"]

10 changes: 5 additions & 5 deletions ui/start_ui.py
@@ -20,10 +20,10 @@
import sys

sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from inference.inference_config import all_models, ModelDescription, Prompt
from inference.inference_config import InferenceConfig as FinetunedConfig
from inference.chat_process import ChatModelGptJ, ChatModelLLama # noqa: F401
from inference.predictor_deployment import PredictorDeployment
from inference_config import all_models, ModelDescription, Prompt
from inference_config import InferenceConfig as FinetunedConfig
from chat_process import ChatModelGptJ, ChatModelLLama # noqa: F401
from predictor_deployment import PredictorDeployment
from ray import serve
import ray
import gradio as gr
@@ -752,7 +752,7 @@ def _init_ui(self):
head_content = """
<div style="color: #fff;text-align: center;">
<div style="position:absolute; left:15px; top:15px; "><img src="/file=ui/images/logo.png" width="50" height="50"/></div>
<p style="color: #fff; font-size: 1.1rem;">Manage LLM Lifecycle</p>
<p style="color: #fff; font-size: 1.1rem;">Manage LLM Lifecycle</p>
<p style="color: #fff; font-size: 0.9rem;">Fine-Tune LLMs using workflow on Ray, Deploy and Inference</p>
</div>
"""