update mx script (#1838)
Signed-off-by: Mengni Wang <[email protected]>
mengniwang95 authored Jun 12, 2024
1 parent a0dee94 commit 6733dab
Showing 1 changed file with 38 additions and 53 deletions.
@@ -31,8 +31,9 @@
                     help="For accuracy measurement only.")
 parser.add_argument("--save_accuracy_path", default=None,
                     help="Save accuracy results path.")
-parser.add_argument("--tasks", type=str, default="lambada_openai",
-                    help="tasks list for accuracy validation")
+parser.add_argument("--tasks", nargs="+", default=["lambada_openai"], type=str,
+                    help="tasks list for accuracy validation"
+)
 parser.add_argument("--peft_model_id", type=str, default=None, help="model_name_or_path of peft model")
 
 args = parser.parse_args()
@@ -54,57 +55,41 @@ def get_user_model():
     return user_model, tokenizer
 
 user_model, tokenizer = get_user_model()
-if args.quantize:
-    from neural_compressor.torch.quantization import MXQuantConfig, quantize
-    quant_config = MXQuantConfig(w_dtype=args.w_dtype, act_dtype=args.act_dtype, weight_only=args.woq)
-    user_model = quantize(model=user_model, quant_config=quant_config)
+from neural_compressor.torch.quantization import MXQuantConfig, prepare, convert
+quant_config = MXQuantConfig(w_dtype=args.w_dtype, act_dtype=args.act_dtype, weight_only=args.woq)
+user_model = prepare(model=user_model, quant_config=quant_config)
+user_model = convert(model=user_model)
+user_model.eval()
 
 if args.accuracy:
-    user_model.eval()
-    from intel_extension_for_transformers.transformers.llm.evaluation.lm_eval import evaluate, LMEvalParser
-    args = LMEvalParser(
-        model="hf",
-        user_model=user_model,
-        tokenizer=tokenizer,
-        batch_size=args.batch_size,
-        tasks=args.tasks,
-        device="cpu",
-    )
-    results = evaluate(args)
-    dumped = json.dumps(results, indent=2)
-    if args.save_accuracy_path:
-        with open(args.save_accuracy_path, "w") as f:
-            f.write(dumped)
-    for task_name in args.tasks:
-        if task_name == "wikitext":
-            acc = results["results"][task_name]["word_perplexity"]
-        else:
-            acc = results["results"][task_name]["acc"]
-    print("Accuracy: %.5f" % acc)
-    print('Batch size = %d' % args.batch_size)
+    from intel_extension_for_transformers.transformers.llm.evaluation.lm_eval import evaluate, LMEvalParser
+    eval_args = LMEvalParser(
+        model="hf",
+        user_model=user_model,
+        tokenizer=tokenizer,
+        batch_size=args.batch_size,
+        tasks=','.join(args.tasks),
+        device="cpu",
+    )
 
-if args.performance:
-    user_model.eval()
-    from intel_extension_for_transformers.llm.evaluation.lm_eval import evaluate
-    import time
-    samples = args.iters * args.batch_size
-    start = time.time()
-    results = evaluate(
-        model="hf",
-        tokenizer=tokenizer,
-        user_model=user_model,
-        batch_size=args.batch_size,
-        tasks=args.tasks,
-        limit=samples,
-    )
-    end = time.time()
-    for task_name in args.tasks:
-        if task_name == "wikitext":
-            acc = results["results"][task_name]["word_perplexity"]
-        else:
-            acc = results["results"][task_name]["acc"]
-    print("Accuracy: %.5f" % acc)
-    print('Throughput: %.3f samples/sec' % (samples / (end - start)))
-    print('Latency: %.3f ms' % ((end - start)*1000 / samples))
-    print('Batch size = %d' % args.batch_size)
+    results = evaluate(eval_args)
+    dumped = json.dumps(results, indent=2)
+    if args.save_accuracy_path:
+        with open(args.save_accuracy_path, "w") as f:
+            f.write(dumped)
+
+    eval_acc = 0
+    for task_name in args.tasks:
+        if task_name == "wikitext":
+            print("Accuracy for %s is: %s" %
+                  (task_name, results["results"][task_name]["word_perplexity,none"]))
+            eval_acc += results["results"][task_name]["word_perplexity,none"]
+        else:
+            print("Accuracy for %s is: %s" %
+                  (task_name, results["results"][task_name]["acc,none"]))
+            eval_acc += results["results"][task_name]["acc,none"]
+
+    if len(args.tasks) != 0:
+        eval_acc /= len(args.tasks)
+    print("Accuracy: %.5f" % eval_acc)
+    print('Batch size = %d' % args.batch_size)
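
Note: the substance of this change is the move from the one-shot quantize() entry point to the prepare()/convert() pair in neural_compressor.torch.quantization, with quantization no longer gated behind --quantize and user_model.eval() called right after conversion. Below is a minimal standalone sketch of that flow, assuming an arbitrary small Hugging Face checkpoint and illustrative dtype values; the real script takes these from --w_dtype, --act_dtype and --woq.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from neural_compressor.torch.quantization import MXQuantConfig, prepare, convert

model_name = "facebook/opt-125m"  # hypothetical checkpoint, not the one used by this example
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# "int8" weights/activations are assumed values; the script passes args.w_dtype/args.act_dtype here.
quant_config = MXQuantConfig(w_dtype="int8", act_dtype="int8", weight_only=False)
model = prepare(model=model, quant_config=quant_config)  # wrap modules for MX quantization
model = convert(model=model)                             # produce the quantized model
model.eval()

# Quick smoke test of the converted model.
with torch.no_grad():
    inputs = tokenizer("The capital of France is", return_tensors="pt")
    print(model(**inputs).logits.shape)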

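The other change worth calling out: --tasks is now parsed as a list (nargs="+") rather than a single comma-separated string, per-task results are read with the newer "acc,none" / "word_perplexity,none" result keys, and the metrics are averaged into one reported number. Because LMEvalParser still takes a comma-separated string, the list is joined back with ','.join(args.tasks). A small illustrative snippet (the argument values are made up):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--tasks", nargs="+", default=["lambada_openai"], type=str,
                    help="tasks list for accuracy validation")
args = parser.parse_args(["--tasks", "lambada_openai", "wikitext"])  # simulated CLI input

print(args.tasks)             # ['lambada_openai', 'wikitext'] -- a Python list
print(",".join(args.tasks))   # 'lambada_openai,wikitext' -- the string handed to LMEvalParser
print(len(args.tasks))        # 2 -- the divisor used when averaging per-task metrics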