Skip to content

Commit

Permalink
fix trust remote for llm examples (#1537)
Browse files Browse the repository at this point in the history
  • Loading branch information
mengfei25 authored Jan 12, 2024
1 parent 49ab28d commit 2f2c9a2
Showing 1 changed file with 6 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def get_user_model():
trust_remote_code=args.trust_remote_code,
revision=args.revision,
)
tokenizer = AutoTokenizer.from_pretrained(args.model)
tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=args.trust_remote_code)
if args.approach == 'weight_only':
user_model = user_model.float()

Expand Down Expand Up @@ -380,7 +380,7 @@ def eval_func(model):
if args.code_generation:
from intel_extension_for_transformers.llm.evaluation.lm_code_eval import evaluate
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(args.model)
tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=args.trust_remote_code)
results = evaluate(
model=user_model,
tokenizer=tokenizer,
Expand Down Expand Up @@ -419,7 +419,8 @@ def eval_func(model):
start = time.time()
results = evaluate(
model="hf-causal",
model_args='pretrained=' + args.model + ',tokenizer=' + args.model + ',dtype=float32',
model_args='pretrained=' + args.model + ',tokenizer=' + args.model \
+ ',dtype=float32' + ",trust_remote_code=" + str(args.trust_remote_code),
user_model=user_model,
batch_size=args.batch_size,
tasks=args.tasks,
Expand All @@ -429,6 +430,8 @@ def eval_func(model):
for task_name in args.tasks:
if task_name == "wikitext":
acc = results["results"][task_name]["word_perplexity"]
elif task_name == "truthfulqa_mc":
acc = results["results"][task_name]["mc1"]
else:
acc = results["results"][task_name]["acc"]
print("Accuracy: %.5f" % acc)
Expand Down

0 comments on commit 2f2c9a2

Please sign in to comment.