From 128eba41e6e18ccf6783de3d6e2f8ea8cc27911f Mon Sep 17 00:00:00 2001
From: Moshe Island
Date: Mon, 11 Sep 2023 14:29:23 +0300
Subject: [PATCH] deepspeed-chat: support any model in chatbot

Currently, the chatbot assumes an OPTForCausalLM model. Modify it to use
the model class required by the checkpoint.

Change-Id: I04cbc28f87c7be4fc89a3fac39a3e5634b151b32
Signed-off-by: Moshe Island
---
 applications/DeepSpeed-Chat/inference/chatbot.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/applications/DeepSpeed-Chat/inference/chatbot.py b/applications/DeepSpeed-Chat/inference/chatbot.py
index 38b900d7d..5a4e36895 100644
--- a/applications/DeepSpeed-Chat/inference/chatbot.py
+++ b/applications/DeepSpeed-Chat/inference/chatbot.py
@@ -10,7 +10,7 @@ import os
 import json
 
 from transformers import pipeline, set_seed
-from transformers import AutoConfig, OPTForCausalLM, AutoTokenizer
+from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
 
 
 def parse_args():
@@ -43,9 +43,10 @@ def get_generator(path):
     tokenizer.pad_token = tokenizer.eos_token
 
     model_config = AutoConfig.from_pretrained(path)
-    model = OPTForCausalLM.from_pretrained(path,
-                                           from_tf=bool(".ckpt" in path),
-                                           config=model_config).half()
+    model_class = AutoModelForCausalLM.from_config(model_config)
+    model = model_class.from_pretrained(path,
+                                        from_tf=bool(".ckpt" in path),
+                                        config=model_config).half()
 
     model.config.end_token_id = tokenizer.eos_token_id
     model.config.pad_token_id = model.config.eos_token_id
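
Illustrative usage (not part of the patch): a minimal sketch of the generic
loading path this change introduces, mirroring the patched get_generator().
The checkpoint id "facebook/opt-1.3b", the GPU/fp16 settings, and the final
pipeline call are assumptions for the example, not code taken from chatbot.py.

    from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, pipeline

    path = "facebook/opt-1.3b"  # hypothetical checkpoint; any causal-LM checkpoint works

    tokenizer = AutoTokenizer.from_pretrained(path)
    tokenizer.pad_token = tokenizer.eos_token

    model_config = AutoConfig.from_pretrained(path)
    # from_config() builds an (untrained) instance of whatever architecture the
    # checkpoint's config names (OPTForCausalLM, LlamaForCausalLM, ...); since
    # from_pretrained() is a classmethod, calling it on that instance then loads
    # the real weights for the resolved class.
    model_class = AutoModelForCausalLM.from_config(model_config)
    model = model_class.from_pretrained(path, config=model_config).half()  # fp16 assumes a GPU

    model.config.end_token_id = tokenizer.eos_token_id
    model.config.pad_token_id = model.config.eos_token_id

    generator = pipeline("text-generation",
                         model=model,
                         tokenizer=tokenizer,
                         device=0)  # device=0 assumes a CUDA device is available
    print(generator("Human: Hello!\nAssistant:", max_new_tokens=32))

Design note: from_config() is used here only to resolve the concrete model
class from the checkpoint's config; calling
AutoModelForCausalLM.from_pretrained(path, config=model_config) directly would
be an equivalent way to load the same model.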