deepspeed zero3 QLoRA finetuning #11625
Conversation
```python
if enable_deepspeed_zero3:
    dst_tensor = torch.empty(dst_size // 2, dtype=torch.bfloat16,
                             device=device)
```
I think we should always do that for NF4 (only)?
The other NF4 paths pack weights into torch.uint8, so the buffer length stays as-is. Only DeepSpeed ZeRO-3 needs NF4 packed into torch.bfloat16, and since bfloat16 is two bytes per element, the buffer's element count must be halved.
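For reference, a minimal sketch (with a made-up `dst_size`) of why the element count is halved: `dst_size` is a byte count, so a bfloat16 buffer of `dst_size // 2` elements occupies exactly the same storage as a uint8 buffer of `dst_size` elements.

```python
import torch

dst_size = 1024  # hypothetical packed NF4 size in bytes

u8 = torch.empty(dst_size, dtype=torch.uint8)             # 1 byte per element
bf16 = torch.empty(dst_size // 2, dtype=torch.bfloat16)   # 2 bytes per element

# Both buffers occupy dst_size bytes of storage.
assert u8.element_size() * u8.numel() == bf16.element_size() * bf16.numel()
```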
```diff
@@ -259,9 +264,12 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,

 def ggml_q_format_convet_cpu2xpu(tensor: torch.Tensor, num_elem: int, qtype: int):
     import os
```
Move the os import to the top of the file, since other modules may share this import.
```python
dst_tensor = torch.empty(dst_size, dtype=torch.uint8,
                         device=device)
if enable_deepspeed_zero3:
    dst_tensor = torch.empty(dst_size // 2, dtype=torch.bfloat16,
                             device=device)
```
Add comments for the magic value 2 and the hard-coded bfloat16 dtype.
done
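A possible commented version of the allocation (illustrative only; the exact comments in the merged change may differ):

```python
if enable_deepspeed_zero3:
    # DeepSpeed ZeRO-3 needs NF4 weights packed as bfloat16 rather than uint8.
    # bfloat16 is 2 bytes per element, so divide the byte count dst_size by 2
    # to keep the same underlying storage size.
    dst_tensor = torch.empty(dst_size // 2, dtype=torch.bfloat16,
                             device=device)
else:
    # All other NF4 paths keep the usual uint8 byte packing.
    dst_tensor = torch.empty(dst_size, dtype=torch.uint8,
                             device=device)
```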
```python
# Arc platform does not support FP64,
# Disable FP64 in DeepSpeedZeroOptimizer_Stage3's _constant_buffered_norm2 method
```
What's the difference between our implementation and DeepSpeed's?
DeepSpeed's version calls .double(), i.e. FP64; this change removes the .double() call, since Arc does not support FP64.
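For illustration, a rough sketch of the kind of FP64-free override being discussed. The method name comes from the comment in the diff, but the signature and body here are assumed from DeepSpeed's stage-3 optimizer and may not match the actual patch:

```python
import torch

def _constant_buffered_norm2_no_fp64(self, input, buffer_size=250000000):
    # Chunked 2-norm over a flattened tensor, in the spirit of DeepSpeed's
    # DeepSpeedZeroOptimizer_Stage3._constant_buffered_norm2, but accumulating
    # in float32 instead of promoting to float64 via .double(), since Arc GPUs
    # do not support FP64.
    norm = None
    for part in input.view(-1).split(buffer_size):
        sq = part.data.float().norm(2) ** 2.0
        norm = sq if norm is None else norm + sq
    return norm ** 0.5
```

The override would then presumably be patched onto DeepSpeedZeroOptimizer_Stage3 in place of the original method.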
OK
Any more comments, or approve? @qiyuangong
```diff
@@ -524,7 +525,8 @@ def load_convert(cls, q_k, optimize_model, *args, **kwargs):
                       imatrix_data=imatrix_data,
                       embedding_qtype=embedding_qtype,
                       enable_xetla=enable_xetla,
-                      mixed_precision=mixed_precision)
+                      mixed_precision=mixed_precision,
+                      enable_deepspeed_zero3=enable_deepspeed_zero3)
```
I don't think we want to introduce this user-level parameter; we should either change all NF4 to BF16, or all training (QLoRA) NF4 to BF16, instead of doing something special for zero3 only.
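A hypothetical sketch of that alternative: derive the packing dtype from the quantization type and training mode instead of exposing a flag. All names here are illustrative, not the project's API; NF4 is assumed to be the project's qtype constant.

```python
import torch

def nf4_packing_dtype(qtype: int, training: bool) -> torch.dtype:
    # Pack all training-time (QLoRA) NF4 as bfloat16 so ZeRO-3 works
    # without a dedicated user-level switch; inference NF4 stays uint8.
    if qtype == NF4 and training:
        return torch.bfloat16
    return torch.uint8
```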
Please take a look again @jason-dai @qiyuangong
```diff
-invalidInputError(tensor.dtype == torch.uint8,
-                  "Input tensor must be uint8")
+invalidInputError(tensor.dtype == torch.bfloat16,
+                  "Input tensor must be bfloat16")
```
Will this change impact other features?
NF4 applications, e.g. QLoRA with ZeRO-2, will not be affected. That said, it might be better to add a qtype == NF4 check?
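One possible reading of that suggestion, sketched with the invalidInputError helper seen in the diff (NF4 assumed in scope):

```python
if qtype == NF4:
    # NF4 is now packed as bfloat16 (see the allocation change above).
    invalidInputError(tensor.dtype == torch.bfloat16,
                      "NF4 input tensor must be bfloat16")
else:
    # Every other qtype keeps the original uint8 byte packing.
    invalidInputError(tensor.dtype == torch.uint8,
                      "Input tensor must be uint8")
```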
LGTM
Passed PR validation.
Description
Transferred from #11048.