Skip to content

Commit

Permalink
Merge branch 'master' into 2_overflow_bf16
Browse files Browse the repository at this point in the history
  • Loading branch information
tjruwase authored Oct 2, 2023
2 parents ea2467f + d8f3f73 commit 775a6fc
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions applications/DeepSpeed-Chat/inference/chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import os
import json
from transformers import pipeline, set_seed
from transformers import AutoConfig, OPTForCausalLM, AutoTokenizer
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM


def parse_args():
Expand Down Expand Up @@ -43,9 +43,10 @@ def get_generator(path):
tokenizer.pad_token = tokenizer.eos_token

model_config = AutoConfig.from_pretrained(path)
model = OPTForCausalLM.from_pretrained(path,
from_tf=bool(".ckpt" in path),
config=model_config).half()
model_class = AutoModelForCausalLM.from_config(model_config)
model = model_class.from_pretrained(path,
from_tf=bool(".ckpt" in path),
config=model_config).half()

model.config.end_token_id = tokenizer.eos_token_id
model.config.pad_token_id = model.config.eos_token_id
Expand Down

0 comments on commit 775a6fc

Please sign in to comment.