Skip to content

Commit

Permalink
[Finetune] Support fine-tuning on Gaudi (#155)
Browse files Browse the repository at this point in the history
* [Fine-tuning] Enable fine-tuning on Gaudi

* update

* update

* update

* update

* update
  • Loading branch information
harborn authored Mar 25, 2024
1 parent a51fd46 commit aa2d08e
Show file tree
Hide file tree
Showing 7 changed files with 161 additions and 83 deletions.
5 changes: 5 additions & 0 deletions docs/finetune_parameters.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ The following are the parameters supported in the finetuning workflow.
|validation_file|None|A json file containing the validation data.|
|validation_split_percentage|5|The percentage of the train set used as validation set in case there's no validation split|
|preprocessing_num_workers|None|The number of processes to use for the preprocessing.|
|max_length|512|Padding sequential data to max length of a batch|
|group|True|Whether to concatenate sentences for more efficient training|
|block_size|512|The block size of the concatenated sentences|
|shuffle|False|Whether to shuffle the data at every epoch|


## Training Parameters
|Configuration Name| Default|Meaning|
Expand Down
53 changes: 26 additions & 27 deletions llm_on_ray/common/dataprocesser/general_processer.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,12 @@ def torch_call(self, examples):

class GeneralProcesser(DataProcesser):
def prepare(self, tokenizer, dataset):
per_device_train_batch_size = self.config.get("per_device_train_batch_size", 1)
per_device_eval_batch_size = self.config.get("per_device_eval_batch_size", 1)
group = self.config.get("group", False)
self.config.get("shuffle", False)
per_device_train_batch_size = self.config.get("per_device_train_batch_size")
per_device_eval_batch_size = self.config.get("per_device_eval_batch_size")
max_length = self.config.get("max_length")
group = self.config.get("group")
block_size = self.config.get("block_size")
shuffle = self.config.get("shuffle")
tokenizer.pad_token = tokenizer.eos_token

if isinstance(dataset, datasets.Dataset):
Expand Down Expand Up @@ -123,8 +125,6 @@ def prompt(rec):
)
column_names += [TEXT_COLUMN_NAME]

max_length = self.config.get("max_length", 1024)

def tokenize_function(examples):
return tokenizer(examples[TEXT_COLUMN_NAME], max_length=max_length)

Expand All @@ -136,7 +136,6 @@ def tokenize_function(examples):
)

if group:
block_size = self.config.get("block_size", 1024)

def group_texts(examples):
# Concatenate all texts.
Expand All @@ -160,30 +159,30 @@ def group_texts(examples):
load_from_cache_file=False,
desc=f"Grouping texts in chunks of {block_size}",
)
default_data_collator = transformers.default_data_collator

else:
default_data_collator = DataCollatorForCompletionOnlyLM(
tokenizer=tokenizer,
mlm=False,
return_tensors="pt",
pad_to_multiple_of=8,
)

train_dataset = tokenized_datasets["train"]
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
shuffle=True,
collate_fn=default_data_collator,
batch_size=per_device_train_batch_size,
data_collator = DataCollatorForCompletionOnlyLM(
tokenizer=tokenizer,
mlm=False,
return_tensors="pt",
pad_to_multiple_of=8,
)

train_dataset = tokenized_datasets["train"]
train_dataloader_params = {
"shuffle": shuffle,
"collate_fn": data_collator,
"batch_size": per_device_train_batch_size,
"pin_memory": True,
}
train_dataloader = torch.utils.data.DataLoader(train_dataset, **train_dataloader_params)

eval_dataloader = None
if "validation" in tokenized_datasets:
eval_dataset = tokenized_datasets["validation"]
eval_dataloader = torch.utils.data.DataLoader(
eval_dataset,
collate_fn=default_data_collator,
batch_size=per_device_eval_batch_size,
)
eval_dataloader_params = {
"shuffle": shuffle,
"collate_fn": data_collator,
"batch_size": per_device_eval_batch_size,
}
eval_dataloader = torch.utils.data.DataLoader(eval_dataset, **eval_dataloader_params)
return train_dataloader, eval_dataloader
15 changes: 14 additions & 1 deletion llm_on_ray/common/torch_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def backend_cls(self):
return EnableCCLBackend


