diff --git a/README.md b/README.md index 58b2a56..9106736 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ - 🚀 针对Llama-2模型扩充了**新版中文词表**,开源了中文LLaMA-2和Alpaca-2大模型 - 🚀 开源了预训练脚本、指令精调脚本,用户可根据需要进一步训练模型 - 🚀 使用个人电脑的CPU/GPU快速在本地进行大模型量化和部署体验 -- 🚀 支持[🤗transformers](https://github.com/huggingface/transformers), [llama.cpp](https://github.com/ggerganov/llama.cpp), [text-generation-webui](https://github.com/oobabooga/text-generation-webui), [LangChain](https://github.com/hwchase17/langchain)等LLaMA生态 +- 🚀 支持[🤗transformers](https://github.com/huggingface/transformers), [llama.cpp](https://github.com/ggerganov/llama.cpp), [text-generation-webui](https://github.com/oobabooga/text-generation-webui), [LangChain](https://github.com/hwchase17/langchain), [vLLM](https://github.com/vllm-project/vllm)等LLaMA生态 - 目前已开源的模型:Chinese-LLaMA-2-7B, Chinese-Alpaca-2-7B ---- @@ -129,11 +129,11 @@ 本项目中的相关模型主要支持以下量化、推理和部署方式。 -| 工具 | 特点 | CPU | GPU | 量化 | GUI | API | 教程 | -| :----------------------------------------------------------- | ---------------------------- | :--: | :--: | :--: | :--: | :--: | :----------------------------------------------------------: | -| [**llama.cpp**](https://github.com/ggerganov/llama.cpp) | 丰富的量化选项和高效本地推理 | ✅ | ✅ | ✅ | ❌ | ✅ | [link](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/llamacpp_zh) | -| [**🤗Transformers**](https://github.com/huggingface/transformers) | 原生transformers推理接口 | ✅ | ✅ | ✅ | ✅ | ❌ | [link](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/inference_with_transformers_zh) | -| [**仿OpenAI API调用**](https://platform.openai.com/docs/api-reference) | 仿OpenAI API接口的服务器Demo | ✅ | ✅ | ✅ | ❌ | ✅ | [link](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/api_calls_zh) | +| 工具 | 特点 | CPU | GPU | 量化 | GUI | API | vLLM | 教程 | +| :----------------------------------------------------------- | ---------------------------- | :--: | :--: | :--: | :--: | :--: | :--: | :----------------------------------------------------------: | +| [**llama.cpp**](https://github.com/ggerganov/llama.cpp) | 丰富的量化选项和高效本地推理 | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | [link](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/llamacpp_zh) | +| [**🤗Transformers**](https://github.com/huggingface/transformers) | 原生transformers推理接口 | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | [link](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/inference_with_transformers_zh) | +| [**仿OpenAI API调用**](https://platform.openai.com/docs/api-reference) | 仿OpenAI API接口的服务器Demo | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | [link](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/api_calls_zh) | ⚠️ 一代模型相关推理与部署支持将陆续迁移到本项目,届时将同步更新相关教程。 diff --git a/README_EN.md b/README_EN.md index 7fdb458..472b99a 100644 --- a/README_EN.md +++ b/README_EN.md @@ -20,7 +20,7 @@ This project is based on the Llama-2, released by Meta, and it is the second gen - 🚀 New extended Chinese vocabulary beyond Llama-2, open-sourcing the Chinese LLaMA-2 and Alpaca-2 LLMs. - 🚀 Open-sourced the pre-training and instruction finetuning (SFT) scripts for further tuning on user's data - 🚀 Quickly deploy and experience the quantized LLMs on CPU/GPU of personal PC -- 🚀 Support for LLaMA ecosystems like [🤗transformers](https://github.com/huggingface/transformers), [llama.cpp](https://github.com/ggerganov/llama.cpp), [text-generation-webui](https://github.com/oobabooga/text-generation-webui), [LangChain](https://github.com/hwchase17/langchain) etc. 
+- 🚀 Support for LLaMA ecosystems like [🤗transformers](https://github.com/huggingface/transformers), [llama.cpp](https://github.com/ggerganov/llama.cpp), [text-generation-webui](https://github.com/oobabooga/text-generation-webui), [LangChain](https://github.com/hwchase17/langchain), [vLLM](https://github.com/vllm-project/vllm) etc. - The currently open-source models are Chinese-LLaMA-2-7B and Chinese-Alpaca-2-7B. ---- @@ -123,11 +123,11 @@ Below are the sizes of the full models in FP16 precision and 4-bit quantization. The models in this project mainly support the following quantization, inference, and deployment methods. -| Tool | Features | CPU | GPU | Quant | GUI | API | Tutorial | -| :----------------------------------------------------------- | ------------------------------------------------------- | :--: | :--: | :---: | :--: | :--: | :----------------------------------------------------------: | -| [**llama.cpp**](https://github.com/ggerganov/llama.cpp) | Rich quantization options and efficient local inference | ✅ | ✅ | ✅ | ❌ | ✅ | [link](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/llamacpp_en) | -| [**🤗Transformers**](https://github.com/huggingface/transformers) | Native transformers inference interface | ✅ | ✅ | ✅ | ✅ | ❌ | [link](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/inference_with_transformers_en) | -| [**OpenAI API Calls**](https://platform.openai.com/docs/api-reference) | A server that implements OpenAI API | ✅ | ✅ | ✅ | ❌ | ✅ | [link](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/api_calls_en) | +| Tool | Features | CPU | GPU | Quant | GUI | API | vLLM | Tutorial | +| :----------------------------------------------------------- | ------------------------------------------------------- | :--: | :--: | :---: | :--: | :--: | :--: | :----------------------------------------------------------: | +| [**llama.cpp**](https://github.com/ggerganov/llama.cpp) | Rich quantization options and efficient local inference | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | [link](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/llamacpp_en) | +| [**🤗Transformers**](https://github.com/huggingface/transformers) | Native transformers inference interface | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | [link](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/inference_with_transformers_en) | +| [**OpenAI API Calls**](https://platform.openai.com/docs/api-reference) | A server that implements OpenAI API | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | [link](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/api_calls_en) | ⚠️ Inference and deployment support related to the first-generation model will be gradually migrated to this project, and relevant tutorials will be updated later. diff --git a/scripts/inference/gradio_demo.py b/scripts/inference/gradio_demo.py index 049698d..eed0dd6 100644 --- a/scripts/inference/gradio_demo.py +++ b/scripts/inference/gradio_demo.py @@ -11,6 +11,10 @@ from threading import Thread import traceback import gc +import json +import requests +from typing import Iterable, List +import subprocess DEFAULT_SYSTEM_PROMPT = """You are a helpful assistant. 你是一个乐于助人的助手。""" @@ -69,6 +73,20 @@ default=DEFAULT_SYSTEM_PROMPT, help="The system prompt of the prompt template." 
) +parser.add_argument( + "--use_vllm", + action='store_true', + help="Use vLLM as back-end LLM service.") +parser.add_argument( + "--post_host", + type=str, + default="localhost", + help="Host of vLLM service.") +parser.add_argument( + "--post_port", + type=int, + default=8000, + help="Port of vLLM service.") args = parser.parse_args() if args.only_cpu is True: args.gpus = "" @@ -92,51 +110,79 @@ def setup(): global tokenizer, model, device, share, port, max_memory - max_memory = args.max_memory - port = args.port - share = args.share - load_in_8bit = args.load_in_8bit - load_type = torch.float16 - if torch.cuda.is_available(): - device = torch.device(0) - else: - device = torch.device('cpu') - if args.tokenizer_path is None: - args.tokenizer_path = args.lora_model - if args.lora_model is None: + if args.use_vllm: + # global share, port, max_memory + max_memory = args.max_memory + port = args.port + share = args.share + + if args.lora_model is not None: + raise ValueError("vLLM currently does not support LoRA, please merge the LoRA weights to the base model.") + if args.load_in_8bit: + raise ValueError("vLLM currently does not support quantization, please use fp16 (default) or unuse --use_vllm.") + if args.only_cpu: + raise ValueError("vLLM requires GPUs with compute capability not less than 7.0. If you want to run only on CPU, please unuse --use_vllm.") + + if args.tokenizer_path is None: args.tokenizer_path = args.base_model - tokenizer = LlamaTokenizer.from_pretrained(args.tokenizer_path, legacy=True) - - base_model = LlamaForCausalLM.from_pretrained( - args.base_model, - load_in_8bit=load_in_8bit, - torch_dtype=load_type, - low_cpu_mem_usage=True, - device_map='auto', - ) - - model_vocab_size = base_model.get_input_embeddings().weight.size(0) - tokenzier_vocab_size = len(tokenizer) - print(f"Vocab of the base model: {model_vocab_size}") - print(f"Vocab of the tokenizer: {tokenzier_vocab_size}") - if model_vocab_size!=tokenzier_vocab_size: - print("Resize model embeddings to fit tokenizer") - base_model.resize_token_embeddings(tokenzier_vocab_size) - if args.lora_model is not None: - print("loading peft model") - model = PeftModel.from_pretrained( - base_model, - args.lora_model, + tokenizer = LlamaTokenizer.from_pretrained(args.tokenizer_path, legacy=True) + + print("Start launch vllm server.") + cmd = [ + "python -m vllm.entrypoints.api_server", + f"--model={args.base_model}", + f"--tokenizer={args.tokenizer_path}", + "--tokenizer-mode=slow", + f"--tensor-parallel-size={len(args.gpus.split(','))}", + "&", + ] + subprocess.check_call(cmd) + else: + max_memory = args.max_memory + port = args.port + share = args.share + load_in_8bit = args.load_in_8bit + load_type = torch.float16 + if torch.cuda.is_available(): + device = torch.device(0) + else: + device = torch.device('cpu') + if args.tokenizer_path is None: + args.tokenizer_path = args.lora_model + if args.lora_model is None: + args.tokenizer_path = args.base_model + tokenizer = LlamaTokenizer.from_pretrained(args.tokenizer_path, legacy=True) + + base_model = LlamaForCausalLM.from_pretrained( + args.base_model, + load_in_8bit=load_in_8bit, torch_dtype=load_type, + low_cpu_mem_usage=True, device_map='auto', ) - else: - model = base_model - if device == torch.device('cpu'): - model.float() + model_vocab_size = base_model.get_input_embeddings().weight.size(0) + tokenizer_vocab_size = len(tokenizer) + print(f"Vocab of the base model: {model_vocab_size}") + print(f"Vocab of the tokenizer: {tokenizer_vocab_size}") + if model_vocab_size != 
tokenizer_vocab_size: + print("Resize model embeddings to fit tokenizer") + base_model.resize_token_embeddings(tokenizer_vocab_size) + if args.lora_model is not None: + print("loading peft model") + model = PeftModel.from_pretrained( + base_model, + args.lora_model, + torch_dtype=load_type, + device_map='auto', + ) + else: + model = base_model + + if device == torch.device('cpu'): + model.float() - model.eval() + model.eval() # Reset the user input @@ -239,6 +285,45 @@ def clear_torch_cache(): torch.cuda.empty_cache() +def post_http_request(prompt: str, + api_url: str, + n: int = 1, + top_p: float = 0.9, + top_k: int = 40, + temperature: float = 0.7, + max_tokens: int = 512, + presence_penalty: float = 1.0, + use_beam_search: bool = False, + stream: bool = False) -> requests.Response: + headers = {"User-Agent": "Test Client"} + pload = { + "prompt": prompt, + "n": n, + "top_p": 1 if use_beam_search else top_p, + "top_k": -1 if use_beam_search else top_k, + "temperature": 0 if use_beam_search else temperature, + "max_tokens": max_tokens, + "use_beam_search": use_beam_search, + "best_of": 5 if use_beam_search else n, + "presence_penalty": presence_penalty, + "stream": stream, + } + print(pload) + + response = requests.post(api_url, headers=headers, json=pload, stream=True) + return response + + +def get_streaming_response(response: requests.Response) -> Iterable[List[str]]: + for chunk in response.iter_lines(chunk_size=8192, + decode_unicode=False, + delimiter=b"\0"): + if chunk: + data = json.loads(chunk.decode("utf-8")) + output = data["text"] + yield output + + # Perform prediction based on the user input and history @torch.no_grad() def predict( @@ -248,7 +333,8 @@ def predict( temperature=0.1, top_k=40, do_sample=True, - repetition_penalty=1.0 + repetition_penalty=1.0, + presence_penalty=0.0, ): while True: print("len(history):", len(history)) @@ -277,46 +363,68 @@ def predict( else: break - inputs = tokenizer(prompt, return_tensors="pt") - input_ids = inputs["input_ids"].to(device) - - generate_params = { - 'input_ids': input_ids, - 'max_new_tokens': max_new_tokens, - 'top_p': top_p, - 'temperature': temperature, - 'top_k': top_k, - 'do_sample': do_sample, - 'repetition_penalty': repetition_penalty, - } + if args.use_vllm: + generate_params = { + 'max_tokens': max_new_tokens, + 'top_p': top_p, + 'temperature': temperature, + 'top_k': top_k, + "use_beam_search": not do_sample, + 'presence_penalty': presence_penalty, + } + + api_url = f"http://{args.post_host}:{args.post_port}/generate" - def generate_with_callback(callback=None, **kwargs): - if 'stopping_criteria' in kwargs: - kwargs['stopping_criteria'].append(Stream(callback_func=callback)) - else: - kwargs['stopping_criteria'] = [Stream(callback_func=callback)] - clear_torch_cache() - with torch.no_grad(): - model.generate(**kwargs) - def generate_with_streaming(**kwargs): - return Iteratorize(generate_with_callback, kwargs, callback=None) + response = post_http_request(prompt, api_url, **generate_params, stream=True) - with generate_with_streaming(**generate_params) as generator: - for output in generator: - next_token_ids = output[len(input_ids[0]):] - if next_token_ids[0] == tokenizer.eos_token_id: - break - new_tokens = tokenizer.decode( - next_token_ids, skip_special_tokens=True) - if isinstance(tokenizer, LlamaTokenizer) and len(next_token_ids) > 0: - if tokenizer.convert_ids_to_tokens(int(next_token_ids[0])).startswith('▁'): - new_tokens = ' ' + new_tokens + for h in get_streaming_response(response): + for line in h: + line = 
line.replace(prompt, '') + history[-1][1] = line + yield history - history[-1][1] = new_tokens - yield history - if len(next_token_ids) >= max_new_tokens: - break + else: + inputs = tokenizer(prompt, return_tensors="pt") + input_ids = inputs["input_ids"].to(device) + + generate_params = { + 'input_ids': input_ids, + 'max_new_tokens': max_new_tokens, + 'top_p': top_p, + 'temperature': temperature, + 'top_k': top_k, + 'do_sample': do_sample, + 'repetition_penalty': repetition_penalty, + } + + def generate_with_callback(callback=None, **kwargs): + if 'stopping_criteria' in kwargs: + kwargs['stopping_criteria'].append(Stream(callback_func=callback)) + else: + kwargs['stopping_criteria'] = [Stream(callback_func=callback)] + clear_torch_cache() + with torch.no_grad(): + model.generate(**kwargs) + + def generate_with_streaming(**kwargs): + return Iteratorize(generate_with_callback, kwargs, callback=None) + + with generate_with_streaming(**generate_params) as generator: + for output in generator: + next_token_ids = output[len(input_ids[0]):] + if next_token_ids[0] == tokenizer.eos_token_id: + break + new_tokens = tokenizer.decode( + next_token_ids, skip_special_tokens=True) + if isinstance(tokenizer, LlamaTokenizer) and len(next_token_ids) > 0: + if tokenizer.convert_ids_to_tokens(int(next_token_ids[0])).startswith('▁'): + new_tokens = ' ' + new_tokens + + history[-1][1] = new_tokens + yield history + if len(next_token_ids) >= max_new_tokens: + break # Call the setup function to initialize the components @@ -370,7 +478,16 @@ def generate_with_streaming(**kwargs): value=1.1, step=0.1, label="Repetition Penalty", - interactive=True) + interactive=True, + visible=False if args.use_vllm else True) + presence_penalty = gr.Slider( + -2.0, + 2.0, + value=1.0, + step=0.1, + label="Presence Penalty", + interactive=True, + visible=True if args.use_vllm else False) params = [user_input, chatbot] predict_params = [ @@ -380,7 +497,8 @@ def generate_with_streaming(**kwargs): temperature, top_k, do_sample, - repetition_penalty] + repetition_penalty, + presence_penalty] submitBtn.click( user, diff --git a/scripts/inference/inference_hf.py b/scripts/inference/inference_hf.py index 16665d3..cf3b651 100644 --- a/scripts/inference/inference_hf.py +++ b/scripts/inference/inference_hf.py @@ -22,8 +22,16 @@ parser.add_argument('--only_cpu',action='store_true',help='only use CPU for inference') parser.add_argument('--alpha',type=str,default="1.0", help="The scaling factor of NTK method, can be a float or 'auto'. ") parser.add_argument('--load_in_8bit',action='store_true', help="Load the LLM in the 8bit mode") +parser.add_argument("--use_vllm", action='store_true', help="Use vLLM as back-end LLM service.") parser.add_argument('--system_prompt',type=str,default=DEFAULT_SYSTEM_PROMPT, help="The system prompt of the prompt template.") args = parser.parse_args() +if args.use_vllm: + if args.lora_model is not None: + raise ValueError("vLLM currently does not support LoRA, please merge the LoRA weights to the base model.") + if args.load_in_8bit: + raise ValueError("vLLM currently does not support quantization, please use fp16 (default) or unuse --use_vllm.") + if args.only_cpu: + raise ValueError("vLLM requires GPUs with compute capability not less than 7.0. 
If you want to run only on CPU, please unuse --use_vllm.") if args.only_cpu is True: args.gpus = "" os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus @@ -31,6 +39,8 @@ from transformers import LlamaForCausalLM, LlamaTokenizer from transformers import GenerationConfig from peft import PeftModel +if args.use_vllm: + from vllm import LLM, SamplingParams import sys parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) @@ -39,15 +49,24 @@ apply_attention_patch(use_memory_efficient_attention=True) apply_ntk_scaling_patch(args.alpha) -generation_config = GenerationConfig( - temperature=0.2, - top_k=40, - top_p=0.9, - do_sample=True, - num_beams=1, - repetition_penalty=1.1, - max_new_tokens=400 -) +if args.use_vllm: + generation_config = dict( + temperature=0.2, + top_k=40, + top_p=0.9, + max_tokens=400, + presence_penalty=1.0, + ) +else: + generation_config = GenerationConfig( + temperature=0.2, + top_k=40, + top_p=0.9, + do_sample=True, + num_beams=1, + repetition_penalty=1.1, + max_new_tokens=400 + ) sample_data = ["为什么要减少污染,保护环境?"] @@ -65,31 +84,41 @@ def generate_prompt(instruction): args.tokenizer_path = args.lora_model if args.lora_model is None: args.tokenizer_path = args.base_model - tokenizer = LlamaTokenizer.from_pretrained(args.tokenizer_path, legacy=True) - - base_model = LlamaForCausalLM.from_pretrained( - args.base_model, - load_in_8bit=args.load_in_8bit, - torch_dtype=load_type, - low_cpu_mem_usage=True, - device_map='auto', - ) - - model_vocab_size = base_model.get_input_embeddings().weight.size(0) - tokenzier_vocab_size = len(tokenizer) - print(f"Vocab of the base model: {model_vocab_size}") - print(f"Vocab of the tokenizer: {tokenzier_vocab_size}") - if model_vocab_size!=tokenzier_vocab_size: - print("Resize model embeddings to fit tokenizer") - base_model.resize_token_embeddings(tokenzier_vocab_size) - if args.lora_model is not None: - print("loading peft model") - model = PeftModel.from_pretrained(base_model, args.lora_model,torch_dtype=load_type,device_map='auto',) + + if args.use_vllm: + model = LLM(model=args.base_model, + tokenizer=args.tokenizer_path, + tokenizer_mode='slow', + tensor_parallel_size=len(args.gpus.split(','))) + tokenizer = LlamaTokenizer.from_pretrained(args.tokenizer_path, legacy=True) else: - model = base_model + tokenizer = LlamaTokenizer.from_pretrained(args.tokenizer_path, legacy=True) + + base_model = LlamaForCausalLM.from_pretrained( + args.base_model, + load_in_8bit=args.load_in_8bit, + torch_dtype=load_type, + low_cpu_mem_usage=True, + device_map='auto', + ) + + model_vocab_size = base_model.get_input_embeddings().weight.size(0) + tokenizer_vocab_size = len(tokenizer) + print(f"Vocab of the base model: {model_vocab_size}") + print(f"Vocab of the tokenizer: {tokenizer_vocab_size}") + if model_vocab_size!=tokenizer_vocab_size: + print("Resize model embeddings to fit tokenizer") + base_model.resize_token_embeddings(tokenizer_vocab_size) + if args.lora_model is not None: + print("loading peft model") + model = PeftModel.from_pretrained(base_model, args.lora_model,torch_dtype=load_type,device_map='auto',) + else: + model = base_model + + if device==torch.device('cpu'): + model.float() + model.eval() - if device==torch.device('cpu'): - model.float() # test data if args.data_file is None: examples = sample_data @@ -99,7 +128,6 @@ def generate_prompt(instruction): print("first 10 examples:") for example in examples[:10]: print(example) - model.eval() with torch.no_grad(): if args.interactive: @@ -121,52 +149,78 @@ def 
generate_prompt(instruction): input_text = generate_prompt(instruction=raw_input_text) else: input_text = raw_input_text - inputs = tokenizer(input_text,return_tensors="pt") #add_special_tokens=False ? - generation_output = model.generate( - input_ids = inputs["input_ids"].to(device), - attention_mask = inputs['attention_mask'].to(device), - eos_token_id=tokenizer.eos_token_id, - pad_token_id=tokenizer.pad_token_id, - generation_config = generation_config - ) - s = generation_output[0] - output = tokenizer.decode(s,skip_special_tokens=True) - if args.with_prompt: - response = output.split("[/INST]")[-1].strip() + + if args.use_vllm: + output = model.generate([input_text], SamplingParams(**generation_config), use_tqdm=False) + response = output[0].outputs[0].text else: - response = output + inputs = tokenizer(input_text,return_tensors="pt") #add_special_tokens=False ? + generation_output = model.generate( + input_ids = inputs["input_ids"].to(device), + attention_mask = inputs['attention_mask'].to(device), + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id, + generation_config = generation_config + ) + s = generation_output[0] + output = tokenizer.decode(s,skip_special_tokens=True) + if args.with_prompt: + response = output.split("[/INST]")[-1].strip() + else: + response = output print("Response: ",response) print("\n") else: print("Start inference.") results = [] - for index, example in enumerate(examples): + if args.use_vllm: if args.with_prompt is True: - input_text = generate_prompt(instruction=example) + inputs = [generate_prompt(example) for example in examples] else: - input_text = example - inputs = tokenizer(input_text,return_tensors="pt") #add_special_tokens=False ? - generation_output = model.generate( - input_ids = inputs["input_ids"].to(device), - attention_mask = inputs['attention_mask'].to(device), - eos_token_id=tokenizer.eos_token_id, - pad_token_id=tokenizer.pad_token_id, - generation_config = generation_config - ) - s = generation_output[0] - output = tokenizer.decode(s,skip_special_tokens=True) - if args.with_prompt: - response = output.split("[/INST]")[1].strip() - else: - response = output - print(f"======={index}=======") - print(f"Input: {example}\n") - print(f"Output: {response}\n") - - results.append({"Input":input_text,"Output":response}) + inputs = examples + outputs = model.generate(inputs, SamplingParams(**generation_config)) + + for index, (example, output) in enumerate(zip(examples, outputs)): + response = output.outputs[0].text + + print(f"======={index}=======") + print(f"Input: {example}\n") + print(f"Output: {response}\n") + + results.append({"Input":example,"Output":response}) + + else: + for index, example in enumerate(examples): + if args.with_prompt is True: + input_text = generate_prompt(instruction=example) + else: + input_text = example + inputs = tokenizer(input_text,return_tensors="pt") #add_special_tokens=False ? 
+ generation_output = model.generate( + input_ids = inputs["input_ids"].to(device), + attention_mask = inputs['attention_mask'].to(device), + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id, + generation_config = generation_config + ) + s = generation_output[0] + output = tokenizer.decode(s,skip_special_tokens=True) + if args.with_prompt: + response = output.split("[/INST]")[1].strip() + else: + response = output + print(f"======={index}=======") + print(f"Input: {example}\n") + print(f"Output: {response}\n") + + results.append({"Input":input_text,"Output":response}) dirname = os.path.dirname(args.predictions_file) os.makedirs(dirname,exist_ok=True) with open(args.predictions_file,'w') as f: json.dump(results,f,ensure_ascii=False,indent=2) - generation_config.save_pretrained('./') + if args.use_vllm: + with open(dirname+'/generation_config.json','w') as f: + json.dump(generation_config,f,ensure_ascii=False,indent=2) + else: + generation_config.save_pretrained('./') diff --git a/scripts/openai_server_demo/README_vllm.md b/scripts/openai_server_demo/README_vllm.md new file mode 100644 index 0000000..b652dc0 --- /dev/null +++ b/scripts/openai_server_demo/README_vllm.md @@ -0,0 +1,229 @@ +# OPENAI API DEMO + +> 更加详细的OPENAI API信息: + +这是一个使用fastapi实现的简易的仿OPENAI API风格的服务器DEMO,您可以使用这个API DEMO来快速搭建基于中文大模型的个人网站以及其他有趣的WEB DEMO。 + +本实现基于vLLM部署LLM后端服务,暂不支持加载LoRA模型、仅CPU部署和使用8bit推理。 + +## 部署方式 + +安装依赖 +``` shell +pip install fastapi uvicorn shortuuid vllm +``` + +启动脚本 +``` shell +python scripts/openai_server_demo/openai_api_server_vllm.py --model /path/to/base_model --tokenizer-mode slow --served-model-name chinese-llama-alpaca-2 +``` + +### 参数说明 + +`--model {base_model}`: 存放HF格式的LLaMA-2模型权重和配置文件的目录,可以是合并后的中文Alpaca-2模型 + +`--tokenizer {tokenizer_path}`: 存放对应tokenizer的目录。若不提供此参数,则其默认值与`--base_model`相同 + +`--tokenizer-mode {tokenizer-mode}`: tokenizer的模式。使用基于LLaMA/LLaMa-2的模型时,固定为`slow` + +`--tensor_parallel_size {tensor_parallel_size}`: 使用的GPU数量。默认为1 + +`--served-model-name {served-model-name}`: API中使用的模型名。若使用中文Alpaca-2系列模型,模型名中务必包含`chinese-llama-alpaca-2` + +`--host {host_name}`: 部署服务的host name。默认值是`localhost` + +`--prot {port}`: 部署服务的端口号。默认值是`8000` + +## API文档 + +### 文字接龙(completion) + +> 有关completion的中文翻译,李宏毅教授将其翻译为文字接龙 + +最基础的API接口,输入prompt,输出语言大模型的文字接龙(completion)结果。 + +API DEMO内置有prompt模板,prompt将被套入instruction模板中,这里输入的prompt应更像指令而非对话。 + +#### 快速体验completion接口 + +请求command: + +``` shell +curl http://localhost:8000/v1/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "chinese-llama-alpaca-2", + "prompt": "告诉我中国的首都在哪里" + }' +``` + +json返回体: + +``` json +{ + "id": "cmpl-41234d71fa034ec3ae90bbf6b5be7", + "object": "text_completion", + "created": 1690870733, + "model": "chinese-llama-alpaca-2", + "choices": [ + { + "index": 0, + "text": "中国的首都是北京。" + } + ] +} +``` + +#### completion接口高级参数 + +请求command: + +``` shell +curl http://localhost:8000/v1/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "chinese-llama-alpaca-2", + "prompt": "告诉我中国和美国分别各有哪些优点缺点", + "max_tokens": 90, + "temperature": 0.7, + "num_beams": 4, + "top_k": 40 + }' +``` + +json返回体: + +``` json +{ + "id": "cmpl-ceca9906bf0a429989e850368cc3f893", + "object": "text_completion", + "created": 1690870952, + "model": "chinese-llama-alpaca-2", + "choices": [ + { + "index": 0, + "text": "中国的优点是拥有丰富的文化和历史,而美国的优点是拥有先进的科技和经济体系。" + } + ] +} +``` + +#### completion接口高级参数说明 + +> 有关Decoding策略,更加详细的细节可以参考 该文章详细讲述了三种LLaMA会用到的Decoding策略:Greedy Decoding、Random Sampling 和 Beam 
Search,Decoding策略是top_k、top_p、temperature等高级参数的基础。 + +`prompt`: 生成文字接龙(completion)的提示。 + +`max_tokens`: 新生成的句子的token长度。 + +`temperature`: 在0和2之间选择的采样温度。较高的值如0.8会使输出更加随机,而较低的值如0.2则会使其输出更具有确定性。temperature越高,使用随机采样最为decoding的概率越大。 + +`use_beam_search`: 使用束搜索(beam search)。默认为`False`,即启用随机采样策略(random sampling) + +`n`: 输出序列的数量,默认为1 + +`best_of`: 当搜索策略为束搜索(beam search)时,该参数为在束搜索(beam search)中所使用的束个数。默认和`n`相同 + +`top_k`: 在随机采样(random sampling)时,前top_k高概率的token将作为候选token被随机采样。 + +`top_p`: 在随机采样(random sampling)时,累积概率超过top_p的token将作为候选token被随机采样,越低随机性越大,举个例子,当top_p设定为0.6时,概率前5的token概率分别为{0.23, 0.20, 0.18, 0.11, 0.10}时,前三个token的累积概率为0.61,那么第4个token将被过滤掉,只有前三的token将作为候选token被随机采样。 + +`presence_penalty`: 重复惩罚,取值范围-2 ~ 2,默认值为0。值大于0表示鼓励模型使用新的token,反之鼓励重复。 + + +### 聊天(chat completion) + +聊天接口支持多轮对话 + +#### 快速体验聊天接口 + +请求command: + +``` shell +curl http://localhost:8000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "chinese-llama-alpaca-2", + "messages": [ + {"role": "user","message": "给我讲一些有关杭州的故事吧"} + ] + }' +``` + +json返回体: + +``` json +{ + "id": "cmpl-8fc1b6356cf64681a41a8739445a8cf8", + "object": "chat.completion", + "created": 1690872695, + "model": "chinese-llama-alpaca-2", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "好的,请问您对杭州有什么特别的偏好吗?" + } + } + ] +} +``` + +#### 多轮对话 + +请求command: + +``` shell +curl http://localhost:8000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "chinese-llama-alpaca-2", + "messages": [ + {"role": "user","message": "给我讲一些有关杭州的故事吧"}, + {"role": "assistant","message": "好的,请问您对杭州有什么特别的偏好吗?"}, + {"role": "user","message": "我比较喜欢和西湖,可以给我讲一下西湖吗"} + ], + "repetition_penalty": 1.0 + }' +``` + +json返回体: + +``` json +{ + "id": "cmpl-02bf36497d3543c980ca2ae8cc4feb63", + "object": "chat.completion", + "created": 1690872676, + "model": "chinese-llama-alpaca-2", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "是的,西湖是杭州最著名的景点之一,它被誉为“人间天堂”。 <\\s>" + } + } + ] +} +``` + +#### 聊天接口高级参数说明 + +`prompt`: 生成文字接龙(completion)的提示。 + +`max_tokens`: 新生成的句子的token长度。 + +`temperature`: 在0和2之间选择的采样温度。较高的值如0.8会使输出更加随机,而较低的值如0.2则会使其输出更具有确定性。temperature越高,使用随机采样最为decoding的概率越大。 + +`use_beam_search`: 使用束搜索(beam search)。默认为`False`,即启用随机采样策略(random sampling) + +`n`: 输出序列的数量,默认为1 + +`best_of`: 当搜索策略为束搜索(beam search)时,该参数为在束搜索(beam search)中所使用的束个数。默认和`n`相同 + +`top_k`: 在随机采样(random sampling)时,前top_k高概率的token将作为候选token被随机采样。 + +`top_p`: 在随机采样(random sampling)时,累积概率超过top_p的token将作为候选token被随机采样,越低随机性越大,举个例子,当top_p设定为0.6时,概率前5的token概率分别为{0.23, 0.20, 0.18, 0.11, 0.10}时,前三个token的累积概率为0.61,那么第4个token将被过滤掉,只有前三的token将作为候选token被随机采样。 + +`presence_penalty`: 重复惩罚,取值范围-2 ~ 2,默认值为0。值大于0表示鼓励模型使用新的token,反之鼓励重复。 diff --git a/scripts/openai_server_demo/openai_api_protocol_vllm.py b/scripts/openai_server_demo/openai_api_protocol_vllm.py new file mode 100644 index 0000000..e933090 --- /dev/null +++ b/scripts/openai_server_demo/openai_api_protocol_vllm.py @@ -0,0 +1,171 @@ +import time +from typing import Dict, List, Literal, Optional, Union + +from pydantic import BaseModel, Field + +from vllm.utils import random_uuid + + +class ErrorResponse(BaseModel): + object: str = "error" + message: str + type: str + param: Optional[str] = None + code: Optional[str] = None + + +class ModelPermission(BaseModel): + id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}") + object: str = "model_permission" + created: int = Field(default_factory=lambda: int(time.time())) + 
allow_create_engine: bool = False + allow_sampling: bool = True + allow_logprobs: bool = True + allow_search_indices: bool = False + allow_view: bool = True + allow_fine_tuning: bool = False + organization: str = "*" + group: Optional[str] = None + is_blocking: str = False + + +class ModelCard(BaseModel): + id: str + object: str = "model" + created: int = Field(default_factory=lambda: int(time.time())) + owned_by: str = "vllm" + root: Optional[str] = None + parent: Optional[str] = None + permission: List[ModelPermission] = Field(default_factory=list) + + +class ModelList(BaseModel): + object: str = "list" + data: List[ModelCard] = Field(default_factory=list) + + +class UsageInfo(BaseModel): + prompt_tokens: int = 0 + total_tokens: int = 0 + completion_tokens: Optional[int] = 0 + + +class ChatCompletionRequest(BaseModel): + model: str + messages: Union[str, List[Dict[str, str]]] + temperature: Optional[float] = 0.7 + top_p: Optional[float] = 1.0 + n: Optional[int] = 1 + max_tokens: Optional[int] = 16 + stop: Optional[Union[str, List[str]]] = Field(default_factory=list) + stream: Optional[bool] = False + presence_penalty: Optional[float] = 0.0 + frequency_penalty: Optional[float] = 0.0 + logit_bias: Optional[Dict[str, float]] = None + user: Optional[str] = None + # Additional parameters supported by vLLM + best_of: Optional[int] = None + top_k: Optional[int] = -1 + ignore_eos: Optional[bool] = False + use_beam_search: Optional[bool] = False + + +class CompletionRequest(BaseModel): + model: str + prompt: Union[str, List[str]] + suffix: Optional[str] = None + max_tokens: Optional[int] = 16 + temperature: Optional[float] = 1.0 + top_p: Optional[float] = 1.0 + n: Optional[int] = 1 + stream: Optional[bool] = False + logprobs: Optional[int] = None + echo: Optional[bool] = False + stop: Optional[Union[str, List[str]]] = Field(default_factory=list) + presence_penalty: Optional[float] = 0.0 + frequency_penalty: Optional[float] = 0.0 + best_of: Optional[int] = None + logit_bias: Optional[Dict[str, float]] = None + user: Optional[str] = None + # Additional parameters supported by vLLM + top_k: Optional[int] = -1 + ignore_eos: Optional[bool] = False + use_beam_search: Optional[bool] = False + + +class LogProbs(BaseModel): + text_offset: List[int] = Field(default_factory=list) + token_logprobs: List[Optional[float]] = Field(default_factory=list) + tokens: List[str] = Field(default_factory=list) + top_logprobs: List[Optional[Dict[str, + float]]] = Field(default_factory=list) + + +class CompletionResponseChoice(BaseModel): + index: int + text: str + logprobs: Optional[LogProbs] = None + finish_reason: Optional[Literal["stop", "length"]] = None + + +class CompletionResponse(BaseModel): + id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}") + object: str = "text_completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[CompletionResponseChoice] + usage: UsageInfo + + +class CompletionResponseStreamChoice(BaseModel): + index: int + text: str + logprobs: Optional[LogProbs] = None + finish_reason: Optional[Literal["stop", "length"]] = None + + +class CompletionStreamResponse(BaseModel): + id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}") + object: str = "text_completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[CompletionResponseStreamChoice] + + +class ChatMessage(BaseModel): + role: str + content: str + + +class ChatCompletionResponseChoice(BaseModel): + index: int + message: 
ChatMessage + finish_reason: Optional[Literal["stop", "length"]] = None + + +class ChatCompletionResponse(BaseModel): + id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}") + object: str = "chat.completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[ChatCompletionResponseChoice] + usage: UsageInfo + + +class DeltaMessage(BaseModel): + role: Optional[str] = None + content: Optional[str] = None + + +class ChatCompletionResponseStreamChoice(BaseModel): + index: int + delta: DeltaMessage + finish_reason: Optional[Literal["stop", "length"]] = None + + +class ChatCompletionStreamResponse(BaseModel): + id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}") + object: str = "chat.completion.chunk" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[ChatCompletionResponseStreamChoice] diff --git a/scripts/openai_server_demo/openai_api_server.py b/scripts/openai_server_demo/openai_api_server.py index 2f78280..ae125ae 100644 --- a/scripts/openai_server_demo/openai_api_server.py +++ b/scripts/openai_server_demo/openai_api_server.py @@ -61,12 +61,12 @@ ) model_vocab_size = base_model.get_input_embeddings().weight.size(0) -tokenzier_vocab_size = len(tokenizer) +tokenizer_vocab_size = len(tokenizer) print(f"Vocab of the base model: {model_vocab_size}") -print(f"Vocab of the tokenizer: {tokenzier_vocab_size}") -if model_vocab_size!=tokenzier_vocab_size: +print(f"Vocab of the tokenizer: {tokenizer_vocab_size}") +if model_vocab_size!=tokenizer_vocab_size: print("Resize model embeddings to fit tokenizer") - base_model.resize_token_embeddings(tokenzier_vocab_size) + base_model.resize_token_embeddings(tokenizer_vocab_size) if args.lora_model is not None: print("loading peft model") model = PeftModel.from_pretrained(base_model, args.lora_model,torch_dtype=load_type,device_map='auto',) diff --git a/scripts/openai_server_demo/openai_api_server_vllm.py b/scripts/openai_server_demo/openai_api_server_vllm.py new file mode 100644 index 0000000..0900e98 --- /dev/null +++ b/scripts/openai_server_demo/openai_api_server_vllm.py @@ -0,0 +1,685 @@ +import argparse +import asyncio +from http import HTTPStatus +import json +import time +from typing import AsyncGenerator, Dict, List, Optional + +import fastapi +from fastapi import BackgroundTasks, Request +from fastapi.exceptions import RequestValidationError +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse, StreamingResponse +from fastchat.conversation import Conversation, SeparatorStyle +from fastchat.model.model_adapter import get_conversation_template + +import uvicorn + +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.async_llm_engine import AsyncLLMEngine +import sys +import os +parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +sys.path.append(parent_dir) +from openai_api_protocol_vllm import ( + CompletionRequest, CompletionResponse, CompletionResponseChoice, + CompletionResponseStreamChoice, CompletionStreamResponse, + ChatCompletionRequest, ChatCompletionResponse, + ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, + ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse, + LogProbs, ModelCard, ModelList, ModelPermission, UsageInfo) +from vllm.logger import init_logger +from vllm.outputs import RequestOutput +from vllm.sampling_params import SamplingParams +from vllm.transformers_utils.tokenizer import get_tokenizer +from 
vllm.utils import random_uuid + +TIMEOUT_KEEP_ALIVE = 5 # seconds + +logger = init_logger(__name__) +served_model = None +app = fastapi.FastAPI() + +from fastchat.conversation import register_conv_template, get_conv_template +from fastchat.model.model_adapter import BaseModelAdapter, model_adapters +import fastchat + +# Chinese LLaMA Alpaca default template +register_conv_template( + Conversation( + name="chinese-llama-alpaca", + system="Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n", + roles=("### Instruction:\n", "### Response:"), + messages=(), + offset=0, + sep_style=SeparatorStyle.NO_COLON_SINGLE, + sep="\n\n", + sep2="", + ) +) + +# Chinese LLaMA Alpaca 2 default template +register_conv_template( + Conversation( + name="chinese-llama-alpaca-2", + system="[INST] <>\nYou are a helpful assistant. 你是一个乐于助人的助手。\n<>\n\n", + roles=("[INST]", "[/INST]"), + messages=(), + offset=0, + sep_style=SeparatorStyle.LLAMA2, + sep=" ", + sep2=" ", + stop_token_ids=[2], + ) +) + +class ChineseLLaMAAlpacaAdapter(BaseModelAdapter): + """The model adapter for Chinese-LLaMA-Alpaca""" + + use_fast_tokenizer = False + + def match(self, model_path: str): + return "chinese-llama-alpaca" in model_path.lower() + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("chinese-llama-alpaca") + +class ChineseLLaMAAlpaca2Adapter(BaseModelAdapter): + """The model adapter for Chinese-LLaMA-Alpaca-2""" + + def match(self, model_path: str): + return "chinese-llama-alpaca-2" in model_path.lower() + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + model, tokenizer = super().load_model(model_path, from_pretrained_kwargs) + model.config.eos_token_id = tokenizer.eos_token_id + model.config.pad_token_id = tokenizer.pad_token_id + return model, tokenizer + + def get_default_conv_template(self, model_path: str) -> Conversation: + return get_conv_template("chinese-llama-alpaca-2") + + +# add model adapters to head of List model_adapters +model_adapters = [ChineseLLaMAAlpacaAdapter()] + model_adapters +model_adapters = [ChineseLLaMAAlpaca2Adapter()] + model_adapters +fastchat.model.model_adapter.model_adapters = model_adapters + +def create_error_response(status_code: HTTPStatus, + message: str) -> JSONResponse: + return JSONResponse(ErrorResponse(message=message, + type="invalid_request_error").dict(), + status_code=status_code.value) + + +@app.exception_handler(RequestValidationError) +async def validation_exception_handler(request, exc): # pylint: disable=unused-argument + return create_error_response(HTTPStatus.BAD_REQUEST, str(exc)) + + +async def check_model(request) -> Optional[JSONResponse]: + if request.model == served_model: + return + ret = create_error_response( + HTTPStatus.NOT_FOUND, + f"The model `{request.model}` does not exist.", + ) + return ret + + +async def get_gen_prompt(request) -> str: + conv = get_conversation_template(request.model) + conv = Conversation( + name=conv.name, + system=conv.system, + roles=conv.roles, + messages=list(conv.messages), # prevent in-place modification + offset=conv.offset, + sep_style=SeparatorStyle(conv.sep_style), + sep=conv.sep, + sep2=conv.sep2, + stop_str=conv.stop_str, + stop_token_ids=conv.stop_token_ids, + ) + + if isinstance(request.messages, str): + prompt = request.messages + else: + for message in request.messages: + msg_role = message["role"] + if msg_role == "system": + conv.system = message["content"] + elif msg_role == "user": + 
conv.append_message(conv.roles[0], message["content"]) + elif msg_role == "assistant": + conv.append_message(conv.roles[1], message["content"]) + else: + raise ValueError(f"Unknown role: {msg_role}") + + # Add a blank message for the assistant. + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + + return prompt + + +async def get_gen_prompt_nochat(request) -> str: + conv = get_conversation_template(request.model) + conv = Conversation( + name=conv.name, + system=conv.system, + roles=conv.roles, + messages=list(conv.messages), # prevent in-place modification + offset=conv.offset, + sep_style=SeparatorStyle(conv.sep_style), + sep=conv.sep, + sep2=conv.sep2, + stop_str=conv.stop_str, + stop_token_ids=conv.stop_token_ids, + ) + + prompt = request.prompt + conv.append_message(conv.roles[0], prompt) + conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + + return prompt + + +async def check_length(request, prompt, model_config): + if hasattr(model_config.hf_config, "max_sequence_length"): + context_len = model_config.hf_config.max_sequence_length + elif hasattr(model_config.hf_config, "seq_length"): + context_len = model_config.hf_config.seq_length + elif hasattr(model_config.hf_config, "max_position_embeddings"): + context_len = model_config.hf_config.max_position_embeddings + elif hasattr(model_config.hf_config, "seq_length"): + context_len = model_config.hf_config.seq_length + else: + context_len = 2048 + + input_ids = tokenizer(prompt).input_ids + token_num = len(input_ids) + + if token_num + request.max_tokens > context_len: + return create_error_response( + HTTPStatus.BAD_REQUEST, + f"This model's maximum context length is {context_len} tokens. " + f"However, you requested {request.max_tokens + token_num} tokens " + f"({token_num} in the messages, " + f"{request.max_tokens} in the completion). " + f"Please reduce the length of the messages or completion.", + ) + else: + return None + + +@app.get("/v1/models") +async def show_available_models(): + """Show available models. Right now we only have one model.""" + model_cards = [ + ModelCard(id=served_model, + root=served_model, + permission=[ModelPermission()]) + ] + return ModelList(data=model_cards) + + +def create_logprobs(token_ids: List[int], + id_logprobs: List[Dict[int, float]], + initial_text_offset: int = 0) -> LogProbs: + """Create OpenAI-style logprobs.""" + logprobs = LogProbs() + last_token_len = 0 + for token_id, id_logprob in zip(token_ids, id_logprobs): + token = tokenizer.convert_ids_to_tokens(token_id) + logprobs.tokens.append(token) + logprobs.token_logprobs.append(id_logprob[token_id]) + if len(logprobs.text_offset) == 0: + logprobs.text_offset.append(initial_text_offset) + else: + logprobs.text_offset.append(logprobs.text_offset[-1] + + last_token_len) + last_token_len = len(token) + + logprobs.top_logprobs.append({ + tokenizer.convert_ids_to_tokens(i): p + for i, p in id_logprob.items() + }) + return logprobs + + +@app.post("/v1/chat/completions") +async def create_chat_completion(raw_request: Request): + """Completion API similar to OpenAI's API. + + See https://platform.openai.com/docs/api-reference/chat/create + for the API specification. This API mimics the OpenAI ChatCompletion API. 
+ + NOTE: Currently we do not support the following features: + - function_call (Users should implement this by themselves) + - logit_bias (to be supported by vLLM engine) + """ + request = ChatCompletionRequest(**await raw_request.json()) + logger.info(f"Received chat completion request: {request}") + + error_check_ret = await check_model(request) + if error_check_ret is not None: + return error_check_ret + + if request.logit_bias is not None: + # TODO: support logit_bias in vLLM engine. + return create_error_response(HTTPStatus.BAD_REQUEST, + "logit_bias is not currently supported") + + prompt = await get_gen_prompt(request) + error_check_ret = await check_length(request, prompt, engine_model_config) + if error_check_ret is not None: + return error_check_ret + + model_name = request.model + request_id = f"cmpl-{random_uuid()}" + created_time = int(time.time()) + try: + sampling_params = SamplingParams( + n=request.n, + presence_penalty=request.presence_penalty, + frequency_penalty=request.frequency_penalty, + temperature=request.temperature, + top_p=request.top_p, + stop=request.stop, + max_tokens=request.max_tokens, + best_of=request.best_of, + top_k=request.top_k, + ignore_eos=request.ignore_eos, + use_beam_search=request.use_beam_search, + ) + except ValueError as e: + return create_error_response(HTTPStatus.BAD_REQUEST, str(e)) + + result_generator = engine.generate(prompt, sampling_params, request_id) + + async def abort_request() -> None: + await engine.abort(request_id) + + def create_stream_response_json( + index: int, + text: str, + finish_reason: Optional[str] = None, + ) -> str: + choice_data = ChatCompletionResponseStreamChoice( + index=index, + delta=DeltaMessage(content=text), + finish_reason=finish_reason, + ) + response = ChatCompletionStreamResponse( + id=request_id, + created=created_time, + model=model_name, + choices=[choice_data], + ) + response_json = response.json(ensure_ascii=False) + + return response_json + + async def completion_stream_generator() -> AsyncGenerator[str, None]: + # First chunk with role + for i in range(request.n): + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=DeltaMessage(role="assistant"), + finish_reason=None, + ) + chunk = ChatCompletionStreamResponse(id=request_id, + choices=[choice_data], + model=model_name) + data = chunk.json(exclude_unset=True, ensure_ascii=False) + yield f"data: {data}\n\n" + + previous_texts = [""] * request.n + previous_num_tokens = [0] * request.n + async for res in result_generator: + res: RequestOutput + for output in res.outputs: + i = output.index + delta_text = output.text[len(previous_texts[i]):] + previous_texts[i] = output.text + previous_num_tokens[i] = len(output.token_ids) + response_json = create_stream_response_json( + index=i, + text=delta_text, + ) + yield f"data: {response_json}\n\n" + if output.finish_reason is not None: + response_json = create_stream_response_json( + index=i, + text="", + finish_reason=output.finish_reason, + ) + yield f"data: {response_json}\n\n" + yield "data: [DONE]\n\n" + + # Streaming response + if request.stream: + background_tasks = BackgroundTasks() + # Abort the request if the client disconnects. + background_tasks.add_task(abort_request) + return StreamingResponse(completion_stream_generator(), + media_type="text/event-stream", + background=background_tasks) + + # Non-streaming response + final_res: RequestOutput = None + async for res in result_generator: + if await raw_request.is_disconnected(): + # Abort the request if the client disconnects. 
+ await abort_request() + return create_error_response(HTTPStatus.BAD_REQUEST, + "Client disconnected") + final_res = res + assert final_res is not None + choices = [] + for output in final_res.outputs: + choice_data = ChatCompletionResponseChoice( + index=output.index, + message=ChatMessage(role="assistant", content=output.text), + finish_reason=output.finish_reason, + ) + choices.append(choice_data) + + num_prompt_tokens = len(final_res.prompt_token_ids) + num_generated_tokens = sum( + len(output.token_ids) for output in final_res.outputs) + usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=num_generated_tokens, + total_tokens=num_prompt_tokens + num_generated_tokens, + ) + response = ChatCompletionResponse( + id=request_id, + created=created_time, + model=model_name, + choices=choices, + usage=usage, + ) + + if request.stream: + # When user requests streaming but we don't stream, we still need to + # return a streaming response with a single event. + response_json = response.json(ensure_ascii=False) + + async def fake_stream_generator() -> AsyncGenerator[str, None]: + yield f"data: {response_json}\n\n" + yield "data: [DONE]\n\n" + + return StreamingResponse(fake_stream_generator(), + media_type="text/event-stream") + + return response + + +@app.post("/v1/completions") +async def create_completion(raw_request: Request): + """Completion API similar to OpenAI's API. + + See https://platform.openai.com/docs/api-reference/completions/create + for the API specification. This API mimics the OpenAI Completion API. + + NOTE: Currently we do not support the following features: + - echo (since the vLLM engine does not currently support + getting the logprobs of prompt tokens) + - suffix (the language models we currently support do not support + suffix) + - logit_bias (to be supported by vLLM engine) + """ + request = CompletionRequest(**await raw_request.json()) + logger.info(f"Received completion request: {request}") + + error_check_ret = await check_model(request) + if error_check_ret is not None: + return error_check_ret + + if request.echo: + # We do not support echo since the vLLM engine does not + # currently support getting the logprobs of prompt tokens. + return create_error_response(HTTPStatus.BAD_REQUEST, + "echo is not currently supported") + + if request.suffix is not None: + # The language models we currently support do not support suffix. + return create_error_response(HTTPStatus.BAD_REQUEST, + "suffix is not currently supported") + + if request.logit_bias is not None: + # TODO: support logit_bias in vLLM engine. 
+ return create_error_response(HTTPStatus.BAD_REQUEST, + "logit_bias is not currently supported") + + model_name = request.model + request_id = f"cmpl-{random_uuid()}" + if isinstance(request.prompt, list): + if len(request.prompt) == 0: + return create_error_response(HTTPStatus.BAD_REQUEST, + "please provide at least one prompt") + if len(request.prompt) > 1: + return create_error_response( + HTTPStatus.BAD_REQUEST, + "multiple prompts in a batch is not currently supported") + prompt = request.prompt[0] + else: + prompt = request.prompt + request.prompt = prompt + prompt = await get_gen_prompt_nochat(request) + created_time = int(time.time()) + try: + sampling_params = SamplingParams( + n=request.n, + best_of=request.best_of, + presence_penalty=request.presence_penalty, + frequency_penalty=request.frequency_penalty, + temperature=request.temperature, + top_p=request.top_p, + top_k=request.top_k, + stop=request.stop, + ignore_eos=request.ignore_eos, + max_tokens=request.max_tokens, + logprobs=request.logprobs, + use_beam_search=request.use_beam_search, + ) + except ValueError as e: + return create_error_response(HTTPStatus.BAD_REQUEST, str(e)) + + result_generator = engine.generate(prompt, sampling_params, request_id) + + # Similar to the OpenAI API, when n != best_of, we do not stream the + # results. In addition, we do not stream the results when use beam search. + stream = (request.stream + and (request.best_of is None or request.n == request.best_of) + and not request.use_beam_search) + + async def abort_request() -> None: + await engine.abort(request_id) + + def create_stream_response_json( + index: int, + text: str, + logprobs: Optional[LogProbs] = None, + finish_reason: Optional[str] = None, + ) -> str: + choice_data = CompletionResponseStreamChoice( + index=index, + text=text, + logprobs=logprobs, + finish_reason=finish_reason, + ) + response = CompletionStreamResponse( + id=request_id, + created=created_time, + model=model_name, + choices=[choice_data], + ) + response_json = response.json(ensure_ascii=False) + + return response_json + + async def completion_stream_generator() -> AsyncGenerator[str, None]: + previous_texts = [""] * request.n + previous_num_tokens = [0] * request.n + async for res in result_generator: + res: RequestOutput + for output in res.outputs: + i = output.index + delta_text = output.text[len(previous_texts[i]):] + if request.logprobs is not None: + logprobs = create_logprobs( + output.token_ids[previous_num_tokens[i]:], + output.logprobs[previous_num_tokens[i]:], + len(previous_texts[i])) + else: + logprobs = None + previous_texts[i] = output.text + previous_num_tokens[i] = len(output.token_ids) + response_json = create_stream_response_json( + index=i, + text=delta_text, + logprobs=logprobs, + ) + yield f"data: {response_json}\n\n" + if output.finish_reason is not None: + logprobs = (LogProbs() + if request.logprobs is not None else None) + response_json = create_stream_response_json( + index=i, + text="", + logprobs=logprobs, + finish_reason=output.finish_reason, + ) + yield f"data: {response_json}\n\n" + yield "data: [DONE]\n\n" + + # Streaming response + if stream: + background_tasks = BackgroundTasks() + # Abort the request if the client disconnects. 
+ background_tasks.add_task(abort_request) + return StreamingResponse(completion_stream_generator(), + media_type="text/event-stream", + background=background_tasks) + + # Non-streaming response + final_res: RequestOutput = None + async for res in result_generator: + if await raw_request.is_disconnected(): + # Abort the request if the client disconnects. + await abort_request() + return create_error_response(HTTPStatus.BAD_REQUEST, + "Client disconnected") + final_res = res + assert final_res is not None + choices = [] + for output in final_res.outputs: + if request.logprobs is not None: + logprobs = create_logprobs(output.token_ids, output.logprobs) + else: + logprobs = None + choice_data = CompletionResponseChoice( + index=output.index, + text=output.text, + logprobs=logprobs, + finish_reason=output.finish_reason, + ) + choices.append(choice_data) + + num_prompt_tokens = len(final_res.prompt_token_ids) + num_generated_tokens = sum( + len(output.token_ids) for output in final_res.outputs) + usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=num_generated_tokens, + total_tokens=num_prompt_tokens + num_generated_tokens, + ) + response = CompletionResponse( + id=request_id, + created=created_time, + model=model_name, + choices=choices, + usage=usage, + ) + + if request.stream: + # When user requests streaming but we don't stream, we still need to + # return a streaming response with a single event. + response_json = response.json(ensure_ascii=False) + + async def fake_stream_generator() -> AsyncGenerator[str, None]: + yield f"data: {response_json}\n\n" + yield "data: [DONE]\n\n" + + return StreamingResponse(fake_stream_generator(), + media_type="text/event-stream") + + return response + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="vLLM OpenAI-Compatible RESTful API server.") + parser.add_argument("--host", + type=str, + default="localhost", + help="host name") + parser.add_argument("--port", type=int, default=8000, help="port number") + parser.add_argument("--allow-credentials", + action="store_true", + help="allow credentials") + parser.add_argument("--allowed-origins", + type=json.loads, + default=["*"], + help="allowed origins") + parser.add_argument("--allowed-methods", + type=json.loads, + default=["*"], + help="allowed methods") + parser.add_argument("--allowed-headers", + type=json.loads, + default=["*"], + help="allowed headers") + parser.add_argument("--served-model-name", + type=str, + default=None, + help="The model name used in the API. If not " + "specified, the model name will be the same as " + "the huggingface name.") + + parser = AsyncEngineArgs.add_cli_args(parser) + args = parser.parse_args() + + app.add_middleware( + CORSMiddleware, + allow_origins=args.allowed_origins, + allow_credentials=args.allow_credentials, + allow_methods=args.allowed_methods, + allow_headers=args.allowed_headers, + ) + + logger.info(f"args: {args}") + + if args.served_model_name is not None: + served_model = args.served_model_name + else: + served_model = args.model + + engine_args = AsyncEngineArgs.from_cli_args(args) + engine = AsyncLLMEngine.from_engine_args(engine_args) + engine_model_config = asyncio.run(engine.get_model_config()) + + # A separate tokenizer to map token IDs to strings. 
+ tokenizer = get_tokenizer(engine_args.tokenizer, + tokenizer_mode=engine_args.tokenizer_mode, + trust_remote_code=engine_args.trust_remote_code) + + uvicorn.run(app, + host=args.host, + port=args.port, + log_level="info", + timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
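As a usage sketch to accompany the new OpenAI-compatible server (`scripts/openai_server_demo/openai_api_server_vllm.py`): the snippet below is hypothetical and not part of the patch. It assumes the server was launched as shown in `README_vllm.md` on the default `localhost:8000` with `--served-model-name chinese-llama-alpaca-2`, and it follows the `ChatCompletionRequest`/`ChatCompletionResponse` models in `openai_api_protocol_vllm.py`, so each message carries a `content` field.

``` python
# Hypothetical client sketch, not included in this PR.
# Assumes openai_api_server_vllm.py is running on localhost:8000 with
# --served-model-name chinese-llama-alpaca-2 (see README_vllm.md).
import requests

API_BASE = "http://localhost:8000/v1"
MODEL = "chinese-llama-alpaca-2"

def chat(messages, max_tokens=256, temperature=0.7, top_k=40):
    """Send a non-streaming chat completion request and return the reply text."""
    payload = {
        "model": MODEL,
        "messages": messages,   # each item: {"role": ..., "content": ...}
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_k": top_k,         # vLLM-specific extension accepted by the server
    }
    resp = requests.post(f"{API_BASE}/chat/completions", json=payload, timeout=300)
    resp.raise_for_status()
    return resp.json()["choices"][0]["message"]["content"]

if __name__ == "__main__":
    history = [{"role": "user", "content": "给我讲一些有关杭州的故事吧"}]
    reply = chat(history)
    print(reply)
    # Multi-turn: append the assistant reply, then ask a follow-up question.
    history.append({"role": "assistant", "content": reply})
    history.append({"role": "user", "content": "我比较喜欢西湖,可以给我讲一下西湖吗"})
    print(chat(history))
```

Streaming works the same way with `"stream": true` in the payload; the server then emits Server-Sent Events (`data: ...` chunks terminated by `data: [DONE]`), which can be consumed with `requests.post(..., stream=True)` and `iter_lines()`.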