add rescale sp loss #917

Open · wants to merge 5 commits into base: main

14 changes: 5 additions & 9 deletions docs/zh_cn/acceleration/train_extreme_long_sequence.rst
@@ -219,7 +219,7 @@ The sequence parallel design in XTuner draws on DeepSpeed's `DeepSpeed Ul
- Data sampler adapted for sequence parallelism (SequenceParallelSampler)
- Data padding and splitting (pad_for_sequence_parallel, split_for_sequence_parallel)
- Attention adapted for sequence parallelism (dispatch_modules)
- Reduce the loss so that the training loss is logged correctly (reduce_sequence_parallel_loss)
- Rescale the loss so that the backward gradients under sequence parallelism match data parallelism (DP) (rescale_sp_loss)

Distributed Environment Initialization
----------------------------------------
@@ -303,20 +303,16 @@ XTuner provides the dispatch_modules interface to support modifying the model's Attention computation
.. tip::
    The above process is implemented in ``xtuner/model/sft.py``.

Reduce Loss
Rescale Loss
-------------

This API is not required for training correctness, but logging the training loss is very useful for monitoring the model's training state.
Because the number of tokens contributing to the loss differs across sp ranks, naively averaging gradients across ranks during data-parallel (DP) gradient synchronization is incorrect under sequence parallelism (SP). XTuner provides the `rescale_sp_loss` API to keep parameter gradients consistent between the sequence-parallel and data-parallel settings.

.. code-block:: python

    from xtuner.parallel.sequence import reduce_sequence_parallel_loss
    from xtuner.parallel.sequence import rescale_sp_loss, get_sequence_parallel_group
    outputs = llm(input_ids=input_ids, labels=labels, **kwargs)
    num_tokens_per_rank = (labels != -100).sum()
    # Suppose the sequence parallel world size equals 4;
    # the losses on rank0, rank1, rank2 and rank3 are different.
    loss = reduce_sequence_parallel_loss(outputs.loss, num_tokens_per_rank)
    # After loss reduction, the losses on rank0, rank1, rank2 and rank3 are the same.
    sp_group = get_sequence_parallel_group()
    rescaled_loss = rescale_sp_loss(outputs.loss, labels, sp_group)
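
For intuition, here is a minimal, self-contained sketch (hypothetical per-rank token counts, no distributed setup) of the per-rank weight that ``rescale_sp_loss`` applies, mirroring ``active_tokens / global_active_tokens * sp_world_size`` from ``xtuner/parallel/sequence/reduce_loss.py``:

.. code-block:: python

    # Hypothetical example: sequence parallel world size = 2.
    # rank0 averages its loss over 10 active tokens, rank1 over 30.
    sp_world_size = 2
    active_tokens = [10, 30]   # per-rank counts of (shifted) labels != -100
    mean_loss = [2.0, 1.0]     # per-rank mean losses

    global_tokens = sum(active_tokens)
    rescaled = [loss * n / global_tokens * sp_world_size
                for loss, n in zip(mean_loss, active_tokens)]

    # Averaging the rescaled losses (what DP gradient averaging effectively does)
    # recovers the token-weighted global mean loss.
    dp_average = sum(rescaled) / sp_world_size
    token_weighted = sum(loss * n for loss, n in zip(mean_loss, active_tokens)) / global_tokens
    assert abs(dp_average - token_weighted) < 1e-6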

.. tip::
    The above process is implemented in ``xtuner/model/sft.py``.
12 changes: 4 additions & 8 deletions docs/zh_cn/user_guides/sequence_parallel.md
@@ -95,7 +95,7 @@ model = dict(
- Data padding (pad_for_sequence_parallel)
- Data splitting (split_for_sequence_parallel)
- Attention adapted for sequence parallelism (dispatch_modules)
- Reduce the loss so that the training loss is logged correctly (reduce_sequence_parallel_loss)
- Rescale the loss so that the backward gradients under sequence parallelism match data parallelism (DP) (rescale_sp_loss)

### Initializing the Sequence Parallel Distributed Environment

@@ -176,16 +176,12 @@ dispatch_modules(model)

### Reduce Loss to Correctly Log the Training Loss

This API is not required for training correctness, but logging the training loss is very useful for monitoring the model's training state.
Because the number of tokens contributing to the loss differs across sp ranks, naively averaging gradients across ranks during data-parallel (DP) gradient synchronization is incorrect under sequence parallelism (SP). XTuner provides the `rescale_sp_loss` API to keep parameter gradients consistent between the sequence-parallel and data-parallel settings.

```python
from xtuner.parallel.sequence import reduce_sequence_parallel_loss
from xtuner.parallel.sequence import rescale_sp_loss, get_sequence_parallel_group
outputs = llm(input_ids=input_ids, labels=labels, **kwargs)
num_tokens_per_rank = (labels != -100).sum()
# Suppose the sequence parallel world size equals 4;
# the losses on rank0, rank1, rank2 and rank3 are different.
loss = reduce_sequence_parallel_loss(outputs.loss, num_tokens_per_rank)
# After loss reduction, the losses on rank0, rank1, rank2 and rank3 are the same.
sp_group = get_sequence_parallel_group()
rescaled_loss = rescale_sp_loss(outputs.loss, labels, sp_group)
```
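
If you also want a loss value that is directly comparable with a run that does not use sequence parallelism, this PR adds a debugging helper as well. A minimal sketch continuing the snippet above (for logging only; it is not the value you backpropagate):

```python
from xtuner.parallel.sequence import reduce_sp_loss_for_debug

# Token-weighted all-reduce across the sp group: every rank gets the same value,
# which should match the loss a pure data-parallel run would report.
debug_loss = reduce_sp_loss_for_debug(outputs.loss, labels, sp_group)
```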

The above process is implemented in xtuner/model/sft.py.
@@ -1,14 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from datasets import load_dataset
from mmengine.dataset import DefaultSampler
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
                            LoggerHook, ParamSchedulerHook)
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
from peft import LoraConfig
from torch.optim import AdamW
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          BitsAndBytesConfig)
from transformers import AutoModelForCausalLM, AutoTokenizer

from xtuner.dataset import process_hf_dataset
from xtuner.dataset.collate_fns import default_collate_fn
21 changes: 14 additions & 7 deletions xtuner/model/sft.py
@@ -15,7 +15,8 @@

from xtuner.parallel.sequence import (get_sequence_parallel_group,
                                      get_sequence_parallel_world_size,
                                      reduce_sequence_parallel_loss,
                                      reduce_sp_loss_for_debug,
                                      rescale_sp_loss,
                                      split_for_sequence_parallel)
from xtuner.registry import BUILDER
from .modules import dispatch_modules
@@ -79,7 +80,6 @@ def __init__(self,
                 tokenizer=None,
                 max_position_embeddings=None):
        super().__init__()

        self.llm = self.build_llm_from_cfg(llm, use_varlen_attn,
                                           max_position_embeddings)

@@ -88,7 +88,6 @@ def __init__(self,
                tokenizer = BUILDER.build(tokenizer)
            smart_tokenizer_and_embedding_resize(tokenizer, self.llm)

        self.llm.config.use_cache = False
        if use_activation_checkpointing:
            # For backward compatibility
            if hasattr(self.llm, 'enable_input_require_grads'):
@@ -116,6 +115,8 @@ def __init__(self,
        # the sequence.
        self.use_varlen_attn = use_varlen_attn

        self.debug_sp = False

    def build_llm_from_cfg(self, llm_cfg, use_varlen_attn,
                           max_position_embeddings):
        # For forward
@@ -288,11 +289,17 @@ def _compute_sequence_parallel_loss(self, data):
        data = self._split_for_sequence_parallel(data)
        outputs = self.llm(**data)
        labels = data['labels']
        num_tokens = (labels != -100).sum()

        sp_group = get_sequence_parallel_group()
        loss = reduce_sequence_parallel_loss(outputs.loss, num_tokens,
                                             sp_group)
        return {'loss': loss}
        loss = rescale_sp_loss(outputs.loss, labels, sp_group)
        output = {'loss': loss}
        if self.debug_sp:
            reduced_loss = reduce_sp_loss_for_debug(outputs.loss, labels,
                                                    sp_group)
            # the string `loss` cannot be part of any key in the output dict
            # https://github.com/open-mmlab/mmengine/blob/main/mmengine/model/base_model/base_model.py#L174 # noqa: E501
            output['reduced_l'] = reduced_loss
        return output

    def compute_loss(self, data, data_samples=None):
        if get_sequence_parallel_world_size() > 1:
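
As a single-process sanity sketch of the property the rescaling is meant to provide (toy tensors and a plain squared error standing in for the LM loss; the split sizes are hypothetical), averaging the gradients of the rescaled per-rank losses reproduces the gradient of the loss computed over the whole sequence:

```python
import torch

torch.manual_seed(0)
w = torch.randn(4, requires_grad=True)
x = torch.randn(6, 4)          # 6 "tokens" of features
target = torch.randn(6)

def chunk_mean_loss(sl):
    # mean loss over one rank's slice of the sequence
    return ((x[sl] @ w - target[sl]) ** 2).mean()

# Reference: loss over all 6 tokens at once (what a pure-DP rank computes).
ref_loss = ((x @ w - target) ** 2).mean()
ref_grad = torch.autograd.grad(ref_loss, w)[0]

# "SP" view: rank0 holds 2 tokens, rank1 holds 4; sp world size = 2.
sp_world, total = 2, 6
chunks, tokens = [slice(0, 2), slice(2, 6)], [2, 4]
rescaled = [chunk_mean_loss(c) * n / total * sp_world for c, n in zip(chunks, tokens)]

# DP gradient synchronisation averages the per-rank gradients.
grads = [torch.autograd.grad(loss, w)[0] for loss in rescaled]
avg_grad = torch.stack(grads).mean(dim=0)
assert torch.allclose(avg_grad, ref_grad, atol=1e-5)
```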
7 changes: 4 additions & 3 deletions xtuner/parallel/sequence/__init__.py
@@ -9,7 +9,7 @@
split_forward_gather_backward)
from .data_collate import (pad_cumulative_len_for_sequence_parallel,
                           pad_for_sequence_parallel)
from .reduce_loss import reduce_sequence_parallel_loss
from .reduce_loss import reduce_sp_loss_for_debug, rescale_sp_loss
from .sampler import SequenceParallelSampler
from .setup_distributed import (get_data_parallel_group,
                                get_data_parallel_rank,
@@ -31,11 +31,12 @@
    'init_sequence_parallel', 'get_sequence_parallel_group',
    'get_sequence_parallel_world_size', 'get_sequence_parallel_rank',
    'get_data_parallel_group', 'get_data_parallel_world_size',
    'get_data_parallel_rank', 'reduce_sequence_parallel_loss', 'init_dist',
    'get_data_parallel_rank', 'init_dist',
    'all_to_all', 'gather_for_sequence_parallel',
    'split_forward_gather_backward', 'gather_forward_split_backward',
    'get_inner_sequence_parallel_group', 'get_inner_sequence_parallel_rank',
    'get_inner_sequence_parallel_world_size', 'init_inner_sequence_parallel',
    'is_inner_sequence_parallel_initialized',
    'pad_cumulative_len_for_sequence_parallel'
    'pad_cumulative_len_for_sequence_parallel', 'rescale_sp_loss',
    'reduce_sp_loss_for_debug'
]
65 changes: 43 additions & 22 deletions xtuner/parallel/sequence/reduce_loss.py
@@ -1,34 +1,55 @@
import copy

import torch
import torch.distributed as dist

from .setup_distributed import get_sequence_parallel_group


class _ReduceLoss(torch.autograd.Function):
def rescale_sp_loss(loss_per_sp_rank,
                    labels_per_sp_rank,
                    sp_group: dist.ProcessGroup = None,
                    ignore_index=-100):
    if sp_group is None:
        sp_group = get_sequence_parallel_group()

    if (sp_group is None) or (dist.get_world_size(sp_group) == 1):
        return loss_per_sp_rank

    shift_labels = labels_per_sp_rank[..., 1:].view(-1)
    active_tokens = (shift_labels != ignore_index).long().sum()
    global_active_tokens = copy.deepcopy(active_tokens)
    dist.all_reduce(global_active_tokens, group=sp_group)
    loss_weight = active_tokens / global_active_tokens * dist.get_world_size(
        group=sp_group)

    @staticmethod
    def forward(ctx, mean_loss, loss_scale, process_group):
        ctx.mode = process_group
        if loss_scale == 0:
            # convert nan to 0 just for logging
            mean_loss = torch.nan_to_num(mean_loss)
        loss_sum = mean_loss * loss_scale
        dist.all_reduce(loss_sum, group=process_group)
        dist.all_reduce(loss_scale, group=process_group)
        loss = loss_sum / loss_scale
        return loss
    if active_tokens == 0:
        # convert nan to 0 just for logging
        loss_per_sp_rank = torch.nan_to_num(loss_per_sp_rank)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None, None
    return loss_per_sp_rank * loss_weight


def reduce_sequence_parallel_loss(mean_loss,
                                  loss_scale,
                                  sp_group: dist.ProcessGroup = None):
    if dist.get_world_size(sp_group) == 1:
        return mean_loss
def reduce_sp_loss_for_debug(loss_per_sp_rank,
                             labels_per_sp_rank,
                             sp_group: dist.ProcessGroup = None,
                             ignore_index=-100):
    # Reduce the loss to check whether the training loss differs when
    # using sp. This function is only used for debugging.
    if sp_group is None:
        # avoid bc breaking
        sp_group = get_sequence_parallel_group()
    return _ReduceLoss.apply(mean_loss, loss_scale, sp_group)

    if (sp_group is None) or (dist.get_world_size(sp_group) == 1):
        return loss_per_sp_rank

    shift_labels = labels_per_sp_rank[..., 1:].view(-1)
    active_tokens = (shift_labels != ignore_index).long().sum()
    if active_tokens == 0:
        # convert nan to 0 just for logging
        loss_per_sp_rank = torch.nan_to_num(loss_per_sp_rank)

    loss_sum = loss_per_sp_rank * active_tokens
    global_active_tokens = copy.deepcopy(active_tokens)
    dist.all_reduce(loss_sum, group=sp_group)
    dist.all_reduce(global_active_tokens, group=sp_group)
    return loss_sum / global_active_tokens
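
One detail worth calling out in both helpers above: the active-token count is taken on `labels[..., 1:]` because the causal-LM loss inside the model is computed on shifted labels, so the first label position never contributes. A tiny sketch (hypothetical labels) of that counting:

```python
import torch

labels = torch.tensor([[-100, -100, 42, 43, -100, 44]])
shift_labels = labels[..., 1:].reshape(-1)   # mirrors labels_per_sp_rank[..., 1:].view(-1)
active_tokens = (shift_labels != -100).long().sum()
print(active_tokens)   # tensor(3)
```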