diff --git a/xtuner/model/llava.py b/xtuner/model/llava.py
index d588b6cd5..67955d02f 100644
--- a/xtuner/model/llava.py
+++ b/xtuner/model/llava.py
@@ -204,9 +204,17 @@ def _prepare_for_flash_attn(cfg, llm_cfg):
                            'Qwen2MoeConfig', 'Starcoder2Config',
                            'Starcoder2Config', 'Phi3Config')
 
-    if SUPPORT_FLASH2 and cls_name in SUPPORT_FLASH_ATTN2:
-        cfg.torch_dtype = torch.bfloat16 \
-            if torch.cuda.is_bf16_supported() else torch.float16
+    torch_dtype = torch.bfloat16 if (
+        torch.cuda.is_available() and torch.cuda.is_bf16_supported()) \
+        else torch.float16
+
+    if getattr(cfg, 'attn_implementation', None) is not None:
+        # Flash Attention 2.0 only supports torch.float16 and
+        # torch.bfloat16 dtypes
+        if cfg.attn_implementation == 'flash_attention_2':
+            cfg.torch_dtype = torch_dtype
+    elif SUPPORT_FLASH2 and cls_name in SUPPORT_FLASH_ATTN2:
+        cfg.torch_dtype = torch_dtype
         cfg.attn_implementation = 'flash_attention_2'
     elif SUPPORT_FLASH1 and cls_name in SUPPORT_SDPA_ATTN:
         cfg.attn_implementation = 'sdpa'
diff --git a/xtuner/model/modules/dispatch/__init__.py b/xtuner/model/modules/dispatch/__init__.py
index 97f511ebe..2089a048b 100644
--- a/xtuner/model/modules/dispatch/__init__.py
+++ b/xtuner/model/modules/dispatch/__init__.py
@@ -7,6 +7,7 @@
 import transformers
 from mmengine import print_log
 from mmengine.utils import digit_version
+from transformers.integrations import is_deepspeed_zero3_enabled
 
 from .baichuan import (baichuan2_norm_head_forward, baichuan_7b_attn_forward,
                        baichuan_13b_attn_forward)
@@ -60,9 +61,10 @@ def dispatch_llama_attn_forward(model, use_varlen_attn):
     print_log(NO_ATTN_WEIGHTS_MSG, 'current', logging.WARNING)
     for module in model.modules():
         # Do not need to dispatch if
-        # type(module).__name__ == 'LlamaSdpaAttention', as flash_attn is
-        # required when using sequence parallel
-        if type(module).__name__ in ('LlamaAttention', 'LlamaFlashAttention2'):
+        # type(module).__name__ in ('LlamaAttention', 'LlamaSdpaAttention').
+        # If we set `attn_implementation` to `sdpa` or `eager` in xtuner
+        # configs, we can not use varlen attn and sequence parallel.
+        if type(module).__name__ == 'LlamaFlashAttention2':
             if use_varlen_attn:
                 print_log('dispatch llama varlen attn forward', 'current')
                 if IS_LOW_VERSION_TRANSFORMERS:
@@ -174,8 +176,11 @@ def dispatch_internlm2_attn_forward(model, use_varlen_attn):
 
     print_log(NO_ATTN_WEIGHTS_MSG, 'current', logging.WARNING)
     for module in model.modules():
-        if type(module).__name__ in ('InternLM2Attention',
-                                     'InternLM2FlashAttention2'):
+        # Do not need to dispatch if
+        # type(module).__name__ == 'InternLM2Attention'.
+        # If we set `attn_implementation` to `eager` in xtuner
+        # configs, we can not use varlen attn and sequence parallel.
+        if type(module).__name__ == 'InternLM2FlashAttention2':
             if use_varlen_attn:
                 print_log('dispatch internlm2 varlen attn forward', 'current')
                 module.forward = types.MethodType(
@@ -308,9 +313,12 @@ def dispatch_mistral_attn_forward(model, use_varlen_attn):
 
     print_log(NO_ATTN_WEIGHTS_MSG, 'current', logging.WARNING)
     for module in model.modules():
-        if type(module).__name__ in ('MistralAttention',
-                                     'MistralFlashAttention2',
-                                     'MixtralAttention',
+        # Do not need to dispatch if
+        # type(module).__name__ in ('MistralAttention', 'MistralSdpaAttention',
+        # 'MixtralAttention', 'MixtralSdpaAttention')
+        # If we set `attn_implementation` to `sdpa` or `eager` in xtuner
+        # configs, we can not use varlen attn and sequence parallel.
+        if type(module).__name__ in ('MistralFlashAttention2',
                                      'MixtralFlashAttention2'):
             if use_varlen_attn:
                 print_log('dispatch mistral varlen attn forward', 'current')
@@ -373,10 +381,10 @@ def dispatch_cohere_attn_forward(model, use_varlen_attn):
     print_log(NO_ATTN_WEIGHTS_MSG, 'current', logging.WARNING)
     for module in model.modules():
         # Do not need to dispatch if
-        # type(module).__name__ == 'CohereSdpaAttention', as flash_attn is
-        # required when using sequence parallel
-        if type(module).__name__ in ('CohereAttention',
-                                     'CohereFlashAttention2'):
+        # type(module).__name__ in ('CohereAttention', 'CohereSdpaAttention').
+        # If we set `attn_implementation` to `sdpa` or `eager` in xtuner
+        # configs, we can not use varlen attn and sequence parallel.
+        if type(module).__name__ == 'CohereFlashAttention2':
             print_log('dispatch cohere attn forward', 'current')
             module.forward = types.MethodType(cohere_attn_forward, module)
 
@@ -401,8 +409,12 @@ def dispatch_qwen2_attn_forward(model, use_varlen_attn):
 
     print_log(NO_ATTN_WEIGHTS_MSG, 'current', logging.WARNING)
     for module in model.modules():
-        if type(module).__name__ in ('Qwen2Attention', 'Qwen2FlashAttention2',
-                                     'Qwen2MoeAttention',
+        # Do not need to dispatch if
+        # type(module).__name__ in ('Qwen2Attention', 'Qwen2SdpaAttention',
+        # 'Qwen2MoeAttention', 'Qwen2MoeSdpaAttention')
+        # If we set `attn_implementation` to `sdpa` or `eager` in xtuner
+        # configs, we can not use varlen attn and sequence parallel.
+        if type(module).__name__ in ('Qwen2FlashAttention2',
                                      'Qwen2MoeFlashAttention2'):
             if use_varlen_attn:
                 print_log('dispatch qwen2 varlen attn forward', 'current')
@@ -467,6 +479,8 @@ def dispatch_modules(model, use_varlen_attn=False):
         if USE_TRITON_KERNEL:
             dispatch_mistral_rmsnorm_forward(model)
         replace_mistral_rote(model)
+        if 'moe' in model_name and is_deepspeed_zero3_enabled():
+            set_mixtral_moe_blocks_z3_leaf_modules(model)
     elif 'cohere' in model_name:
         dispatch_cohere_attn_forward(model, use_varlen_attn)
         dispatch_cohere_layernorm_forward(model)
@@ -475,6 +489,8 @@
         dispatch_qwen2_attn_forward(model, use_varlen_attn)
         if USE_TRITON_KERNEL:
             dispatch_qwen2_rmsnorm_forward(model)
+        if 'moe' in model_name and is_deepspeed_zero3_enabled():
+            set_qwen_moe_blocks_z3_leaf_modules(model)
 
 
 __all__ = ['dispatch_modules']
diff --git a/xtuner/model/sft.py b/xtuner/model/sft.py
index 94da26789..5fb4d6930 100644
--- a/xtuner/model/sft.py
+++ b/xtuner/model/sft.py
@@ -186,10 +186,17 @@ def _prepare_for_flash_attn(cfg, llm_cfg):
                            'Qwen2MoeConfig', 'Starcoder2Config',
                            'Starcoder2Config', 'Phi3Config')
 
-    if SUPPORT_FLASH2 and cls_name in SUPPORT_FLASH_ATTN2:
-        cfg.torch_dtype = torch.bfloat16 if (
-            torch.cuda.is_available() and torch.cuda.is_bf16_supported()) \
-            else torch.float16
+    torch_dtype = torch.bfloat16 if (
+        torch.cuda.is_available() and torch.cuda.is_bf16_supported()) \
+        else torch.float16
+
+    if getattr(cfg, 'attn_implementation', None) is not None:
+        # Flash Attention 2.0 only supports torch.float16 and
+        # torch.bfloat16 dtypes
+        if cfg.attn_implementation == 'flash_attention_2':
+            cfg.torch_dtype = torch_dtype
+    elif SUPPORT_FLASH2 and cls_name in SUPPORT_FLASH_ATTN2:
+        cfg.torch_dtype = torch_dtype
         cfg.attn_implementation = 'flash_attention_2'
     elif SUPPORT_FLASH1 and cls_name in SUPPORT_SDPA_ATTN:
         cfg.attn_implementation = 'sdpa'
diff --git a/xtuner/tools/train.py b/xtuner/tools/train.py
index 23e3d2a3f..2696c9a4f 100644
--- a/xtuner/tools/train.py
+++ b/xtuner/tools/train.py
@@ -104,6 +104,24 @@ def check_cfg(cfg):
     if getattr(cfg, 'sequence_parallel_size', 1) > 1:
         assert SUPPORT_FLASH2, ('`flash_attn` is required if you want to use '
                                 'sequence parallel.')
+        attn_implementation = getattr(cfg.model.llm, 'attn_implementation',
+                                      None)
+        assert (attn_implementation is None or
+                attn_implementation == 'flash_attention_2'), \
+            ('If you want to use sequence parallel, please set '
+             'attn_implementation to `flash_attention_2` or do not '
+             f'set this attribute. Got `{attn_implementation}` .')
+
+    if getattr(cfg, 'use_varlen_attn', False):
+        assert SUPPORT_FLASH2, ('`flash_attn` is required if you set '
+                                '`use_varlen_attn` to True.')
+        attn_implementation = getattr(cfg.model.llm, 'attn_implementation',
+                                      None)
+        assert (attn_implementation is None or
+                attn_implementation == 'flash_attention_2'), \
+            ('If you want to set `use_varlen_attn` to True, please set'
+             ' attn_implementation to `flash_attention_2` or do not '
+             f'set this attribute. Got `{attn_implementation}` .')
 
 
 def main():
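
Note (not part of the patch): the new `attn_implementation` handling in `_prepare_for_flash_attn` and the stricter `check_cfg` assertions are driven by the `llm` section of an xtuner config. Below is a minimal, hypothetical sketch of such a config fragment; the checkpoint name is a placeholder and the surrounding keys simply follow the usual mmengine-style xtuner configs.

from transformers import AutoModelForCausalLM
from xtuner.model import SupervisedFinetune

# Placeholder checkpoint; substitute the model you are fine-tuning.
pretrained_model_name_or_path = 'internlm/internlm2-chat-7b'

model = dict(
    type=SupervisedFinetune,
    # check_cfg (above) requires flash_attention_2, or an unset
    # attn_implementation, whenever use_varlen_attn=True or
    # sequence_parallel_size > 1.
    use_varlen_attn=False,
    llm=dict(
        type=AutoModelForCausalLM.from_pretrained,
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        trust_remote_code=True,
        # With this patch an explicit value here is respected rather than
        # overwritten by the automatic flash-attn detection; for
        # 'flash_attention_2', torch_dtype is still forced to bf16/fp16
        # by _prepare_for_flash_attn.
        attn_implementation='flash_attention_2'))

If `attn_implementation` is instead set to 'sdpa' or 'eager', the dispatch functions above leave the attention modules untouched, so varlen attention and sequence parallel remain unavailable.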