forked from deep-diver/LLM-As-Chatbot
global_vars.py
import yaml
from transformers import GenerationConfig
from models import alpaca, stablelm, koalpaca


def initialize_globals(args):
    global model, model_type, stream_model, tokenizer
    global gen_config, gen_config_raw
    global gen_config_summarization

    # Infer the model type from the base checkpoint URL first,
    # then from the fine-tuned checkpoint URL.
    model_type = "alpaca"
    if "stablelm" in args.base_url:
        model_type = "stablelm"
    elif "KoAlpaca-Polyglot" in args.base_url:
        model_type = "koalpaca-polyglot"
    elif "gpt4-alpaca" in args.ft_ckpt_url:
        model_type = "alpaca-gpt4"
    elif "alpaca" in args.ft_ckpt_url:
        model_type = "alpaca"
    else:
        print("unsupported model type")
        quit()
    print(f"determined model type: {model_type}")

    # Load the model and tokenizer with the loader that matches the model type.
    load_model = get_load_model(model_type)
    model, tokenizer = load_model(
        base=args.base_url,
        finetuned=args.ft_ckpt_url,
        multi_gpu=args.multi_gpu,
        force_download_ckpt=args.force_download_ckpt
    )

    # Separate generation configs for chat and for summarization.
    gen_config, gen_config_raw = get_generation_config(args.gen_config_path)
    gen_config_summarization, _ = get_generation_config(args.gen_config_summarization_path)
    stream_model = model


def get_load_model(model_type):
    # Dispatch to the loader module that matches the model type. Returns None
    # for unknown types (initialize_globals never reaches that case, since it
    # exits first on unsupported types).
    if model_type == "alpaca" or model_type == "alpaca-gpt4":
        return alpaca.load_model
    elif model_type == "stablelm":
        return stablelm.load_model
    elif model_type == "koalpaca-polyglot":
        return koalpaca.load_model
    else:
        return None


def get_generation_config(path):
    # Parse the YAML file and build a transformers.GenerationConfig from the
    # mapping under its "generation_config" key; also return the raw dict.
    with open(path, 'rb') as f:
        generation_config = yaml.safe_load(f.read())

    generation_config = generation_config["generation_config"]
    return GenerationConfig(**generation_config), generation_config
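
# A config YAML consumed by get_generation_config is expected to hold
# GenerationConfig keyword arguments under a top-level "generation_config"
# key. Illustrative example (hypothetical values, not from the original repo):
#
# generation_config:
#   do_sample: true
#   temperature: 0.9
#   top_p: 0.75
#   top_k: 50
#   max_new_tokens: 256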


def get_constraints_config(path):
    # NOTE: ConstraintsConfig is referenced here but never imported or defined
    # in this module; it must be provided elsewhere in the project for this
    # function to work.
    with open(path, 'rb') as f:
        constraints_config = yaml.safe_load(f.read())

    return ConstraintsConfig(**constraints_config), constraints_config["constraints"]
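
A minimal usage sketch (not part of the original file), assuming argparse-style
arguments with the attribute names this module reads; the checkpoint URL and
config paths below are placeholders:

from types import SimpleNamespace

import global_vars

args = SimpleNamespace(
    base_url="stabilityai/stablelm-tuned-alpha-7b",  # contains "stablelm"
    ft_ckpt_url="",
    multi_gpu=False,
    force_download_ckpt=False,
    gen_config_path="configs/gen_config.yaml",  # hypothetical path
    gen_config_summarization_path="configs/gen_config_summarization.yaml",  # hypothetical path
)

global_vars.initialize_globals(args)
# global_vars.model, global_vars.tokenizer, global_vars.gen_config,
# global_vars.gen_config_summarization, and global_vars.stream_model
# are now populated for the rest of the application.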