Skip to content

Commit

Permalink
Add ChatGLM C-Eval Evaluator (#10095)
Browse files Browse the repository at this point in the history
* Add ChatGLM ceval evaluator

* Modify ChatGLM Evaluator Reference
  • Loading branch information
NovTi authored Feb 7, 2024
1 parent e2f2376 commit 156c232
Show file tree
Hide file tree
Showing 2 changed files with 232 additions and 4 deletions.
25 changes: 21 additions & 4 deletions python/llm/dev/benchmark/ceval/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from bigdl.llm.utils.common.log4Error import invalidInputError
from evaluators.qwen import QwenEvaluator
from evaluators.llama import LlamaEvaluator
from evaluators.chatglm import ChatGLMEvaluator


TASK_NAME_MAPPING = {
Expand Down Expand Up @@ -280,7 +281,6 @@ def main(args, evaluator):

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model_family", type=str, default="llama")
parser.add_argument("--model_path", type=str, default="meta-llama/Llama-2-7b-chat-hf")
parser.add_argument("--eval_type", type=str, default="validation")
parser.add_argument("--device", type=str, default="xpu")
Expand All @@ -289,22 +289,39 @@ def main(args, evaluator):

args = parser.parse_args()

if args.model_family == "llama":
# decide the model family
model_families = ['llama', 'qwen', 'chatglm']

model_family = None
for family in model_families:
if family in args.model_path.lower():
model_family = family

assert model_family is not None, f"Model {args.model_path}'s model family is not implemented"

if model_family == "llama":
evaluator = LlamaEvaluator(
choices=choices,
model_path=args.model_path,
device=args.device,
qtype=args.qtype
)
elif args.model_family == "qwen":
elif model_family == "qwen":
evaluator = QwenEvaluator(
choices=choices,
model_path=args.model_path,
device=args.device,
qtype=args.qtype
)
elif model_family == "chatglm":
evaluator = ChatGLMEvaluator(
choices=choices,
model_path=args.model_path,
device=args.device,
qtype=args.qtype
)
else:
invalidInputError(
False,
"Invalid model_family, currently support llama and qwen only.")
"Invalid model_family, currently support llama, qwen, and chatglm only.")
main(args, evaluator=evaluator)
211 changes: 211 additions & 0 deletions python/llm/dev/benchmark/ceval/evaluators/chatglm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# refer to https://github.com/THUDM/ChatGLM2-6B/blob/main/evaluation/evaluate_ceval.py

import re
import torch
from tqdm import tqdm
from thefuzz import process
from transformers import AutoTokenizer

from evaluators.evaluator import Evaluator
from bigdl.llm.transformers import AutoModelForCausalLM
from transformers.generation.utils import LogitsProcessorList
from transformers.generation.logits_process import LogitsProcessor


class InvalidScoreLogitsProcessor(LogitsProcessor):
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
if torch.isnan(scores).any() or torch.isinf(scores).any():
scores.zero_()
scores[..., 5] = 5e4
return scores


class ChatGLMEvaluator(Evaluator):
def __init__(self, choices, model_path="THUDM/chatglm-6b", device="xpu", qtype="sym_int4"):
super(ChatGLMEvaluator, self).__init__(choices, model_path, device, qtype)
self.tokenizer = AutoTokenizer.from_pretrained(
self.model_path,
trust_remote_code=True
)
self.model = AutoModelForCausalLM.from_pretrained(
self.model_path,
load_in_low_bit=self.qtype,
optimize_model=True,
use_cache=True,
trust_remote_code=True
).eval().to(self.device)

def generate_few_shot_prompt(self, subject, dev_df, cot=False):
message = []
k = self.k
if self.k == -1:
k = dev_df.shape[0]
message.append(self.format_example(dev_df.iloc[0, :], cot=cot, add_prompt=f"以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n"))
for i in range(1, k):
message.append(self.format_example(dev_df.iloc[i, :], cot=cot))
return message

def format_example(self, line, include_answer=True, cot=False, add_prompt=''):
example = add_prompt + line['question']
# print(example)
for choice in self.choices:
example += f'\n{choice}. {line[f"{choice}"]}'
example += '\n答案:'
if include_answer:
if cot:
ans = "让我们一步一步思考,\n" + line["explanation"] + f"\n所以答案是{line['answer']}。"
else:
ans = line["answer"]
m = (example, ans)
return m
return example

def extract_cot_answer(self, line, gen_ans):
m = re.findall(r'所以答案是(.+?)。', gen_ans, re.M)
if len(m) > 0 and m[-1] in self.choices:
return m[-1], True
answer_patterns = [
r'([ABCD])是正确的',
r'选项([ABCD])正确',
r'答案为([ABCD])',
r'答案是([ABCD])',
r'答案([ABCD])',
r'选择([ABCD])',
r'答案:([ABCD])',
r'选择答案([ABCD])'
]
# RE extraction
for answer_pattern in answer_patterns:
m = re.search(answer_pattern, gen_ans, re.M)
if m:
answer = m.group(1)
return answer, False
# only containing one choice-character
m = re.findall(r'[ABCD]', gen_ans, re.M)
if len(m) == 1:
answer = m[0]
return answer, False
answer_word_counter = 0
# only containing one choice-context
for c in self.choices:
if str(line[f'{c}']) in gen_ans:
answer = c
answer_word_counter += 1
if answer_word_counter == 1:
return answer, False
return '-', False

def build_prompt(self, text):
return "[Round {}]\n\n问:{}\n\n答:".format(1, text)

def generate_dist(self, model, tokenizer, query, history, max_length=2048,
do_sample=False, logits_processor=None):

if history is None:
history = []
if logits_processor is None:
logits_processor = LogitsProcessorList()
logits_processor.append(InvalidScoreLogitsProcessor())

if not history:
prompt = query
else:
prompt = ""
for i, (old_query, response) in enumerate(history):
prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)

