Commit

[Fix] Do not set attn_implementation to flash_attention_2 or sdpa if users already set it in XTuner configs. (#609)

* do not set attn_implementation to flash_attention_2 or sdpa if users already set it

* check cfg: if we want to use varlen attn or sequence parallel, we should set attn_implementation to flash_attention_2 or leave the attribute unset (see the config sketch below).
HIT-cwh authored Apr 25, 2024
1 parent fc4225a commit 60e0cc9
Showing 4 changed files with 70 additions and 21 deletions.
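For context, this is roughly where the attribute lives in an XTuner config. The snippet below is a minimal sketch assuming the usual mmengine-style Python config layout; the model path and the surrounding training fields are placeholders, and only the attention-related settings are shown.

from transformers import AutoModelForCausalLM
from xtuner.model import SupervisedFinetune

use_varlen_attn = True        # requires flash_attn
sequence_parallel_size = 2    # > 1 also requires flash_attn

model = dict(
    type=SupervisedFinetune,
    use_varlen_attn=use_varlen_attn,
    llm=dict(
        type=AutoModelForCausalLM.from_pretrained,
        pretrained_model_name_or_path='internlm/internlm2-chat-7b',
        # After this commit the value set here is respected instead of being
        # overwritten; with varlen attn or sequence parallel it has to be
        # 'flash_attention_2' or be left out entirely.
        attn_implementation='flash_attention_2'))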
14 changes: 11 additions & 3 deletions xtuner/model/llava.py
@@ -204,9 +204,17 @@ def _prepare_for_flash_attn(cfg, llm_cfg):
                            'Qwen2MoeConfig', 'Starcoder2Config',
                            'Starcoder2Config', 'Phi3Config')
 
-    if SUPPORT_FLASH2 and cls_name in SUPPORT_FLASH_ATTN2:
-        cfg.torch_dtype = torch.bfloat16 \
-            if torch.cuda.is_bf16_supported() else torch.float16
+    torch_dtype = torch.bfloat16 if (
+        torch.cuda.is_available() and torch.cuda.is_bf16_supported()) \
+        else torch.float16
+
+    if getattr(cfg, 'attn_implementation', None) is not None:
+        # Flash Attention 2.0 only supports torch.float16 and
+        # torch.bfloat16 dtypes
+        if cfg.attn_implementation == 'flash_attention_2':
+            cfg.torch_dtype = torch_dtype
+    elif SUPPORT_FLASH2 and cls_name in SUPPORT_FLASH_ATTN2:
+        cfg.torch_dtype = torch_dtype
         cfg.attn_implementation = 'flash_attention_2'
     elif SUPPORT_FLASH1 and cls_name in SUPPORT_SDPA_ATTN:
         cfg.attn_implementation = 'sdpa'
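Pulled out of the diff above for readability, the branch now reads as in the sketch below. This is a simplified standalone rendering: cfg is the llm config the function mutates, and the SUPPORT_* flags and tuples, which are module-level constants in llava.py, are passed in explicitly here so the snippet is self-contained. The same change is applied to xtuner/model/sft.py further down.

import torch


def pick_attn_implementation(cfg, cls_name, support_flash1, support_flash2,
                             support_sdpa_attn, support_flash_attn2):
    """Sketch of the post-commit _prepare_for_flash_attn branch."""
    torch_dtype = torch.bfloat16 if (
        torch.cuda.is_available()
        and torch.cuda.is_bf16_supported()) else torch.float16

    if getattr(cfg, 'attn_implementation', None) is not None:
        # A user-provided choice is left untouched; only the dtype is
        # adjusted, since Flash Attention 2 supports fp16/bf16 only.
        if cfg.attn_implementation == 'flash_attention_2':
            cfg.torch_dtype = torch_dtype
    elif support_flash2 and cls_name in support_flash_attn2:
        cfg.torch_dtype = torch_dtype
        cfg.attn_implementation = 'flash_attention_2'
    elif support_flash1 and cls_name in support_sdpa_attn:
        cfg.attn_implementation = 'sdpa'
    return cfg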
44 changes: 30 additions & 14 deletions xtuner/model/modules/dispatch/__init__.py
@@ -7,6 +7,7 @@
 import transformers
 from mmengine import print_log
 from mmengine.utils import digit_version
+from transformers.integrations import is_deepspeed_zero3_enabled
 
 from .baichuan import (baichuan2_norm_head_forward, baichuan_7b_attn_forward,
                        baichuan_13b_attn_forward)
@@ -60,9 +61,10 @@ def dispatch_llama_attn_forward(model, use_varlen_attn):
     print_log(NO_ATTN_WEIGHTS_MSG, 'current', logging.WARNING)
     for module in model.modules():
         # Do not need to dispatch if
-        # type(module).__name__ == 'LlamaSdpaAttention', as flash_attn is
-        # required when using sequence parallel
-        if type(module).__name__ in ('LlamaAttention', 'LlamaFlashAttention2'):
+        # type(module).__name__ in ('LlamaAttention', 'LlamaSdpaAttention').
+        # If we set `attn_implementation` to `sdpa` or `eager` in xtuner
+        # configs, we can not use varlen attn and sequence parallel.
+        if type(module).__name__ == 'LlamaFlashAttention2':
             if use_varlen_attn:
                 print_log('dispatch llama varlen attn forward', 'current')
                 if IS_LOW_VERSION_TRANSFORMERS:
@@ -174,8 +176,11 @@ def dispatch_internlm2_attn_forward(model, use_varlen_attn):
 
     print_log(NO_ATTN_WEIGHTS_MSG, 'current', logging.WARNING)
     for module in model.modules():
-        if type(module).__name__ in ('InternLM2Attention',
-                                     'InternLM2FlashAttention2'):
+        # Do not need to dispatch if
+        # type(module).__name__ == 'InternLM2Attention'.
+        # If we set `attn_implementation` to `eager` in xtuner
+        # configs, we can not use varlen attn and sequence parallel.
+        if type(module).__name__ == 'InternLM2FlashAttention2':
             if use_varlen_attn:
                 print_log('dispatch internlm2 varlen attn forward', 'current')
                 module.forward = types.MethodType(
@@ -308,9 +313,12 @@ def dispatch_mistral_attn_forward(model, use_varlen_attn):
 
     print_log(NO_ATTN_WEIGHTS_MSG, 'current', logging.WARNING)
     for module in model.modules():
-        if type(module).__name__ in ('MistralAttention',
-                                     'MistralFlashAttention2',
-                                     'MixtralAttention',
+        # Do not need to dispatch if
+        # type(module).__name__ in ('MistralAttention', 'MistralSdpaAttention',
+        # 'MixtralAttention', 'MixtralSdpaAttention')
+        # If we set `attn_implementation` to `sdpa` or `eager` in xtuner
+        # configs, we can not use varlen attn and sequence parallel.
+        if type(module).__name__ in ('MistralFlashAttention2',
                                      'MixtralFlashAttention2'):
             if use_varlen_attn:
                 print_log('dispatch mistral varlen attn forward', 'current')
@@ -373,10 +381,10 @@ def dispatch_cohere_attn_forward(model, use_varlen_attn):
     print_log(NO_ATTN_WEIGHTS_MSG, 'current', logging.WARNING)
     for module in model.modules():
         # Do not need to dispatch if
-        # type(module).__name__ == 'CohereSdpaAttention', as flash_attn is
-        # required when using sequence parallel
-        if type(module).__name__ in ('CohereAttention',
-                                     'CohereFlashAttention2'):
+        # type(module).__name__ in ('CohereAttention', 'CohereSdpaAttention').
+        # If we set `attn_implementation` to `sdpa` or `eager` in xtuner
+        # configs, we can not use varlen attn and sequence parallel.
+        if type(module).__name__ == 'CohereFlashAttention2':
             print_log('dispatch cohere attn forward', 'current')
             module.forward = types.MethodType(cohere_attn_forward, module)
 
@@ -401,8 +409,12 @@ def dispatch_qwen2_attn_forward(model, use_varlen_attn):
 
     print_log(NO_ATTN_WEIGHTS_MSG, 'current', logging.WARNING)
     for module in model.modules():
-        if type(module).__name__ in ('Qwen2Attention', 'Qwen2FlashAttention2',
-                                     'Qwen2MoeAttention',
+        # Do not need to dispatch if
+        # type(module).__name__ in ('Qwen2Attention', 'Qwen2SdpaAttention',
+        # 'Qwen2MoeAttention', 'Qwen2MoeSdpaAttention')
+        # If we set `attn_implementation` to `sdpa` or `eager` in xtuner
+        # configs, we can not use varlen attn and sequence parallel.
+        if type(module).__name__ in ('Qwen2FlashAttention2',
                                      'Qwen2MoeFlashAttention2'):
             if use_varlen_attn:
                 print_log('dispatch qwen2 varlen attn forward', 'current')
@@ -467,6 +479,8 @@ def dispatch_modules(model, use_varlen_attn=False):
         if USE_TRITON_KERNEL:
             dispatch_mistral_rmsnorm_forward(model)
         replace_mistral_rote(model)
+        if 'moe' in model_name and is_deepspeed_zero3_enabled():
+            set_mixtral_moe_blocks_z3_leaf_modules(model)
     elif 'cohere' in model_name:
         dispatch_cohere_attn_forward(model, use_varlen_attn)
         dispatch_cohere_layernorm_forward(model)
@@ -475,6 +489,8 @@
         dispatch_qwen2_attn_forward(model, use_varlen_attn)
         if USE_TRITON_KERNEL:
             dispatch_qwen2_rmsnorm_forward(model)
+        if 'moe' in model_name and is_deepspeed_zero3_enabled():
+            set_qwen_moe_blocks_z3_leaf_modules(model)
 
 
 __all__ = ['dispatch_modules']
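The dispatch changes above all follow one pattern, sketched below: walk the model and rebind forward only on the *FlashAttention2 attention classes, since the varlen and sequence-parallel forwards depend on flash_attn; eager and SDPA attention modules are deliberately left alone. The two is_deepspeed_zero3_enabled() branches presumably mark the sparse MoE blocks as ZeRO-3 leaf modules so their experts are handled as a unit under ZeRO-3. In the sketch, llama_attn_forward and llama_varlen_attn_forward stand in for the real dispatch targets imported at the top of the file.

import types


def dispatch_flash_attn_forward(model, use_varlen_attn, llama_attn_forward,
                                llama_varlen_attn_forward):
    """Simplified rendering of the dispatch pattern used in this file."""
    for module in model.modules():
        # 'LlamaAttention' (eager) and 'LlamaSdpaAttention' are skipped: if
        # attn_implementation was pinned to 'eager' or 'sdpa' in the config,
        # varlen attn and sequence parallel are unavailable anyway.
        if type(module).__name__ == 'LlamaFlashAttention2':
            forward = (llama_varlen_attn_forward
                       if use_varlen_attn else llama_attn_forward)
            module.forward = types.MethodType(forward, module)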
15 changes: 11 additions & 4 deletions xtuner/model/sft.py
@@ -186,10 +186,17 @@ def _prepare_for_flash_attn(cfg, llm_cfg):
                            'Qwen2MoeConfig', 'Starcoder2Config',
                            'Starcoder2Config', 'Phi3Config')
 
-    if SUPPORT_FLASH2 and cls_name in SUPPORT_FLASH_ATTN2:
-        cfg.torch_dtype = torch.bfloat16 if (
-            torch.cuda.is_available() and torch.cuda.is_bf16_supported()) \
-            else torch.float16
+    torch_dtype = torch.bfloat16 if (
+        torch.cuda.is_available() and torch.cuda.is_bf16_supported()) \
+        else torch.float16
+
+    if getattr(cfg, 'attn_implementation', None) is not None:
+        # Flash Attention 2.0 only supports torch.float16 and
+        # torch.bfloat16 dtypes
+        if cfg.attn_implementation == 'flash_attention_2':
+            cfg.torch_dtype = torch_dtype
+    elif SUPPORT_FLASH2 and cls_name in SUPPORT_FLASH_ATTN2:
+        cfg.torch_dtype = torch_dtype
         cfg.attn_implementation = 'flash_attention_2'
     elif SUPPORT_FLASH1 and cls_name in SUPPORT_SDPA_ATTN:
         cfg.attn_implementation = 'sdpa'
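The sft.py hunk mirrors the llava.py one above. In effect, the prepared llm cfg is what eventually reaches the Hugging Face loader, roughly as below. This is a hedged illustration only: the model path is a placeholder, and attn_implementation/torch_dtype are the attributes that _prepare_for_flash_attn either respects or fills in.

import torch
from transformers import AutoModelForCausalLM

llm = AutoModelForCausalLM.from_pretrained(
    'internlm/internlm2-chat-7b',
    torch_dtype=torch.bfloat16,                # filled in for flash-attn use
    attn_implementation='flash_attention_2')   # kept if the user set it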
18 changes: 18 additions & 0 deletions xtuner/tools/train.py
@@ -104,6 +104,24 @@ def check_cfg(cfg):
     if getattr(cfg, 'sequence_parallel_size', 1) > 1:
         assert SUPPORT_FLASH2, ('`flash_attn` is required if you want to use '
                                 'sequence parallel.')
+        attn_implementation = getattr(cfg.model.llm, 'attn_implementation',
+                                      None)
+        assert (attn_implementation is None or
+                attn_implementation == 'flash_attention_2'), \
+            ('If you want to use sequence parallel, please set '
+             'attn_implementation to `flash_attention_2` or do not '
+             f'set this attribute. Got `{attn_implementation}` .')
+
+    if getattr(cfg, 'use_varlen_attn', False):
+        assert SUPPORT_FLASH2, ('`flash_attn` is required if you set '
+                                '`use_varlen_attn` to True.')
+        attn_implementation = getattr(cfg.model.llm, 'attn_implementation',
+                                      None)
+        assert (attn_implementation is None or
+                attn_implementation == 'flash_attention_2'), \
+            ('If you want to set `use_varlen_attn` to True, please set'
+             ' attn_implementation to `flash_attention_2` or do not '
+             f'set this attribute. Got `{attn_implementation}` .')
 
 
 def main():
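A small sketch of what the new check_cfg guard accepts and rejects, using toy mmengine configs. The helper name and the configs are illustrative; the attribute layout (cfg.sequence_parallel_size, cfg.use_varlen_attn, cfg.model.llm.attn_implementation) matches what train.py inspects.

from mmengine.config import Config


def check_attn_implementation(cfg):
    """Mimics the assertions added to check_cfg above."""
    needs_flash = (getattr(cfg, 'sequence_parallel_size', 1) > 1
                   or getattr(cfg, 'use_varlen_attn', False))
    attn_implementation = getattr(cfg.model.llm, 'attn_implementation', None)
    if needs_flash:
        assert attn_implementation in (None, 'flash_attention_2'), \
            ('Set attn_implementation to `flash_attention_2` or leave it '
             f'unset. Got `{attn_implementation}`.')


ok = Config(dict(sequence_parallel_size=2,
                 model=dict(llm=dict(attn_implementation='flash_attention_2'))))
bad = Config(dict(sequence_parallel_size=2,
                  model=dict(llm=dict(attn_implementation='sdpa'))))

check_attn_implementation(ok)       # passes silently
try:
    check_attn_implementation(bad)  # sdpa + sequence parallel is rejected
except AssertionError as err:
    print(err)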
