Deepspeed Zero3 QLoRA Fine-tuning #11048

Closed · wants to merge 3 commits
@@ -0,0 +1,250 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file are adapted from
# https://github.com/tloen/alpaca-lora/blob/main/finetune.py
#
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import List
os.environ["ACCELERATE_USE_XPU"] = "true"
import fire
import torch
from datasets import load_dataset
import accelerate
import transformers

from transformers import AutoTokenizer, BitsAndBytesConfig, AutoConfig, AutoModelForCausalLM
Contributor review comment: Why not use the AutoModelForCausalLM from ipex-llm? (A sketch of this alternative appears after the model-loading call in train() below.)

from peft import (
get_peft_model_state_dict,
set_peft_model_state_dict,
)

current_dir = os.path.dirname(os.path.realpath(__file__))
common_util_path = os.path.join(current_dir, '..', '..')
import sys
sys.path.append(common_util_path)
from common.utils import Prompter, get_int_from_env, wandb_check, get_train_val_data

from ipex_llm.transformers.qlora import get_peft_model, prepare_model_for_kbit_training,\
    LoraConfig
from ipex_llm.utils.common import invalidInputError
import deepspeed as ds

local_rank = get_int_from_env(["LOCAL_RANK","MPI_LOCALRANKID"], "0")
world_size = get_int_from_env(["WORLD_SIZE","PMI_SIZE"], "1")
port = get_int_from_env(["MASTER_PORT"], 29500)
os.environ["LOCAL_RANK"] = str(local_rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["RANK"] = str(local_rank)
os.environ["MASTER_PORT"] = str(port)

def train(
    # model/data params
    base_model: str = "meta-llama/Llama-2-7b-hf",  # the only required argument, defaults to "meta-llama/Llama-2-7b-hf"
    data_path: str = "yahma/alpaca-cleaned",
    output_dir: str = "./ipex-deepspeed-zero3-qlora-alpaca",
    # training hyperparams
    bf16: bool = True,  # default to bf16
    batch_size: int = 128,
    micro_batch_size: int = 2,  # defaults to 2, limited by GPU memory
    num_epochs: int = 3,
    learning_rate: float = 3e-5,  # defaults to 3e-5 to avoid divergence
    cutoff_len: int = 256,
    val_set_size: int = 2000,
    # lora hyperparams
    lora_r: int = 8,
    lora_alpha: int = 16,
    lora_dropout: float = 0.05,
    lora_target_modules: List[str] = [
        "q_proj",
        "v_proj",
        "k_proj",
        "o_proj",
        "up_proj",
        "down_proj",
        "gate_proj"
    ],  # according to the QLoRA paper (https://arxiv.org/pdf/2305.14314.pdf), it is suggested to fine-tune all linear layers
    # llm hyperparams
    train_on_inputs: bool = True,  # if False, masks out inputs in loss
    add_eos_token: bool = False,
    group_by_length: bool = False,  # faster, but produces an odd training loss curve
    # wandb params
    wandb_project: str = "",
    wandb_run_name: str = "",
    wandb_watch: str = "",  # options: false | gradients | all
    wandb_log_model: str = "",  # options: false | true
    resume_from_checkpoint: str = None,  # either training checkpoint or final adapter
    prompt_template_name: str = "alpaca",  # the prompt template to use, defaults to alpaca
    gradient_checkpointing: bool = False,
    deepspeed: str = None,
    training_mode: str = "qlora",
):
    invalidInputError(training_mode == "qlora",
                      f"This example is for qlora training mode, but got training_mode={training_mode}.")
    if int(os.environ.get("LOCAL_RANK", 0)) == 0:
        print(
            f"Training Alpaca-LoRA model with params:\n"
            f"base_model: {base_model}\n"
            f"data_path: {data_path}\n"
            f"output_dir: {output_dir}\n"
            f"batch_size: {batch_size}\n"
            f"micro_batch_size: {micro_batch_size}\n"
            f"num_epochs: {num_epochs}\n"
            f"learning_rate: {learning_rate}\n"
            f"cutoff_len: {cutoff_len}\n"
            f"val_set_size: {val_set_size}\n"
            f"lora_r: {lora_r}\n"
            f"lora_alpha: {lora_alpha}\n"
            f"lora_dropout: {lora_dropout}\n"
            f"lora_target_modules: {lora_target_modules}\n"
            f"train_on_inputs: {train_on_inputs}\n"
            f"add_eos_token: {add_eos_token}\n"
            f"group_by_length: {group_by_length}\n"
            f"wandb_project: {wandb_project}\n"
            f"wandb_run_name: {wandb_run_name}\n"
            f"wandb_watch: {wandb_watch}\n"
            f"wandb_log_model: {wandb_log_model}\n"
            f"resume_from_checkpoint: {resume_from_checkpoint or False}\n"
            f"prompt template: {prompt_template_name}\n"
            f"training_mode: {training_mode}\n"
        )
    assert (
        base_model
    ), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"
    gradient_accumulation_steps = batch_size // micro_batch_size

    prompter = Prompter(prompt_template_name)

    device_map = "auto"
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    if ddp:
        device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
        gradient_accumulation_steps = gradient_accumulation_steps // world_size
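        # Illustrative arithmetic (example numbers, not from the PR): with batch_size=128,
        # micro_batch_size=2 and WORLD_SIZE=4, each rank accumulates 128 // 2 // 4 = 16
        # micro-batches per optimizer step, so the effective global batch size stays at 128.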

    # Check if parameter passed or if set within environ
    use_wandb = wandb_check(wandb_project, wandb_watch, wandb_log_model)


    model_config = AutoConfig.from_pretrained(base_model)
    with ds.zero.Init(config_dict_or_path=deepspeed):
        model = AutoModelForCausalLM.from_pretrained(
            base_model,
            config=model_config,
            torch_dtype=torch.bfloat16,
            ignore_mismatched_sizes=True,
        )

Contributor review comment (on the AutoModelForCausalLM.from_pretrained call above): why not set load_in_low_bit?
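For reference, a minimal sketch of the alternative the two review comments point at, assuming the ipex_llm.transformers wrapper and its load_in_low_bit, optimize_model and modules_to_not_convert arguments behave as in the other ipex-llm QLoRA examples; how load-time quantization would interact with DeepSpeed ZeRO-3 parameter partitioning is left open here, and the snippet is illustrative rather than part of the PR:

# Sketch only: quantize to NF4 at load time through the ipex-llm wrapper
# instead of converting the plain Hugging Face model afterwards.
from ipex_llm.transformers import AutoModelForCausalLM as IpexAutoModelForCausalLM

model = IpexAutoModelForCausalLM.from_pretrained(
    base_model,
    load_in_low_bit="nf4",               # 4-bit NormalFloat, as used for QLoRA
    optimize_model=False,                # keep inference-time fusions off for training
    torch_dtype=torch.bfloat16,
    modules_to_not_convert=["lm_head"],  # keep lm_head in higher precision
)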

    from transformers import LlamaTokenizer
    tokenizer = LlamaTokenizer.from_pretrained(base_model, trust_remote_code=True)
    print(f"Tokenizer loaded on rank {os.environ.get('LOCAL_RANK')}")

    tokenizer.pad_token_id = (
        0  # unk. we want this to be different from the eos token
    )
    tokenizer.padding_side = "left"  # Allow batched inference

Contributor review comment (on the pad_token_id and padding_side lines above): This code is not necessary anymore.

    print(model)

    # Prepare an IPEX-LLM compatible PEFT model
    model = prepare_model_for_kbit_training(model,
                                            use_gradient_checkpointing=gradient_checkpointing,
                                            enable_deepspeed_zero3=True)

    config = LoraConfig(
        r=lora_r,
        lora_alpha=lora_alpha,
        target_modules=lora_target_modules,
        lora_dropout=lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
        training_mode=training_mode,
    )
    print(f"Lora Config: {config}")
    model = get_peft_model(model, config, enable_deepspeed_zero3=True)

    if data_path.endswith(".json") or data_path.endswith(".jsonl"):
        data = load_dataset("json", data_files=data_path)
    else:
        data = load_dataset(data_path)

    model.print_trainable_parameters()  # Be more transparent about the % of trainable params.

    train_data, val_data = get_train_val_data(data, tokenizer, prompter, train_on_inputs,
                                              add_eos_token, cutoff_len, val_set_size, seed=42)

    trainer = transformers.Trainer(
        model=model,
        train_dataset=train_data,
        eval_dataset=val_data,
        args=transformers.TrainingArguments(
            per_device_train_batch_size=micro_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            # warmup_ratio=0.03,
            # warmup_steps=100,
            max_grad_norm=0.3,
            num_train_epochs=num_epochs,
            learning_rate=learning_rate,
            lr_scheduler_type="cosine",
            bf16=True,  # ensure training is more stable
            logging_steps=1,
            optim="adamw_torch",
            evaluation_strategy="steps" if val_set_size > 0 else "no",
            save_strategy="steps",
            eval_steps=100 if val_set_size > 0 else None,
            save_steps=100,
            output_dir=output_dir,
            save_total_limit=100,
            load_best_model_at_end=True if val_set_size > 0 else False,
            ddp_find_unused_parameters=False if ddp else None,
            group_by_length=group_by_length,
            report_to="wandb" if use_wandb else None,
            run_name=wandb_run_name if use_wandb else None,
            gradient_checkpointing=gradient_checkpointing,
            ddp_backend="ccl",
            deepspeed=deepspeed,
            save_safetensors=False,
        ),
        data_collator=transformers.DataCollatorForSeq2Seq(
            tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
        ),
    )
    model.config.use_cache = False

    trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    model.save_pretrained(output_dir)

    print(
        "\n If there's a warning about missing keys above, please disregard :)"
    )


if __name__ == "__main__":
    fire.Fire(train)
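The deepspeed argument is handed both to ds.zero.Init and to transformers.TrainingArguments, so it may be a path to a JSON file or an in-memory dict. A minimal ZeRO-3 style configuration of the kind this example expects might look like the sketch below; the file name and every value are illustrative assumptions, not taken from this PR, and must be kept consistent with micro_batch_size, batch_size and the number of ranks actually launched.

import json

# Hypothetical ZeRO-3 config; values are examples only and must match the
# TrainingArguments above (micro batch, gradient accumulation, world size).
ds_zero3_config = {
    "zero_optimization": {
        "stage": 3,
        "contiguous_gradients": True,
        "overlap_comm": True,
    },
    "bf16": {"enabled": True},
    "train_micro_batch_size_per_gpu": 2,  # = micro_batch_size
    "gradient_accumulation_steps": 32,    # = 128 // 2 // world_size, here for 2 ranks
    "train_batch_size": 128,              # = batch_size
}

# Write it out and point the script at it, e.g.:
#   python <this script> --deepspeed ./deepspeed_zero3_config.json
with open("./deepspeed_zero3_config.json", "w") as f:
    json.dump(ds_zero3_config, f, indent=2)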
99 changes: 68 additions & 31 deletions python/llm/src/ipex_llm/transformers/convert.py
@@ -273,6 +273,7 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                                  enable_xetla=False,
                                  mixed_precision=False,
                                  act_order=False,
+                                 enable_deepspeed_zero3=False,
                                  ):
     from ipex_llm.transformers.low_bit_linear import LowBitLinear, FP4Params, \
         FP16Linear, BF16Linear
@@ -337,36 +338,70 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
                 if in_features % 64 != 0:
                     # now our kernel requires in_features is a multiple of 64
                     continue
-                new_linear = LowBitLinear(
-                    in_features,
-                    out_features,
-                    qtype,
-                    module.bias is not None,
-                    mp_group=mp_group,
-                    enable_xetla=enable_xetla,
-                    optimize_lm_head=optimize_lm_head
-                )
-                cur_qtype, cur_imatrix = get_cur_qtype_and_imatrix(qtype,
-                                                                   full_module_name,
-                                                                   imatrix_data,
-                                                                   model_config)
-                # mixed precison for lm_head
-                if mixed_precision and is_lm_head(name, model_config, out_features):
-                    if cur_qtype in [ggml_tensor_qtype["sym_int4"],
-                                     ggml_tensor_qtype["asym_int4"]]:
-                        cur_qtype = ggml_tensor_qtype["sym_int8"]
-                device = module.weight.data.device
-                # Copy the weights
-                paramsLowBit = FP4Params(data=module.weight.data,
-                                         requires_grad=False,
-                                         quantized=False,
-                                         _shape=None,
-                                         convert_shape_only=convert_shape_only,
-                                         qtype=cur_qtype,
-                                         imatrix=cur_imatrix,
-                                         in_features=in_features,
-                                         enable_xetla=enable_xetla).to(device)
-                new_linear._parameters['weight'] = paramsLowBit
+                if enable_deepspeed_zero3:
+                    cur_qtype, cur_imatrix = get_cur_qtype_and_imatrix(qtype,
+                                                                       full_module_name,
+                                                                       imatrix_data,
+                                                                       model_config)
+                    # mixed precison for lm_head
+                    if mixed_precision and is_lm_head(name, model_config, out_features):
+                        if cur_qtype in [ggml_tensor_qtype["sym_int4"],
+                                         ggml_tensor_qtype["asym_int4"]]:
+                            cur_qtype = ggml_tensor_qtype["sym_int8"]
+                    device = module.weight.data.device
+                    # Copy the weights
+                    new_weight = FP4Params(data=module.weight.data,
+                                           requires_grad=False,
+                                           quantized=False,
+                                           _shape=None,
+                                           convert_shape_only=convert_shape_only,
+                                           qtype=cur_qtype,
+                                           imatrix=cur_imatrix,
+                                           in_features=in_features,
+                                           enable_xetla=enable_xetla,
+                                           enable_deepspeed_zero3=enable_deepspeed_zero3).to(device)
+                    new_linear = LowBitLinear(
+                        in_features,
+                        out_features,
+                        qtype,
+                        module.bias is not None,
+                        mp_group=mp_group,
+                        enable_xetla=enable_xetla,
+                        optimize_lm_head=optimize_lm_head,
+                        enable_deepspeed_zero3=enable_deepspeed_zero3,
+                        weight=new_weight
+                    )
+                else:
+                    new_linear = LowBitLinear(
+                        in_features,
+                        out_features,
+                        qtype,
+                        module.bias is not None,
+                        mp_group=mp_group,
+                        enable_xetla=enable_xetla,
+                        optimize_lm_head=optimize_lm_head
+                    )
+                    cur_qtype, cur_imatrix = get_cur_qtype_and_imatrix(qtype,
+                                                                       full_module_name,
+                                                                       imatrix_data,
+                                                                       model_config)
+                    # mixed precison for lm_head
+                    if mixed_precision and is_lm_head(name, model_config, out_features):
+                        if cur_qtype in [ggml_tensor_qtype["sym_int4"],
+                                         ggml_tensor_qtype["asym_int4"]]:
+                            cur_qtype = ggml_tensor_qtype["sym_int8"]
+                    device = module.weight.data.device
+                    # Copy the weights
+                    paramsLowBit = FP4Params(data=module.weight.data,
+                                             requires_grad=False,
+                                             quantized=False,
+                                             _shape=None,
+                                             convert_shape_only=convert_shape_only,
+                                             qtype=cur_qtype,
+                                             imatrix=cur_imatrix,
+                                             in_features=in_features,
+                                             enable_xetla=enable_xetla).to(device)
+                    new_linear._parameters['weight'] = paramsLowBit
                 if module.bias is not None:
                     new_linear._parameters['bias'] = nn.Parameter(module.bias.data)\
                         .to(device)
@@ -757,7 +792,8 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
                          imatrix_data=None,
                          embedding_qtype=None,
                          enable_xetla=False,
-                         mixed_precision=False):
+                         mixed_precision=False,
+                         enable_deepspeed_zero3=False):
     logger.info(f"Converting the current model to "
                 f"{list(ggml_tensor_qtype.keys())[list(ggml_tensor_qtype.values()).index(qtype)]} "
                 f"format......")
@@ -788,6 +824,7 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
         enable_xetla=enable_xetla,
         mixed_precision=mixed_precision,
         act_order=act_order,
+        enable_deepspeed_zero3=enable_deepspeed_zero3,
     )
     if not has_been_replaced:
         warnings.warn(