# first round prompt
inputs = tokenizer([prompt], padding=True, return_tensors="pt",
truncation=True, max_length=max_length).to(model.device)

# first round generation
outputs = model.generate(**inputs, do_sample=do_sample, max_new_tokens=512)

# organize intermediate_outputs
intermediate_outputs = []
for idx in range(len(outputs)):
output = outputs.tolist()[idx][len(inputs["input_ids"][idx]):]
response = tokenizer.decode(output)
intermediate_outputs.append(response)

# prepare second round prompt
extraction_prompt = '综上所述,ABCD中正确的选项是:'
answer_texts = [query + intermediate + "\n" + extraction_prompt for intermediate in intermediate_outputs]
input_tokens = [self.build_prompt(answer_text) for answer_text in answer_texts]
inputs = tokenizer(input_tokens, padding=True, return_tensors="pt", truncation=True, max_length=2048).to(model.device)

# second round generation
outputs = model(**inputs, return_last_logit=True)

logits = outputs.logits[:, -1]
choice_tokens = [tokenizer.encode(choice, add_special_tokens=False)[0] for choice in self.choices]
logits = logits[:, choice_tokens]
preds = logits.argmax(dim=-1)

return self.choices[preds]

@torch.no_grad()
def eval_subject(
self,
subject_name,
test_df,
eval_type="validation", # "test","validation",
dev_df=None,
few_shot=False,
cot=False,
):
if eval_type == "validation":
correct_num = 0

if few_shot:
history = self.generate_few_shot_prompt(subject_name, dev_df, cot=cot)
else:
history = []

answers = list(test_df['answer'])

for row_index, row in tqdm(test_df.iterrows(), total=len(test_df)):
question = self.format_example(row, include_answer=False, cot=cot)

if few_shot:
response, _ = self.model.chat(self.tokenizer, question, do_sample=False, history=history)
response = response.strip()
# For ChatGLM, we use answer extraction in answer-only mode too.
ans, direct_extract = self.extract_cot_answer(row, response)
else: # zero-shot by extracting answer from distribution
ans = self.generate_dist(self.model, self.tokenizer, question, do_sample=False, max_length=2048, history=history)

if ans == answers[row_index]:
correct_num += 1

correct_ratio = 100*correct_num/len(answers)

return correct_ratio, None
elif eval_type == "test":
answers = {}
for i, row in tqdm(test_df.iterrows(), total=len(test_df)):
question = self.format_example(row)
response, _ = self.model.chat(
self.tokenizer,
question,
history=None,
)
pred = self.extract_answer(response, row)
answers[str(i)] = pred
return None, answers

0 comments on commit 156c232

Please sign in to comment.