
Add vLLM support for gradio demo, inference script and openai api demo #35

Merged · 10 commits · Aug 2, 2023
12 changes: 6 additions & 6 deletions README.md
@@ -20,7 +20,7 @@
- 🚀 针对Llama-2模型扩充了**新版中文词表**,开源了中文LLaMA-2和Alpaca-2大模型
- 🚀 开源了预训练脚本、指令精调脚本,用户可根据需要进一步训练模型
- 🚀 使用个人电脑的CPU/GPU快速在本地进行大模型量化和部署体验
- 🚀 支持[🤗transformers](https://github.com/huggingface/transformers), [llama.cpp](https://github.com/ggerganov/llama.cpp), [text-generation-webui](https://github.com/oobabooga/text-generation-webui), [LangChain](https://github.com/hwchase17/langchain)等LLaMA生态
- 🚀 支持[🤗transformers](https://github.com/huggingface/transformers), [llama.cpp](https://github.com/ggerganov/llama.cpp), [text-generation-webui](https://github.com/oobabooga/text-generation-webui), [LangChain](https://github.com/hwchase17/langchain), [vLLM](https://github.com/vllm-project/vllm)等LLaMA生态
- 目前已开源的模型:Chinese-LLaMA-2-7B, Chinese-Alpaca-2-7B

----
@@ -129,11 +129,11 @@

本项目中的相关模型主要支持以下量化、推理和部署方式。

| 工具 | 特点 | CPU | GPU | 量化 | GUI | API | 教程 |
| :----------------------------------------------------------- | ---------------------------- | :--: | :--: | :--: | :--: | :--: | :----------------------------------------------------------: |
| [**llama.cpp**](https://github.com/ggerganov/llama.cpp) | 丰富的量化选项和高效本地推理 | ✅ | ✅ | ✅ | ❌ | ✅ | [link](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/llamacpp_zh) |
| [**🤗Transformers**](https://github.com/huggingface/transformers) | 原生transformers推理接口 | ✅ | ✅ | ✅ | ✅ | ❌ | [link](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/inference_with_transformers_zh) |
| [**仿OpenAI API调用**](https://platform.openai.com/docs/api-reference) | 仿OpenAI API接口的服务器Demo | ✅ | ✅ | ✅ | ❌ | ✅ | [link](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/api_calls_zh) |
| 工具 | 特点 | CPU | GPU | 量化 | GUI | API | vLLM | 教程 |
| :----------------------------------------------------------- | ---------------------------- | :--: | :--: | :--: | :--: | :--: | :--: | :----------------------------------------------------------: |
| [**llama.cpp**](https://github.com/ggerganov/llama.cpp) | 丰富的量化选项和高效本地推理 | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | [link](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/llamacpp_zh) |
| [**🤗Transformers**](https://github.com/huggingface/transformers) | 原生transformers推理接口 | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | [link](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/inference_with_transformers_zh) |
| [**仿OpenAI API调用**](https://platform.openai.com/docs/api-reference) | 仿OpenAI API接口的服务器Demo | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | [link](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/api_calls_zh) |

⚠️ 一代模型相关推理与部署支持将陆续迁移到本项目,届时将同步更新相关教程。
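
As an illustration only (not part of this diff), calling the OpenAI-style API demo from Python could look like the sketch below; the host, port, route, and payload fields are assumptions and should be adjusted to however the demo server is actually launched:

```python
import requests

# Hypothetical endpoint: adjust host, port, and route to match the running demo server.
API_URL = "http://localhost:19327/v1/completions"

payload = {
    "prompt": "请介绍一下中国的首都。",
    "max_tokens": 256,
    "temperature": 0.7,
    "top_p": 0.9,
}

response = requests.post(API_URL, json=payload, timeout=60)
response.raise_for_status()
print(response.json())
```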

12 changes: 6 additions & 6 deletions README_EN.md
@@ -20,7 +20,7 @@ This project is based on the Llama-2, released by Meta, and it is the second gen
- 🚀 New extended Chinese vocabulary beyond Llama-2, open-sourcing the Chinese LLaMA-2 and Alpaca-2 LLMs.
- 🚀 Open-sourced the pre-training and instruction finetuning (SFT) scripts for further tuning on user's data
- 🚀 Quickly deploy and experience the quantized LLMs on CPU/GPU of personal PC
- 🚀 Support for LLaMA ecosystems like [🤗transformers](https://github.com/huggingface/transformers), [llama.cpp](https://github.com/ggerganov/llama.cpp), [text-generation-webui](https://github.com/oobabooga/text-generation-webui), [LangChain](https://github.com/hwchase17/langchain) etc.
- 🚀 Support for LLaMA ecosystems like [🤗transformers](https://github.com/huggingface/transformers), [llama.cpp](https://github.com/ggerganov/llama.cpp), [text-generation-webui](https://github.com/oobabooga/text-generation-webui), [LangChain](https://github.com/hwchase17/langchain), [vLLM](https://github.com/vllm-project/vllm) etc.
- The currently open-source models are Chinese-LLaMA-2-7B and Chinese-Alpaca-2-7B.

----
@@ -123,11 +123,11 @@ Below are the sizes of the full models in FP16 precision and 4-bit quantization.

The models in this project mainly support the following quantization, inference, and deployment methods.

| Tool | Features | CPU | GPU | Quant | GUI | API | Tutorial |
| :----------------------------------------------------------- | ------------------------------------------------------- | :--: | :--: | :---: | :--: | :--: | :----------------------------------------------------------: |
| [**llama.cpp**](https://github.com/ggerganov/llama.cpp) | Rich quantization options and efficient local inference | ✅ | ✅ | ✅ | ❌ | ✅ | [link](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/llamacpp_en) |
| [**🤗Transformers**](https://github.com/huggingface/transformers) | Native transformers inference interface | ✅ | ✅ | ✅ | ✅ | ❌ | [link](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/inference_with_transformers_en) |
| [**OpenAI API Calls**](https://platform.openai.com/docs/api-reference) | A server that implements OpenAI API | ✅ | ✅ | ✅ | ❌ | ✅ | [link](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/api_calls_en) |
| Tool | Features | CPU | GPU | Quant | GUI | API | vLLM | Tutorial |
| :----------------------------------------------------------- | ------------------------------------------------------- | :--: | :--: | :---: | :--: | :--: | :--: | :----------------------------------------------------------: |
| [**llama.cpp**](https://github.com/ggerganov/llama.cpp) | Rich quantization options and efficient local inference | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | [link](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/llamacpp_en) |
| [**🤗Transformers**](https://github.com/huggingface/transformers) | Native transformers inference interface | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | [link](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/inference_with_transformers_en) |
| [**OpenAI API Calls**](https://platform.openai.com/docs/api-reference) | A server that implements OpenAI API | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | [link](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/api_calls_en) |

⚠️ Inference and deployment support related to the first-generation model will be gradually migrated to this project, and relevant tutorials will be updated later.
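
As an illustration only (not part of this diff), a minimal streaming client for the vLLM `api_server` backend might look like the following sketch. It mirrors the request payload built by `post_http_request` in `scripts/inference/gradio_demo.py` below; the host, port, and prompt are placeholders, and the server is assumed to have been started with `python -m vllm.entrypoints.api_server`.

```python
import json
import requests

# Hypothetical host/port; match the --post_host/--post_port values used for the demo.
API_URL = "http://localhost:8000/generate"

payload = {
    "prompt": "你好,请介绍一下你自己。",
    "n": 1,
    "top_p": 0.9,
    "top_k": 40,
    "temperature": 0.7,
    "max_tokens": 512,
    "use_beam_search": False,
    "best_of": 1,
    "presence_penalty": 1.0,
    "stream": True,
}

# Stream results as they are produced; vLLM's api_server separates JSON chunks with "\0".
with requests.post(API_URL, json=payload, stream=True) as response:
    for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
        if chunk:
            texts = json.loads(chunk.decode("utf-8"))["text"]
            print(texts[0], end="\r", flush=True)
print()
```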

274 changes: 196 additions & 78 deletions scripts/inference/gradio_demo.py
@@ -11,6 +11,10 @@
from threading import Thread
import traceback
import gc
import json
import requests
from typing import Iterable, List
import subprocess

DEFAULT_SYSTEM_PROMPT = """You are a helpful assistant. 你是一个乐于助人的助手。"""

@@ -69,6 +73,20 @@
default=DEFAULT_SYSTEM_PROMPT,
help="The system prompt of the prompt template."
)
parser.add_argument(
"--use_vllm",
action='store_true',
help="Use vLLM as back-end LLM service.")
parser.add_argument(
"--post_host",
type=str,
default="localhost",
help="Host of vLLM service.")
parser.add_argument(
"--post_port",
type=int,
default=8000,
help="Port of vLLM service.")
args = parser.parse_args()
if args.only_cpu is True:
args.gpus = ""
@@ -92,51 +110,79 @@

def setup():
global tokenizer, model, device, share, port, max_memory
max_memory = args.max_memory
port = args.port
share = args.share
load_in_8bit = args.load_in_8bit
load_type = torch.float16
if torch.cuda.is_available():
device = torch.device(0)
else:
device = torch.device('cpu')
if args.tokenizer_path is None:
args.tokenizer_path = args.lora_model
if args.lora_model is None:
if args.use_vllm:
# global share, port, max_memory
max_memory = args.max_memory
port = args.port
share = args.share

if args.lora_model is not None:
raise ValueError("vLLM currently does not support LoRA, please merge the LoRA weights to the base model.")
if args.load_in_8bit:
raise ValueError("vLLM currently does not support quantization, please use fp16 (default) or unuse --use_vllm.")
if args.only_cpu:
raise ValueError("vLLM requires GPUs with compute capability not less than 7.0. If you want to run only on CPU, please unuse --use_vllm.")

if args.tokenizer_path is None:
args.tokenizer_path = args.base_model
tokenizer = LlamaTokenizer.from_pretrained(args.tokenizer_path, legacy=True)

base_model = LlamaForCausalLM.from_pretrained(
args.base_model,
load_in_8bit=load_in_8bit,
torch_dtype=load_type,
low_cpu_mem_usage=True,
device_map='auto',
)

model_vocab_size = base_model.get_input_embeddings().weight.size(0)
tokenzier_vocab_size = len(tokenizer)
print(f"Vocab of the base model: {model_vocab_size}")
print(f"Vocab of the tokenizer: {tokenzier_vocab_size}")
if model_vocab_size!=tokenzier_vocab_size:
print("Resize model embeddings to fit tokenizer")
base_model.resize_token_embeddings(tokenzier_vocab_size)
if args.lora_model is not None:
print("loading peft model")
model = PeftModel.from_pretrained(
base_model,
args.lora_model,
tokenizer = LlamaTokenizer.from_pretrained(args.tokenizer_path, legacy=True)

print("Start launch vllm server.")
cmd = [
"python -m vllm.entrypoints.api_server",
f"--model={args.base_model}",
f"--tokenizer={args.tokenizer_path}",
"--tokenizer-mode=slow",
f"--tensor-parallel-size={len(args.gpus.split(','))}",
"&",
]
subprocess.check_call(cmd)

Codacy Static Code Analysis warning on scripts/inference/gradio_demo.py#L139: subprocess call - check for execution of untrusted input.
else:
max_memory = args.max_memory
port = args.port
share = args.share
load_in_8bit = args.load_in_8bit
load_type = torch.float16
if torch.cuda.is_available():
device = torch.device(0)
else:
device = torch.device('cpu')
if args.tokenizer_path is None:
args.tokenizer_path = args.lora_model
if args.lora_model is None:
args.tokenizer_path = args.base_model
tokenizer = LlamaTokenizer.from_pretrained(args.tokenizer_path, legacy=True)

base_model = LlamaForCausalLM.from_pretrained(
args.base_model,
load_in_8bit=load_in_8bit,
torch_dtype=load_type,
low_cpu_mem_usage=True,
device_map='auto',
)
else:
model = base_model

if device == torch.device('cpu'):
model.float()
model_vocab_size = base_model.get_input_embeddings().weight.size(0)
tokenizer_vocab_size = len(tokenizer)
print(f"Vocab of the base model: {model_vocab_size}")
print(f"Vocab of the tokenizer: {tokenizer_vocab_size}")
if model_vocab_size != tokenizer_vocab_size:
print("Resize model embeddings to fit tokenizer")
base_model.resize_token_embeddings(tokenizer_vocab_size)
if args.lora_model is not None:
print("loading peft model")
model = PeftModel.from_pretrained(
base_model,
args.lora_model,
torch_dtype=load_type,
device_map='auto',
)
else:
model = base_model

if device == torch.device('cpu'):
model.float()

model.eval()
model.eval()


# Reset the user input
@@ -239,6 +285,45 @@
torch.cuda.empty_cache()


def post_http_request(prompt: str,
api_url: str,
n: int = 1,
top_p: float = 0.9,
top_k: int = 40,
temperature: float = 0.7,
max_tokens: int = 512,
presence_penalty: float = 1.0,
use_beam_search: bool = False,
stream: bool = False) -> requests.Response:
headers = {"User-Agent": "Test Client"}
pload = {
"prompt": prompt,
"n": n,
"top_p": 1 if use_beam_search else top_p,
"top_k": -1 if use_beam_search else top_k,
"temperature": 0 if use_beam_search else temperature,
"max_tokens": max_tokens,
"use_beam_search": use_beam_search,
"best_of": 5 if use_beam_search else n,
"presence_penalty": presence_penalty,
"stream": stream,
}
print(pload)

response = requests.post(api_url, headers=headers, json=pload, stream=True)
return response


def get_streaming_response(response: requests.Response) -> Iterable[List[str]]:
for chunk in response.iter_lines(chunk_size=8192,
decode_unicode=False,
delimiter=b"\0"):
if chunk:
data = json.loads(chunk.decode("utf-8"))
output = data["text"]
yield output
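
# Illustrative usage of the two helpers above (a sketch only, not invoked by the demo itself);
# it assumes a vLLM api_server is already listening on args.post_host:args.post_port:
#
#     api_url = f"http://{args.post_host}:{args.post_port}/generate"
#     response = post_http_request("你好", api_url, max_tokens=64, stream=True)
#     for texts in get_streaming_response(response):
#         print(texts[0], end="\r")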


# Perform prediction based on the user input and history
@torch.no_grad()
def predict(
@@ -248,7 +333,8 @@
temperature=0.1,
top_k=40,
do_sample=True,
repetition_penalty=1.0
repetition_penalty=1.0,
presence_penalty=0.0,
):
while True:
print("len(history):", len(history))
@@ -277,46 +363,68 @@
else:
break

inputs = tokenizer(prompt, return_tensors="pt")
input_ids = inputs["input_ids"].to(device)

generate_params = {
'input_ids': input_ids,
'max_new_tokens': max_new_tokens,
'top_p': top_p,
'temperature': temperature,
'top_k': top_k,
'do_sample': do_sample,
'repetition_penalty': repetition_penalty,
}
if args.use_vllm:
generate_params = {
'max_tokens': max_new_tokens,
'top_p': top_p,
'temperature': temperature,
'top_k': top_k,
"use_beam_search": not do_sample,
'presence_penalty': presence_penalty,
}

api_url = f"http://{args.post_host}:{args.post_port}/generate"

def generate_with_callback(callback=None, **kwargs):
if 'stopping_criteria' in kwargs:
kwargs['stopping_criteria'].append(Stream(callback_func=callback))
else:
kwargs['stopping_criteria'] = [Stream(callback_func=callback)]
clear_torch_cache()
with torch.no_grad():
model.generate(**kwargs)

def generate_with_streaming(**kwargs):
return Iteratorize(generate_with_callback, kwargs, callback=None)
response = post_http_request(prompt, api_url, **generate_params, stream=True)

with generate_with_streaming(**generate_params) as generator:
for output in generator:
next_token_ids = output[len(input_ids[0]):]
if next_token_ids[0] == tokenizer.eos_token_id:
break
new_tokens = tokenizer.decode(
next_token_ids, skip_special_tokens=True)
if isinstance(tokenizer, LlamaTokenizer) and len(next_token_ids) > 0:
if tokenizer.convert_ids_to_tokens(int(next_token_ids[0])).startswith('▁'):
new_tokens = ' ' + new_tokens
for h in get_streaming_response(response):
for line in h:
line = line.replace(prompt, '')
history[-1][1] = line
yield history

history[-1][1] = new_tokens
yield history
if len(next_token_ids) >= max_new_tokens:
break
else:
inputs = tokenizer(prompt, return_tensors="pt")
input_ids = inputs["input_ids"].to(device)

generate_params = {
'input_ids': input_ids,
'max_new_tokens': max_new_tokens,
'top_p': top_p,
'temperature': temperature,
'top_k': top_k,
'do_sample': do_sample,
'repetition_penalty': repetition_penalty,
}

def generate_with_callback(callback=None, **kwargs):
if 'stopping_criteria' in kwargs:
kwargs['stopping_criteria'].append(Stream(callback_func=callback))
else:
kwargs['stopping_criteria'] = [Stream(callback_func=callback)]
clear_torch_cache()
with torch.no_grad():
model.generate(**kwargs)

def generate_with_streaming(**kwargs):
return Iteratorize(generate_with_callback, kwargs, callback=None)

with generate_with_streaming(**generate_params) as generator:
for output in generator:
next_token_ids = output[len(input_ids[0]):]
if next_token_ids[0] == tokenizer.eos_token_id:
break
new_tokens = tokenizer.decode(
next_token_ids, skip_special_tokens=True)
if isinstance(tokenizer, LlamaTokenizer) and len(next_token_ids) > 0:
if tokenizer.convert_ids_to_tokens(int(next_token_ids[0])).startswith('▁'):
new_tokens = ' ' + new_tokens

history[-1][1] = new_tokens
yield history
if len(next_token_ids) >= max_new_tokens:
break


# Call the setup function to initialize the components
@@ -370,7 +478,16 @@
value=1.1,
step=0.1,
label="Repetition Penalty",
interactive=True)
interactive=True,
visible=not args.use_vllm)
presence_penalty = gr.Slider(
-2.0,
2.0,
value=1.0,
step=0.1,
label="Presence Penalty",
interactive=True,
visible=args.use_vllm)

params = [user_input, chatbot]
predict_params = [
@@ -380,7 +497,8 @@
temperature,
top_k,
do_sample,
repetition_penalty]
repetition_penalty,
presence_penalty]

submitBtn.click(
user,