-
Notifications
You must be signed in to change notification settings - Fork 1.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ChatGLM3-6B LoRA Fine-tuning Demo #11450
Merged
Merged
Changes from 4 commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
38b3bf6
ChatGLM3-6B LoRA Fine-tuning Demo
Uxito-Ada a3d8d60
refine
Uxito-Ada 5d1c83e
refine
Uxito-Ada 5e6ee70
add 2-card deepspeed
Uxito-Ada e088926
refine format
Uxito-Ada d1434bf
add mpi4py and deepspeed install
Uxito-Ada File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
148 changes: 148 additions & 0 deletions
148
python/llm/example/GPU/LLM-Finetuning/LoRA/chatglm_finetune/README.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
# LoRA Fine-Tuning on ChatGLM3-6B with IPEX-LLM | ||
|
||
This example ports [ChatGLM3-6B lora_finetune](https://github.com/THUDM/ChatGLM3/blob/main/finetune_demo/lora_finetune.ipynb) demo to IPEX-LLM on [Intel Arc GPU](../../README.md). | ||
|
||
### 1. Install | ||
|
||
```bash | ||
conda create -n llm python=3.11 | ||
conda activate llm | ||
pip install "jieba>=0.42.1" | ||
pip install "ruamel_yaml>=0.18.6" | ||
pip install "rouge_chinese>=1.0.3" | ||
pip install "jupyter>=1.0.0" | ||
pip install "datasets>=2.18.0" | ||
pip install "peft>=0.10.0" | ||
pip install typer | ||
pip install sentencepiece | ||
pip install nltk | ||
pip install "numpy<2.0.0" | ||
# below command will install intel_extension_for_pytorch==2.1.10+xpu as default | ||
pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ | ||
pip install oneccl_bind_pt==2.1.100 --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ | ||
``` | ||
|
||
### 2. Configures OneAPI Environment Variables | ||
```bash | ||
source /opt/intel/oneapi/setvars.sh | ||
``` | ||
|
||
### 3. LoRA Fine-Tune on ChatGLM3-6B | ||
|
||
First, download the dataset: we use `AdvertiseGen` to finetune ChatGLM3-6B in the following, and please now get it from [Google Drive](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view?usp=sharing) or [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1), and unzip it in the current directory. Then, process the dataset with the below script: | ||
|
||
```bash | ||
python process_advertise_gen_dataset.py | ||
``` | ||
|
||
Then, './AdvertiseGen' will be converted to './AdvertiseGen_fix'. Now, we have prepared the dataset, and are going to start LoRA fine-tuning on ChatGLM3-6B. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ' changes to ` will high this work |
||
|
||
#### 3.1. Fine-Tune with a Single Arc Card | ||
|
||
Start the fine-tuning by: | ||
|
||
Uxito-Ada marked this conversation as resolved.
Show resolved
Hide resolved
|
||
```bash | ||
bash lora_finetuning_on_chatglm3_6b_with_1_arc_card.sh | ||
``` | ||
|
||
Then, you will get output are as below: | ||
|
||
```bash | ||
2024-06-27 13:47:02,680 - root - INFO - intel_extension_for_pytorch auto imported | ||
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████| 7/7 [00:01<00:00, 6.47it/s] | ||
2024-06-27 13:47:03,794 - ipex_llm.transformers.utils - INFO - Converting the current model to bf16 format...... | ||
[2024-06-27 13:47:04,105] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to xpu (auto detect) | ||
trainable params: 487,424 || all params: 6,244,071,424 || trainable%: 0.0078 | ||
PeftModelForCausalLM( | ||
(base_model): LoraModel( | ||
(model): ChatGLMForConditionalGeneration( | ||
(transformer): ChatGLMModel( | ||
(embedding): Embedding( | ||
(word_embeddings): Embedding(65024, 4096) | ||
) | ||
(rotary_pos_emb): RotaryEmbedding() | ||
(encoder): GLMTransformer( | ||
(layers): ModuleList( | ||
(0-27): 28 x GLMBlock( | ||
(input_layernorm): RMSNorm() | ||
(self_attention): SelfAttention( | ||
(query_key_value): LoraLowBitLinear( | ||
(base_layer): BF16Linear(in_features=4096, out_features=4608, bias=True) | ||
(lora_dropout): ModuleDict( | ||
(default): Dropout(p=0.1, inplace=False) | ||
) | ||
(lora_A): ModuleDict( | ||
(default): Linear(in_features=4096, out_features=2, bias=False) | ||
) | ||
(lora_B): ModuleDict( | ||
(default): Linear(in_features=2, out_features=4608, bias=False) | ||
) | ||
(lora_embedding_A): ParameterDict() | ||
(lora_embedding_B): ParameterDict() | ||
(qa_pool): Identity() | ||
) | ||
(core_attention): CoreAttention( | ||
(attention_dropout): Dropout(p=0.0, inplace=False) | ||
) | ||
(dense): BF16Linear(in_features=4096, out_features=4096, bias=False) | ||
) | ||
(post_attention_layernorm): RMSNorm() | ||
(mlp): MLP( | ||
(dense_h_to_4h): BF16Linear(in_features=4096, out_features=27392, bias=False) | ||
(dense_4h_to_h): BF16Linear(in_features=13696, out_features=4096, bias=False) | ||
) | ||
) | ||
) | ||
(final_layernorm): RMSNorm() | ||
) | ||
(output_layer): BF16Linear(in_features=4096, out_features=65024, bias=False) | ||
) | ||
) | ||
) | ||
) | ||
--> Model | ||
|
||
--> model has 0.487424M params | ||
|
||
train_dataset: Dataset({ | ||
features: ['input_ids', 'labels'], | ||
num_rows: 114599 | ||
}) | ||
val_dataset: Dataset({ | ||
features: ['input_ids', 'output_ids'], | ||
num_rows: 1070 | ||
}) | ||
test_dataset: Dataset({ | ||
features: ['input_ids', 'output_ids'], | ||
num_rows: 1070 | ||
}) | ||
--> Sanity check | ||
'[gMASK]': 64790 -> -100 | ||
'sop': 64792 -> -100 | ||
'<|user|>': 64795 -> -100 | ||
'': 30910 -> -100 | ||
'\n': 13 -> -100 | ||
...... | ||
|
||
# Here it takes time to finish the whole fine-tuning | ||
|
||
...... | ||
|
||
Training completed. Do not forget to share your model on huggingface.co/models =) | ||
|
||
|
||
{'train_runtime': xxxx.xxxx, 'train_samples_per_second': x.xxx, 'train_steps_per_second': x.xxx, 'train_loss': xx.xx, 'epoch': x.xx} | ||
100%|████████████████████████████████████████████████████████████████████████████████████████████| 3000/3000 [xx:xx<00:00, x.xxit/s] | ||
***** Running Prediction ***** | ||
Num examples = 1070 | ||
Batch size = 4 | ||
100%|██████████████████████████████████████████████████████████████████████████████████████████████| 268/268 [xx:xx<00:00, x.xxs/it] | ||
``` | ||
|
||
#### 3.2. Fine-Tune with 2 Arc Cards | ||
|
||
Start the data-parallel fine-tuning on 2 Intel Arc XPU cards by: | ||
|
||
```bash | ||
bash lora_finetuning_on_chatglm3_6b_with_2_arc_cards.sh | ||
``` |
15 changes: 15 additions & 0 deletions
15
python/llm/example/GPU/LLM-Finetuning/LoRA/chatglm_finetune/deepspeed_config.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
{ | ||
"zero_optimization": { | ||
"stage": 2, | ||
"offload_optimizer": { | ||
"device": "cpu" | ||
}, | ||
"contiguous_gradients": true, | ||
"overlap_comm": true | ||
}, | ||
"bf16": { | ||
"enabled": true | ||
}, | ||
"train_micro_batch_size_per_gpu": "auto", | ||
"gradient_accumulation_steps": "auto" | ||
} |
47 changes: 47 additions & 0 deletions
47
python/llm/example/GPU/LLM-Finetuning/LoRA/chatglm_finetune/lora_config.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
# This is ported from https://github.com/THUDM/ChatGLM3/blob/main/finetune_demo/configs/lora.yaml | ||
data_config: | ||
train_file: train.json | ||
val_file: dev.json | ||
test_file: dev.json | ||
num_proc: 16 | ||
max_input_length: 128 | ||
max_output_length: 128 | ||
training_args: | ||
# see `transformers.Seq2SeqTrainingArguments` | ||
output_dir: ./output | ||
max_steps: 3000 | ||
# needed to be fit for the dataset | ||
learning_rate: 5e-5 | ||
# settings for data loading | ||
per_device_train_batch_size: 1 | ||
dataloader_num_workers: 16 | ||
remove_unused_columns: false | ||
# settings for saving checkpoints | ||
save_strategy: steps | ||
save_steps: 500 | ||
# settings for logging | ||
log_level: info | ||
logging_strategy: steps | ||
logging_steps: 10 | ||
# settings for evaluation | ||
per_device_eval_batch_size: 4 | ||
evaluation_strategy: steps | ||
eval_steps: 1000 | ||
# settings for optimizer | ||
# adam_epsilon: 1e-6 | ||
# uncomment the following line to detect nan or inf values | ||
# debug: underflow_overflow | ||
predict_with_generate: true | ||
# see `transformers.GenerationConfig` | ||
generation_config: | ||
max_new_tokens: 128 | ||
# set your absolute deepspeed path here | ||
#deepspeed: ds_zero_2.json | ||
# set to true if train with cpu. | ||
use_cpu: false | ||
peft_config: | ||
peft_type: LORA | ||
task_type: CAUSAL_LM | ||
r: 2 | ||
lora_alpha: 8 | ||
lora_dropout: 0.1 |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Single ARC doesn't need oneccl
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is necessary, as XPU accelerator needs CCL. Without CCL, accelerator will switch to CUDA, and trainer will schedule model to CPU rather than XPU.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK