
LoRA support #2049

Merged: 4 commits into wenet-e2e:main on Apr 7, 2024
Conversation

@fclearner (Contributor) commented Oct 12, 2023

LoRA support

References:

  • https://github.com/microsoft/LoRA/tree/bugfix_MergedLinear
  • https://kexue.fm/archives/9590
  • https://github.com/huggingface/peft
  • QLoRA (better for memory cost): https://arxiv.org/pdf/2305.14314.pdf

(I think the PEFT methods used in NLP all have the potential to fine-tune ASR models.)
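
For readers unfamiliar with LoRA, the sketch below shows the basic idea of wrapping a frozen linear layer with a low-rank adapter. It is a minimal illustration only, not the code in this PR (which follows the microsoft/LoRA reference above); the class and argument names are made up for the example.

```python
import math
import torch.nn as nn


class LoRALinear(nn.Module):
    """Minimal LoRA wrapper: y = W x + (alpha / r) * B(A(dropout(x))), with W frozen."""

    def __init__(self, in_features, out_features, r=8, alpha=8, dropout=0.1):
        super().__init__()
        self.base = nn.Linear(in_features, out_features)   # pretrained weight, frozen
        self.base.weight.requires_grad = False
        self.base.bias.requires_grad = False
        self.lora_a = nn.Linear(in_features, r, bias=False)   # down-projection A
        self.lora_b = nn.Linear(r, out_features, bias=False)  # up-projection B
        self.scaling = alpha / r
        self.dropout = nn.Dropout(dropout)
        nn.init.kaiming_uniform_(self.lora_a.weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_b.weight)  # adapter starts as a no-op

    def forward(self, x):
        return self.base(x) + self.lora_b(self.lora_a(self.dropout(x))) * self.scaling
```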

LoRA experiment:

GPUs: 4 × 3090

LoRA applied to the listed projections (lora_list) in the encoder.

| base_model | finetune_method | lora_rank | lora_alpha | lora_dropout | lora_list | decoding mode | test-aishell CER | wenetspeech-test-net CER | num of parameters | time per epoch | GPU mem |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| wenet_u2pp_conformer (wenetspeech) | baseline | / | / | / | / | attention rescoring | 5.83% | 10.82% | / | / | / |
| wenet_u2pp_conformer (wenetspeech) | aishell finetune (20 epochs) | / | / | / | / | attention rescoring | 3.16% | 17.02% | 122553574 | 645 seconds | 21170 MB |
| wenet_u2pp_conformer (wenetspeech) | aishell LoRA (20 epochs) | 8 | 8 | 0.1 | q,k,v,linear_out | attention rescoring | 5.08% | 14.32% | 393216 | 350 seconds | 17268 MB |
| wenet_u2pp_conformer (wenetspeech) | aishell LoRA (20 epochs) | 16 | 16 | 0.1 | q,k,v,linear_out | attention rescoring | 4.98% | 14.51% | 786432 | 351 seconds | 17321 MB |
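
As a sanity check on the "num of parameters" column: assuming the wenetspeech u2++ conformer encoder uses d_model = 512 and 12 blocks (an assumption taken from the standard wenetspeech recipe, not stated in this thread), each adapted linear contributes r × (d_in + d_out) trainable weights, which reproduces the counts in the table:

```python
# Hypothetical back-of-the-envelope check, assuming d_model = 512 and
# 12 conformer encoder blocks (standard wenetspeech u2++ configuration).
d_model, num_blocks = 512, 12
lora_list = ["q", "k", "v", "linear_out"]  # adapted projections, as in the table

def lora_trainable_params(rank):
    # each adapted Linear(d_model, d_model) gains A: (rank x d_model) and B: (d_model x rank)
    per_linear = rank * (d_model + d_model)
    return per_linear * len(lora_list) * num_blocks

print(lora_trainable_params(8))   # 393216 -> matches the rank-8 row
print(lora_trainable_params(16))  # 786432 -> matches the rank-16 row
```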

@xingchensong (Member)

thx! any numbers?

@fclearner (Contributor, Author)

> thx! any numbers?

I conducted experiments on my own data and will update with results on the open-source data soon.

@Mddct (Collaborator) commented Oct 20, 2023

This could be done via inheritance, e.g. a lora_conformer_encoder and a lora_attention that override the encoder and attention, with a layout like:

  • wenet/finetune/lora/encoder.py
  • wenet/finetune/lora/attention.py

Then initialize them in init_model.py. That way there is almost no intrusion into the original code, and since fine-tuning also comes in LoRA variants, adapters, and so on, it makes later extensions easier.

@fclearner (Contributor, Author)

> This could be done via inheritance, e.g. a lora_conformer_encoder and a lora_attention that override the encoder and attention, with a layout like:
>
>   • wenet/finetune/lora/encoder.py
>   • wenet/finetune/lora/attention.py
>
> Then initialize them in init_model.py. That way there is almost no intrusion into the original code, and since fine-tuning also comes in LoRA variants, adapters, and so on, it makes later extensions easier.

OK, that's a good idea.

@xingchensong added the enhancement (New feature or request) label on Nov 2, 2023
@xingchensong (Member) commented Nov 2, 2023

> This could be done via inheritance, e.g. a lora_conformer_encoder and a lora_attention that override the encoder and attention, with a layout like:
>
>   • wenet/finetune/lora/encoder.py
>   • wenet/finetune/lora/attention.py
>
> Then initialize them in init_model.py. That way there is almost no intrusion into the original code, and since fine-tuning also comes in LoRA variants, adapters, and so on, it makes later extensions easier.

The experiments look great, and I agree with Mddct's suggestion:

  1. Could you adopt that structure (wenet/finetune/lora/encoder.py inheriting from wenet/transformer/encoder.py, and likewise for attention)? The loralib files can also go under wenet/finetune/lora. One benefit is that the LoRA code makes no intrusive changes; another is that a wenet/finetune/adapter implementation can be added later. (A rough sketch of this layout follows below.)
  2. Rebase the branch and move the train.py changes into other files (the specific changes are commented inline).

After that it can be merged.
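
To make the proposed layout concrete, here is a rough sketch of the inheritance pattern under discussion. It is illustrative only: the class name LoRAMultiHeadedAttention and the LoRALinear helper are hypothetical, not the exact code that was merged, though MultiHeadedAttention and its linear_q/k/v/out attributes do exist in wenet/transformer/attention.py.

```python
# wenet/finetune/lora/attention.py (sketch; names are illustrative)
from wenet.transformer.attention import MultiHeadedAttention
from wenet.finetune.lora.layers import LoRALinear  # hypothetical LoRA linear wrapper


class LoRAMultiHeadedAttention(MultiHeadedAttention):
    """Same forward as the parent; only the selected projections are swapped for LoRA."""

    def __init__(self, n_head, n_feat, dropout_rate,
                 lora_rank=8, lora_alpha=8, lora_dropout=0.1,
                 lora_list=("q", "k", "v", "linear_out")):
        super().__init__(n_head, n_feat, dropout_rate)
        attr_names = {"q": "linear_q", "k": "linear_k",
                      "v": "linear_v", "linear_out": "linear_out"}
        for key in lora_list:
            # replace the plain nn.Linear with a LoRA-augmented linear of the same shape
            setattr(self, attr_names[key],
                    LoRALinear(n_feat, n_feat, r=lora_rank,
                               alpha=lora_alpha, dropout=lora_dropout))
```

An encoder subclass (e.g. under wenet/finetune/lora/encoder.py) would then build its attention layers from this class, and init_model.py would pick the LoRA variants when LoRA fine-tuning is enabled, leaving the original transformer code untouched.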

(Three inline review comments on wenet/bin/train.py, now outdated and resolved.)
@Mddct (Collaborator) commented Apr 2, 2024

Nice, I'll take a look in the next couple of days.

@xingchensong previously approved these changes Apr 7, 2024

@xingchensong (Member)

Thanks for all the hard work! Great job!

@xingchensong merged commit 01ada04 into wenet-e2e:main on Apr 7, 2024
6 checks passed
@srdfjy (Contributor) commented Jul 16, 2024

Hi @fclearner, you mentioned earlier that LoRA still needs a follow-up patch. When will it be merged?

@fclearner (Contributor, Author)

> Hi @fclearner, you mentioned earlier that LoRA still needs a follow-up patch. When will it be merged?

It's really just that the encoder arguments were passed incorrectly because of an earlier commit. You can patch it yourself; I haven't had time to verify the code yet and will take a look over the weekend.

@fclearner (Contributor, Author)

> Hi @fclearner, you mentioned earlier that LoRA still needs a follow-up patch. When will it be merged?

Hello, I've drafted a fix, but my training run still has an issue; I'll look into it tonight:
https://github.com/fclearner/wenet/tree/LoRA_fix_args

@fclearner (Contributor, Author)

> Hi @fclearner, you mentioned earlier that LoRA still needs a follow-up patch. When will it be merged?

Calling model.eval() raises an error. It runs fine without deepspeed, so it looks like something goes wrong during deepspeed initialization, probably the same problem as huggingface/alignment-handbook#57.

Zth9730 pushed a commit to Zth9730/wenet that referenced this pull request Aug 7, 2024
add casual model

(squashed commit message trimmed; among many unrelated changes from that fork it includes:)

LoRA support (wenet-e2e#2049)

* support lora for v3.0.1

* format code and update lora attention && encoder

* fix bug when lora_list is None

Co-authored-by: Xingchen Song(宋星辰) <[email protected]>
Labels: enhancement (New feature or request)

5 participants