From a17d93e6fdad7a0bf73c297dd16e66fdcd11b2ee Mon Sep 17 00:00:00 2001
From: abhilash1910
Date: Wed, 19 Jul 2023 00:47:53 -0700
Subject: [PATCH 1/2] xpu support

---
 qlora.py | 41 ++++++++++++++++++++++++++++++++++++++---
 1 file changed, 38 insertions(+), 3 deletions(-)

diff --git a/qlora.py b/qlora.py
index 554de1ca..be63d6a0 100644
--- a/qlora.py
+++ b/qlora.py
@@ -14,6 +14,9 @@ import logging
 import bitsandbytes as bnb
 import pandas as pd
+import importlib
+from packaging import version
+from packaging.version import parse
 import torch
 import transformers
@@ -41,7 +44,31 @@ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
 
 
-torch.backends.cuda.matmul.allow_tf32 = True
+def is_ipex_available():
+    def get_major_and_minor_from_version(full_version):
+        return str(version.parse(full_version).major) + "." + str(version.parse(full_version).minor)
+
+    _torch_version = importlib.metadata.version("torch")
+    if importlib.util.find_spec("intel_extension_for_pytorch") is None:
+        return False
+    _ipex_version = "N/A"
+    try:
+        _ipex_version = importlib.metadata.version("intel_extension_for_pytorch")
+    except importlib.metadata.PackageNotFoundError:
+        return False
+    torch_major_and_minor = get_major_and_minor_from_version(_torch_version)
+    ipex_major_and_minor = get_major_and_minor_from_version(_ipex_version)
+    if torch_major_and_minor != ipex_major_and_minor:
+        warnings.warn(
+            f"Intel Extension for PyTorch {ipex_major_and_minor} needs to work with PyTorch {ipex_major_and_minor}.*,"
+            f" but PyTorch {_torch_version} is found. Please switch to the matching version and run again."
+        )
+        return False
+    return True
+
+
+if torch.cuda.is_available():
+    torch.backends.cuda.matmul.allow_tf32 = True
 
 logger = logging.getLogger(__name__)
@@ -261,7 +288,11 @@ def touch(fname, times=None):
 
 
 def get_accelerate_model(args, checkpoint_dir):
-    n_gpus = torch.cuda.device_count()
+    if torch.cuda.is_available():
+        n_gpus = torch.cuda.device_count()
+    if is_ipex_available() and torch.xpu.is_available():
+        n_gpus = torch.xpu.device_count()
+
     max_memory = f'{args.max_memory_MB}MB'
     max_memory = {i: max_memory for i in range(n_gpus)}
     device_map = "auto"
@@ -303,6 +334,10 @@ def get_accelerate_model(args, checkpoint_dir):
             print('='*80)
             print('Your GPU supports bfloat16, you can accelerate training with the argument --bf16')
             print('='*80)
+
+    if compute_dtype == torch.float16 and (is_ipex_available() and torch.xpu.is_available()):
+        compute_dtype = torch.bfloat16
+        print('Intel XPU does not supports float16 yet, you can accelerate training with the argument --bf16')
 
     setattr(model, 'model_parallel', True)
     setattr(model, 'is_parallelizable', True)
@@ -651,7 +686,7 @@ def train():
         **vars(model_args), **vars(data_args), **vars(training_args)
     )
     print(args)
-    
+
     checkpoint_dir, completed_training = get_last_checkpoint(args.output_dir)
     if completed_training:
         print('Detected that training was already completed!')

From a9644b41540d9ab8b13348fcc68ab93a5381b36a Mon Sep 17 00:00:00 2001
From: abhilash1910
Date: Wed, 19 Jul 2023 04:01:03 -0700
Subject: [PATCH 2/2] xpu support

---
 qlora.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/qlora.py b/qlora.py
index be63d6a0..cac08eca 100644
--- a/qlora.py
+++ b/qlora.py
@@ -337,7 +337,7 @@ def get_accelerate_model(args, checkpoint_dir):
 
     if compute_dtype == torch.float16 and (is_ipex_available() and torch.xpu.is_available()):
         compute_dtype = torch.bfloat16
-        print('Intel XPU does not supports float16 yet, you can accelerate training with the argument --bf16')
+        print('Intel XPU does not support float16 yet, so switching to bfloat16')
 
     setattr(model, 'model_parallel', True)
     setattr(model, 'is_parallelizable', True)
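For reference only (not part of the patches above): a minimal standalone sketch of the device-selection pattern the series introduces, i.e. prefer CUDA when present and otherwise fall back to an Intel XPU exposed through intel_extension_for_pytorch. The helper name pick_backend is hypothetical and exists only for this illustration.

    import importlib.util

    import torch

    def pick_backend():
        # Prefer CUDA when present; otherwise try Intel XPU via IPEX.
        if torch.cuda.is_available():
            return "cuda", torch.cuda.device_count()
        if importlib.util.find_spec("intel_extension_for_pytorch") is not None:
            import intel_extension_for_pytorch  # noqa: F401  (registers torch.xpu)
            if torch.xpu.is_available():
                return "xpu", torch.xpu.device_count()
        return "cpu", 0

    if __name__ == "__main__":
        backend, n_devices = pick_backend()
        print(f"Training would target {backend} with {n_devices} device(s)")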