diff --git a/python/llm/dev/benchmark/ceval/eval.py b/python/llm/dev/benchmark/ceval/eval.py
index ceff8b208c0..61ad6f91eb9 100644
--- a/python/llm/dev/benchmark/ceval/eval.py
+++ b/python/llm/dev/benchmark/ceval/eval.py
@@ -24,6 +24,7 @@ from bigdl.llm.utils.common.log4Error import invalidInputError
 
 from evaluators.qwen import QwenEvaluator
 from evaluators.llama import LlamaEvaluator
+from evaluators.chatglm import ChatGLMEvaluator
 
 
 TASK_NAME_MAPPING = {
@@ -280,7 +281,6 @@ def main(args, evaluator):
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--model_family", type=str, default="llama")
     parser.add_argument("--model_path", type=str, default="meta-llama/Llama-2-7b-chat-hf")
     parser.add_argument("--eval_type", type=str, default="validation")
     parser.add_argument("--device", type=str, default="xpu")
@@ -289,22 +289,39 @@ def main(args, evaluator):
     args = parser.parse_args()
 
-    if args.model_family == "llama":
+    # decide the model family
+    model_families = ['llama', 'qwen', 'chatglm']
+
+    model_family = None
+    for family in model_families:
+        if family in args.model_path.lower():
+            model_family = family
+
+    assert model_family is not None, f"Model {args.model_path}'s model family is not implemented"
+
+    if model_family == "llama":
         evaluator = LlamaEvaluator(
             choices=choices,
             model_path=args.model_path,
             device=args.device,
             qtype=args.qtype
         )
-    elif args.model_family == "qwen":
+    elif model_family == "qwen":
         evaluator = QwenEvaluator(
             choices=choices,
             model_path=args.model_path,
             device=args.device,
             qtype=args.qtype
        )
+    elif model_family == "chatglm":
+        evaluator = ChatGLMEvaluator(
+            choices=choices,
+            model_path=args.model_path,
+            device=args.device,
+            qtype=args.qtype
+        )
     else:
         invalidInputError(
             False,
-            "Invalid model_family, currently support llama and qwen only.")
+            "Invalid model_family, currently support llama, qwen, and chatglm only.")
 
     main(args, evaluator=evaluator)
diff --git a/python/llm/dev/benchmark/ceval/evaluators/chatglm.py b/python/llm/dev/benchmark/ceval/evaluators/chatglm.py
new file mode 100644
index 00000000000..435a43178ce
--- /dev/null
+++ b/python/llm/dev/benchmark/ceval/evaluators/chatglm.py
@@ -0,0 +1,211 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# refer to https://github.com/THUDM/ChatGLM2-6B/blob/main/evaluation/evaluate_ceval.py
+
+import re
+import torch
+from tqdm import tqdm
+from thefuzz import process
+from transformers import AutoTokenizer
+
+from evaluators.evaluator import Evaluator
+from bigdl.llm.transformers import AutoModelForCausalLM
+from transformers.generation.utils import LogitsProcessorList
+from transformers.generation.logits_process import LogitsProcessor
+
+
+class InvalidScoreLogitsProcessor(LogitsProcessor):
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        if torch.isnan(scores).any() or torch.isinf(scores).any():
+            scores.zero_()
+            scores[..., 5] = 5e4
+        return scores
+
+
+class ChatGLMEvaluator(Evaluator):
+    def __init__(self, choices, model_path="THUDM/chatglm-6b", device="xpu", qtype="sym_int4"):
+        super(ChatGLMEvaluator, self).__init__(choices, model_path, device, qtype)
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            self.model_path,
+            trust_remote_code=True
+        )
+        self.model = AutoModelForCausalLM.from_pretrained(
+            self.model_path,
+            load_in_low_bit=self.qtype,
+            optimize_model=True,
+            use_cache=True,
+            trust_remote_code=True
+        ).eval().to(self.device)
+
+    def generate_few_shot_prompt(self, subject, dev_df, cot=False):
+        message = []
+        k = self.k
+        if self.k == -1:
+            k = dev_df.shape[0]
+        message.append(self.format_example(dev_df.iloc[0, :], cot=cot, add_prompt=f"以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n"))
+        for i in range(1, k):
+            message.append(self.format_example(dev_df.iloc[i, :], cot=cot))
+        return message
+
+    def format_example(self, line, include_answer=True, cot=False, add_prompt=''):
+        example = add_prompt + line['question']
+        # print(example)
+        for choice in self.choices:
+            example += f'\n{choice}. {line[f"{choice}"]}'
+        example += '\n答案:'
+        if include_answer:
+            if cot:
+                ans = "让我们一步一步思考,\n" + line["explanation"] + f"\n所以答案是{line['answer']}。"
+            else:
+                ans = line["answer"]
+            m = (example, ans)
+            return m
+        return example
+
+    def extract_cot_answer(self, line, gen_ans):
+        m = re.findall(r'所以答案是(.+?)。', gen_ans, re.M)
+        if len(m) > 0 and m[-1] in self.choices:
+            return m[-1], True
+        answer_patterns = [
+            r'([ABCD])是正确的',
+            r'选项([ABCD])正确',
+            r'答案为([ABCD])',
+            r'答案是([ABCD])',
+            r'答案([ABCD])',
+            r'选择([ABCD])',
+            r'答案:([ABCD])',
+            r'选择答案([ABCD])'
+        ]
+        # RE extraction
+        for answer_pattern in answer_patterns:
+            m = re.search(answer_pattern, gen_ans, re.M)
+            if m:
+                answer = m.group(1)
+                return answer, False
+        # only containing one choice-character
+        m = re.findall(r'[ABCD]', gen_ans, re.M)
+        if len(m) == 1:
+            answer = m[0]
+            return answer, False
+        answer_word_counter = 0
+        # only containing one choice-context
+        for c in self.choices:
+            if str(line[f'{c}']) in gen_ans:
+                answer = c
+                answer_word_counter += 1
+        if answer_word_counter == 1:
+            return answer, False
+        return '-', False
+
+    def build_prompt(self, text):
+        return "[Round {}]\n\n问:{}\n\n答:".format(1, text)
+
+    def generate_dist(self, model, tokenizer, query, history, max_length=2048,
+                      do_sample=False, logits_processor=None):
+
+        if history is None:
+            history = []
+        if logits_processor is None:
+            logits_processor = LogitsProcessorList()
+        logits_processor.append(InvalidScoreLogitsProcessor())
+
+        if not history:
+            prompt = query
+        else:
+            prompt = ""
+            for i, (old_query, response) in enumerate(history):
+                prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
+            prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
+
+        # first round prompt
+        inputs = tokenizer([prompt], padding=True, return_tensors="pt",
+                           truncation=True, max_length=max_length).to(model.device)
+
+        # first round generation
+        outputs = model.generate(**inputs, do_sample=do_sample, max_new_tokens=512)
+
+        # organize intermediate_outputs
+        intermediate_outputs = []
+        for idx in range(len(outputs)):
+            output = outputs.tolist()[idx][len(inputs["input_ids"][idx]):]
+            response = tokenizer.decode(output)
+            intermediate_outputs.append(response)
+
+        # prepare second round prompt
+        extraction_prompt = '综上所述,ABCD中正确的选项是:'
+        answer_texts = [query + intermediate + "\n" + extraction_prompt for intermediate in intermediate_outputs]
+        input_tokens = [self.build_prompt(answer_text) for answer_text in answer_texts]
+        inputs = tokenizer(input_tokens, padding=True, return_tensors="pt", truncation=True, max_length=2048).to(model.device)
+
+        # second round generation
+        outputs = model(**inputs, return_last_logit=True)
+
+        logits = outputs.logits[:, -1]
+        choice_tokens = [tokenizer.encode(choice, add_special_tokens=False)[0] for choice in self.choices]
+        logits = logits[:, choice_tokens]
+        preds = logits.argmax(dim=-1)
+
+        return self.choices[preds]
+
+    @torch.no_grad()
+    def eval_subject(
+        self,
+        subject_name,
+        test_df,
+        eval_type="validation",  # "test","validation"
+        dev_df=None,
+        few_shot=False,
+        cot=False,
+    ):
+        if eval_type == "validation":
+            correct_num = 0
+
+            if few_shot:
+                history = self.generate_few_shot_prompt(subject_name, dev_df, cot=cot)
+            else:
+                history = []
+
+            answers = list(test_df['answer'])
+
+            for row_index, row in tqdm(test_df.iterrows(), total=len(test_df)):
+                question = self.format_example(row, include_answer=False, cot=cot)
+
+                if few_shot:
+                    response, _ = self.model.chat(self.tokenizer, question, do_sample=False, history=history)
+                    response = response.strip()
+                    # For ChatGLM, we use answer extraction in answer-only mode too.
+                    ans, direct_extract = self.extract_cot_answer(row, response)
+                else: # zero-shot by extracting answer from distribution
+                    ans = self.generate_dist(self.model, self.tokenizer, question, do_sample=False, max_length=2048, history=history)
+
+                if ans == answers[row_index]:
+                    correct_num += 1
+
+            correct_ratio = 100*correct_num/len(answers)
+
+            return correct_ratio, None
+        elif eval_type == "test":
+            answers = {}
+            for i, row in tqdm(test_df.iterrows(), total=len(test_df)):
+                # the test split carries no answer column, so build the prompt without one
+                question = self.format_example(row, include_answer=False)
+                response, _ = self.model.chat(
+                    self.tokenizer,
+                    question,
+                    history=None,
+                )
+                pred = self.extract_answer(response, row)
+                answers[str(i)] = pred
+            return None, answers
\ No newline at end of file
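Usage sketch (illustrative, not part of the patch): the evaluator is now picked by matching 'llama', 'qwen' or 'chatglm' against the lower-cased --model_path, so a ChatGLM checkpoint needs no extra flag. Assuming the script is run from python/llm/dev/benchmark/ceval, that the C-Eval data is already in place, and that --qtype is the low-bit format argument defined elsewhere in eval.py, an invocation could look like:

    python eval.py --model_path THUDM/chatglm2-6b --eval_type validation --device xpu --qtype sym_int4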