Deepspeed Zero3 QLoRA Fine-tuning #11048
Conversation
@@ -524,7 +536,10 @@ class MatMulLowBit(torch.autograd.Function):
     def forward(ctx, A, weight, input_seq_size):
         ctx.is_empty = False
         import linear_q4_0
-        result = linear_q4_0.forward_new(A, weight.data, weight.qtype, input_seq_size)
+        if hasattr(weight, "enable_deepspeed_zero3") and weight.enable_deepspeed_zero3:
+            result = linear_q4_0.forward_new(A, weight.data.byte(), NF4, input_seq_size)
@cyita Please check whether this packing produces the correct result.
Please resolve the conflict with a rebase.
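A minimal PyTorch-level sketch of the packing question above, assuming (as the dst_tensor change further down suggests) that the Zero3 path stores the packed uint8 qweight bytes inside a bfloat16 buffer. It only checks that the byte layout round-trips through such a view; it does not exercise the PR's linear_q4_0 kernel.

# stand-in for the NF4-packed qweight bytes; the buffer name and shape are illustrative
import torch

packed = torch.randint(0, 256, (1024,), dtype=torch.uint8)
as_bf16 = packed.view(torch.bfloat16)      # 2 bytes per element -> 512 bf16 values
restored = as_bf16.view(torch.uint8)       # reinterpret back to the original byte layout

assert torch.equal(packed, restored)       # packing/unpacking through the view preserves every byte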
tokenizer = LlamaTokenizer.from_pretrained(base_model, trust_remote_code=True)
print(f"Tokenizer loaded on rank {os.environ.get('LOCAL_RANK')}")

tokenizer.pad_token_id = (
tokenizer.pad_token_id = (
0 # unk. we want this to be different from the eos token
)
tokenizer.padding_side = "left" # Allow batched inference
This code is no longer necessary.
dst_tensor = torch.empty(dst_size, dtype=torch.uint8,
                         device=device)
if enable_deepspeed_zero3:
    dst_tensor = torch.empty(dst_size, dtype=torch.bfloat16,
If using torch.bfloat16, I think we only need half of dst_size, right? (assuming dst_size is an even number)
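A small sketch of the size bookkeeping behind this comment: dst_size counts bytes when the buffer is uint8, so a bfloat16 buffer covering the same bytes only needs dst_size // 2 elements (dst_size assumed even; the value below is illustrative).

import torch

dst_size = 4096                                          # byte count of the packed weight
u8 = torch.empty(dst_size, dtype=torch.uint8)            # dst_size * 1 byte
bf16 = torch.empty(dst_size // 2, dtype=torch.bfloat16)  # (dst_size // 2) * 2 bytes

assert u8.numel() * u8.element_size() == bf16.numel() * bf16.element_size() == dst_size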
import accelerate
import transformers

from transformers import AutoTokenizer, BitsAndBytesConfig, AutoConfig, AutoModelForCausalLM
Why not use the AutoModelForCausalLM from ipex-llm?
model_config = AutoConfig.from_pretrained(base_model)
with ds.zero.Init(config_dict_or_path=deepspeed):
    model = AutoModelForCausalLM.from_pretrained(
Why not set load_in_low_bit?
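A sketch of what the two comments above seem to suggest: load through ipex-llm's AutoModelForCausalLM with load_in_low_bit instead of the plain transformers class. Whether this composes cleanly with ds.zero.Init is exactly the question this PR addresses; the model id and config path below are placeholders, not values from the PR.

import torch
import deepspeed as ds
from transformers import AutoConfig
from ipex_llm.transformers import AutoModelForCausalLM   # ipex-llm's low-bit loader

base_model = "meta-llama/Llama-2-7b-hf"                   # placeholder for the script's base_model
ds_config = "deepspeed_zero3.json"                        # placeholder for the script's deepspeed config

model_config = AutoConfig.from_pretrained(base_model)
with ds.zero.Init(config_dict_or_path=ds_config):
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        config=model_config,
        load_in_low_bit="nf4",        # so the runtime qtype does not have to be hard-coded
        torch_dtype=torch.bfloat16,
    )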
@@ -524,7 +536,10 @@ class MatMulLowBit(torch.autograd.Function):
     def forward(ctx, A, weight, input_seq_size):
         ctx.is_empty = False
         import linear_q4_0
-        result = linear_q4_0.forward_new(A, weight.data, weight.qtype, input_seq_size)
+        if hasattr(weight, "enable_deepspeed_zero3") and weight.enable_deepspeed_zero3:
+            result = linear_q4_0.forward_new(A, weight.data.byte(), NF4, input_seq_size)
weight.data.byte() converts every element to torch.uint8. Do you mean something like weight.data.view(torch.uint8)?
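A tiny demonstration of the distinction raised above, using a stand-in tensor rather than the PR's weight.data: .byte() converts values element by element (lossy for anything but small non-negative integers), while .view(torch.uint8) reinterprets the raw bytes without touching them.

import torch

w = torch.tensor([1.0, 2.5], dtype=torch.bfloat16)   # stand-in for weight.data in the Zero3 path

cast = w.byte()                                      # value conversion: 2 uint8 elements, fractional parts dropped
raw = w.view(torch.uint8)                            # byte reinterpretation: 4 uint8 elements, bit-exact

assert cast.numel() == 2 and raw.numel() == 4
assert torch.equal(raw.view(torch.bfloat16), w)      # only the view round-trips to the original weights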
And shouldn't the qtype be the same as the load_in_low_bit passed to from_pretrained, instead of being hard-coded to NF4?
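A sketch of what this comment seems to propose, edited into the hunk above and using only names visible there (linear_q4_0, weight.qtype); it is not standalone code, and whether weight.qtype survives the Zero3 repacking is an assumption the PR author would need to confirm.

if hasattr(weight, "enable_deepspeed_zero3") and weight.enable_deepspeed_zero3:
    # reinterpret the packed buffer instead of casting it, and use the weight's own qtype
    result = linear_q4_0.forward_new(A, weight.data.view(torch.uint8),
                                     weight.qtype, input_seq_size)
else:
    result = linear_q4_0.forward_new(A, weight.data, weight.qtype, input_seq_size)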
deprecated
Description
Use DeepSpeed Zero3 to split and distribute the layers of a large model across multiple XPUs and run QLoRA fine-tuning.
1. Why the change?
as above
2. User API changes
Add an enable_deepspeed_zero3 option to from_pretrained and to qlora.py (see the usage sketch after this list).
3. Summary of the change
as above
4. How to test?
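A minimal usage sketch of the new option, assuming it is accepted by ipex-llm's from_pretrained as the description says; the model id, config path, and the other keyword arguments are placeholders or existing ipex-llm/DeepSpeed arguments, not details confirmed by this PR.

import torch
import deepspeed as ds
from ipex_llm.transformers import AutoModelForCausalLM

with ds.zero.Init(config_dict_or_path="deepspeed_zero3.json"):   # placeholder config path
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf",                               # placeholder model id
        load_in_low_bit="nf4",
        torch_dtype=torch.bfloat16,
        enable_deepspeed_zero3=True,                              # the new flag added by this PR
    )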