From 7e4170f2f556ae9c959126e6366911912ec31412 Mon Sep 17 00:00:00 2001
From: BUAADreamer <1428195643@qq.com>
Date: Wed, 14 Jun 2023 08:59:04 +0800
Subject: [PATCH] add server.py to support a streaming generation API

---
 src/server.py   | 105 ++++++++++++++++++++++++++++++++++++++++++++++++
 src/web_demo.py |  11 +++--
 2 files changed, 110 insertions(+), 6 deletions(-)
 create mode 100644 src/server.py

diff --git a/src/server.py b/src/server.py
new file mode 100644
index 0000000000..299e104495
--- /dev/null
+++ b/src/server.py
@@ -0,0 +1,105 @@
+# coding=utf-8
+# Implements a FastAPI server that streams generated text back to the client.
+
+from threading import Thread
+
+import uvicorn
+from fastapi import FastAPI, Request
+from starlette.responses import StreamingResponse
+from transformers import TextIteratorStreamer
+
+from utils import (
+    Template,
+    load_pretrained,
+    prepare_infer_args,
+    get_logits_processor
+)
+
+app = FastAPI()
+
+
+@app.get("/hello")
+def hello():
+    return "hello world!"
+
+
+def parse_text(text):  # copied from https://github.com/GaiZhenbiao/ChuanhuChatGPT
+    lines = text.split("\n")
+    lines = [line for line in lines if line != ""]
+    count = 0
+    for i, line in enumerate(lines):
+        if "```" in line:
+            count += 1
+            items = line.split("`")
+            if count % 2 == 1:
+                lines[i] = "<pre><code class=\"language-{}\">".format(items[-1])
+            else:
+                lines[i] = "<br /></code></pre>"
+        else:
+            if i > 0:
+                if count % 2 == 1:
+                    line = line.replace("`", "\`")
+                    line = line.replace("<", "&lt;")
+                    line = line.replace(">", "&gt;")
+                    line = line.replace(" ", "&nbsp;")
+                    line = line.replace("*", "&ast;")
+                    line = line.replace("_", "&lowbar;")
+                    line = line.replace("-", "&#45;")
+                    line = line.replace(".", "&#46;")
+                    line = line.replace("!", "&#33;")
+                    line = line.replace("(", "&#40;")
+                    line = line.replace(")", "&#41;")
+                    line = line.replace("$", "&#36;")
+                lines[i] = "<br />" + line
" + line + text = "".join(lines) + return text + + +def predict(query, max_length, top_p, temperature, history): + input_ids = tokenizer([prompt_template.get_prompt(query, history)], return_tensors="pt")["input_ids"] + input_ids = input_ids.to(model.device) + gen_kwargs = { + "input_ids": input_ids, + "do_sample": True, + "top_p": top_p, + "temperature": temperature, + "num_beams": generating_args.num_beams, + "max_length": max_length, + "repetition_penalty": generating_args.repetition_penalty, + "logits_processor": get_logits_processor(), + "streamer": streamer + } + thread = Thread(target=model.generate, kwargs=gen_kwargs) + thread.start() + response = '' + for new_text in streamer: + response += new_text + print(new_text) + s = parse_text(response) + yield s[-1] + + +@app.post("/chat") +async def chat(request: Request): + json_post_raw = await request.json() + json_post = json.dumps(json_post_raw) + json_post_list = json.loads(json_post) + messages = json_post_list.get("messages")[:-1] + history = [] + if len(messages) > 2: + for i in range(0, len(messages) - 1, 2): + history.append([messages[i]['content'], messages[i + 1]['content']]) + prompt = messages[-1]['content'] + model = json_post_list.get("model") # keep this for future use + return StreamingResponse(predict(query=prompt, max_length=512, top_p=0.7, temperature=0.95, history=history), + media_type="text/event-stream") + + +if __name__ == "__main__": + model_args, data_args, finetuning_args, generating_args = prepare_infer_args() + model, tokenizer = load_pretrained(model_args, finetuning_args) + prompt_template = Template(data_args.prompt_template) + streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True) + uvicorn.run(app, host='0.0.0.0', port=8000, workers=1) diff --git a/src/web_demo.py b/src/web_demo.py index 2cceddd388..11ba06b219 100644 --- a/src/web_demo.py +++ b/src/web_demo.py @@ -17,10 +17,8 @@ from transformers import TextIteratorStreamer from transformers.utils.versions import require_version - require_version("gradio>=3.30.0", "To fix: pip install gradio>=3.30.0") - model_args, data_args, finetuning_args, generating_args = prepare_infer_args() model, tokenizer = load_pretrained(model_args, finetuning_args) @@ -45,7 +43,7 @@ def postprocess(self, y): gr.Chatbot.postprocess = postprocess -def parse_text(text): # copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT +def parse_text(text): # copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT lines = text.split("\n") lines = [line for line in lines if line != ""] count = 0 @@ -112,7 +110,6 @@ def reset_state(): with gr.Blocks() as demo: - gr.HTML("""

@@ -134,11 +131,13 @@ def reset_state(): emptyBtn = gr.Button("Clear History") max_length = gr.Slider(0, 2048, value=1024, step=1.0, label="Maximum length", interactive=True) top_p = gr.Slider(0, 1, value=generating_args.top_p, step=0.01, label="Top P", interactive=True) - temperature = gr.Slider(0, 1.5, value=generating_args.temperature, step=0.01, label="Temperature", interactive=True) + temperature = gr.Slider(0, 1.5, value=generating_args.temperature, step=0.01, label="Temperature", + interactive=True) history = gr.State([]) - submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history], show_progress=True) + submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history], + show_progress=True) submitBtn.click(reset_user_input, [], [user_input]) emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)
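
Usage note (not part of the patch): a minimal client sketch for exercising the new
/chat endpoint, assuming the server has been started with `python src/server.py`
and is reachable on localhost:8000, and that the `requests` package is installed.
The payload keys (`model`, `messages`) mirror what the handler above reads;
everything else here is illustrative.

    # hypothetical example client; not included in this patch
    import requests

    payload = {
        "model": "default",  # accepted but currently unused by the server
        "messages": [
            {"role": "user", "content": "Hello, who are you?"}
        ]
    }
    # stream=True lets us consume the text/event-stream body incrementally
    with requests.post("http://127.0.0.1:8000/chat", json=payload, stream=True) as resp:
        resp.encoding = "utf-8"  # ensure chunks are decoded as text
        for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
            print(chunk, end="", flush=True)

Because predict() yields the full parsed response on every new token, each chunk
is cumulative; a real client should replace its display buffer with the latest
chunk rather than concatenate them.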