
Deepspeed Zero3 QLoRA Fine-tuning #11048

Closed
wants to merge 3 commits

Conversation

Uxito-Ada (Contributor)

Description

Use DeepSpeed ZeRO-3 to split and distribute the layers of a large model across multiple XPUs and execute QLoRA fine-tuning.

1. Why the change?

as above

2. User API changes

Add an enable_deepspeed_zero3 flag to from_pretrained and to qlora.py (see the usage sketch at the end of this description).

3. Summary of the change

as above

4. How to test?

  • N/A
  • Unit test
  • Application test
  • Document test
  • ...
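
For reference, a rough usage sketch of the new flag. This is a sketch only: the model id is a placeholder, the call assumes the flag is surfaced through ipex-llm's AutoModelForCausalLM.from_pretrained, and whether it combines with load_in_low_bit this way is one of the open questions in the review below.

from ipex_llm.transformers import AutoModelForCausalLM

# Hypothetical call shape for the proposed API; not a released ipex-llm interface.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",       # placeholder model id
    load_in_low_bit="nf4",            # low-bit format used for QLoRA
    enable_deepspeed_zero3=True,      # proposed flag: shard layers with DeepSpeed ZeRO-3
)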

@@ -524,7 +536,10 @@ class MatMulLowBit(torch.autograd.Function):
    def forward(ctx, A, weight, input_seq_size):
        ctx.is_empty = False
        import linear_q4_0
        result = linear_q4_0.forward_new(A, weight.data, weight.qtype, input_seq_size)
        if hasattr(weight, "enable_deepspeed_zero3") and weight.enable_deepspeed_zero3:
            result = linear_q4_0.forward_new(A, weight.data.byte(), NF4, input_seq_size)
Contributor

@cyita Please check whether this packing produces the correct result.

@qiyuangong requested a review from cyita on May 22, 2024 06:36
@qiyuangong (Contributor)

Please resolve the conflict with a rebase. Otherwise LGTM.

tokenizer = LlamaTokenizer.from_pretrained(base_model, trust_remote_code=True)
print(f"Tokenizer loaded on rank {os.environ.get('LOCAL_RANK')}")

tokenizer.pad_token_id = (
Contributor

tokenizer.pad_token_id = (
        0  # unk. we want this to be different from the eos token
    )
    tokenizer.padding_side = "left"  # Allow batched inference

This code is not necessary anymore.

dst_tensor = torch.empty(dst_size, dtype=torch.uint8,
                         device=device)
if enable_deepspeed_zero3:
    dst_tensor = torch.empty(dst_size, dtype=torch.bfloat16,
Contributor

If using torch.bfloat16, I think we only need half the dst_size, right? (assuming dst_size is an even number)
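
To make the size arithmetic concrete (a standalone check, not code from this PR): dst_size counts elements, and a torch.bfloat16 element occupies two bytes, so half as many elements cover the same byte budget as a torch.uint8 buffer.

import torch

dst_size = 8  # element count of the original uint8 buffer (assumed even)

u8 = torch.empty(dst_size, dtype=torch.uint8)             # 8 elements * 1 byte  = 8 bytes
bf16 = torch.empty(dst_size // 2, dtype=torch.bfloat16)   # 4 elements * 2 bytes = 8 bytes

# Both buffers occupy the same number of bytes.
assert u8.numel() * u8.element_size() == bf16.numel() * bf16.element_size()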

import accelerate
import transformers

from transformers import AutoTokenizer, BitsAndBytesConfig, AutoConfig, AutoModelForCausalLM
Contributor

Why not use the AutoModelForCausalLM from ipex-llm?


model_config = model_config = AutoConfig.from_pretrained(base_model)
with ds.zero.Init(config_dict_or_path=deepspeed):
    model = AutoModelForCausalLM.from_pretrained(
Contributor

Why not set load_in_low_bit?
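
For context, a sketch of what these two suggestions might look like together: ipex-llm's AutoModelForCausalLM with load_in_low_bit set, inside the ZeRO-3 init context from the diff. The placeholder values and exact keyword names are assumptions, not a tested change.

import deepspeed as ds
from transformers import AutoConfig
from ipex_llm.transformers import AutoModelForCausalLM  # reviewer suggestion: ipex-llm wrapper

base_model = "meta-llama/Llama-2-7b-hf"   # placeholder model id
ds_config = "ds_zero3_config.json"        # placeholder for the ZeRO-3 config passed in qlora.py

model_config = AutoConfig.from_pretrained(base_model)

with ds.zero.Init(config_dict_or_path=ds_config):
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        config=model_config,
        load_in_low_bit="nf4",         # reviewer suggestion: pass the low-bit format here
        enable_deepspeed_zero3=True,   # flag introduced by this PR
    )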

@@ -524,7 +536,10 @@ class MatMulLowBit(torch.autograd.Function):
    def forward(ctx, A, weight, input_seq_size):
        ctx.is_empty = False
        import linear_q4_0
        result = linear_q4_0.forward_new(A, weight.data, weight.qtype, input_seq_size)
        if hasattr(weight, "enable_deepspeed_zero3") and weight.enable_deepspeed_zero3:
            result = linear_q4_0.forward_new(A, weight.data.byte(), NF4, input_seq_size)
Contributor

weight.data.byte() converts every element to torch.uint8. Do you mean something like weight.data.view(torch.uint8)?
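
A small standalone example of the distinction being raised: .byte() performs a numeric cast per element, while .view(torch.uint8) reinterprets the existing bytes in place without copying.

import torch

t = torch.tensor([1.0, 2.5], dtype=torch.bfloat16)

cast = t.byte()                 # numeric cast: 2 uint8 values, data truncated to [1, 2]
reinterp = t.view(torch.uint8)  # reinterpretation: 4 uint8 values, raw bytes unchanged

print(cast.shape)                           # torch.Size([2])
print(reinterp.shape)                       # torch.Size([4])
print(reinterp.data_ptr() == t.data_ptr())  # True: shares the same memory, no copy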

@yangw1234 (Contributor), Jun 4, 2024

And shouldn't the qtype be the same as the load_in_low_bit passed to from_pretrained, instead of being hard-coded to NF4?
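
Putting the two review comments together, the intended change presumably looks something like the sketch below (untested; it reinterprets the bytes and reuses the qtype carried by the weight instead of the hard-coded NF4 constant).

if getattr(weight, "enable_deepspeed_zero3", False):
    # Reinterpret storage rather than cast values, and keep the weight's own qtype.
    result = linear_q4_0.forward_new(A, weight.data.view(torch.uint8),
                                     weight.qtype, input_seq_size)
else:
    result = linear_q4_0.forward_new(A, weight.data, weight.qtype, input_seq_size)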

@Uxito-Ada (Contributor, Author)

deprecated

@Uxito-Ada closed this on Jul 19, 2024
@Uxito-Ada mentioned this pull request on Jul 19, 2024