deepspeed zero3 QLoRA finetuning (#11625)
* deepspeed zero3 QLoRA finetuning

* Update convert.py

* Update low_bit_linear.py

* Update utils.py

* Update qlora_finetune_llama2_13b_arch_2_card.sh

* Update low_bit_linear.py

* Update alpaca_qlora_finetuning.py

* Update low_bit_linear.py

* Update utils.py

* Update convert.py

* Update alpaca_qlora_finetuning.py

* Update alpaca_qlora_finetuning.py

* Update low_bit_linear.py

* Update deepspeed_zero3.json

* Update qlora_finetune_llama2_13b_arch_2_card.sh

* Update low_bit_linear.py

* Update low_bit_linear.py

* Update utils.py

* fix style

* fix style

* Update alpaca_qlora_finetuning.py

* Update qlora_finetune_llama2_13b_arch_2_card.sh

* Update convert.py

* Update low_bit_linear.py

* Update model.py

* Update alpaca_qlora_finetuning.py

* Update low_bit_linear.py

* Update low_bit_linear.py

* Update low_bit_linear.py
Uxito-Ada authored Aug 13, 2024
1 parent a184b12 commit 70c828b
Showing 5 changed files with 119 additions and 14 deletions.
alpaca_qlora_finetuning.py
@@ -144,6 +144,14 @@ def train(

prompter = Prompter(prompt_template_name)

if deepspeed is not None and "zero3" in deepspeed:
from ipex_llm.transformers.utils \
import _constant_buffered_norm2
from ipex_llm.llm_patching import replace_attr
import deepspeed as ds
replace_attr(ds.runtime.zero.stage3.DeepSpeedZeroOptimizer_Stage3,
"_constant_buffered_norm2", _constant_buffered_norm2)

device_map = "auto"
world_size = int(os.environ.get("WORLD_SIZE", 1))
ddp = world_size != 1
@@ -161,7 +169,7 @@
optimize_model=False,
torch_dtype=torch.bfloat16,
modules_to_not_convert=["lm_head"],
trust_remote_code=True,
trust_remote_code=True
)
else:
# According to the QLoRA paper, using "nf4" could yield better model quality than "int4"
@@ -186,9 +194,10 @@
# # device_map=device_map,
# modules_to_not_convert=["lm_head"],
# )
print(f"Model loaded on rank {os.environ.get('LOCAL_RANK')}")
model = model.to(f'xpu:{os.environ.get("LOCAL_RANK", 0)}')
print(f"Model moved to rank {os.environ.get('LOCAL_RANK')}")
if deepspeed is not None and "zero3" not in deepspeed:
print(f"Model loaded on rank {os.environ.get('LOCAL_RANK')}")
model = model.to(f'xpu:{os.environ.get("LOCAL_RANK", 0)}')
print(f"Model moved to rank {os.environ.get('LOCAL_RANK')}")

tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
print(f"Tokenizer loaded on rank {os.environ.get('LOCAL_RANK')}")
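The hunk above patches DeepSpeed before the model is built, so that ZeRO-3's gradient-norm helper avoids FP64 on Arc GPUs. A minimal standalone sketch of the same idea is shown below; it assumes replace_attr behaves like a plain attribute rebind (essentially setattr), which is how it is used in the hunk, so treat it as illustrative rather than the exact ipex_llm.llm_patching implementation.

# Sketch only: rebind the Stage-3 optimizer's norm helper before any optimizer
# instance is created, so every instance picks up the FP64-free implementation.
import deepspeed as ds
from ipex_llm.transformers.utils import _constant_buffered_norm2

def replace_attr(obj, name, value):
    # assumed behaviour of ipex_llm.llm_patching.replace_attr: a plain rebind
    setattr(obj, name, value)

replace_attr(ds.runtime.zero.stage3.DeepSpeedZeroOptimizer_Stage3,
             "_constant_buffered_norm2", _constant_buffered_norm2)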
deepspeed_zero3.json
@@ -0,0 +1,15 @@
{
"zero_optimization": {
"stage": 3,
"contiguous_gradients": true,
"overlap_comm": true,
"offload_optimizer": {"device": "cpu"}
},
"bf16": {
"enabled": true
},
"world_size": 2,
"train_batch_size": 32,
"train_micro_batch_size_per_gpu": 2,
"gradient_accumulation_steps": 8
}
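For reference, the batch-size fields above are mutually consistent: train_batch_size must equal train_micro_batch_size_per_gpu × gradient_accumulation_steps × world_size, i.e. 2 × 8 × 2 = 32. A quick sanity check in plain Python, with the values copied from this JSON:

micro_batch_per_gpu = 2       # train_micro_batch_size_per_gpu
grad_accum_steps = 8          # gradient_accumulation_steps
world_size = 2                # world_size, one process per GPU
assert micro_batch_per_gpu * grad_accum_steps * world_size == 32  # train_batch_size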
qlora_finetune_llama2_13b_arch_2_card.sh
@@ -0,0 +1,41 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

export MASTER_ADDR=127.0.0.1
export MASTER_PORT=29503
export FI_PROVIDER=tcp
export CCL_ATL_TRANSPORT=ofi
export CCL_ZE_IPC_EXCHANGE=sockets
export UR_L0_IN_ORDER_BARRIER_BY_SIGNAL=0
basekit_root=/opt/intel/oneapi
source $basekit_root/setvars.sh --force
source $basekit_root/ccl/latest/env/vars.sh --force

NUM_GPUS=2 # number of GPUs used
export USE_XETLA=OFF
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=2
export TORCH_LLM_ALLREDUCE=0 # Different from PVC
export DS_SKIP_CUDA_CHECK=1

mpirun -n $NUM_GPUS \
python -u ./alpaca_qlora_finetuning.py \
--base_model "meta-llama/Llama-2-13b-hf" \
--data_path "yahma/alpaca-cleaned" \
--output_dir "./ipex-llm-qlora-alpaca" \
--gradient_checkpointing True \
--micro_batch_size 2 \
--batch_size 32 \
--deepspeed ./deepspeed_zero3.json
47 changes: 37 additions & 10 deletions python/llm/src/ipex_llm/transformers/low_bit_linear.py
@@ -229,6 +229,13 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
dst_tensor = dst_tensor.reshape(tensor.shape[0], tensor.shape[-1] // QK)
scale = torch.empty(n // k, dtype=torch.float32,
device=device)
elif qtype == NF4:
# DeepSpeed ZeRO-3 requires a unified dtype across layers,
# so bfloat16 is used here to stay consistent with the other layers.
# dst_size above is computed for uint8; for bfloat16 the buffer
# needs only half as many elements.
dst_tensor = torch.empty(dst_size // 2, dtype=torch.bfloat16,
device=device)
else:
dst_tensor = torch.empty(dst_size, dtype=torch.uint8,
device=device)
@@ -260,12 +267,15 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,


def ggml_q_format_convet_cpu2xpu(tensor: torch.Tensor, num_elem: int, qtype: int):

invalidInputError(tensor.dtype == torch.uint8,
"Input tensor must be uint8")
if qtype == NF4:
invalidInputError(tensor.dtype == torch.bfloat16,
"NF4 Input tensor must be bfloat16")
else:
invalidInputError(tensor.dtype == torch.uint8,
"Input tensor except NF4 must be uint8")

invalidInputError(tensor.device == torch.device('cpu'),
"Input tensor must be uint8")
"Input tensor must be on cpu")

src = ctypes.c_void_p(tensor.data.data_ptr())

@@ -370,7 +380,6 @@ def use_batch_forward(x: torch.Tensor, qtype: int, output_len: int):

# Rename to FP4Params to trigger initializing
# the params layer with all parameters on the CPU
# https://github.com/huggingface/accelerate/blob/main/src/accelerate/utils/modeling.py#L333
class FP4Params(torch.nn.Parameter):
def __new__(cls,
data=None,
@@ -582,7 +591,13 @@ class MatMulLowBit(torch.autograd.Function):
def forward(ctx, A, weight, input_seq_size):
ctx.is_empty = False
import xe_linear
result = xe_linear.forward_new(A, weight.data, weight.qtype, input_seq_size)
if weight.qtype == NF4:
result = xe_linear.forward_new(A,
weight.data.view(torch.uint8),
weight.qtype,
input_seq_size)
else:
result = xe_linear.forward_new(A, weight.data, weight.qtype, input_seq_size)
if any(ctx.needs_input_grad[:2]):
ctx.tensors = (A, weight)
else:
@@ -602,7 +617,12 @@ def backward(ctx, grad_output):
if req_gradA:
if torch.xpu.is_autocast_xpu_enabled():
grad_output = grad_output.to(torch.xpu.get_autocast_xpu_dtype())
dequant_weight = xe_linear.dequant(A, weight.data, weight.qtype)
if weight.qtype == NF4:
dequant_weight = xe_linear.dequant(A,
weight.data.view(torch.uint8),
weight.qtype)
else:
dequant_weight = xe_linear.dequant(A, weight.data, weight.qtype)
grad_A = torch.matmul(grad_output, dequant_weight.reshape(weight._shape))

return grad_A, grad_weight, None
@@ -737,9 +757,16 @@ def forward(self, x: torch.Tensor):
if x_2d.requires_grad:
result = MatMulLowBit.apply(x_2d, self.weight, input_seq_size)
else:
result = xe_linear.forward_new(x_2d, self.weight.data,
self.weight.qtype,
input_seq_size)
if self.weight.qtype == NF4:
result = xe_linear.forward_new(x_2d,
self.weight.data.view(torch.uint8),
self.weight.qtype,
input_seq_size)
else:
result = xe_linear.forward_new(x_2d,
self.weight.data,
self.weight.qtype,
input_seq_size)
elif self.enable_xetla:
x_2d = x_2d.half()
result = xe_linear.mm_xetla(x_2d, self.weight.data, self.qtype)
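The NF4 branches added above all follow one pattern: the packed NF4 weight is stored in a bfloat16 buffer so that ZeRO-3 sees a single dtype across all partitioned parameters, and it is reinterpreted as uint8 only at the kernel boundary. A small sketch (illustrative values, not the real packed layout) of why this round-trips losslessly: torch.Tensor.view(dtype) reinterprets the same bytes without copying.

import torch

packed = torch.randint(0, 256, (64,), dtype=torch.uint8)  # stand-in for NF4-packed bytes
as_bf16 = packed.view(torch.bfloat16)   # 2 bytes per element -> 32 bfloat16 values
back = as_bf16.view(torch.uint8)        # reinterpret again; no copy, no loss
assert torch.equal(packed, back)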
13 changes: 13 additions & 0 deletions python/llm/src/ipex_llm/transformers/utils.py
@@ -382,3 +382,16 @@ def check_hidden_size(qtype, hidden_size):
"required for fq6_k - using fallback quantization fp6.")
return ggml_tensor_qtype["fp6"]
return qtype


# Arc platform does not support FP64,
# Disable FP64 in DeepSpeedZeroOptimizer_Stage3's _constant_buffered_norm2 method
# https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/runtime/zero/stage3.py#L1365
def _constant_buffered_norm2(self, input, buffer_size=250000000):
norm = None
for part in input.view(-1).split(buffer_size):
if norm is None:
norm = part.data.norm(2)**2.0
else:
norm += part.data.norm(2)**2.0
return norm**0.5
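As the comment notes, the stock DeepSpeed helper accumulates this norm with FP64, which Arc does not support; the replacement above keeps the chunked, buffered structure but stays in the input's own dtype. A small check that the chunked computation matches a direct 2-norm, using the function defined above on CPU float32 with hypothetical sizes (self is unused, so None is passed):

import torch

x = torch.randn(1_000_000)                                         # float32 on CPU
chunked = _constant_buffered_norm2(None, x, buffer_size=250_000)   # 4 chunks of 250k
direct = x.norm(2)
assert torch.allclose(chunked, direct, rtol=1e-4)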