def libs_import():
def xpu_libs_import():
"""try to import IPEX and oneCCL."""
try:
import intel_extension_for_pytorch
Expand All @@ -39,6 +39,14 @@ def libs_import():
raise ImportError("Please install torch-ccl") from ccl_not_exist


def hpu_libs_import():
    """Ensure the Habana (Gaudi/HPU) PyTorch bridge is importable.

    Raises:
        ImportError: if ``habana_frameworks`` is not installed, chained from
            the underlying import failure.
    """
    try:
        import habana_frameworks.torch  # noqa: F401
    except ImportError as err:
        raise ImportError("Please install habana_frameworks") from err


def _set_torch_distributed_env_vars(device):
if device is not None:
os.environ["ACCELERATE_TORCH_DEVICE"] = device
Expand All @@ -48,6 +56,11 @@ class EnableCCLBackend(_TorchBackend):
device: Optional[str] = None

def on_start(self, worker_group: WorkerGroup, backend_config: RayTorchConfig):
    """Asynchronously trigger the device-appropriate library import on each
    worker, then delegate to the base backend's startup.

    HPU devices get the Habana import hook; everything else gets the
    IPEX/oneCCL import hook.
    """
    if self.device is not None and self.device.startswith("hpu"):
        libs_import = hpu_libs_import
    else:
        libs_import = xpu_libs_import
    for worker_idx in range(len(worker_group)):
        worker_group.execute_single_async(worker_idx, libs_import)
    super().on_start(worker_group, backend_config)
Expand Down
20 changes: 18 additions & 2 deletions llm_on_ray/common/trainer/default_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,8 @@ def prepare(self, model, tokenizer, dataset, optimizer, accelerator):
# self.model, self.optimizer, self.lr_scheduler, ..., are prepared with 2 steps
# because it is recommended way to prepare model and optimizer while using FSDP.
# https://huggingface.co/docs/accelerate/usage_guides/fsdp#a-few-caveats-to-be-aware-of
accelerate_mode = self.config.get("accelerate_mode")
if accelerate_mode in ["GPU_DEEPSPEED"]:
self.accelerate_mode = self.config.get("accelerate_mode")
if self.accelerate_mode in ["GPU_DEEPSPEED"]:
lr = lr_scheduler_config.get("learning_rate", 0.001)
weight_decay = lr_scheduler_config.get("weight_decay", 0)
from accelerate.utils import DummyOptim, DummyScheduler
Expand All @@ -163,6 +163,14 @@ def prepare(self, model, tokenizer, dataset, optimizer, accelerator):
self.lr_scheduler,
) = accelerator.prepare(optimizer, train_dataloader, eval_dataloader, lr_scheduler)

if self.accelerate_mode in ["HPU_DDP"]:
import habana_frameworks.torch.core as htcore
from habana_frameworks.torch.utils.internal import is_lazy

self.htcore = htcore
else:
self.htcore = None

checkpoint = self.config.get("checkpoint")
if checkpoint is not None:
self.recovery(checkpoint)
Expand All @@ -180,12 +188,20 @@ def train(self):
logger.info(f"Start training epoch {idx}, total_steps {total_steps}")
for step, batch in enumerate(self.train_dataloader):
with self.accelerator.accumulate(self.model):
self.model.train()
batch = batch.to(device=self.accelerator.device)
outputs = self.model(**batch)
loss = outputs.loss
self.accelerator.backward(loss)
if self.htcore is not None:
self.htcore.mark_step()
self.optimizer.step()
if self.htcore is not None:
self.htcore.mark_step()
if self.lr_scheduler is not None:
self.lr_scheduler.step()
if self.htcore is not None:
self.htcore.mark_step()
self.optimizer.zero_grad()

if step % logging_steps == 0:
Expand Down
Loading

0 comments on commit aa2d08e

Please sign in to comment.