diff --git a/README.md b/README.md
index e0be695ef..4e729226c 100644
--- a/README.md
+++ b/README.md
@@ -39,6 +39,8 @@ English | [简体中文](README_zh-CN.md)
## 🎉 News
+- **\[2024/07\]** Support [InternLM 2.5](xtuner/configs/internlm/internlm2_5_chat_7b/) models!
+- **\[2024/06\]** Support [DeepSeek V2](xtuner/configs/deepseek/deepseek_v2_chat/) models! **2x faster!**
- **\[2024/04\]** [LLaVA-Phi-3-mini](https://huggingface.co/xtuner/llava-phi-3-mini-hf) is released! Click [here](xtuner/configs/llava/phi3_mini_4k_instruct_clip_vit_large_p14_336) for details!
- **\[2024/04\]** [LLaVA-Llama-3-8B](https://huggingface.co/xtuner/llava-llama-3-8b) and [LLaVA-Llama-3-8B-v1.1](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1) are released! Click [here](xtuner/configs/llava/llama3_8b_instruct_clip_vit_large_p14_336) for details!
- **\[2024/04\]** Support [Llama 3](xtuner/configs/llama) models!
@@ -100,16 +102,15 @@ XTuner is an efficient, flexible and full-featured toolkit for fine-tuning large
@@ -203,14 +204,14 @@ XTuner supports fine-tuning large language models. For the dataset preprocessing guide, please refer to the [docs](.
xtuner train ${CONFIG_NAME_OR_PATH}
```
- For example, we can fine-tune InternLM2-Chat-7B on the oasst1 dataset using QLoRA:
+ For example, we can fine-tune InternLM2.5-Chat-7B on the oasst1 dataset using QLoRA:
```shell
# On a single GPU
- xtuner train internlm2_chat_7b_qlora_oasst1_e3 --deepspeed deepspeed_zero2
+ xtuner train internlm2_5_chat_7b_qlora_oasst1_e3 --deepspeed deepspeed_zero2
# On multiple GPUs
- (DIST) NPROC_PER_NODE=${GPU_NUM} xtuner train internlm2_chat_7b_qlora_oasst1_e3 --deepspeed deepspeed_zero2
- (SLURM) srun ${SRUN_ARGS} xtuner train internlm2_chat_7b_qlora_oasst1_e3 --launcher slurm --deepspeed deepspeed_zero2
+ (DIST) NPROC_PER_NODE=${GPU_NUM} xtuner train internlm2_5_chat_7b_qlora_oasst1_e3 --deepspeed deepspeed_zero2
+ (SLURM) srun ${SRUN_ARGS} xtuner train internlm2_5_chat_7b_qlora_oasst1_e3 --launcher slurm --deepspeed deepspeed_zero2
```
- `--deepspeed` means using [DeepSpeed](https://github.com/microsoft/DeepSpeed) 🚀 to optimize the training process. XTuner ships several built-in strategies, including ZeRO-1, ZeRO-2 and ZeRO-3. To disable this feature, simply remove this argument.
@@ -233,16 +234,10 @@ xtuner chat ${NAME_OR_PATH_TO_LLM} --adapter {NAME_OR_PATH_TO_ADAPTER} [optional
For example:
-Chat with InternLM2-Chat-7B, with the adapter trained on the oasst1 dataset:
+Chat with InternLM2.5-Chat-7B:
```shell
-xtuner chat internlm/internlm2-chat-7b --adapter xtuner/internlm2-chat-7b-qlora-oasst1 --prompt-template internlm2_chat
-```
-
-Chat with LLaVA-InternLM2-7B:
-
-```shell
-xtuner chat internlm/internlm2-chat-7b --visual-encoder openai/clip-vit-large-patch14-336 --llava xtuner/llava-internlm2-7b --prompt-template internlm2_chat --image $IMAGE_PATH
+xtuner chat internlm/internlm2_5-7b-chat --prompt-template internlm2_chat
```
For more examples, please refer to the [docs](./docs/zh_cn/user_guides/chat.md).
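
For reference, `--prompt-template internlm2_chat` applies the InternLM2 chat format at inference time. The sketch below is a rough, transformers-only equivalent of the chat command above; it is not XTuner's code path, and it assumes a GPU is available and that the `internlm/internlm2_5-7b-chat` checkpoint ships a Hugging Face chat template.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

name = 'internlm/internlm2_5-7b-chat'
tokenizer = AutoTokenizer.from_pretrained(name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    name, torch_dtype=torch.float16, trust_remote_code=True).cuda().eval()

# Build the prompt with the checkpoint's own chat template
# (assumed to correspond to XTuner's `internlm2_chat` template).
messages = [{'role': 'user', 'content': 'Please tell me five scenic spots in Shanghai'}]
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors='pt').to(model.device)

output_ids = model.generate(input_ids, max_new_tokens=256, do_sample=False)
print(tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True))
```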
diff --git a/xtuner/configs/internlm/internlm2_5_chat_7b/internlm2_5_chat_7b_full_finetune_custom_dataset_e1.py b/xtuner/configs/internlm/internlm2_5_chat_7b/internlm2_5_chat_7b_full_finetune_custom_dataset_e1.py
new file mode 100644
index 000000000..bc8a2816a
--- /dev/null
+++ b/xtuner/configs/internlm/internlm2_5_chat_7b/internlm2_5_chat_7b_full_finetune_custom_dataset_e1.py
@@ -0,0 +1,226 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+"""Data format:
+[
+ {
+ "conversation": [
+ {
+ "system": "",
+ "input": "xxx",
+ "output": "xxx"
+ },
+ {
+ "input": "xxx",
+ "output": "xxx"
+ }
+ ]
+ },
+...
+]
+Please refer to https://github.com/InternLM/xtuner/blob/main/docs/en/user_guides/dataset_format.md for details.
+""" # noqa: E501
+from datasets import load_dataset
+from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
+ LoggerHook, ParamSchedulerHook)
+from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR
+from torch.optim import AdamW
+from torch.utils.data import BatchSampler
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from xtuner.dataset import process_hf_dataset
+from xtuner.dataset.collate_fns import default_collate_fn
+from xtuner.dataset.map_fns import template_map_fn_factory
+from xtuner.dataset.samplers import InternRepoSampler
+from xtuner.engine import (DatasetInfoHook, EvaluateChatHook, ThroughputHook,
+ VarlenAttnArgsToMessageHubHook)
+from xtuner.engine.runner import TrainLoop
+from xtuner.model import SupervisedFinetune
+from xtuner.utils import PROMPT_TEMPLATE
+
+#######################################################################
+# PART 1 Settings #
+#######################################################################
+# Model
+pretrained_model_name_or_path = 'internlm/internlm2_5-7b-chat'
+use_varlen_attn = True
+
+# Data
+data_files = ['/path/to/json/file.json']
+prompt_template = PROMPT_TEMPLATE.internlm2_chat
+max_length = 32768
+pack_to_max_length = True
+
+# parallel
+sequence_parallel_size = 1
+
+# Scheduler & Optimizer
+# batch size per device; set it to 1 when `use_varlen_attn` = True
+# To clarify: because samples are packed to `max_length`, enlarging the batch
+# size is essentially the same as enlarging `max_length`; for example, doubling
+# `max_length` is tantamount to doubling the batch size.
+batch_size = 1
+accumulative_counts = 1  # 1 bs * 1 acc * 64 GPUs = global batch size 64
+accumulative_counts *= sequence_parallel_size
+dataloader_num_workers = 4
+max_epochs = 1
+optim_type = AdamW
+lr = 4e-5
+betas = (0.9, 0.95)
+weight_decay = 0.01
+max_norm = 1 # grad clip
+warm_up_ratio = 0.025
+
+# Save
+save_steps = 500
+save_total_limit = 2 # Maximum checkpoints to keep (-1 means unlimited)
+
+# Evaluate the generation performance during the training
+evaluation_freq = 500
+SYSTEM = ''
+evaluation_inputs = [
+ '请给我介绍五个上海的景点', 'Please tell me five scenic spots in Shanghai'
+]
+
+#######################################################################
+# PART 2 Model & Tokenizer #
+#######################################################################
+tokenizer = dict(
+ type=AutoTokenizer.from_pretrained,
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
+ trust_remote_code=True,
+ padding_side='right')
+
+model = dict(
+ type=SupervisedFinetune,
+ use_varlen_attn=use_varlen_attn,
+ llm=dict(
+ type=AutoModelForCausalLM.from_pretrained,
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
+ trust_remote_code=True))
+
+#######################################################################
+# PART 3 Dataset & Dataloader #
+#######################################################################
+train_dataset = dict(
+ type=process_hf_dataset,
+ use_varlen_attn=use_varlen_attn,
+ dataset=dict(type=load_dataset, path='json', data_files=data_files),
+ tokenizer=tokenizer,
+ max_length=max_length,
+ dataset_map_fn=None,
+ template_map_fn=dict(
+ type=template_map_fn_factory, template=prompt_template),
+ remove_unused_columns=True,
+ shuffle_before_pack=True,
+ pack_to_max_length=pack_to_max_length)
+
+train_dataloader = dict(
+ batch_size=batch_size,
+ num_workers=dataloader_num_workers,
+ dataset=train_dataset,
+ sampler=dict(type=InternRepoSampler, shuffle=True, seed=1024),
+ batch_sampler=dict(
+ type=BatchSampler, drop_last=True, batch_size=batch_size),
+ collate_fn=dict(type=default_collate_fn, use_varlen_attn=use_varlen_attn))
+
+#######################################################################
+# PART 4 Scheduler & Optimizer #
+#######################################################################
+# optimizer
+optim_wrapper = dict(
+ type=AmpOptimWrapper,
+ optimizer=dict(
+ type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
+ clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
+ accumulative_counts=accumulative_counts,
+ loss_scale='dynamic',
+)
+
+# learning policy
+# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
+param_scheduler = [
+ dict(
+ type='LinearLR',
+ start_factor=1 / 40,
+ by_epoch=True,
+ begin=0,
+ end=warm_up_ratio * max_epochs,
+ convert_to_iter_based=True),
+ dict(
+ type=CosineAnnealingLR,
+ eta_min=lr * 0.15,
+ by_epoch=True,
+ begin=warm_up_ratio * max_epochs,
+ end=max_epochs,
+ convert_to_iter_based=True)
+]
+
+# train, val, test setting
+train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
+
+#######################################################################
+# PART 5 Runtime #
+#######################################################################
+# Log the dialogue periodically during the training process, optional
+custom_hooks = [
+ dict(
+ type=DatasetInfoHook, tokenizer=tokenizer,
+ is_intern_repo_dataset=True),
+ dict(
+ type=EvaluateChatHook,
+ tokenizer=tokenizer,
+ every_n_iters=evaluation_freq,
+ evaluation_inputs=evaluation_inputs,
+ system=SYSTEM,
+ prompt_template=prompt_template),
+ dict(type=ThroughputHook)
+]
+
+if use_varlen_attn:
+ custom_hooks += [dict(type=VarlenAttnArgsToMessageHubHook)]
+
+# configure default hooks
+default_hooks = dict(
+ # record the time of every iteration.
+ timer=dict(type=IterTimerHook),
+    # print log every iteration.
+ logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=1),
+ # enable the parameter scheduler.
+ param_scheduler=dict(type=ParamSchedulerHook),
+ # save checkpoint per `save_steps`.
+ checkpoint=dict(
+ type=CheckpointHook,
+ by_epoch=False,
+ interval=save_steps,
+ max_keep_ckpts=save_total_limit),
+    # set sampler seed in distributed environment.
+ sampler_seed=dict(type=DistSamplerSeedHook),
+)
+
+# configure environment
+env_cfg = dict(
+ # whether to enable cudnn benchmark
+ cudnn_benchmark=False,
+ # set multi process parameters
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
+ # set distributed parameters
+ dist_cfg=dict(backend='nccl'),
+)
+
+# set visualizer
+visualizer = None
+
+# set log level
+log_level = 'INFO'
+
+# load from which checkpoint
+load_from = None
+
+# whether to resume training from the loaded checkpoint
+resume = False
+
+# Defaults to use random seed and disable `deterministic`
+randomness = dict(seed=None, deterministic=False)
+
+log_processor = dict(
+ by_epoch=False,
+ window_size=1,
+ mean_pattern=r'.*(loss|time|data_time|grad_norm|tflops).*')
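
As a companion to the config above, here is a minimal sketch of preparing a dataset in the conversation format described in the module docstring. The path reuses the `/path/to/json/file.json` placeholder from `data_files`, and the sample content is purely illustrative.

```python
import json

from datasets import load_dataset

# One record with a two-turn conversation, matching the docstring format.
samples = [
    {
        'conversation': [
            {
                'system': '',
                'input': 'What is XTuner?',
                'output': 'XTuner is a toolkit for fine-tuning large language models.'
            },
            {
                'input': 'Which models does it support?',
                'output': 'It supports models such as InternLM, Llama and DeepSeek.'
            }
        ]
    }
]

# Write to the placeholder path used by the config (replace with a real path).
with open('/path/to/json/file.json', 'w', encoding='utf-8') as f:
    json.dump(samples, f, ensure_ascii=False, indent=2)

# Sanity check: load it the same way the config's `train_dataset` does.
ds = load_dataset('json', data_files=['/path/to/json/file.json'])
print(ds['train'][0]['conversation'][0]['input'])
```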
diff --git a/xtuner/configs/internlm/internlm2_5_chat_7b/internlm2_5_chat_7b_qlora_alpaca_e3.py b/xtuner/configs/internlm/internlm2_5_chat_7b/internlm2_5_chat_7b_qlora_alpaca_e3.py
new file mode 100644
index 000000000..7dfc92617
--- /dev/null
+++ b/xtuner/configs/internlm/internlm2_5_chat_7b/internlm2_5_chat_7b_qlora_alpaca_e3.py
@@ -0,0 +1,219 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from datasets import load_dataset
+from mmengine.dataset import DefaultSampler
+from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
+ LoggerHook, ParamSchedulerHook)
+from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
+from peft import LoraConfig
+from torch.optim import AdamW
+from transformers import (AutoModelForCausalLM, AutoTokenizer,
+ BitsAndBytesConfig)
+
+from xtuner.dataset import process_hf_dataset
+from xtuner.dataset.collate_fns import default_collate_fn
+from xtuner.dataset.map_fns import alpaca_map_fn, template_map_fn_factory
+from xtuner.engine.hooks import (DatasetInfoHook, EvaluateChatHook,
+ VarlenAttnArgsToMessageHubHook)
+from xtuner.engine.runner import TrainLoop
+from xtuner.model import SupervisedFinetune
+from xtuner.parallel.sequence import SequenceParallelSampler
+from xtuner.utils import PROMPT_TEMPLATE, SYSTEM_TEMPLATE
+
+#######################################################################
+# PART 1 Settings #
+#######################################################################
+# Model
+pretrained_model_name_or_path = 'internlm/internlm2_5-7b-chat'
+use_varlen_attn = False
+
+# Data
+alpaca_en_path = 'tatsu-lab/alpaca'
+prompt_template = PROMPT_TEMPLATE.internlm2_chat
+max_length = 2048
+pack_to_max_length = True
+
+# parallel
+sequence_parallel_size = 1
+
+# Scheduler & Optimizer
+batch_size = 1 # per_device
+accumulative_counts = 1
+accumulative_counts *= sequence_parallel_size
+dataloader_num_workers = 0
+max_epochs = 3
+optim_type = AdamW
+lr = 2e-4
+betas = (0.9, 0.999)
+weight_decay = 0
+max_norm = 1 # grad clip
+warmup_ratio = 0.03
+
+# Save
+save_steps = 500
+save_total_limit = 2 # Maximum checkpoints to keep (-1 means unlimited)
+
+# Evaluate the generation performance during the training
+evaluation_freq = 500
+SYSTEM = SYSTEM_TEMPLATE.alpaca
+evaluation_inputs = [
+ '请给我介绍五个上海的景点', 'Please tell me five scenic spots in Shanghai'
+]
+
+#######################################################################
+# PART 2 Model & Tokenizer #
+#######################################################################
+tokenizer = dict(
+ type=AutoTokenizer.from_pretrained,
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
+ trust_remote_code=True,
+ padding_side='right')
+
+model = dict(
+ type=SupervisedFinetune,
+ use_varlen_attn=use_varlen_attn,
+ llm=dict(
+ type=AutoModelForCausalLM.from_pretrained,
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
+ trust_remote_code=True,
+ torch_dtype=torch.float16,
+ quantization_config=dict(
+ type=BitsAndBytesConfig,
+ load_in_4bit=True,
+ load_in_8bit=False,
+ llm_int8_threshold=6.0,
+ llm_int8_has_fp16_weight=False,
+ bnb_4bit_compute_dtype=torch.float16,
+ bnb_4bit_use_double_quant=True,
+ bnb_4bit_quant_type='nf4')),
+ lora=dict(
+ type=LoraConfig,
+ r=64,
+ lora_alpha=16,
+ lora_dropout=0.1,
+ bias='none',
+ task_type='CAUSAL_LM'))
+
+#######################################################################
+# PART 3 Dataset & Dataloader #
+#######################################################################
+alpaca_en = dict(
+ type=process_hf_dataset,
+ dataset=dict(type=load_dataset, path=alpaca_en_path),
+ tokenizer=tokenizer,
+ max_length=max_length,
+ dataset_map_fn=alpaca_map_fn,
+ template_map_fn=dict(
+ type=template_map_fn_factory, template=prompt_template),
+ remove_unused_columns=True,
+ shuffle_before_pack=True,
+ pack_to_max_length=pack_to_max_length,
+ use_varlen_attn=use_varlen_attn)
+
+sampler = SequenceParallelSampler \
+ if sequence_parallel_size > 1 else DefaultSampler
+train_dataloader = dict(
+ batch_size=batch_size,
+ num_workers=dataloader_num_workers,
+ dataset=alpaca_en,
+ sampler=dict(type=sampler, shuffle=True),
+ collate_fn=dict(type=default_collate_fn, use_varlen_attn=use_varlen_attn))
+
+#######################################################################
+# PART 4 Scheduler & Optimizer #
+#######################################################################
+# optimizer
+optim_wrapper = dict(
+ type=AmpOptimWrapper,
+ optimizer=dict(
+ type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
+ clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
+ accumulative_counts=accumulative_counts,
+ loss_scale='dynamic',
+ dtype='float16')
+
+# learning policy
+# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
+param_scheduler = [
+ dict(
+ type=LinearLR,
+ start_factor=1e-5,
+ by_epoch=True,
+ begin=0,
+ end=warmup_ratio * max_epochs,
+ convert_to_iter_based=True),
+ dict(
+ type=CosineAnnealingLR,
+ eta_min=0.0,
+ by_epoch=True,
+ begin=warmup_ratio * max_epochs,
+ end=max_epochs,
+ convert_to_iter_based=True)
+]
+
+# train, val, test setting
+train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
+
+#######################################################################
+# PART 5 Runtime #
+#######################################################################
+# Log the dialogue periodically during the training process, optional
+custom_hooks = [
+ dict(type=DatasetInfoHook, tokenizer=tokenizer),
+ dict(
+ type=EvaluateChatHook,
+ tokenizer=tokenizer,
+ every_n_iters=evaluation_freq,
+ evaluation_inputs=evaluation_inputs,
+ system=SYSTEM,
+ prompt_template=prompt_template)
+]
+
+if use_varlen_attn:
+ custom_hooks += [dict(type=VarlenAttnArgsToMessageHubHook)]
+
+# configure default hooks
+default_hooks = dict(
+ # record the time of every iteration.
+ timer=dict(type=IterTimerHook),
+ # print log every 10 iterations.
+ logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
+ # enable the parameter scheduler.
+ param_scheduler=dict(type=ParamSchedulerHook),
+ # save checkpoint per `save_steps`.
+ checkpoint=dict(
+ type=CheckpointHook,
+ by_epoch=False,
+ interval=save_steps,
+ max_keep_ckpts=save_total_limit),
+    # set sampler seed in distributed environment.
+ sampler_seed=dict(type=DistSamplerSeedHook),
+)
+
+# configure environment
+env_cfg = dict(
+ # whether to enable cudnn benchmark
+ cudnn_benchmark=False,
+ # set multi process parameters
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
+ # set distributed parameters
+ dist_cfg=dict(backend='nccl'),
+)
+
+# set visualizer
+visualizer = None
+
+# set log level
+log_level = 'INFO'
+
+# load from which checkpoint
+load_from = None
+
+# whether to resume training from the loaded checkpoint
+resume = False
+
+# Defaults to use random seed and disable `deterministic`
+randomness = dict(seed=None, deterministic=False)
+
+# set log processor
+log_processor = dict(by_epoch=False)
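
XTuner instantiates the nested `dict(type=..., ...)` specs lazily at runtime. As a rough standalone approximation of what the QLoRA model spec above amounts to in plain transformers/peft calls (not XTuner's actual construction path, and ignoring the `SupervisedFinetune` wrapper):

```python
import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

name = 'internlm/internlm2_5-7b-chat'

# 4-bit NF4 quantization, mirroring the `quantization_config` dict above.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    llm_int8_threshold=6.0,
    llm_int8_has_fp16_weight=False,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type='nf4')

llm = AutoModelForCausalLM.from_pretrained(
    name,
    torch_dtype=torch.float16,
    quantization_config=bnb_config,
    trust_remote_code=True)

# LoRA adapter, mirroring the `lora` dict above. The config omits
# `target_modules` and lets XTuner pick suitable linear layers itself; the
# explicit list here is an assumption based on InternLM2's layer names.
lora_config = LoraConfig(
    r=64,
    lora_alpha=16,
    lora_dropout=0.1,
    bias='none',
    task_type='CAUSAL_LM',
    target_modules=['wqkv', 'wo', 'w1', 'w2', 'w3'])
peft_model = get_peft_model(llm, lora_config)
peft_model.print_trainable_parameters()
```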
diff --git a/xtuner/configs/internlm/internlm2_5_chat_7b/internlm2_5_chat_7b_qlora_oasst1_e3.py b/xtuner/configs/internlm/internlm2_5_chat_7b/internlm2_5_chat_7b_qlora_oasst1_e3.py
new file mode 100644
index 000000000..98b097efb
--- /dev/null
+++ b/xtuner/configs/internlm/internlm2_5_chat_7b/internlm2_5_chat_7b_qlora_oasst1_e3.py
@@ -0,0 +1,219 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from datasets import load_dataset
+from mmengine.dataset import DefaultSampler
+from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
+ LoggerHook, ParamSchedulerHook)
+from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
+from peft import LoraConfig
+from torch.optim import AdamW
+from transformers import (AutoModelForCausalLM, AutoTokenizer,
+ BitsAndBytesConfig)
+
+from xtuner.dataset import process_hf_dataset
+from xtuner.dataset.collate_fns import default_collate_fn
+from xtuner.dataset.map_fns import oasst1_map_fn, template_map_fn_factory
+from xtuner.engine.hooks import (DatasetInfoHook, EvaluateChatHook,
+ VarlenAttnArgsToMessageHubHook)
+from xtuner.engine.runner import TrainLoop
+from xtuner.model import SupervisedFinetune
+from xtuner.parallel.sequence import SequenceParallelSampler
+from xtuner.utils import PROMPT_TEMPLATE
+
+#######################################################################
+# PART 1 Settings #
+#######################################################################
+# Model
+pretrained_model_name_or_path = 'internlm/internlm2_5-7b-chat'
+use_varlen_attn = False
+
+# Data
+data_path = 'timdettmers/openassistant-guanaco'
+prompt_template = PROMPT_TEMPLATE.internlm2_chat
+max_length = 2048
+pack_to_max_length = True
+
+# parallel
+sequence_parallel_size = 1
+
+# Scheduler & Optimizer
+batch_size = 1 # per_device
+accumulative_counts = 16
+accumulative_counts *= sequence_parallel_size
+dataloader_num_workers = 0
+max_epochs = 3
+optim_type = AdamW
+lr = 2e-4
+betas = (0.9, 0.999)
+weight_decay = 0
+max_norm = 1 # grad clip
+warmup_ratio = 0.03
+
+# Save
+save_steps = 500
+save_total_limit = 2 # Maximum checkpoints to keep (-1 means unlimited)
+
+# Evaluate the generation performance during the training
+evaluation_freq = 500
+SYSTEM = ''
+evaluation_inputs = [
+ '请给我介绍五个上海的景点', 'Please tell me five scenic spots in Shanghai'
+]
+
+#######################################################################
+# PART 2 Model & Tokenizer #
+#######################################################################
+tokenizer = dict(
+ type=AutoTokenizer.from_pretrained,
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
+ trust_remote_code=True,
+ padding_side='right')
+
+model = dict(
+ type=SupervisedFinetune,
+ use_varlen_attn=use_varlen_attn,
+ llm=dict(
+ type=AutoModelForCausalLM.from_pretrained,
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
+ trust_remote_code=True,
+ torch_dtype=torch.float16,
+ quantization_config=dict(
+ type=BitsAndBytesConfig,
+ load_in_4bit=True,
+ load_in_8bit=False,
+ llm_int8_threshold=6.0,
+ llm_int8_has_fp16_weight=False,
+ bnb_4bit_compute_dtype=torch.float16,
+ bnb_4bit_use_double_quant=True,
+ bnb_4bit_quant_type='nf4')),
+ lora=dict(
+ type=LoraConfig,
+ r=64,
+ lora_alpha=16,
+ lora_dropout=0.1,
+ bias='none',
+ task_type='CAUSAL_LM'))
+
+#######################################################################
+# PART 3 Dataset & Dataloader #
+#######################################################################
+train_dataset = dict(
+ type=process_hf_dataset,
+ dataset=dict(type=load_dataset, path=data_path),
+ tokenizer=tokenizer,
+ max_length=max_length,
+ dataset_map_fn=oasst1_map_fn,
+ template_map_fn=dict(
+ type=template_map_fn_factory, template=prompt_template),
+ remove_unused_columns=True,
+ shuffle_before_pack=True,
+ pack_to_max_length=pack_to_max_length,
+ use_varlen_attn=use_varlen_attn)
+
+sampler = SequenceParallelSampler \
+ if sequence_parallel_size > 1 else DefaultSampler
+train_dataloader = dict(
+ batch_size=batch_size,
+ num_workers=dataloader_num_workers,
+ dataset=train_dataset,
+ sampler=dict(type=sampler, shuffle=True),
+ collate_fn=dict(type=default_collate_fn, use_varlen_attn=use_varlen_attn))
+
+#######################################################################
+# PART 4 Scheduler & Optimizer #
+#######################################################################
+# optimizer
+optim_wrapper = dict(
+ type=AmpOptimWrapper,
+ optimizer=dict(
+ type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
+ clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
+ accumulative_counts=accumulative_counts,
+ loss_scale='dynamic',
+ dtype='float16')
+
+# learning policy
+# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
+param_scheduler = [
+ dict(
+ type=LinearLR,
+ start_factor=1e-5,
+ by_epoch=True,
+ begin=0,
+ end=warmup_ratio * max_epochs,
+ convert_to_iter_based=True),
+ dict(
+ type=CosineAnnealingLR,
+ eta_min=0.0,
+ by_epoch=True,
+ begin=warmup_ratio * max_epochs,
+ end=max_epochs,
+ convert_to_iter_based=True)
+]
+
+# train, val, test setting
+train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
+
+#######################################################################
+# PART 5 Runtime #
+#######################################################################
+# Log the dialogue periodically during the training process, optional
+custom_hooks = [
+ dict(type=DatasetInfoHook, tokenizer=tokenizer),
+ dict(
+ type=EvaluateChatHook,
+ tokenizer=tokenizer,
+ every_n_iters=evaluation_freq,
+ evaluation_inputs=evaluation_inputs,
+ system=SYSTEM,
+ prompt_template=prompt_template)
+]
+
+if use_varlen_attn:
+ custom_hooks += [dict(type=VarlenAttnArgsToMessageHubHook)]
+
+# configure default hooks
+default_hooks = dict(
+ # record the time of every iteration.
+ timer=dict(type=IterTimerHook),
+ # print log every 10 iterations.
+ logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
+ # enable the parameter scheduler.
+ param_scheduler=dict(type=ParamSchedulerHook),
+ # save checkpoint per `save_steps`.
+ checkpoint=dict(
+ type=CheckpointHook,
+ by_epoch=False,
+ interval=save_steps,
+ max_keep_ckpts=save_total_limit),
+    # set sampler seed in distributed environment.
+ sampler_seed=dict(type=DistSamplerSeedHook),
+)
+
+# configure environment
+env_cfg = dict(
+ # whether to enable cudnn benchmark
+ cudnn_benchmark=False,
+ # set multi process parameters
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
+ # set distributed parameters
+ dist_cfg=dict(backend='nccl'),
+)
+
+# set visualizer
+visualizer = None
+
+# set log level
+log_level = 'INFO'
+
+# load from which checkpoint
+load_from = None
+
+# whether to resume training from the loaded checkpoint
+resume = False
+
+# Defaults to use random seed and disable `deterministic`
+randomness = dict(seed=None, deterministic=False)
+
+# set log processor
+log_processor = dict(by_epoch=False)
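
To make the training setup above concrete, here is a quick back-of-the-envelope sketch of the effective batch size and warmup span implied by the oasst1 QLoRA config; the GPU count is an assumption purely for illustration.

```python
# Values taken from the oasst1 QLoRA config above.
batch_size = 1            # packed sequences per device per forward pass
accumulative_counts = 16  # gradient accumulation steps
max_length = 2048         # each packed sequence holds up to this many tokens
warmup_ratio = 0.03
max_epochs = 3

num_gpus = 8              # assumption, for illustration only

global_batch = batch_size * accumulative_counts * num_gpus
tokens_per_step = global_batch * max_length
print(f'{global_batch} packed sequences (~{tokens_per_step:,} tokens) per optimizer step')
print(f'linear warmup spans the first {warmup_ratio * max_epochs:.2f} of {max_epochs} epochs')
```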