[train_engine] support fsdp #2412
Conversation
2254b80 to b3e3ff6 (compare)
A100: step mode (benchmark screenshots omitted)

(The machine at hand cannot run tools like nvitop, so this is only a quick first look.) zero3 currently does not wrap the embedding into its own FSDP unit; that can be done later if needed, or the wrap logic can be made user-defined. Simulated 8-card 2080, zero3 fp32.

Commit 2f1f4df fixes the A100 bf16 precision issue; performance is now strictly aligned with deepspeed bf16. FSDP zero2 bf16.

A100 fp32 zero2 / FSDP zero2 fp32.
"deepspeed": | ||
torch.cuda.amp.autocast(enabled=dtype is not None, | ||
dtype=dtype, | ||
cache_enabled=False), | ||
"torch_ddp": | ||
torch.cuda.amp.autocast(enabled=scaler is not None), | ||
"torch_fsdp": | ||
torch.cuda.amp.autocast(enabled=True, dtype=dtype) | ||
if dtype is not None else nullcontext() |
I think ddp should also be able to set dtype.
If the dtype arg is moved into the add-fsdp-args function, then it would be fine to just hard-code ddp's amp to fp16 here.
Brother Zhou, it feels like fsdp's autocast could be merged with deepspeed's: when dtype is None, enabled is False (which is effectively nullcontext), so the semantics of the two should be the same. (Have you tried directly reusing deepspeed's autocast config?)
In principle yes, but under fsdp, performance with torch.cuda.amp.autocast(enabled=False, dtype=fp32) is just a bit worse; only nullcontext performs well. Not sure why.
llama-recipes also uses nullcontext here: https://github.com/meta-llama/llama-recipes/blob/main/src/llama_recipes/utils/train_utils.py#L64
Do you have comparison numbers?
Yes, I'll post them today.
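A minimal sketch of the merge discussed in this thread (the helper name amp_context and the train_engine argument are placeholders, not the PR's actual API): when dtype is None, the deepspeed/fsdp branch falls back to nullcontext() instead of a disabled autocast, which is the behaviour the thread reports as faster under fsdp.

from contextlib import nullcontext

import torch


def amp_context(train_engine: str, dtype, scaler):
    # Hypothetical unified helper: one autocast context per training engine.
    if train_engine == "torch_ddp":
        # ddp keeps the scaler-driven fp16 autocast.
        return torch.cuda.amp.autocast(enabled=scaler is not None)
    if train_engine in ("deepspeed", "torch_fsdp"):
        if dtype is None:
            # Plain nullcontext instead of autocast(enabled=False):
            # a disabled autocast was observed to be slower under fsdp.
            return nullcontext()
        return torch.cuda.amp.autocast(enabled=True,
                                       dtype=dtype,
                                       cache_enabled=False)
    return nullcontext()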
# TODO(Mddct): Support user customization
# see more wrap methods:
# https://github.com/meta-llama/llama-recipes/blob/main/src/llama_recipes/utils/fsdp_utils.py#L13 # noqa
if mode == 'model':
    enc_dec_wrap_policy = partial(
        lambda_auto_wrap_policy,
        lambda_fn=lambda module: isinstance(
            module,
            tuple(WENET_ENCODER_CLASSES.values()) + tuple(
                WENET_DECODER_CLASSES.values())))
    return enc_dec_wrap_policy
else:
    to_wrap_class = set()
    to_wrap_class.update(set(WENET_ENCODER_LAYERS_CLASSES.values()))
    to_wrap_class.update(set(WENET_DECODER_LAYERS_CLASSES.values()))
    layers_wrap_policy = partial(transformer_auto_wrap_policy,
                                 transformer_layer_cls=to_wrap_class)
It would be nice to add a link explaining the difference between lambda wrap and transformer wrap, to make it easier to learn from. (Thanks, Brother Zhou :))
From the way they are used, the lambda policy is applied at the whole-encoder level and the transformer policy at the layer level; why are there these two different granularities?
Each wrap unit means an all-gather communication happens around that unit's forward for its input/output gradients and the optimizer-state sharding.
FSDP is quite flexible: the wrap policy controls the granularity of the "sharding". At encoder/decoder granularity, effectively only the optimizer state is sharded, not the gradients (memory-wise roughly zero1); wrapping each layer gives layer-level sharding, roughly zero2.
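For reference, a hedged sketch (toy module names, not wenet classes) of how a layer-level wrap policy and the sharding strategy plug into FSDP; it assumes torch.distributed is already initialized:

from functools import partial

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


class ToyLayer(nn.Module):
    # Stand-in for TransformerEncoderLayer / TransformerDecoderLayer.
    def __init__(self):
        super().__init__()
        self.ff = nn.Linear(256, 256)

    def forward(self, x):
        return self.ff(x)


def build_fsdp(model: nn.Module) -> FSDP:
    # Wrap every ToyLayer as its own FSDP unit, so sharding (and the
    # corresponding all-gathers) happens at layer granularity rather than
    # once for the whole encoder/decoder.
    layer_policy = partial(transformer_auto_wrap_policy,
                           transformer_layer_cls={ToyLayer})
    return FSDP(
        model,
        auto_wrap_policy=layer_policy,
        # SHARD_GRAD_OP roughly corresponds to zero2, FULL_SHARD to zero3.
        sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
    )

Per the explanation above, wrapping only the whole encoder/decoder (the lambda policy in 'model' mode) leaves one big FSDP unit, which is why it behaves closer to zero1 in terms of memory.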
wenet/utils/fsdp_utils.py (Outdated)
def check_gradient_checkpoint(model):
    ckpt_laye_types = []
    if hasattr(model, 'encoder') and hasattr(model.encoder,
                                             'gradient_checkpointing'):
        if model.encoder.gradient_checkpointing:
            model.encoder.gradient_checkpointing = False
            ckpt_laye_types += list(WENET_ENCODER_LAYERS_CLASSES.values())
    if hasattr(model, 'decoder') and hasattr(model.decoder,
                                             'gradient_checkpointing'):
        if model.decoder.gradient_checkpointing:
            model.decoder.gradient_checkpointing = False
            ckpt_laye_types += list(WENET_DECODER_LAYERS_CLASSES.values())
    return tuple(ckpt_laye_types)
def apply_fsdp_checkpointing(model, ckpt_layer_types: tuple):
    if len(ckpt_layer_types) == 0:
        return
    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
        checkpoint_wrapper,
        CheckpointImpl,
        apply_activation_checkpointing,
    )
    non_reentrant_wrapper = partial(
        checkpoint_wrapper,
        checkpoint_impl=CheckpointImpl.NO_REENTRANT,
    )
    apply_activation_checkpointing(
        model,
        checkpoint_wrapper_fn=non_reentrant_wrapper,
        check_fn=lambda submodule: isinstance(submodule, ckpt_layer_types))
This grad ckpt setup feels quite odd. Deepspeed can seamlessly share the same grad ckpt config and usage logic with ddp; why does fsdp need to do it this way? (Suggest also adding the reason as a note in the code, for everyone's learning.)
FSDP (zero2, zero3) can also use wenet's existing approach; the 'model' wrap mode doesn't quite work with it (probably a current fsdp bug), so it is written in a unified way here. Besides, this approach has become standard practice in LLM codebases, see: https://github.com/meta-llama/llama-recipes/blob/main/src/llama_recipes/policies/activation_checkpointing_functions.py#L21
This also has a benefit: if the model code itself has no checkpoint calls, this approach applies checkpointing at the outermost level without modifying any code inside the model, zero intrusion.
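A hedged sketch of how the two helpers above might be wired together at the call site (the actual order and arguments in the PR's training code may differ): collect the layer types and disable the in-model flag before wrapping, then re-apply activation checkpointing from the outside after the FSDP wrap.

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from wenet.utils.fsdp_utils import (apply_fsdp_checkpointing,
                                    check_gradient_checkpoint)


def wrap_with_fsdp_and_ckpt(model, wrap_policy):
    # Hypothetical call-site sketch; assumes torch.distributed is initialized
    # and that `model` / `wrap_policy` are built elsewhere.
    # 1. Record which layer classes had gradient_checkpointing enabled and
    #    turn the in-model flag off, so the model's own checkpoint calls
    #    stay unused.
    ckpt_layer_types = check_gradient_checkpoint(model)
    # 2. Wrap with FSDP as usual.
    model = FSDP(model, auto_wrap_policy=wrap_policy)
    # 3. Re-apply activation checkpointing on those layer types from the
    #    outside, without touching the model code itself.
    apply_fsdp_checkpointing(model, ckpt_layer_types)
    return model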
wenet/utils/fsdp_utils.py (Outdated)
from wenet.utils.init_model import WENET_DECODER_CLASSES, WENET_ENCODER_CLASSES

WENET_ENCODER_LAYERS_CLASSES = {
    'transformer_encoder_laer': TransformerEncoderLayer,
laer -> layer. Also, I don't quite see why a dict is needed; later on only values() is taken and the keys don't seem to be used.
This is just to keep the naming consistent with class_utils.py.
great job! ❤️
May I ask: if data_type is set to shard, will it be much slower? On wenetspeech my step/sec only reaches 0.18.

Single-machine multi-GPU is not affected, but scaling fsdp/deepspeed to multiple machines currently (@xingchensong) also depends on the inter-machine communication bandwidth; a follow-up PR will try to fix the multi-machine fsdp issues. Are you running multi-machine multi-GPU?

I'm using a single machine with 4*A100, so it's not a data_type problem then?

Then it's not.
add casual model
fix typo
rm ckpt
add topk topp sampler
fix positoin
[train_engine] support fsdp (wenet-e2e#2412): * [train_engine] support fsdp * [train_engine] support fsdp * unify scaler and amp * fp32&&fp16 works in fsdp env * fix fsdp in cv auto cast * try to fix wenet.join fsdp * implementing zero1 under fsdp is almost equivalent to deepspeed's zero1 * fix clip_and_grad_ * fix train summary * all wenet xxxformer works (-paraformer -transducer) * try to fix nan * add barrier for cv * add destroy group for end of all train * refactor wrap methods and ckpt works * fix ckpt * fix cv in dtype != float32 * fix ckpt in model mode * fix bf16 amp * refactor scaler and autocast, fix fp32 fp16 bf16 for fsdp * fix fp32 nullcontext to nullcontext() * modify after review * fix lint * fix lint
LoRA support (wenet-e2e#2049): * support lora for v3.0.1 * format code and update lora attention && encoder * fix bug when lora_list is None (Co-authored-by: Xingchen Song(宋星辰) <[email protected]>)
[env] update python version and deepspeed version (wenet-e2e#2462): * [env] update python version and deepspeed version * [env] fix lint
fix rope pos embdining (wenet-e2e#2463): * fix rope pos embdining * fix dropout * fix comment
[transformer] add multi warmup and learning rate for different modules (wenet-e2e#2449): * [transformer] add multi warmup and learning rate for different modules * fix typo * it works in warmuplr * fix lr in tensorboard in step mode * fix cv log * cv works * refactor cv log * add helper lrs_to_string * fix lrstr * fix ddp multiple lr * fix initial step * revert to -1 * fix sub params dup * fix step * fix step * fix log * add assert for scheduler * add comment for log (Co-authored-by: Xingchen Song(宋星辰) <[email protected]>)
add generate
add toto
support sft & pretrain training
forward gemm conversion works
support init casual model
[whisper] limit language to Chinese (wenet-e2e#2470)
[train] convert tensor to scalar (wenet-e2e#2471)
[workflow] upgrad python version to 3.10 (wenet-e2e#2472): * [workflow] upgrad python version to 3.10 * [workflow] try to pass
refactor cache behaviour in training mode (reduce compute cost and memory) (wenet-e2e#2473)
all gemma model works
fix ut
fix ut (wenet-e2e#2477): * fix ut * fix py version
[transformer] Make MoE runnable (wenet-e2e#2474)
[transformer] fix mqa (wenet-e2e#2478)
enable mmap in torch.load (wenet-e2e#2479)
[example] Add deespeed configs of different stages for illustration (wenet-e2e#2485)
[example] Fix prefetch and step_save (wenet-e2e#2486)
[ctl] simplified ctl (wenet-e2e#2483): * [ctl] simplified ctl * [ctl] unify
[branchformer] simplified branchformer (wenet-e2e#2482): * [transformer] simplified branchformer * fix yaml * support mqa gradiengt ckpt sdpa * fix gradient checkponit * add deepspeed comment in layer dropout * fix comment
[e_branchformer] simplified e_branchformer (wenet-e2e#2484): * [e_branchformer] simplified ctl * try to fix ut * try to fix ut * fix activation * fix att args * e-branformer works
[transformer] refactor cache (wenet-e2e#2481): * [transformer] refactor cache * fix ut * unify cache type in branchformer and ebranchformer
fix cache
fix gradient ckpt in branchformer/ebranformer (wenet-e2e#2488)
fix search after refactor cache (wenet-e2e#2490)
generate works!
unify chat pattern
convert llama3 works
[transformer] set use_reentrant=False for gradient ckpt (wenet-e2e#2491)
[transformer] fix warning: ignore(True) has been deprecated (wenet-e2e#2492): * [transformer] fix warning: ignore(True) has been deprecated * [transformer] fix warning: ignore(True) has been deprecated
[log] avoid reduntant logging (wenet-e2e#2493)
fix w1 w2 w3 in feedforward
add 70b
temporarily mv LLM to wenet
support llm dataset
unify config
add dataset yaml in script
support llm dataset
dynamic static bucket works
[transformer] refacgtor mqa repeat (wenet-e2e#2497)
[transformer] fix mqa in cross att (wenet-e2e#2498)
[deepspeed] update json config (wenet-e2e#2499)
training works
pretrain works
refactor covert
fix flash att in generate
llama works
fix llama3
fix speed
try fix ut
support stop tokens in gen and support ppl
support stop tokens in gen and support ppl
torch fsdp to support large speech models or LLMs
make it work
fsdp memory usage analysis
aishell benchmark
whisper benchmark
v100 zero1 vs zero2 fp32 comparison
a100 zero1 and zero2, fp32 / bf16
PRs related to fsdp
To be considered in follow-up PRs