fix a bug of fine-tuning on HPU (#239)
harborn authored May 30, 2024
1 parent 309bb63 commit e7bcec0
Showing 1 changed file with 24 additions and 2 deletions.
26 changes: 24 additions & 2 deletions llm_on_ray/finetune/finetune.py
@@ -42,7 +42,16 @@
 from importlib import util


-def set_seed(config):
+def adapt_transformers_to_device(config: Dict):
+    device = config["Training"]["device"]
+    if device in ["hpu"]:
+        from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
+
+        # adapt transformers to gaudi
+        adapt_transformers_to_gaudi()
+
+
+def set_seed(config: Dict):
     seed = config["Training"].get("seed", None)
     if seed is None:
         return
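
A minimal usage sketch of the new helper (the import path mirrors the changed file, and the config values below are illustrative assumptions; on non-HPU devices the call is effectively a no-op):

    from llm_on_ray.finetune.finetune import adapt_transformers_to_device

    # Hypothetical config excerpt: only the "Training.device" entry is read by the helper.
    config = {"Training": {"device": "hpu"}}

    # On an HPU machine with optimum-habana installed, this patches the transformers
    # library with Gaudi-optimized model implementations; for "cpu" or "gpu" the
    # function simply returns without doing anything.
    adapt_transformers_to_device(config)
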
@@ -57,7 +66,7 @@ def set_seed(config):
     _set_seed(seed)


-def convert_to_training_args(cls, config):
+def convert_to_training_args(cls, config: Dict):
     device = config["Training"]["device"]
     accelerate_mode = config["Training"]["accelerate_mode"]
     save_strategy = config["General"]["save_strategy"]
@@ -312,11 +321,22 @@ def get_trainer(config: Dict, model, tokenizer, tokenized_dataset, data_collator
     elif device in ["hpu"]:
         from optimum.habana.transformers import GaudiTrainer
         from optimum.habana.transformers import GaudiTrainingArguments
+        from optimum.habana import GaudiConfig

+        # If gaudi_config_name is provided, load the Gaudi config from the Hugging Face Hub (https://huggingface.co/Habana); otherwise use a default GaudiConfig.
+        gaudi_config_name = config["General"].get("gaudi_config_name", None)
+        if gaudi_config_name is not None:
+            gaudi_config = GaudiConfig.from_pretrained(gaudi_config_name)
+        else:
+            gaudi_config = GaudiConfig()
+            gaudi_config.use_fused_adam = True
+            gaudi_config.use_fused_clip_norm = True
+
         training_args = convert_to_training_args(GaudiTrainingArguments, config)
         trainer = GaudiTrainer(
             model=model,
             args=training_args,
+            gaudi_config=gaudi_config,
             train_dataset=tokenized_dataset["train"],
             eval_dataset=tokenized_dataset["validation"]
             if tokenized_dataset.get("validation") is not None
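
For reference, a hedged sketch of how the new option might appear in a finetuning config; the key names come from the code above, while the values (including the Habana/gpt2 Gaudi config repo) are illustrative assumptions:

    # Hypothetical excerpt of a finetuning config dict; only keys referenced in this
    # diff are shown, and every value here is an assumption for illustration.
    config = {
        "General": {
            "gaudi_config_name": "Habana/gpt2",  # assumed example; any Gaudi config under https://huggingface.co/Habana
            "save_strategy": "no",
        },
        "Training": {
            "device": "hpu",
            "accelerate_mode": "DDP",  # assumed example value
        },
    }

    # Mirrors the fallback above: omit "gaudi_config_name" (or leave it None) and a
    # default GaudiConfig with fused Adam and fused clip-norm is used instead.
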
@@ -331,6 +351,8 @@ def train_func(config: Dict[str, Any]):
 def train_func(config: Dict[str, Any]):
     os.chdir(config["cwd"])

+    adapt_transformers_to_device(config)
+
     set_seed(config)

     tokenizer = load_tokenizer(config)
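
The call is placed at the very top of train_func presumably because adapt_transformers_to_gaudi() patches transformers classes in place, so it has to run on each Ray worker before the tokenizer and model are built. A minimal sketch of that ordering (the model name is an illustrative assumption):

    from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
    from transformers import AutoModelForCausalLM

    adapt_transformers_to_gaudi()  # patch transformers for Gaudi/HPU first
    model = AutoModelForCausalLM.from_pretrained("gpt2")  # now built with the patched classes
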