deepspeed zero3 QLoRA finetuning #11625
@@ -0,0 +1,15 @@
{
  "zero_optimization": {
    "stage": 3,
    "contiguous_gradients": true,
    "overlap_comm": true,
    "offload_optimizer": {"device": "cpu"}
  },
  "bf16": {
    "enabled": true
  },
  "world_size": 2,
  "train_batch_size": 32,
  "train_micro_batch_size_per_gpu": 2,
  "gradient_accumulation_steps": 8
}
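The batch-size fields above are mutually consistent: DeepSpeed requires train_batch_size = train_micro_batch_size_per_gpu × gradient_accumulation_steps × world_size, which here is 2 × 8 × 2 = 32. A minimal sanity check of that invariant (illustrative Python, not part of this PR; it assumes the config is saved as deepspeed_zero3.json next to the script):

import json

# Load the ZeRO-3 config shown above and verify DeepSpeed's batch-size invariant.
with open("deepspeed_zero3.json") as f:
    cfg = json.load(f)

expected = (cfg["train_micro_batch_size_per_gpu"]
            * cfg["gradient_accumulation_steps"]
            * cfg["world_size"])
assert cfg["train_batch_size"] == expected, \
    f"train_batch_size should be {expected}, got {cfg['train_batch_size']}"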
------------------------------------------------------------------------
@@ -0,0 +1,42 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

export MASTER_ADDR=127.0.0.1
export MASTER_PORT=29503
export FI_PROVIDER=tcp
export CCL_ATL_TRANSPORT=ofi
export CCL_ZE_IPC_EXCHANGE=sockets
export UR_L0_IN_ORDER_BARRIER_BY_SIGNAL=0
basekit_root=/opt/intel/oneapi
source $basekit_root/setvars.sh --force
source $basekit_root/ccl/latest/env/vars.sh --force

NUM_GPUS=2 # number of GPUs to use
export USE_XETLA=OFF
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=2
export TORCH_LLM_ALLREDUCE=0 # Different from PVC
export DS_SKIP_CUDA_CHECK=1
export IPEX_LLM_ENABLE_DEEPSPEED_ZERO3=1

mpirun -n $NUM_GPUS \
  python -u ./alpaca_qlora_finetuning.py \
    --base_model "meta-llama/Llama-2-13b-hf" \
    --data_path "yahma/alpaca-cleaned" \
    --output_dir "./ipex-llm-qlora-alpaca" \
    --gradient_checkpointing True \
    --micro_batch_size 2 \
    --batch_size 32 \
    --deepspeed ./deepspeed_zero3.json
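The launcher mirrors the JSON config: --micro_batch_size 2 and --batch_size 32 correspond to train_micro_batch_size_per_gpu and train_batch_size, and mpirun -n $NUM_GPUS starts one Python rank per GPU. The exported IPEX_LLM_ENABLE_DEEPSPEED_ZERO3=1 is the switch that the quantization changes below key off; a minimal sketch of how such a gate is read on the Python side (mirroring the os.getenv call added in this PR, shown here only for illustration):

import os

# "0" is the default when the variable is unset, so existing inference and
# single-GPU fine-tuning paths are unaffected unless the user opts in.
enable_deepspeed_zero3 = os.getenv("IPEX_LLM_ENABLE_DEEPSPEED_ZERO3", "0") == "1"

if enable_deepspeed_zero3:
    print("NF4 payloads will be allocated as bf16 buffers for DeepSpeed ZeRO-3")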
------------------------------------------------------------------------
@@ -208,7 +208,8 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
                        device=None, convert_shape_only=False,
                        imatrix: torch.Tensor=None,
                        in_features: int=None,
-                       enable_scale_search: bool=False):
+                       enable_scale_search: bool=False,
+                       enable_deepspeed_zero3: bool=False):
     QK = ggml.ggml_qk_size(qtype)
     block_size_in_bytes = ggml.ggml_type_size(qtype)
@@ -229,8 +230,12 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
             scale = torch.empty(n // k, dtype=torch.float32,
                                 device=device)
         else:
-            dst_tensor = torch.empty(dst_size, dtype=torch.uint8,
-                                     device=device)
+            if enable_deepspeed_zero3:
+                dst_tensor = torch.empty(dst_size // 2, dtype=torch.bfloat16,
+                                         device=device)
+            else:
+                dst_tensor = torch.empty(dst_size, dtype=torch.uint8,
+                                         device=device)

     if not convert_shape_only and device != 'meta':
         dst = ctypes.c_void_p(dst_tensor.data.data_ptr())

Review comment on the bfloat16 branch: I think we should always do that for NF4 (only)?
Reply: Other NF4s are packed in torch.uint8, which does not make the buffer length redundant.
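The magic value 2 and the hard-coded bfloat16 go together: one torch.bfloat16 element is two bytes, so dst_size // 2 bf16 elements occupy exactly the dst_size bytes that the torch.uint8 path allocates. DeepSpeed ZeRO-3 can then partition and gather the quantized weight as a bf16 parameter, while the low-bit kernels later reinterpret the same memory with .view(torch.uint8). A small standalone sketch of that byte-for-byte equivalence (illustrative only, with an example dst_size; not code from the PR):

import torch

dst_size = 1024  # example size of the packed NF4 payload, in bytes

u8 = torch.empty(dst_size, dtype=torch.uint8)             # original packed layout
bf16 = torch.empty(dst_size // 2, dtype=torch.bfloat16)   # ZeRO-3-friendly layout

# Both buffers span the same number of bytes...
assert u8.numel() * u8.element_size() == bf16.numel() * bf16.element_size()

# ...and a zero-copy reinterpretation recovers the packed uint8 view the kernels expect.
assert bf16.view(torch.uint8).shape == u8.shape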
@@ -259,9 +264,12 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,


 def ggml_q_format_convet_cpu2xpu(tensor: torch.Tensor, num_elem: int, qtype: int):
+    import os
+    enable_deepspeed_zero3 = (os.getenv("IPEX_LLM_ENABLE_DEEPSPEED_ZERO3", "0") == "1")

-    invalidInputError(tensor.dtype == torch.uint8,
-                      "Input tensor must be uint8")
+    if not enable_deepspeed_zero3:
+        invalidInputError(tensor.dtype == torch.uint8,
+                          "Input tensor must be uint8")

     invalidInputError(tensor.device == torch.device('cpu'),
                       "Input tensor must be uint8")

Review comment on the import: Move the os import to the top of the file, because other modules may share this import.
@@ -381,7 +389,8 @@ def __new__(cls,
                 imatrix=None,
                 in_features=None,
                 enable_xetla=False,
-                enable_scale_search=False):
+                enable_scale_search=False,
+                enable_deepspeed_zero3=False):
         if data is None:
             data = torch.empty(0)

@@ -395,6 +404,7 @@ def __new__(cls,
         self.in_features = in_features
         self.enable_xetla = enable_xetla
         self.enable_scale_search = enable_scale_search
+        self.enable_deepspeed_zero3 = enable_deepspeed_zero3
         return self

     def ggml_mse(self, w, ggml_qtype, device):
@@ -453,7 +463,8 @@ def quantize(self, device=None):
                                               convert_shape_only=self.convert_shape_only,
                                               imatrix=self.imatrix,
                                               in_features=self.in_features,
-                                              enable_scale_search=self.enable_scale_search)
+                                              enable_scale_search=self.enable_scale_search,
+                                              enable_deepspeed_zero3=self.enable_deepspeed_zero3)
             self.data = w_quantized
             self.quantized = True
             self._shape = w.shape
@@ -581,7 +592,13 @@ class MatMulLowBit(torch.autograd.Function):
     def forward(ctx, A, weight, input_seq_size):
         ctx.is_empty = False
         import xe_linear
-        result = xe_linear.forward_new(A, weight.data, weight.qtype, input_seq_size)
+        if hasattr(weight, "enable_deepspeed_zero3") and weight.enable_deepspeed_zero3:
+            result = xe_linear.forward_new(A,
+                                           weight.data.view(torch.uint8),
+                                           weight.qtype,
+                                           input_seq_size)
+        else:
+            result = xe_linear.forward_new(A, weight.data, weight.qtype, input_seq_size)
         if any(ctx.needs_input_grad[:2]):
             ctx.tensors = (A, weight)
         else:
@@ -601,7 +618,12 @@ def backward(ctx, grad_output):
         if req_gradA:
             if torch.xpu.is_autocast_xpu_enabled():
                 grad_output = grad_output.to(torch.xpu.get_autocast_xpu_dtype())
-            dequant_weight = xe_linear.dequant(A, weight.data, weight.qtype)
+            if hasattr(weight, "enable_deepspeed_zero3") and weight.enable_deepspeed_zero3:
+                dequant_weight = xe_linear.dequant(A,
+                                                   weight.data.view(torch.uint8),
+                                                   weight.qtype)
+            else:
+                dequant_weight = xe_linear.dequant(A, weight.data, weight.qtype)
             grad_A = torch.matmul(grad_output, dequant_weight.reshape(weight._shape))

         return grad_A, grad_weight, None
@@ -640,13 +662,15 @@ class LowBitLinear(nn.Linear):
     def __init__(self, input_features, output_features, qtype, bias=True,
                  conver_to_half=True, mp_group=None, enable_xetla=False,
                  optimize_lm_head=False, act_order=False,
-                 enable_scale_search=False):
+                 enable_scale_search=False,
+                 enable_deepspeed_zero3=False):
         super().__init__(input_features, output_features, bias)
         self.weight = FP4Params(self.weight.data,
                                 requires_grad=False,
                                 quantized=False, _shape=None, qtype=qtype,
                                 enable_xetla=enable_xetla,
-                                enable_scale_search=enable_scale_search)
+                                enable_scale_search=enable_scale_search,
+                                enable_deepspeed_zero3=enable_deepspeed_zero3)
         self.in_len = input_features
         self.out_len = output_features
         self.weight_shape = (self.out_len, self.in_len)
@@ -666,6 +690,7 @@ def __init__(self, input_features, output_features, qtype, bias=True,
         self.is_lm_head = self.in_len * self.out_len >= 32000 * 4096 and self.bias is None
         self.low_memory_mode = self.is_lm_head
         self.act_order = act_order
+        self.enable_deepspeed_zero3 = enable_deepspeed_zero3
         if act_order:
             self.register_buffer(
                 "g_idx_map",
@@ -736,9 +761,17 @@ def forward(self, x: torch.Tensor):
                 if x_2d.requires_grad:
                     result = MatMulLowBit.apply(x_2d, self.weight, input_seq_size)
                 else:
-                    result = xe_linear.forward_new(x_2d, self.weight.data,
-                                                   self.weight.qtype,
-                                                   input_seq_size)
+                    if hasattr(self.weight, "enable_deepspeed_zero3") \
+                            and self.weight.enable_deepspeed_zero3:
+                        result = xe_linear.forward_new(x_2d,
+                                                       self.weight.data.view(torch.uint8),
+                                                       self.weight.qtype,
+                                                       input_seq_size)
+                    else:
+                        result = xe_linear.forward_new(x_2d,
+                                                       self.weight.data,
+                                                       self.weight.qtype,
+                                                       input_seq_size)
             elif self.enable_xetla:
                 x_2d = x_2d.half()
                 result = xe_linear.mm_xetla(x_2d, self.weight.data, self.qtype)
------------------------------------------------------------------------
@@ -454,6 +454,7 @@ def load_convert(cls, q_k, optimize_model, *args, **kwargs):
         if embedding_qtype is not None:
             embedding_qtype = ggml_tensor_qtype[embedding_qtype]
         enable_xetla = kwargs.pop("enable_xetla", False)
+        enable_deepspeed_zero3 = kwargs.pop("enable_deepspeed_zero3", False)
         _args = copy.deepcopy(args)
         _kwargs = copy.deepcopy(kwargs)
         awq_config = None
@@ -524,7 +525,8 @@ def load_convert(cls, q_k, optimize_model, *args, **kwargs):
                                      imatrix_data=imatrix_data,
                                      embedding_qtype=embedding_qtype,
                                      enable_xetla=enable_xetla,
-                                     mixed_precision=mixed_precision)
+                                     mixed_precision=mixed_precision,
+                                     enable_deepspeed_zero3=enable_deepspeed_zero3)

         if disk_embedding:
             from ipex_llm.transformers.embedding import DiskEmbedding

Review comment on the new keyword argument: I don't think we want to introduce this user-level parameter; we should either change all NF4 to BF16, or all training (QLoRA) NF4 to BF16, instead of doing something special for zero3 only.
Reply: pls take a look again @jason-dai @qiyuangong
------------------------------------------------------------------------
@@ -361,3 +361,15 @@ def get_modelscope_hf_config(model_id_or_path: str,
 def is_torch_bf16_gpu_available():
     # always true for XPU and CPU
     return True
+
+
+# Arc platform does not support FP64;
+# disable FP64 in DeepSpeedZeroOptimizer_Stage3's _constant_buffered_norm2 method
+def _constant_buffered_norm2(self, input, buffer_size=250000000):
+    norm = None
+    for part in input.view(-1).split(buffer_size):
+        if norm is None:
+            norm = part.data.norm(2)**2.0
+        else:
+            norm += part.data.norm(2)**2.0
+    return norm**0.5

Review comment on the override: What's different between our implementation and ds's one?
Reply: ds is double(), fp64.
Review comment: OK
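The override reproduces DeepSpeed's buffered norm but keeps every partial result in the tensor's own dtype rather than promoting to double(), which is what makes it safe on Arc. A sketch of how such an override might be installed, assuming upstream DeepSpeed's DeepSpeedZeroOptimizer_Stage3 class (the patch site shown here is illustrative, not necessarily the wiring used in this PR):

# Monkey-patch DeepSpeed's stage-3 optimizer so gradient-norm computation never
# launches FP64 kernels on Arc GPUs.
from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3

DeepSpeedZeroOptimizer_Stage3._constant_buffered_norm2 = _constant_buffered_norm2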
Review comment (on the bfloat16 buffer allocation earlier in this PR): Add comments for magic value 2 and hard-coded type bfloat16.
Reply: done