From aa2d08e5c6999e5010d813b030e17f090ce34d64 Mon Sep 17 00:00:00 2001
From: harborn
Date: Mon, 25 Mar 2024 17:54:16 +0800
Subject: [PATCH] [Finetune] Support fine-tuning on Gaudi (#155)

* [Fine-tuning] Enable fine-tuning on Gaudi

* update

* upate

* update

* update

* update
---
 docs/finetune_parameters.md                    |   5 +
 .../common/dataprocesser/general_processer.py  |  53 ++++---
 llm_on_ray/common/torch_config.py              |  15 +-
 llm_on_ray/common/trainer/default_trainer.py   |  20 ++-
 llm_on_ray/finetune/finetune.py                | 140 +++++++++++-------
 llm_on_ray/finetune/finetune_config.py         |   9 +-
 pyproject.toml                                 |   2 +-
 7 files changed, 161 insertions(+), 83 deletions(-)

diff --git a/docs/finetune_parameters.md b/docs/finetune_parameters.md
index 69a906e86..2fc2a2f23 100644
--- a/docs/finetune_parameters.md
+++ b/docs/finetune_parameters.md
@@ -23,6 +23,11 @@ The following are the parameters supported in the finetuning workflow.
 |validation_file|None|A json file containing the validation data.|
 |validation_split_percentage|5|The percentage of the train set used as validation set in case there's no validation split|
 |preprocessing_num_workers|None|The number of processes to use for the preprocessing.|
+|max_length|512|Pad sequential data to the max length of a batch|
+|group|True|Whether to concatenate sentences for more efficient training|
+|block_size|512|The block size of concatenated sentences|
+|shuffle|False|Whether to shuffle the data at every epoch|
+

 ## Training Parameters
 |Configuration Name| Default|Meaning|
diff --git a/llm_on_ray/common/dataprocesser/general_processer.py b/llm_on_ray/common/dataprocesser/general_processer.py
index cd09064a6..63d6225c9 100644
--- a/llm_on_ray/common/dataprocesser/general_processer.py
+++ b/llm_on_ray/common/dataprocesser/general_processer.py
@@ -84,10 +84,12 @@ def torch_call(self, examples):

 class GeneralProcesser(DataProcesser):
     def prepare(self, tokenizer, dataset):
-        per_device_train_batch_size = self.config.get("per_device_train_batch_size", 1)
-        per_device_eval_batch_size = self.config.get("per_device_eval_batch_size", 1)
-        group = self.config.get("group", False)
-        self.config.get("shuffle", False)
+        per_device_train_batch_size = self.config.get("per_device_train_batch_size")
+        per_device_eval_batch_size = self.config.get("per_device_eval_batch_size")
+        max_length = self.config.get("max_length")
+        group = self.config.get("group")
+        block_size = self.config.get("block_size")
+        shuffle = self.config.get("shuffle")
         tokenizer.pad_token = tokenizer.eos_token

         if isinstance(dataset, datasets.Dataset):
@@ -123,8 +125,6 @@ def prompt(rec):
             )
             column_names += [TEXT_COLUMN_NAME]

-            max_length = self.config.get("max_length", 1024)
-
             def tokenize_function(examples):
                 return tokenizer(examples[TEXT_COLUMN_NAME], max_length=max_length)

@@ -136,7 +136,6 @@ def tokenize_function(examples):
             )

             if group:
-                block_size = self.config.get("block_size", 1024)

                 def group_texts(examples):
                     # Concatenate all texts.
@@ -160,30 +159,30 @@ def group_texts(examples):
                 load_from_cache_file=False,
                 desc=f"Grouping texts in chunks of {block_size}",
             )
-            default_data_collator = transformers.default_data_collator
-
-        else:
-            default_data_collator = DataCollatorForCompletionOnlyLM(
-                tokenizer=tokenizer,
-                mlm=False,
-                return_tensors="pt",
-                pad_to_multiple_of=8,
-            )
-        train_dataset = tokenized_datasets["train"]
-        train_dataloader = torch.utils.data.DataLoader(
-            train_dataset,
-            shuffle=True,
-            collate_fn=default_data_collator,
-            batch_size=per_device_train_batch_size,
+        data_collator = DataCollatorForCompletionOnlyLM(
+            tokenizer=tokenizer,
+            mlm=False,
+            return_tensors="pt",
+            pad_to_multiple_of=8,
         )
+
+        train_dataset = tokenized_datasets["train"]
+        train_dataloader_params = {
+            "shuffle": shuffle,
+            "collate_fn": data_collator,
+            "batch_size": per_device_train_batch_size,
+            "pin_memory": True,
+        }
+        train_dataloader = torch.utils.data.DataLoader(train_dataset, **train_dataloader_params)
+
         eval_dataloader = None
         if "validation" in tokenized_datasets:
             eval_dataset = tokenized_datasets["validation"]
-            eval_dataloader = torch.utils.data.DataLoader(
-                eval_dataset,
-                collate_fn=default_data_collator,
-                batch_size=per_device_eval_batch_size,
-            )
+            eval_dataloader_params = {
+                "shuffle": shuffle,
+                "collate_fn": data_collator,
+                "batch_size": per_device_eval_batch_size,
+            }
+            eval_dataloader = torch.utils.data.DataLoader(eval_dataset, **eval_dataloader_params)

         return train_dataloader, eval_dataloader
diff --git a/llm_on_ray/common/torch_config.py b/llm_on_ray/common/torch_config.py
index a051de56f..40f7e6125 100644
--- a/llm_on_ray/common/torch_config.py
+++ b/llm_on_ray/common/torch_config.py
@@ -23,7 +23,7 @@ def backend_cls(self):
         return EnableCCLBackend


-def libs_import():
+def xpu_libs_import():
     """try to import IPEX and oneCCL."""
     try:
         import intel_extension_for_pytorch
@@ -39,6 +39,14 @@ def libs_import():
         raise ImportError("Please install torch-ccl") from ccl_not_exist


+def hpu_libs_import():
+    """try to import habana frameworks for torch"""
+    try:
+        import habana_frameworks.torch  # noqa: F401
+    except ImportError as habana_not_exist:
+        raise ImportError("Please install habana_frameworks") from habana_not_exist
+
+
 def _set_torch_distributed_env_vars(device):
     if device is not None:
         os.environ["ACCELERATE_TORCH_DEVICE"] = device
@@ -48,6 +56,11 @@ class EnableCCLBackend(_TorchBackend):
     device: Optional[str] = None

     def on_start(self, worker_group: WorkerGroup, backend_config: RayTorchConfig):
+        libs_import = (
+            hpu_libs_import
+            if self.device is not None and self.device.startswith("hpu")
+            else xpu_libs_import
+        )
         for i in range(len(worker_group)):
             worker_group.execute_single_async(i, libs_import)
         super().on_start(worker_group, backend_config)
diff --git a/llm_on_ray/common/trainer/default_trainer.py b/llm_on_ray/common/trainer/default_trainer.py
index 8cceecbf9..51fc8a4be 100644
--- a/llm_on_ray/common/trainer/default_trainer.py
+++ b/llm_on_ray/common/trainer/default_trainer.py
@@ -135,8 +135,8 @@ def prepare(self, model, tokenizer, dataset, optimizer, accelerator):
         # self.model, self.optimizer, self.lr_scheduler, ..., are prepared with 2 steps
         # because it is recommended way to prepare model and optimizer while using FSDP.
         # https://huggingface.co/docs/accelerate/usage_guides/fsdp#a-few-caveats-to-be-aware-of
-        accelerate_mode = self.config.get("accelerate_mode")
-        if accelerate_mode in ["GPU_DEEPSPEED"]:
+        self.accelerate_mode = self.config.get("accelerate_mode")
+        if self.accelerate_mode in ["GPU_DEEPSPEED"]:
             lr = lr_scheduler_config.get("learning_rate", 0.001)
             weight_decay = lr_scheduler_config.get("weight_decay", 0)
             from accelerate.utils import DummyOptim, DummyScheduler
@@ -163,6 +163,14 @@ def prepare(self, model, tokenizer, dataset, optimizer, accelerator):
             self.lr_scheduler,
         ) = accelerator.prepare(optimizer, train_dataloader, eval_dataloader, lr_scheduler)

+        if self.accelerate_mode in ["HPU_DDP"]:
+            import habana_frameworks.torch.core as htcore
+            from habana_frameworks.torch.utils.internal import is_lazy
+
+            self.htcore = htcore
+        else:
+            self.htcore = None
+
         checkpoint = self.config.get("checkpoint")
         if checkpoint is not None:
             self.recovery(checkpoint)
@@ -180,12 +188,20 @@ def train(self):
             logger.info(f"Start training epoch {idx}, total_steps {total_steps}")
             for step, batch in enumerate(self.train_dataloader):
                 with self.accelerator.accumulate(self.model):
+                    self.model.train()
+                    batch = batch.to(device=self.accelerator.device)
                     outputs = self.model(**batch)
                     loss = outputs.loss
                     self.accelerator.backward(loss)
+                    if self.htcore is not None:
+                        self.htcore.mark_step()
                     self.optimizer.step()
+                    if self.htcore is not None:
+                        self.htcore.mark_step()
                     if self.lr_scheduler is not None:
                         self.lr_scheduler.step()
+                    if self.htcore is not None:
+                        self.htcore.mark_step()
                     self.optimizer.zero_grad()

                 if step % logging_steps == 0:
diff --git a/llm_on_ray/finetune/finetune.py b/llm_on_ray/finetune/finetune.py
index e38596915..e0ba92b74 100644
--- a/llm_on_ray/finetune/finetune.py
+++ b/llm_on_ray/finetune/finetune.py
@@ -5,8 +5,6 @@
 from typing import Any, Dict, Union, Optional

 import torch
-import accelerate
-from accelerate.utils import is_xpu_available

 import ray
 from ray.train.torch import TorchTrainer
@@ -15,7 +13,7 @@
 from pydantic_yaml import parse_yaml_raw_as

-from accelerate import FullyShardedDataParallelPlugin, DeepSpeedPlugin
+from accelerate import DeepSpeedPlugin
 from torch.distributed.fsdp.fully_sharded_data_parallel import (
     FullOptimStateDictConfig,
     FullStateDictConfig,
@@ -23,6 +21,22 @@

 from llm_on_ray import common
 from llm_on_ray.finetune.finetune_config import FinetuneConfig
+from importlib import util
+
+use_habana = False
+if util.find_spec("habana_frameworks") is not None:
+    from optimum.habana.accelerate import GaudiAccelerator as Accelerator
+    from optimum.habana.accelerate.utils import (
+        GaudiFullyShardedDataParallelPlugin as FullyShardedDataParallelPlugin,
+    )
+    from optimum.habana.utils import set_seed
+
+    use_habana = True
+else:
+    from accelerate import Accelerator, FullyShardedDataParallelPlugin
+    from accelerate.utils import set_seed, is_xpu_available
+
+    use_habana = False


 def get_accelerate_environment_variable(mode: str, config: Union[Dict[str, Any], None]) -> dict:
@@ -61,12 +75,27 @@ def get_accelerate_environment_variable(mode: str, config: Union[Dict[str, Any],
             "ACCELERATE_USE_DEEPSPEED": "true",
             "ACCELERATE_MIXED_PRECISION": mixed_precision,
         },
+        "HPU_DDP": {
+            "ACCELERATE_USE_CPU": "false",
+            "ACCELERATE_USE_XPU": "false",
+            "ACCELERATE_USE_IPEX": "false",
+            "ACCELERATE_MIXED_PRECISION": mixed_precision,
+        },
     }
     if mode not in mode_env_vars:
         raise ValueError(f"accelerate mode must be one of {list(mode_env_vars.keys())}")
     return mode_env_vars[mode]


+def get_device_environment_variable(device):
+    if device == "hpu":
+        return {
+            "HABANA_VISIBLE_DEVICES": "all",
+            "RAY_EXPERIMENTAL_NOSET_HABANA_VISIBLE_MODULES": "true",
+        }
+    return {}
+
+
 def convert_dtype(dtype: str) -> Optional[torch.dtype]:
     supported_dtypes = {
         "fp16": torch.float16,
@@ -82,43 +111,12 @@ def train_func(config: Dict[str, Any]):
     os.chdir(cwd)

     gradient_accumulation_steps = config["Training"].get("gradient_accumulation_steps", 1)
-
-    accelerate_mode = config["Training"]["accelerate_mode"]
-    if accelerate_mode in ["GPU_FSDP"]:
-        fsdp_plugin = FullyShardedDataParallelPlugin(
-            state_dict_config=FullStateDictConfig(offload_to_cpu=False, rank0_only=False),
-            optim_state_dict_config=FullOptimStateDictConfig(
-                offload_to_cpu=False, rank0_only=False
-            ),
-        )
-        deepspeed_plugin = None
-
-    elif accelerate_mode in ["GPU_DEEPSPEED"]:
-        fsdp_plugin = None
-        hf_ds_config = config["Training"]["deepspeed_config_file"]
-        deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=hf_ds_config)
-
-    else:
-        fsdp_plugin = None
-        deepspeed_plugin = None
-
-    output_dir = config["General"]["output_dir"]
-    accelerator = accelerate.Accelerator(
-        gradient_accumulation_steps=gradient_accumulation_steps,
-        fsdp_plugin=fsdp_plugin,
-        deepspeed_plugin=deepspeed_plugin,
-    )
-    epochs = config["Training"]["epochs"]
     base_model = config["General"]["base_model"]
     dataset_file = config["Dataset"]["train_file"]

-    common.logger.info(
-        f"accelerator generate finish, accelerator device type = {accelerator.device}"
-    )
-
     seed = config["Training"].get("seed")
     if seed is not None:
-        accelerate.utils.set_seed(seed)
+        set_seed(seed)

     datasets = common.dataset.Dataset.registory.get("HuggingfaceDataset")()(
         config={
@@ -157,6 +155,36 @@ def train_func(config: Dict[str, Any]):
         },
     )

+    accelerate_mode = config["Training"]["accelerate_mode"]
+    if accelerate_mode in ["GPU_FSDP"]:
+        fsdp_plugin = FullyShardedDataParallelPlugin(
+            state_dict_config=FullStateDictConfig(offload_to_cpu=False, rank0_only=False),
+            optim_state_dict_config=FullOptimStateDictConfig(
+                offload_to_cpu=False, rank0_only=False
+            ),
+        )
+        deepspeed_plugin = None
+    elif accelerate_mode in ["GPU_DEEPSPEED"]:
+        fsdp_plugin = None
+        hf_ds_config = config["Training"]["deepspeed_config_file"]
+        deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=hf_ds_config)
+    else:
+        fsdp_plugin = None
+        deepspeed_plugin = None
+
+    output_dir = config["General"]["output_dir"]
+    accelerator = Accelerator(
+        device_placement=False,
+        gradient_accumulation_steps=gradient_accumulation_steps,
+        fsdp_plugin=fsdp_plugin,
+        deepspeed_plugin=deepspeed_plugin,
+    )
+    epochs = config["Training"]["epochs"]
+
+    common.logger.info(
+        f"accelerator generate finish, accelerator device type = {accelerator.device}"
+    )
+
     trainer = common.trainer.Trainer.registory.get("DefaultTrainer")(
         config={
             "accelerate_mode": config["Training"]["accelerate_mode"],
@@ -169,7 +197,10 @@ def train_func(config: Dict[str, Any]):
                 "per_device_train_batch_size": config["Training"]["batch_size"],
                 "per_device_eval_batch_size": config["Training"]["batch_size"],
                 "preprocessing_num_workers": config["Dataset"].get("preprocessing_num_workers", 1),
-                "shuffle": True,
+                "max_length": config["Dataset"].get("max_length", 512),
+                "group": config["Dataset"].get("group", True),
+                "block_size": config["Dataset"].get("block_size", 512),
+                "shuffle": config["Dataset"].get("shuffle", False),
             },
             "lr_scheduler": {
                 "enable": True,
@@ -181,9 +212,7 @@ def train_func(config: Dict[str, Any]):
             },
             "checkpoint": {
                 "root_path": config["General"]["checkpoint_dir"],
-            }
-            if config["General"].get("checkpoint_dir")
-            else None,
+            },
         }
     )

@@ -231,6 +260,7 @@ def main(external_config=None):
         config = external_config

     config["cwd"] = os.getcwd()
+
     num_training_workers = config["Training"].get("num_training_workers")
     resources_per_worker = config["Training"].get("resources_per_worker")

@@ -241,6 +271,7 @@ def main(external_config=None):
     use_cpu = True if accelerate_mode.startswith("CPU") else False
     use_gpu = True if accelerate_mode.startswith("GPU") else False
     ccl_worker_count = 1 if use_cpu is True else num_training_workers
+    device = config["Training"]["device"].lower()

     if not ray.is_initialized():
         runtime_env = {
@@ -257,18 +288,24 @@ def main(external_config=None):
         accelerate_env_vars = get_accelerate_environment_variable(accelerate_mode, config)
         runtime_env["env_vars"].update(accelerate_env_vars)

+        device_env_vars = get_device_environment_variable(device)
+        runtime_env["env_vars"].update(device_env_vars)
+
         if config["General"]["gpt_base_model"] is True:
             runtime_env["pip"] = ["transformers==4.26.0"]

-        import intel_extension_for_pytorch as ipex
-
-        if "xpu" in ipex.__version__:
-            num_cpus = (
-                resources_per_worker["CPU"] * num_training_workers + 1
-            )  # additional 1 for head worker
-            ray.init(num_cpus=num_cpus, runtime_env=runtime_env)
-        else:
+        if use_habana:
             ray.init(runtime_env=runtime_env)
+        else:
+            import intel_extension_for_pytorch as ipex
+
+            if "xpu" in ipex.__version__:
+                num_cpus = (
+                    resources_per_worker["CPU"] * num_training_workers + 1
+                )  # additional 1 for head worker
+                ray.init(num_cpus=num_cpus, runtime_env=runtime_env)
+            else:
+                ray.init(runtime_env=runtime_env)

     common.logger.info(f"ray available resources = {ray.available_resources()}")

@@ -279,12 +316,15 @@ def main(external_config=None):
         placement_strategy="SPREAD",
     )

-    device = config["Training"]["device"].lower()
-    if device == "gpu" and is_xpu_available():
+    if not use_habana and device == "gpu" and is_xpu_available():
         device = "xpu"

     if config.get("torch_config", None) is None:
-        backend = "ccl" if device == "cpu" or device == "xpu" else None
+        backend = None
+        if device == "cpu" or device == "xpu":
+            backend = "ccl"
+        elif device == "hpu":
+            backend = "hccl"
         torch_config = common.TorchConfig(backend=backend, device=device)
     else:
         customer_torch_config = config.get("torch_config")
diff --git a/llm_on_ray/finetune/finetune_config.py b/llm_on_ray/finetune/finetune_config.py
index d640829cf..8ff799d1b 100644
--- a/llm_on_ray/finetune/finetune_config.py
+++ b/llm_on_ray/finetune/finetune_config.py
@@ -36,11 +36,16 @@ class Dataset(BaseModel):
     train_file: str
     validation_file: Optional[str]
     validation_split_percentage: int
+    max_length: int = 512
+    group: bool = True
+    block_size: int = 512
+    shuffle: bool = False


 class RayResourceConfig(BaseModel):
     CPU: int
     GPU: int = 0
+    HPU: int = 0


 class Training(BaseModel):
@@ -61,14 +66,14 @@ class Training(BaseModel):

     @validator("device")
     def check_device(cls, v: str):
-        devices = ["CPU", "GPU"]
+        devices = ["CPU", "GPU", "HPU"]
         if v not in devices:
             raise ValueError(f"device must be one of {devices}")
         return v

     @validator("accelerate_mode")
     def check_accelerate_mode(cls, v: str):
-        modes = ["CPU_DDP", "GPU_DDP", "GPU_FSDP", "GPU_DEEPSPEED"]
+        modes = ["CPU_DDP", "GPU_DDP", "GPU_FSDP", "HPU_DDP", "GPU_DEEPSPEED"]
         if v not in modes:
             raise ValueError(f"accelerate_mode must be one of {modes}")
         return v
diff --git a/pyproject.toml b/pyproject.toml
index 95996773e..332775d5b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -21,7 +21,7 @@ dependencies = [
     "accelerate",
     "datasets>=2.14.6",
     "numpy",
-    "ray>=2.9",
+    "ray<2.10",
     "typing>=3.7.4.3",
     "tabulate",
     "ray[tune]",
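Note on the default_trainer.py change: HPU training here relies on Gaudi lazy mode, where operations are queued into a graph and only executed when habana_frameworks.torch.core.mark_step() is called. Below is a minimal sketch of that training-step pattern, assuming habana_frameworks is installed and an HPU is available; the train_step helper is illustrative only and is not part of this patch (the patch itself goes through Accelerate and also calls mark_step() after the LR-scheduler step).

import torch
import habana_frameworks.torch.core as htcore


def train_step(model, batch, optimizer, device=torch.device("hpu")):
    # Mirrors the core pattern added to DefaultTrainer.train(): flush the lazy
    # graph after backward() and again after optimizer.step().
    model.train()
    batch = {k: v.to(device) for k, v in batch.items()}
    loss = model(**batch).loss
    loss.backward()
    htcore.mark_step()  # execute the queued forward/backward graph
    optimizer.step()
    htcore.mark_step()  # execute the queued optimizer update
    optimizer.zero_grad()
    return loss.item()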