diff --git a/llm_on_ray/finetune/finetune.py b/llm_on_ray/finetune/finetune.py index eb8cea170..eb4996cb5 100644 --- a/llm_on_ray/finetune/finetune.py +++ b/llm_on_ray/finetune/finetune.py @@ -42,7 +42,16 @@ from importlib import util -def set_seed(config): +def adapt_transformers_to_device(config: Dict): + device = config["Training"]["device"] + if device in ["hpu"]: + from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi + + # adapt transformers to gaudi + adapt_transformers_to_gaudi() + + +def set_seed(config: Dict): seed = config["Training"].get("seed", None) if seed is None: return @@ -57,7 +66,7 @@ def set_seed(config): _set_seed(seed) -def convert_to_training_args(cls, config): +def convert_to_training_args(cls, config: Dict): device = config["Training"]["device"] accelerate_mode = config["Training"]["accelerate_mode"] save_strategy = config["General"]["save_strategy"] @@ -312,11 +321,22 @@ def get_trainer(config: Dict, model, tokenizer, tokenized_dataset, data_collator elif device in ["hpu"]: from optimum.habana.transformers import GaudiTrainer from optimum.habana.transformers import GaudiTrainingArguments + from optimum.habana import GaudiConfig + + # If gaudi_config_name is provided, load gaudi_config from huggingface model hub(https://huggingface.co/Habana), otherwise use default gaudi_config + gaudi_config_name = config["General"].get("gaudi_config_name", None) + if gaudi_config_name is not None: + gaudi_config = GaudiConfig.from_pretrained(gaudi_config_name) + else: + gaudi_config = GaudiConfig() + gaudi_config.use_fused_adam = True + gaudi_config.use_fused_clip_norm = True training_args = convert_to_training_args(GaudiTrainingArguments, config) trainer = GaudiTrainer( model=model, args=training_args, + gaudi_config=gaudi_config, train_dataset=tokenized_dataset["train"], eval_dataset=tokenized_dataset["validation"] if tokenized_dataset.get("validation") is not None @@ -331,6 +351,8 @@ def get_trainer(config: Dict, model, tokenizer, tokenized_dataset, data_collator def train_func(config: Dict[str, Any]): os.chdir(config["cwd"]) + adapt_transformers_to_device(config) + set_seed(config) tokenizer = load_tokenizer(config)