Skip to content

Commit

Permalink
Support internlm2.5 (#803)
Browse files Browse the repository at this point in the history
* add internlm 2.5 configs

* update readme
  • Loading branch information
HIT-cwh authored Jul 3, 2024
1 parent b98d413 commit 44749c2
Show file tree
Hide file tree
Showing 5 changed files with 688 additions and 36 deletions.
31 changes: 12 additions & 19 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ English | [简体中文](README_zh-CN.md)

## 🎉 News

- **\[2024/07\]** Support [InternLM 2.5](xtuner/configs/internlm/internlm2_5_chat_7b/) models!
- **\[2024/06\]** Support [DeepSeek V2](xtuner/configs/deepseek/deepseek_v2_chat/) models! **2x faster!**
- **\[2024/04\]** [LLaVA-Phi-3-mini](https://huggingface.co/xtuner/llava-phi-3-mini-hf) is released! Click [here](xtuner/configs/llava/phi3_mini_4k_instruct_clip_vit_large_p14_336) for details!
- **\[2024/04\]** [LLaVA-Llama-3-8B](https://huggingface.co/xtuner/llava-llama-3-8b) and [LLaVA-Llama-3-8B-v1.1](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1) are released! Click [here](xtuner/configs/llava/llama3_8b_instruct_clip_vit_large_p14_336) for details!
- **\[2024/04\]** Support [Llama 3](xtuner/configs/llama) models!
Expand Down Expand Up @@ -100,16 +102,15 @@ XTuner is an efficient, flexible and full-featured toolkit for fine-tuning large
<tr valign="top">
<td align="left" valign="top">
<ul>
<li><a href="https://huggingface.co/internlm">InternLM2</a></li>
<li><a href="https://huggingface.co/meta-llama">Llama 3</a></li>
<li><a href="https://huggingface.co/meta-llama">Llama 2</a></li>
<li><a href="https://huggingface.co/internlm">InternLM2 / 2.5</a></li>
<li><a href="https://huggingface.co/meta-llama">Llama 2 / 3</a></li>
<li><a href="https://huggingface.co/collections/microsoft/phi-3-6626e15e9585a200d2d761e3">Phi-3</a></li>
<li><a href="https://huggingface.co/THUDM/chatglm2-6b">ChatGLM2</a></li>
<li><a href="https://huggingface.co/THUDM/chatglm3-6b">ChatGLM3</a></li>
<li><a href="https://huggingface.co/Qwen/Qwen-7B">Qwen</a></li>
<li><a href="https://huggingface.co/baichuan-inc/Baichuan2-7B-Base">Baichuan2</a></li>
<li><a href="https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1">Mixtral 8x7B</a></li>
<li><a href="https://huggingface.co/deepseek-ai/deepseek-moe-16b-chat">DeepSeek MoE</a></li>
<li><a href="https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1">Mixtral</a></li>
<li><a href="https://huggingface.co/deepseek-ai/DeepSeek-V2-Chat">DeepSeek V2</a></li>
<li><a href="https://huggingface.co/google">Gemma</a></li>
<li>...</li>
</ul>
Expand Down Expand Up @@ -203,14 +204,14 @@ XTuner supports the efficient fine-tune (*e.g.*, QLoRA) for LLMs. Dataset prepar
xtuner train ${CONFIG_NAME_OR_PATH}
```

For example, we can start the QLoRA fine-tuning of InternLM2-Chat-7B with oasst1 dataset by
For example, we can start the QLoRA fine-tuning of InternLM2.5-Chat-7B with oasst1 dataset by

```shell
# On a single GPU
xtuner train internlm2_chat_7b_qlora_oasst1_e3 --deepspeed deepspeed_zero2
xtuner train internlm2_5_chat_7b_qlora_oasst1_e3 --deepspeed deepspeed_zero2
# On multiple GPUs
(DIST) NPROC_PER_NODE=${GPU_NUM} xtuner train internlm2_chat_7b_qlora_oasst1_e3 --deepspeed deepspeed_zero2
(SLURM) srun ${SRUN_ARGS} xtuner train internlm2_chat_7b_qlora_oasst1_e3 --launcher slurm --deepspeed deepspeed_zero2
(DIST) NPROC_PER_NODE=${GPU_NUM} xtuner train internlm2_5_chat_7b_qlora_oasst1_e3 --deepspeed deepspeed_zero2
(SLURM) srun ${SRUN_ARGS} xtuner train internlm2_5_chat_7b_qlora_oasst1_e3 --launcher slurm --deepspeed deepspeed_zero2
```

- `--deepspeed` means using [DeepSpeed](https://github.com/microsoft/DeepSpeed) 🚀 to optimize the training. XTuner comes with several integrated strategies including ZeRO-1, ZeRO-2, and ZeRO-3. If you wish to disable this feature, simply remove this argument.
Expand All @@ -231,18 +232,10 @@ XTuner provides tools to chat with pretrained / fine-tuned LLMs.
xtuner chat ${NAME_OR_PATH_TO_LLM} --adapter {NAME_OR_PATH_TO_ADAPTER} [optional arguments]
```

For example, we can start the chat with

InternLM2-Chat-7B with adapter trained from oasst1 dataset:

```shell
xtuner chat internlm/internlm2-chat-7b --adapter xtuner/internlm2-chat-7b-qlora-oasst1 --prompt-template internlm2_chat
```

LLaVA-InternLM2-7B:
For example, we can start the chat with InternLM2.5-Chat-7B :

```shell
xtuner chat internlm/internlm2-chat-7b --visual-encoder openai/clip-vit-large-patch14-336 --llava xtuner/llava-internlm2-7b --prompt-template internlm2_chat --image $IMAGE_PATH
xtuner chat internlm/internlm2_5-chat-7b --prompt-template internlm2_chat
```

For more examples, please see [chat.md](./docs/en/user_guides/chat.md).
Expand Down
29 changes: 12 additions & 17 deletions README_zh-CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@

## 🎉 更新

- **\[2024/07\]** 支持 [InternLM 2.5](xtuner/configs/internlm/internlm2_5_chat_7b/) 模型!
- **\[2024/06\]** 支持 [DeepSeek V2](xtuner/configs/deepseek/deepseek_v2_chat/) models! **训练速度提升一倍!**
- **\[2024/04\]** 多模态大模型 [LLaVA-Phi-3-mini](https://huggingface.co/xtuner/llava-phi-3-mini-hf) 发布!快速开始请查阅此[文档](xtuner/configs/llava/phi3_mini_4k_instruct_clip_vit_large_p14_336)
- **\[2024/04\]** 多模态大模型 [LLaVA-Llama-3-8B](https://huggingface.co/xtuner/llava-llama-3-8b)[LLaVA-Llama-3-8B-v1.1](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1) 发布!快速开始请查阅此[文档](xtuner/configs/llava/llama3_8b_instruct_clip_vit_large_p14_336)
- **\[2024/04\]** 支持 [Llama 3](xtuner/configs/llama) 模型!
Expand Down Expand Up @@ -100,16 +102,15 @@ XTuner 是一个高效、灵活、全能的轻量化大模型微调工具库。
<tr valign="top">
<td align="left" valign="top">
<ul>
<li><a href="https://huggingface.co/internlm">InternLM2</a></li>
<li><a href="https://huggingface.co/meta-llama">Llama 3</a></li>
<li><a href="https://huggingface.co/meta-llama">Llama 2</a></li>
<li><a href="https://huggingface.co/internlm">InternLM 2 / 2.5</a></li>
<li><a href="https://huggingface.co/meta-llama">Llama 2 / 3</a></li>
<li><a href="https://huggingface.co/collections/microsoft/phi-3-6626e15e9585a200d2d761e3">Phi-3</a></li>
<li><a href="https://huggingface.co/THUDM/chatglm2-6b">ChatGLM2</a></li>
<li><a href="https://huggingface.co/THUDM/chatglm3-6b">ChatGLM3</a></li>
<li><a href="https://huggingface.co/Qwen/Qwen-7B">Qwen</a></li>
<li><a href="https://huggingface.co/baichuan-inc/Baichuan2-7B-Base">Baichuan2</a></li>
<li><a href="https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1">Mixtral 8x7B</a></li>
<li><a href="https://huggingface.co/deepseek-ai/deepseek-moe-16b-chat">DeepSeek MoE</a></li>
<li><a href="https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1">Mixtral</a></li>
<li><a href="https://huggingface.co/deepseek-ai/DeepSeek-V2-Chat">DeepSeek V2</a></li>
<li><a href="https://huggingface.co/google">Gemma</a></li>
<li>...</li>
</ul>
Expand Down Expand Up @@ -203,14 +204,14 @@ XTuner 支持微调大语言模型。数据集预处理指南请查阅[文档](.
xtuner train ${CONFIG_NAME_OR_PATH}
```

例如,我们可以利用 QLoRA 算法在 oasst1 数据集上微调 InternLM2-Chat-7B:
例如,我们可以利用 QLoRA 算法在 oasst1 数据集上微调 InternLM2.5-Chat-7B:

```shell
# 单卡
xtuner train internlm2_chat_7b_qlora_oasst1_e3 --deepspeed deepspeed_zero2
xtuner train internlm2_5_chat_7b_qlora_oasst1_e3 --deepspeed deepspeed_zero2
# 多卡
(DIST) NPROC_PER_NODE=${GPU_NUM} xtuner train internlm2_chat_7b_qlora_oasst1_e3 --deepspeed deepspeed_zero2
(SLURM) srun ${SRUN_ARGS} xtuner train internlm2_chat_7b_qlora_oasst1_e3 --launcher slurm --deepspeed deepspeed_zero2
(DIST) NPROC_PER_NODE=${GPU_NUM} xtuner train internlm2_5_chat_7b_qlora_oasst1_e3 --deepspeed deepspeed_zero2
(SLURM) srun ${SRUN_ARGS} xtuner train internlm2_5_chat_7b_qlora_oasst1_e3 --launcher slurm --deepspeed deepspeed_zero2
```

- `--deepspeed` 表示使用 [DeepSpeed](https://github.com/microsoft/DeepSpeed) 🚀 来优化训练过程。XTuner 内置了多种策略,包括 ZeRO-1、ZeRO-2、ZeRO-3 等。如果用户期望关闭此功能,请直接移除此参数。
Expand All @@ -233,16 +234,10 @@ xtuner chat ${NAME_OR_PATH_TO_LLM} --adapter {NAME_OR_PATH_TO_ADAPTER} [optional

例如:

与 InternLM2-Chat-7B, oasst1 adapter 对话:
与 InternLM2.5-Chat-7B 对话:

```shell
xtuner chat internlm/internlm2-chat-7b --adapter xtuner/internlm2-chat-7b-qlora-oasst1 --prompt-template internlm2_chat
```

与 LLaVA-InternLM2-7B 对话:

```shell
xtuner chat internlm/internlm2-chat-7b --visual-encoder openai/clip-vit-large-patch14-336 --llava xtuner/llava-internlm2-7b --prompt-template internlm2_chat --image $IMAGE_PATH
xtuner chat internlm/internlm2-chat-7b --prompt-template internlm2_chat
```

更多示例,请查阅[文档](./docs/zh_cn/user_guides/chat.md)。
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
# Copyright (c) OpenMMLab. All rights reserved.
"""Data format:
[
{
"conversation": [
{
"system": "",
"input": "xxx",
"output": "xxx"
},
{
"input": "xxx",
"output": "xxx"
}
]
},
...
]
Please refer to https://github.com/InternLM/xtuner/blob/main/docs/en/user_guides/dataset_format.md for details.
""" # noqa: E501
from datasets import load_dataset
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
LoggerHook, ParamSchedulerHook)
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR
from torch.optim import AdamW
from torch.utils.data import BatchSampler
from transformers import AutoModelForCausalLM, AutoTokenizer

from xtuner.dataset import process_hf_dataset
from xtuner.dataset.collate_fns import default_collate_fn
from xtuner.dataset.map_fns import template_map_fn_factory
from xtuner.dataset.samplers import InternRepoSampler
from xtuner.engine import (DatasetInfoHook, EvaluateChatHook, ThroughputHook,
VarlenAttnArgsToMessageHubHook)
from xtuner.engine.runner import TrainLoop
from xtuner.model import SupervisedFinetune
from xtuner.utils import PROMPT_TEMPLATE

#######################################################################
# PART 1 Settings #
#######################################################################
# Model
pretrained_model_name_or_path = 'internlm/internlm2_5-7b-chat'
use_varlen_attn = True

# Data
data_files = ['/path/to/json/file.json']
prompt_template = PROMPT_TEMPLATE.internlm2_chat
max_length = 32768
pack_to_max_length = True

# parallel
sequence_parallel_size = 1

# Scheduler & Optimizer
# batch size per device, set to 1 if `use_varlen_attn` = True
# To clarify, enlarging the batch size essentially enlarges the `max_length`.
# For example, doubling the max length is tantamount to doubling the batch size
batch_size = 1
accumulative_counts = 1 # 1bs * 1acc * 64gpu = 64 batchsize
accumulative_counts *= sequence_parallel_size
dataloader_num_workers = 4
max_epochs = 1
optim_type = AdamW
lr = 4e-5
betas = (0.9, 0.95)
weight_decay = 0.01
max_norm = 1 # grad clip
warm_up_ratio = 0.025

# Save
save_steps = 500
save_total_limit = 2 # Maximum checkpoints to keep (-1 means unlimited)

# Evaluate the generation performance during the training
evaluation_freq = 500
SYSTEM = ''
evaluation_inputs = [
'请给我介绍五个上海的景点', 'Please tell me five scenic spots in Shanghai'
]

#######################################################################
# PART 2 Model & Tokenizer #
#######################################################################
tokenizer = dict(
type=AutoTokenizer.from_pretrained,
pretrained_model_name_or_path=pretrained_model_name_or_path,
trust_remote_code=True,
padding_side='right')

model = dict(
type=SupervisedFinetune,
use_varlen_attn=use_varlen_attn,
llm=dict(
type=AutoModelForCausalLM.from_pretrained,
pretrained_model_name_or_path=pretrained_model_name_or_path,
trust_remote_code=True))

#######################################################################
# PART 3 Dataset & Dataloader #
#######################################################################
train_dataset = dict(
type=process_hf_dataset,
use_varlen_attn=use_varlen_attn,
dataset=dict(type=load_dataset, path='json', data_files=data_files),
tokenizer=tokenizer,
max_length=max_length,
dataset_map_fn=None,
template_map_fn=dict(
type=template_map_fn_factory, template=prompt_template),
remove_unused_columns=True,
shuffle_before_pack=True,
pack_to_max_length=pack_to_max_length)

train_dataloader = dict(
batch_size=batch_size,
num_workers=dataloader_num_workers,
dataset=train_dataset,
sampler=dict(type=InternRepoSampler, shuffle=True, seed=1024),
batch_sampler=dict(
type=BatchSampler, drop_last=True, batch_size=batch_size),
collate_fn=dict(type=default_collate_fn, use_varlen_attn=use_varlen_attn))

#######################################################################
# PART 4 Scheduler & Optimizer #
#######################################################################
# optimizer
optim_wrapper = dict(
type=AmpOptimWrapper,
optimizer=dict(
type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
accumulative_counts=accumulative_counts,
loss_scale='dynamic',
)

# learning policy
# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
param_scheduler = [
dict(
type='LinearLR',
start_factor=1 / 40,
by_epoch=True,
begin=0,
end=warm_up_ratio * max_epochs,
convert_to_iter_based=True),
dict(
type=CosineAnnealingLR,
eta_min=lr * 0.15,
by_epoch=True,
begin=warm_up_ratio * max_epochs,
end=max_epochs,
convert_to_iter_based=True)
]

# train, val, test setting
train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)

#######################################################################
# PART 5 Runtime #
#######################################################################
# Log the dialogue periodically during the training process, optional
custom_hooks = [
dict(
type=DatasetInfoHook, tokenizer=tokenizer,
is_intern_repo_dataset=True),
dict(
type=EvaluateChatHook,
tokenizer=tokenizer,
every_n_iters=evaluation_freq,
evaluation_inputs=evaluation_inputs,
system=SYSTEM,
prompt_template=prompt_template),
dict(type=ThroughputHook)
]

if use_varlen_attn:
custom_hooks += [dict(type=VarlenAttnArgsToMessageHubHook)]

# configure default hooks
default_hooks = dict(
# record the time of every iteration.
timer=dict(type=IterTimerHook),
# print log every 100 iterations.
logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=1),
# enable the parameter scheduler.
param_scheduler=dict(type=ParamSchedulerHook),
# save checkpoint per `save_steps`.
checkpoint=dict(
type=CheckpointHook,
by_epoch=False,
interval=save_steps,
max_keep_ckpts=save_total_limit),
# set sampler seed in distributed evrionment.
sampler_seed=dict(type=DistSamplerSeedHook),
)

# configure environment
env_cfg = dict(
# whether to enable cudnn benchmark
cudnn_benchmark=False,
# set multi process parameters
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
# set distributed parameters
dist_cfg=dict(backend='nccl'),
)

# set visualizer
visualizer = None

# set log level
log_level = 'INFO'

# load from which checkpoint
load_from = None

# whether to resume training from the loaded checkpoint
resume = False

# Defaults to use random seed and disable `deterministic`
randomness = dict(seed=None, deterministic=False)

log_processor = dict(
by_epoch=False,
window_size=1,
mean_pattern=r'.*(loss|time|data_time|grad_norm|tflops).*')
Loading

0 comments on commit 44749c2

Please sign in to comment.