
[train_engine] support fsdp #2412

Merged: 26 commits merged into main from Mddct-fsdp on Apr 7, 2024
Conversation

@Mddct (Collaborator) commented Mar 15, 2024

torch fsdp to support large speech models or LLMs

  • make it work (see the code sketch after this list for how these sharding modes map to torch's ShardingStrategy)

    • no shard (== ddp)
    • wrap enc dec
    • zero2
    • zero3
    • uneven data
    • fix clip_grad_norm_ in no shard env
    • activation checkpoint
  • fsdp memory usage analysis

  • aishell benchmark

    • no shard
    • wrap enc dec
    • zero2
    • zero3
  • whisper benchmark

    • fp32 (v100)
    • bf16 (A100)
  • v100: zero1 vs zero2, fp32 comparison

  • a100: zero1 vs zero2, fp32 and bf16
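
A minimal sketch of how the modes above correspond to torch's `ShardingStrategy` (an assumption based on the descriptions in this PR, not a copy of the wenet code); "zero1" is approximated by wrapping only the encoder/decoder, as discussed later in this thread:

```python
from torch.distributed.fsdp import ShardingStrategy

# Hypothetical mapping, for illustration only (not the wenet code):
FSDP_SHARDING = {
    "no_shard": ShardingStrategy.NO_SHARD,    # nothing sharded; behaves like DDP
    "zero2": ShardingStrategy.SHARD_GRAD_OP,  # shard gradients + optimizer states
    "zero3": ShardingStrategy.FULL_SHARD,     # additionally shard parameters
}
```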

PRs related to fsdp

To be considered in follow-up PRs:

  • For params with requires_grad=False, Adam weight decay needs to be set to 0
  • There are some restrictions on shared embeddings: they must live inside the same wrapped unit (FSDP unit)

(Screenshot: 2024-03-31 23:47:17)

@Mddct mentioned this pull request on Mar 15, 2024
@Mddct force-pushed the Mddct-fsdp branch 2 times, most recently from 2254b80 to b3e3ff6, on March 17, 2024 04:57
@Mddct (Collaborator, Author) commented Mar 24, 2024

#2363 (comment)

All rows use: step mode, avg 20, save interval every 1000 steps (stage1 shuffle), bucket_boundaries: [500, 1000, 1500], bucket_batch_sizes: [128, 64, 32, 16], data type raw.

| config | training time | att / rescore / ctc greedy / ctc beam (WER) | steps/sec |
| --- | --- | --- | --- |
| noshard | 11h30min (ignore; the machine hit an I/O issue) | 5.59 / 5.22 / 5.73 / 5.74 | 1.9-2.2 |
| zero1 | 10h24min | 5.63 / 5.32 / 5.96 / 5.96 | 2.2 |
| zero2, but wrapping only enc/dec (not layer-level) | 10h28min | 5.66 / 5.39 / 6.16 / 6.16 | 2.2 |
| zero2 | 15h7min | 5.72 / 5.38 / 5.90 / 5.90 | 1.5 |
| zero3 | 17h43min | 5.59 / 5.30 / 5.83 / 5.83 | 1.3 |

@Mddct (Collaborator, Author) commented Mar 27, 2024

A100: step mode

All rows use: step mode, avg 20, save interval every 1000 steps, bucket_boundaries: [500, 1000, 1500], bucket_batch_sizes: [128, 64, 32, 16], data type raw.

| config | training time | att / rescore / ctc greedy / ctc beam (WER) | steps/sec |
| --- | --- | --- | --- |
| noshard | 7.5 | 5.69 / 5.36 / 5.92 / 5.92 | 3.1 |
| model | 7.5 | 5.60 / 5.30 / 5.98 / 5.98 | 3.1 |
| zero2 | 7.5 | 5.61 / 5.30 / 5.82 / 5.82 | 3.1 |
| zero3 | 7.6 | 5.61 / 5.28 / 5.85 / 5.85 | 3.0 |

@Mddct (Collaborator, Author) commented Mar 27, 2024

#2173 (comment)

(The machines at hand can't run tools like nvitop, so this is just a quick first look.)
4 × A100, GPU memory capped at 24 GB via https://pytorch.org/docs/stable/generated/torch.cuda.set_per_process_memory_fraction.html

zero2:
(Screenshot: 2024-03-27 16:25:44)

zero3:

(Screenshot: 2024-03-27 16:43:34)

The embedding is currently not wrapped as its own FSDP unit; we can do that later if there is a need, or make the wrap logic user-defined.

Simulating 8 × 2080 GPUs, zero3 fp32:

(Screenshot: 2024-03-27 17:08:06)
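
A minimal sketch of how such a cap can be set (the 80 GB total and the local_rank handling are assumptions, not taken from this PR):

```python
import torch

# Cap this process's CUDA allocator at ~24 GB of an (assumed) 80 GB A100.
total_gb, cap_gb = 80, 24
local_rank = 0  # normally read from the launcher, e.g. the LOCAL_RANK env var
torch.cuda.set_per_process_memory_fraction(cap_gb / total_gb, device=local_rank)
```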

@Mddct (Collaborator, Author) commented Apr 6, 2024

Commit 2f1f4df fixes the A100 bf16 precision issue; results are now strictly aligned with deepspeed bf16.

FSDP zero2 bf16

  • Feature info: using log_mel_spectrogram feature, no cmvn, no speed perturb
  • Training info: bf16, fsdp zero2, activation checkpointing, batch dynamic 12000, acc_grad 1, 8 × A100 GPU, 80 epoch (step mode), conf/finetune_whisper_largev3_conv2d4.yaml (10 hours)
  • Decoding info: ctc_weight 0.3, average_num 3
  • Git hash: TBD

| decoding mode | CER | detail |
| --- | --- | --- |
| attention decoder | 3.37% | N=104765 C=101361 S=3283 D=121 I=127 |
| ctc greedy search | 7.43% | N=104765 C=97819 S=6776 D=170 I=835 |
| ctc prefix beam search | 7.41% | N=104765 C=97831 S=6766 D=168 I=834 |
| attention rescoring | 5.64% | N=104765 C=99135 S=5521 D=109 I=274 |
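
For reference, a hedged sketch of one common way to run FSDP in bf16 through its MixedPrecision policy (not necessarily the exact configuration used in wenet):

```python
import torch
from torch.distributed.fsdp import (FullyShardedDataParallel as FSDP,
                                    MixedPrecision, ShardingStrategy)

# bf16 for parameters, gradient reduction, and buffers; combined with
# ShardingStrategy.SHARD_GRAD_OP this gives a "zero2"-style bf16 setup.
bf16_policy = MixedPrecision(param_dtype=torch.bfloat16,
                             reduce_dtype=torch.bfloat16,
                             buffer_dtype=torch.bfloat16)

# model = FSDP(model, mixed_precision=bf16_policy,
#              sharding_strategy=ShardingStrategy.SHARD_GRAD_OP)
```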

@Mddct (Collaborator, Author) commented Apr 7, 2024

A100 fp32 zero2

FSDP zero2 fp32

  • Feature info: using log_mel_spectrogram feature, no cmvn, no speed perturb
  • Training info: fp32, fsdp zero2, activation checkpointing, batch dynamic 12000, acc_grad 1, 8 × A100 GPU, 80 epoch (step mode), conf/finetune_whisper_largev3_conv2d4.yaml (1d 10h)
  • Decoding info: ctc_weight 0.3, average_num 3
  • Git hash: TBD

| decoding mode | CER | detail |
| --- | --- | --- |
| attention decoder | 3.32% | N=104765 C=101390 S=3245 D=130 I=104 |
| ctc greedy search | 7.08% | N=104765 C=98104 S=6489 D=172 I=754 |
| ctc prefix beam search | 7.07% | N=104765 C=98108 S=6486 D=171 I=754 |
| attention rescoring | 5.43% | N=104765 C=99305 S=5348 D=112 I=227 |

@xingchensong previously approved these changes Apr 7, 2024
Comment on lines +558 to +566
"deepspeed":
torch.cuda.amp.autocast(enabled=dtype is not None,
dtype=dtype,
cache_enabled=False),
"torch_ddp":
torch.cuda.amp.autocast(enabled=scaler is not None),
"torch_fsdp":
torch.cuda.amp.autocast(enabled=True, dtype=dtype)
if dtype is not None else nullcontext()
Member:

I think ddp should also be able to set dtype.

@xingchensong (Member), Apr 7, 2024:

If the dtype arg is moved into the add-fsdp-args function, then hard-coding the ddp amp here to fp16 is fine.

@xingchensong (Member), Apr 7, 2024:

Zhou, it feels like the fsdp autocast could be merged with deepspeed's: when dtype is None, enabled is False (which is effectively a nullcontext), so the semantics should be the same. (Have you tried directly reusing the deepspeed autocast setup?)

@Mddct (Collaborator, Author), Apr 7, 2024:

In principle yes, but under fsdp, torch.cuda.amp.autocast(enabled=False, dtype=fp32) is somewhat slower; only nullcontext performs well here, and I don't fully understand why.
nullcontext is also used here: https://github.com/meta-llama/llama-recipes/blob/main/src/llama_recipes/utils/train_utils.py#L64
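
A minimal sketch of the fallback pattern being discussed (illustrative only; the real selection lives in the autocast dict shown above):

```python
from contextlib import nullcontext

import torch

def fsdp_forward_context(dtype):
    # Prefer a plain nullcontext over autocast(enabled=False), which was
    # observed in this thread to be slower under fsdp.
    if dtype is None:
        return nullcontext()
    return torch.cuda.amp.autocast(enabled=True, dtype=dtype)
```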

Member:

Do you have comparison data?

@Mddct (Collaborator, Author):

Yes, I'll post it today.

Comment on lines +43 to +59
# TODO(Mddct): Support user customization
# see more wrap methods:
# https://github.com/meta-llama/llama-recipes/blob/main/src/llama_recipes/utils/fsdp_utils.py#L13 # noqa
if mode == 'model':
enc_dec_wrap_policy = partial(
lambda_auto_wrap_policy,
lambda_fn=lambda module: isinstance(
module,
tuple(WENET_ENCODER_CLASSES.values()) + tuple(
WENET_DECODER_CLASSES.values())))
return enc_dec_wrap_policy
else:
to_wrap_class = set()
to_wrap_class.update(set(WENET_ENCODER_LAYERS_CLASSES.values()))
to_wrap_class.update(set(WENET_DECODER_LAYERS_CLASSES.values()))
layers_wrap_policy = partial(transformer_auto_wrap_policy,
transformer_layer_cls=to_wrap_class)
Member:

It would be great to attach a link explaining the difference between lambda wrap and transformer wrap, for easier learning. (Thanks, Zhou :))

Member:

From the way they are used, lambda is applied at the whole-encoder level and transformer at the layer level. Why have these two different granularities?

@Mddct (Collaborator, Author):

A wrap means that, around that wrapped unit's forward, the input/output gradients and the optimizer-state sharding incur an all-gather communication.

fsdp is quite flexible: the wrap controls the granularity of the sharding. Wrapping at the enc/dec granularity effectively shards only the optimizer states, not the gradients (memory-wise roughly zero1); wrapping each layer gives layer-level sharding, roughly zero2.
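
A hedged sketch of how a layer-level policy ends up in the FSDP constructor (torch.nn.TransformerEncoderLayer stands in for wenet's encoder/decoder layer classes; the call itself is illustrative, not the exact wenet code):

```python
from functools import partial

import torch
from torch.distributed.fsdp import (FullyShardedDataParallel as FSDP,
                                    ShardingStrategy)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

# Every matched layer becomes its own FSDP unit, so sharding (and the
# corresponding all-gathers) happens per layer rather than once per encoder.
layer_policy = partial(transformer_auto_wrap_policy,
                       transformer_layer_cls={torch.nn.TransformerEncoderLayer})

# model = FSDP(model,
#              auto_wrap_policy=layer_policy,
#              sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,  # "zero2"
#              device_id=torch.cuda.current_device())
```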


Comment on lines 80 to 110
def check_gradient_checkpoint(model):
ckpt_laye_types = []
if hasattr(model, 'encoder') and hasattr(model.encoder,
'gradient_checkpointing'):
if model.encoder.gradient_checkpointing:
model.encoder.gradient_checkpointing = False
ckpt_laye_types += list(WENET_ENCODER_LAYERS_CLASSES.values())
if hasattr(model, 'decoder') and hasattr(model.decoder,
'gradient_checkpointing'):
if model.decoder.gradient_checkpointing:
model.decoder.gradient_checkpointing = False
ckpt_laye_types += list(WENET_DECODER_LAYERS_CLASSES.values())
return tuple(ckpt_laye_types)


def apply_fsdp_checkpointing(model, ckpt_layer_types: tuple):
if len(ckpt_layer_types) == 0:
return
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper,
CheckpointImpl,
apply_activation_checkpointing,
)
non_reentrant_wrapper = partial(
checkpoint_wrapper,
checkpoint_impl=CheckpointImpl.NO_REENTRANT,
)
apply_activation_checkpointing(
model,
checkpoint_wrapper_fn=non_reentrant_wrapper,
check_fn=lambda submodule: isinstance(submodule, ckpt_layer_types))
Member:

This grad ckpt setup feels a bit odd: deepspeed can seamlessly share the same grad ckpt config and usage logic with ddp, so why does fsdp have to do it this way? (It would be good to add the reason as a note in the code so everyone can learn from it.)

@Mddct (Collaborator, Author):

fsdp (zero2, zero3) can also use wenet's existing approach, but the 'model' mode doesn't quite work with it (probably a current fsdp bug), so it is written here in a unified way. This style is also pretty much the standard in the LLM world, see: https://github.com/meta-llama/llama-recipes/blob/main/src/llama_recipes/policies/activation_checkpointing_functions.py#L21

@Mddct (Collaborator, Author):

This also has an advantage: if the model code itself has no checkpoint calls, checkpointing can still be applied at the outermost level without changing any model code, zero intrusion.
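
A hedged usage sketch of the two helpers from the diff above (the ordering around the FSDP wrap is an assumption, not a quote of the wenet train script):

```python
# 1) Collect the layer classes whose gradient_checkpointing flag was set and
#    turn the flag off inside the model.
ckpt_layer_types = check_gradient_checkpoint(model)
# 2) Wrap the model with FSDP (see the wrap-policy discussion above).
# model = FSDP(model, ...)
# 3) Re-apply activation checkpointing from the outside on those layer classes.
apply_fsdp_checkpointing(model, ckpt_layer_types)
```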

from wenet.utils.init_model import WENET_DECODER_CLASSES, WENET_ENCODER_CLASSES

WENET_ENCODER_LAYERS_CLASSES = {
'transformer_encoder_laer': TransformerEncoderLayer,
Member:

laer -> layer. Also, I don't quite get why a dict is used here: later only values() is taken, and the keys don't seem to be used anywhere.

@Mddct (Collaborator, Author):

This is just to keep the naming consistent with class_utils.py.

@Mddct requested a review from xingchensong on April 7, 2024 15:40
@xingchensong merged commit b8191ce into main on Apr 7, 2024
6 checks passed
@xingchensong deleted the Mddct-fsdp branch on April 7, 2024 16:47
@xingchensong (Member):

great job! ❤️

@programYoung commented May 31, 2024

(quoting the A100 step-mode benchmark table from the Mar 27 comment above)

If data_type is set to shard, will it be much slower? On wenetspeech my steps/sec only reaches 0.18.

Single-machine multi-GPU is not affected, but currently there are issues with fsdp/deepspeed scaling to multiple machines @xingchensong, as well as inter-machine communication bandwidth; a follow-up PR will try to fix the multi-machine fsdp problem.

Are you running multi-machine multi-GPU?

@programYoung

(quoting the table and the exchange above)

I'm using a single machine with 4 × A100, so then it's not a data_type problem?

@Mddct (Collaborator, Author) commented May 31, 2024

(quoting the table and the exchange above)

Then no, it isn't.

Zth9730 pushed a commit to Zth9730/wenet that referenced this pull request Aug 7, 2024
add causal model

fix typo

rm ckpt

add topk topp sampler

fix position

[train_engine] support fsdp (wenet-e2e#2412)

* [train_engine] support fsdp

* [train_engine] support fsdp

* unify scaler and amp

* fp32&&fp16 works in fsdp env

* fix fsdp in cv auto cast

* try to fix wenet.join fsdp

* implementing zero1 under fsdp is almost equivalent to deepspeed's zero1

* fix clip_and_grad_

* fix train summary

* all wenet xxxformer works (-paraformer -transducer)

* try to fix nan

* add barrier for cv

* add destroy group for end of all train

* refactor wrap methods and ckpt works

* fix ckpt

* fix cv in dtype != float32

* fix ckpt in model mode

* fix bf16 amp

* refactor scaler and autocast, fix fp32 fp16 bf16 for fsdp

* fix fp32 nullcontext to nullcontext()

* modify after review

* fix lint

* fix lint

LoRA support (wenet-e2e#2049)

* support lora for v3.0.1

* format code and update lora attention && encoder

* fix bug when lora_list is None

---------

Co-authored-by: Xingchen Song(宋星辰) <[email protected]>

[env] update python version and deepspeed version (wenet-e2e#2462)

* [env] update python version and deepspeed version

* [env] fix lint

fix rope pos embedding (wenet-e2e#2463)

* fix rope pos embedding

* fix dropout

* fix comment

[transformer] add multi warmup and learning rate for different modules (wenet-e2e#2449)

* [transformer] add multi warmup and learning rate for different modules

* fix typo

* it works in warmuplr

* fix lr in tensorboard in step mode

* fix cv log

* cv works

* refactor cv log

* add helper lrs_to_string

* fix lrstr

* fix ddp multiple lr

* fix initial step

* revert to -1

* fix sub params dup

* fix step

* fix step

* fix log

* add assert for scheduler

* add comment for log

---------

Co-authored-by: Xingchen Song(宋星辰) <[email protected]>

add generate

add toto

support sft & pretrain training forward

gemm conversion works

support init causal model

[whisper] limit language to Chinese (wenet-e2e#2470)

[train] convert tensor to scalar (wenet-e2e#2471)

[workflow] upgrade python version to 3.10 (wenet-e2e#2472)

* [workflow] upgrade python version to 3.10

* [workflow] try to pass

refactor cache behaviour in training mode (reduce compute cost and memory) (wenet-e2e#2473)

all gemma model works

fix ut

fix ut (wenet-e2e#2477)

* fix ut

* fix py version

[transformer] Make MoE runnable (wenet-e2e#2474)

[transformer] fix mqa (wenet-e2e#2478)

enable mmap in torch.load (wenet-e2e#2479)

[example] Add deespeed configs of different stages for illustration (wenet-e2e#2485)

[example] Fix prefetch and step_save (wenet-e2e#2486)

[ctl] simplified ctl (wenet-e2e#2483)

* [ctl] simplified ctl

* [ctl] unify

[branchformer] simplified branchformer (wenet-e2e#2482)

* [transformer] simplified branchformer

* fix yaml

* support mqa gradient ckpt sdpa

* fix gradient checkpoint

* add deepspeed comment in layer dropout

* fix comment

[e_branchformer] simplified e_branchformer (wenet-e2e#2484)

* [e_branchformer] simplified ctl

* try to fix ut

* try to fix ut

* fix activation

* fix att args

* e-branchformer works

[transformer] refactor cache (wenet-e2e#2481)

* [transformer] refactor cache

* fix ut

* unify cache type in branchformer and ebranchformer

fix cache

fix gradient ckpt in branchformer/ebranformer (wenet-e2e#2488)

fix search after refactor cache (wenet-e2e#2490)

generate works!

unify chat pattern

convert llama3 works

[transformer] set use_reentrant=False for gradient ckpt (wenet-e2e#2491)

[transformer] fix warning: ignore(True) has been deprecated (wenet-e2e#2492)

* [transformer] fix warning: ignore(True) has been deprecated

* [transformer] fix warning: ignore(True) has been deprecated

[log] avoid reduntant logging (wenet-e2e#2493)

fix w1 w2 w3 in feedforward

add 70b temporarily

mv LLM to wenet

support llm dataset

unify config

add dataset yaml in script

support llm dataset

dynamic static bucket works

[transformer] refactor mqa repeat (wenet-e2e#2497)

[transformer] fix mqa in cross att (wenet-e2e#2498)

[deepspeed] update json config (wenet-e2e#2499)

training works

pretrain works

refactor covert

fix flash att in generate

llama works

fix llama3

fix speed

try fix ut

support stop tokens in gen and support ppl

support stop tokens in gen and support ppl