From 0a3d4bd43f69c29e2f8a3b07ac13036e41c6579c Mon Sep 17 00:00:00 2001 From: xinhe Date: Mon, 4 Mar 2024 16:30:53 +0800 Subject: [PATCH] improve HPU usage (#1643) refine example add hpu in auto accelerator and fix bug --------- Signed-off-by: xin3he --- .../quantization/habana_fp8/requirement.txt | 6 +- .../quantization/habana_fp8/run_llm.py | 115 +---------------- .../quantization/habana_fp8/utils.py | 122 ++++++++++++++++++ .../torch/algorithms/weight_only/gptq.py | 3 +- .../torch/utils/auto_accelerator.py | 49 ++++++- 5 files changed, 179 insertions(+), 116 deletions(-) diff --git a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/requirement.txt b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/requirement.txt index 2b39bcbeadf..d3655acd742 100644 --- a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/requirement.txt +++ b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/requirement.txt @@ -1,5 +1,7 @@ transformers datasets +accelerate SentencePiece -intel_extension_for_transformers -lm_eval +lm_eval==0.3.0 +openpyxl +einops diff --git a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/run_llm.py b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/run_llm.py index d7fb14c89d2..0198a888b86 100644 --- a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/run_llm.py +++ b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/run_llm.py @@ -17,12 +17,8 @@ import transformers from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig import habana_frameworks.torch.core as htcore -import numpy as np -import lm_eval -import lm_eval.tasks -import lm_eval.evaluator from accelerate import init_empty_weights -from utils import itrex_bootstrap_stderr, show_msg, save_to_excel +from utils import show_msg, eval_func torch.set_grad_enabled(False) @@ -30,8 +26,6 @@ torch.device('hpu') -# to avoid out-of-memory caused by Popen for large language models. 
-lm_eval.metrics.bootstrap_stderr = itrex_bootstrap_stderr parser = argparse.ArgumentParser() @@ -52,6 +46,7 @@ parser.add_argument("--precision", type=str, default='fp8_e4m3', help="Select from ['fp8_e4m3', 'fp8_e5m2', 'bf16', 'fp16'], \ ['bf16', 'fp16'] only work with cast approach") +parser.add_argument("--autotune", action="store_true") parser.add_argument("--accuracy", action="store_true") parser.add_argument("--performance", action="store_true") parser.add_argument("--generate", action="store_true") @@ -182,8 +177,9 @@ ### dynamic & static quantization ### if args.approach in ["dynamic", "static"] and not args.load: print("device:", next(user_model.parameters()).device) - from neural_compressor.torch.quantization.config import FP8Config, get_default_fp8_config - from neural_compressor.torch.quantization import quantize + from neural_compressor.torch.quantization import ( + quantize, autotune, FP8Config, get_default_fp8_config, TuningConfig, get_default_fp8_config_set + ) dtype = args.precision if args.approach == "dynamic": from neural_compressor.torch.algorithms.habana_fp8 import quantize_dynamic @@ -300,106 +296,7 @@ def replace_torch_mm_bmm(): if args.accuracy: - - class HabanaModelAdapter(lm_eval.base.BaseLM): - def __init__(self, tokenizer, model, args, options): - super().__init__() - self.tokenizer = tokenizer - self.model = model.eval() - self._batch_size = args.batch_size - self.buckets = list(sorted(args.buckets)) - self.options = options - self._device = "hpu" - torch.set_grad_enabled(False) - - @property - def eot_token_id(self): - return self.model.config.eos_token_id - - @property - def max_length(self): - return self.buckets[-1] - - @property - def max_gen_toks(self): - raise NotImplementedError() - - @property - def batch_size(self): - return self._batch_size - - @property - def device(self): - # We need to do padding ourselves, otherwise we'll end up with recompilations - # Returning 'cpu' to keep tensors on CPU in lm_eval code - return 'cpu' # 'hpu' - - def tok_encode(self, string): - if re.search("chatglm3", args.model.lower()) or re.search("llama", args.model.lower()) : - string = string.lstrip() - return self.tokenizer.encode(string, add_special_tokens=False) - - def tok_decode(self, tokens): - return self.tokenizer.decode(tokens, skip_special_tokens=True) - - def _model_generate(self, context, max_length, eos_token_id): - raise NotImplementedError() - - def find_bucket(self, length): - return [b for b in self.buckets if b >= length][0] - - def _model_call(self, inps): - seq_length = inps.shape[-1] - padding_length = 0 - bucket_length = self.find_bucket(seq_length) - padding_length = bucket_length - seq_length - inps = F.pad(inps, (0, padding_length), value=self.model.config.pad_token_id) - logits = self.model(inps.to(self._device))["logits"].cpu() - - if padding_length > 0: - logits = logits[:, :-padding_length, :] - logits = logits.to(torch.float32) - return logits - - lm_tasks = lm_eval.tasks.get_task_dict(args.tasks) - options = None - lm = HabanaModelAdapter(tokenizer, user_model, args, options) - - eval_start = time.perf_counter() - if args.approach == "cast": - from neural_compressor.torch.amp import autocast - if args.precision == "fp8_e4m3": - dtype = torch.float8_e4m3fn - elif args.precision == "fp8_e5m2": - dtype = torch.float8_e5m2 - elif args.precision == "fp16": - dtype = torch.float16 - elif args.precision == "bf16": - dtype = torch.bfloat16 - with autocast('hpu', dtype=dtype): - results = lm_eval.evaluator.evaluate(lm, lm_tasks, limit=args.limit) 
- else: - results = lm_eval.evaluator.evaluate(lm, lm_tasks, limit=args.limit) - print(lm_eval.evaluator.make_table(results)) - eval_end = time.perf_counter() - print("Duration:", eval_end - eval_start) - results['args'] = vars(args) - results['duration'] = eval_end - eval_start - - - dumped = json.dumps(results, indent=2) - accu_dict = {} - case_name = args.approach + "-" + args.precision - for task_name in args.tasks: - if task_name == "wikitext": - print("Accuracy for %s is: %s" % (task_name, results["results"][task_name]["word_perplexity"]), flush=True) - accu_dict[task_name] = [args.model, case_name, results["results"][task_name]["word_perplexity"]] - else: - print("Accuracy for %s is: %s" % (task_name, results["results"][task_name]["acc"]), flush=True) - accu_dict[task_name] = [args.model, case_name, results["results"][task_name]["acc"]] - if args.dump_to_excel and local_rank in [-1, 0]: - save_to_excel(accu_dict) - + eval_func(user_model, tokenizer=tokenizer, args=args) # dump final message of HPU show_msg() diff --git a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/utils.py b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/utils.py index 7eac0e0bdf7..34ac214b44e 100644 --- a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/utils.py +++ b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/utils.py @@ -33,3 +33,125 @@ def save_to_excel(dict): df_existing = pd.DataFrame() df_combined = pd.concat([df_existing, df_new], axis=0, ignore_index=True) df_combined.to_excel('output.xlsx', index=False, engine='openpyxl', header=True) + + +def eval_func(user_model, tokenizer, args): + import os + import re + import time + import json + import torch + import habana_frameworks.torch.hpex + import torch.nn.functional as F + import lm_eval + import lm_eval.tasks + import lm_eval.evaluator + + # to avoid out-of-memory caused by Popen for large language models. 
+ lm_eval.metrics.bootstrap_stderr = itrex_bootstrap_stderr + + class HabanaModelAdapter(lm_eval.base.BaseLM): + def __init__(self, tokenizer, model, args, options): + super().__init__() + self.tokenizer = tokenizer + self.model = model.eval() + self._batch_size = args.batch_size + self.buckets = list(sorted(args.buckets)) + self.options = options + self._device = "hpu" + torch.set_grad_enabled(False) + + @property + def eot_token_id(self): + return self.model.config.eos_token_id + + @property + def max_length(self): + return self.buckets[-1] + + @property + def max_gen_toks(self): + raise NotImplementedError() + + @property + def batch_size(self): + return self._batch_size + + @property + def device(self): + # We need to do padding ourselves, otherwise we'll end up with recompilations + # Returning 'cpu' to keep tensors on CPU in lm_eval code + return 'cpu' # 'hpu' + + def tok_encode(self, string): + if ( + re.search("chatglm3", args.model.lower()) or + re.search("llama", args.model.lower()) or + re.search("mistral", args.model.lower()) + ): + string = string.lstrip() + return self.tokenizer.encode(string, add_special_tokens=False) + + def tok_decode(self, tokens): + return self.tokenizer.decode(tokens, skip_special_tokens=True) + + def _model_generate(self, context, max_length, eos_token_id): + raise NotImplementedError() + + def find_bucket(self, length): + return [b for b in self.buckets if b >= length][0] + + def _model_call(self, inputs): + seq_length = inputs.shape[-1] + padding_length = 0 + bucket_length = self.find_bucket(seq_length) + padding_length = bucket_length - seq_length + inputs = F.pad(inputs, (0, padding_length), value=self.model.config.pad_token_id) + logits = self.model(inputs.to(self._device))["logits"].cpu() + + if padding_length > 0: + logits = logits[:, :-padding_length, :] + logits = logits.to(torch.float32) + return logits + + lm_tasks = lm_eval.tasks.get_task_dict(args.tasks) + options = None + lm = HabanaModelAdapter(tokenizer, user_model, args, options) + + eval_start = time.perf_counter() + if args.approach == "cast": + from neural_compressor.torch.amp import autocast + if args.precision == "fp8_e4m3": + dtype = torch.float8_e4m3fn + elif args.precision == "fp8_e5m2": + dtype = torch.float8_e5m2 + elif args.precision == "fp16": + dtype = torch.float16 + elif args.precision == "bf16": + dtype = torch.bfloat16 + with autocast('hpu', dtype=dtype): + results = lm_eval.evaluator.evaluate(lm, lm_tasks, limit=args.limit) + else: + results = lm_eval.evaluator.evaluate(lm, lm_tasks, limit=args.limit) + print(lm_eval.evaluator.make_table(results)) + eval_end = time.perf_counter() + print("Duration:", eval_end - eval_start) + results['args'] = vars(args) + results['duration'] = eval_end - eval_start + + # make sure that result is dumped only once during multi-cards evaluation + local_rank = int(os.getenv('LOCAL_RANK', '-1')) + if local_rank in [-1, 0]: + dumped = json.dumps(results, indent=2) + accu_dict = {} + case_name = args.approach + "-" + args.precision + for task_name in args.tasks: + if task_name == "wikitext": + print("Accuracy for %s is: %s" % (task_name, results["results"][task_name]["word_perplexity"]), flush=True) + accu_dict[task_name] = [args.model, case_name, results["results"][task_name]["word_perplexity"]] + else: + print("Accuracy for %s is: %s" % (task_name, results["results"][task_name]["acc"]), flush=True) + accu_dict[task_name] = [args.model, case_name, results["results"][task_name]["acc"]] + if args.dump_to_excel: + save_to_excel(accu_dict) + 
return results["results"][task_name]["acc"] diff --git a/neural_compressor/torch/algorithms/weight_only/gptq.py b/neural_compressor/torch/algorithms/weight_only/gptq.py index 3752ffc46c8..5c5d68a4f72 100644 --- a/neural_compressor/torch/algorithms/weight_only/gptq.py +++ b/neural_compressor/torch/algorithms/weight_only/gptq.py @@ -257,8 +257,7 @@ def __init__( # device self.device = get_device(kwargs.pop("device", "auto")) - if str(self.model.device).startswith("cuda"): - self.device = self.model.device + self.model.to(self.device) self.is_ready = False self.export_compressed_model = export_compressed_model diff --git a/neural_compressor/torch/utils/auto_accelerator.py b/neural_compressor/torch/utils/auto_accelerator.py index 57af493b738..2887f9166ec 100644 --- a/neural_compressor/torch/utils/auto_accelerator.py +++ b/neural_compressor/torch/utils/auto_accelerator.py @@ -31,7 +31,8 @@ from neural_compressor.torch.utils import logger -PRIORITY_CUDA = 100 +PRIORITY_HPU = 100 +PRIORITY_CUDA = 95 PRIORITY_CPU = 90 @@ -53,8 +54,9 @@ class CPU_Accelerator: """ def decorator(accelerator_cls): - cls.registered_accelerators.setdefault(name, {}) - cls.registered_accelerators[name] = (accelerator_cls, priority) + if accelerator_cls.is_available(): + cls.registered_accelerators.setdefault(name, {}) + cls.registered_accelerators[name] = (accelerator_cls, priority) return accelerator_cls return decorator @@ -202,6 +204,47 @@ def empty_cache(self): return torch.cuda.empty_cache() +@register_accelerator(name="hpu", priority=PRIORITY_HPU) +class HPU_Accelerator(Auto_Accelerator): + def __init__(self) -> None: + self._name = "hpu" + + def name(self) -> str: + return self._name + + @classmethod + def is_available(cls) -> bool: + from .environ import is_hpex_available + + if is_hpex_available(): + return torch.hpu.is_available() + else: + return False + + def device_name(self, device_indx) -> str: + if device_indx is None: + return "hpu" + return f"hpu:{device_indx}" + + def synchronize(self): + return torch.hpu.synchronize() + + def set_device(self, device_index): + return torch.hpu.set_device(device_index) + + def current_device(self): + return torch.hpu.current_device() + + def current_device_name(self): + return "hpu:{}".format(torch.hpu.current_device()) + + def device(self, device_index=None): + return torch.hpu.device(device_index) + + def empty_cache(self): + return torch.hpu.empty_cache() + + def auto_detect_accelerator(device_name="auto") -> Auto_Accelerator: # Force use the cpu on node has both cpu and gpu: `FORCE_DEVICE=cpu` python main.py ... # The `FORCE_DEVICE` is case insensitive.
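
A minimal usage sketch of the accelerator auto-detection added above, assuming a machine where
habana_frameworks (HPEX) is installed and torch.hpu.is_available() returns True; this is
illustrative only and not part of the patch itself:

    # With HPU registered at PRIORITY_HPU = 100 (above CUDA at 95 and CPU at 90) and
    # registration now gated on is_available(), auto detection prefers HPU when present.
    from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator

    accelerator = auto_detect_accelerator()   # resolves to the HPU accelerator on a Gaudi node
    print(accelerator.name())                 # "hpu" (falls back to "cuda" or "cpu" elsewhere)
    print(accelerator.current_device_name())  # e.g. "hpu:0"
    accelerator.synchronize()                 # dispatches to torch.hpu.synchronize()

    # The selection can still be forced via the (case-insensitive) FORCE_DEVICE variable,
    # e.g. `FORCE_DEVICE=cpu python main.py ...`, as noted in the comments above.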