server.py (forked from MeetKai/functionary)
import argparse
import json
import uuid
from typing import Union

import torch
import uvicorn
from fastapi import FastAPI
from fastapi.responses import JSONResponse, StreamingResponse
from transformers import (AutoModelForCausalLM, LlamaForCausalLM,
                          LlamaTokenizerFast)

from functionary.inference import generate_message
from functionary.inference_stream import generate_stream
from functionary.openai_types import (ChatCompletion, ChatCompletionChunk,
                                      ChatInput, Choice, StreamChoice)

app = FastAPI(title="Functionary API")


@app.post("/v1/chat/completions")
async def chat_endpoint(chat_input: ChatInput):
    """Serve OpenAI-style chat completions, as a single response or as an SSE stream."""
    request_id = str(uuid.uuid4())
    if not chat_input.stream:
        # Non-streaming: generate the full assistant message in one call.
        response_message = generate_message(
            messages=chat_input.messages,
            functions=chat_input.functions,
            tools=chat_input.tools,
            temperature=chat_input.temperature,
            model=model,  # type: ignore
            tokenizer=tokenizer,
            device=model.device,
        )
        finish_reason = "stop"
        if response_message.function_call is not None:
            finish_reason = "function_call"  # follow the OpenAI function-calling response format
        result = ChatCompletion(
            id=request_id,
            choices=[Choice.from_message(response_message, finish_reason)],
        )
        return result.dict(exclude_none=True)
    else:
        # Streaming: emit OpenAI-style chunks as server-sent events.
        response_generator = generate_stream(
            messages=chat_input.messages,
            functions=chat_input.functions,
            tools=chat_input.tools,
            temperature=chat_input.temperature,
            model=model,  # type: ignore
            tokenizer=tokenizer,
        )

        def get_response_stream():
            for response in response_generator:
                chunk = StreamChoice(**response)
                result = ChatCompletionChunk(id=request_id, choices=[chunk])
                chunk_dic = result.dict(exclude_unset=True)
                chunk_data = json.dumps(chunk_dic, ensure_ascii=False)
                yield f"data: {chunk_data}\n\n"
            yield "data: [DONE]\n\n"

        return StreamingResponse(get_response_stream(), media_type="text/event-stream")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Functionary API Server")
    parser.add_argument(
        "--model",
        type=str,
        default="musabgultekin/functionary-7b-v1",
        help="Model name",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="auto",
        help="choose which device to host the model: cpu, cuda, cuda:xxx, or auto",
    )
parser.add_argument("--load_in_8bit", type=bool, default=False)
    args = parser.parse_args()

    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        low_cpu_mem_usage=True,
        device_map=args.device,
        torch_dtype=torch.bfloat16 if args.device == "cpu" else torch.float16,
        load_in_8bit=args.load_in_8bit,
    )
    tokenizer = LlamaTokenizerFast.from_pretrained(args.model, legacy=True)
    print(tokenizer)

    uvicorn.run(app, host="0.0.0.0", port=8000)
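
For reference, below is a minimal client sketch, not part of server.py. It assumes the server above is running on http://localhost:8000 and that the requests package is installed; the payload fields are assumed to match the OpenAI-style schema accepted by ChatInput, and the get_current_weather function is hypothetical, included only to exercise the function-calling path.

import json

import requests

# Hypothetical function schema, used only for illustration.
payload = {
    "messages": [{"role": "user", "content": "What is the weather in Istanbul?"}],
    "functions": [
        {
            "name": "get_current_weather",
            "description": "Get the current weather for a given city",
            "parameters": {
                "type": "object",
                "properties": {"location": {"type": "string"}},
                "required": ["location"],
            },
        }
    ],
    "temperature": 0.0,
    "stream": False,
}

# Non-streaming request: the server returns a single ChatCompletion JSON body.
response = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
print(json.dumps(response.json(), indent=2))

# Streaming variant: set "stream": True and read the server-sent events line by line.
payload["stream"] = True
with requests.post(
    "http://localhost:8000/v1/chat/completions", json=payload, stream=True
) as stream_response:
    for line in stream_response.iter_lines(decode_unicode=True):
        if line and line != "data: [DONE]":
            print(json.loads(line[len("data: "):]))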