Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[enhancement] improve the performace of bloom model conversion, reduce the memory and time cost #568

Merged
merged 2 commits into from
Apr 24, 2023

Conversation

Yangruipis
Copy link
Contributor

what can i do

  1. Load model from pytorch bin / safetensors file directory, instead of call from_pretrained function of transformers, cause the from_pretrained method may take a lot of time for weight initialization, and auto covnert bf16 weights to fp32 which doubles the memory.
  2. Convert every single model file in subprocess with python multiprocessing. The origin code loads entire model in main process, and pass each parameters to subprocess, this may take a lot of memory and cost more time.

conversion benchmarks

code model format by-shard nproc elapsed(s) memory(G)
before bloom-175b safetensors x 72 / 8 3390.0 [1] 700
after bloom-175b safetensors x 72 NO 8 1516.66 350
after bloom-175b safetensors x 72 YES 8 1165.03 50
after bloom-175b safetensors x 72 YES 24 494.81 150

[1]: from_pretrained: 1910.47, convert: 1479.53

some screenshots

image

image

image

@Yangruipis
Copy link
Contributor Author

output alignment

I compared the md5 of each wegiths of model bigscience/bloom-560m, including:

  1. code before this pr
  2. code after this pr, disable --by-shard
  3. code after this pr, enable --by-shard

And the weight files md5 are ALL strictly alignmented

screenshots

image

@jaedeok-nvidia
Copy link

Thank @Yangruipis to propose a significant improvement of the bloom converter. It looks very good to me.
Related to the bloom converter there is a multiprocessing stuck issue, can you please help us to fix the issue in this PR (explained in the comments)?

@Yangruipis
Copy link
Contributor Author

sure, I'll check it later

@Yangruipis
Copy link
Contributor Author

@kjaedeok I couldn't locate the issue or any related comments. Could you please direct me to them?

@jaedeok-nvidia
Copy link

The issue was reported directly to us, and the above workaround solved it.

@Yangruipis
Copy link
Contributor Author

glad to hear that

@byshiue byshiue merged commit 19b2956 into NVIDIA:main Apr 24, 2023
sfc-gh-ashankar added a commit to neevaco/FasterTransformer that referenced this pull request Jul 11, 2023
* Update beam_search_topk_kernels.cu

fix: fix bug of beam search

* fix: change int of some kernels to int64_t to prevent overflow

* fix: gpt tensor shapes inconsistency (NVIDIA#505)

Signed-off-by: AkiyamaYummy <[email protected]>

* Update gpt_guide.md (NVIDIA#529)

* fix: fix bug of gpt buffer and gpt gemm overflow

* Update T5DecodingWeight.cc

fix: fix loading bug of t5

* [Enhancement]add pytorch backend support for gptneox (NVIDIA#550)

* add pytorch backend support for gptneox

Signed-off-by: AkiyamaYummy <[email protected]>

* fix early stopping invalid

* 1) Some unused parameters and logic have been removed. 2) Revisions that would affect pipeline parallelism have been reverted. 3) The code has been made capable of direct validation on TabbyML/NeoX-1.3B.

Signed-off-by: AkiyamaYummy <[email protected]>

* Change the names of classes, removing 'parallel' from their names

Signed-off-by: AkiyamaYummy <[email protected]>

* Format the code.

Signed-off-by: AkiyamaYummy <[email protected]>

* Only print results when rank is 0.

Signed-off-by: AkiyamaYummy <[email protected]>

* Add dist.init_process_group().

Signed-off-by: AkiyamaYummy <[email protected]>

* update docs

Signed-off-by: AkiyamaYummy <[email protected]>

---------

Signed-off-by: AkiyamaYummy <[email protected]>

* Update cublasMMWrapper.cc

Fix the CUBLAS_VERSION checking of cublasMMWrapper

* Update cublasMMWrapper.cc

* fix overflow in softmax_kernel when process long seqlen and big batch_size (NVIDIA#524)

* Update unfused_attention_kernels.cu

fix bug of softmax kernel

* [Enhancement]create huggingface_gptneox_convert.py (NVIDIA#569)

* create huggingface_gptneox_convert.py

Signed-off-by: AkiyamaYummy <[email protected]>

* adjust HF's multi bin files

Signed-off-by: AkiyamaYummy <[email protected]>

* update gptneox_guide.md

Signed-off-by: AkiyamaYummy <[email protected]>

---------

Signed-off-by: AkiyamaYummy <[email protected]>

* perf(bloom): improve performance of huggingface_bloom_convert.py, decrease the time cost and the mem using (NVIDIA#568)

Co-authored-by: r.yang <[email protected]>

* Fix/gpt early stop (NVIDIA#584)

* fix: fix bug of early stopping of gpt

* [bugfix] Fix 2-shot All Reduce correctness issue (indexing bug). (NVIDIA#672)

FasterTransformer 2-shot all reduce is implemented as a reduce-scatter + all-gather. There is an indexing bug in the all-gather step. Prior to this change, 2-shot all reduce was only producing correct results on device 0. Now, all devices have the correct results.

* fix: swap tensor bug (NVIDIA#683)

* Support size_per_head=112 (NVIDIA#660)

* fix multi-gpu build

* add support for size_per_head=112 for gpt decoder

* remove mpi_cxx from multi-gpu build for now (NVIDIA#705)

---------

Signed-off-by: AkiyamaYummy <[email protected]>
Co-authored-by: byshiue <[email protected]>
Co-authored-by: _yummy_ <[email protected]>
Co-authored-by: Ying Sheng <[email protected]>
Co-authored-by: zhangxin81 <[email protected]>
Co-authored-by: 杨睿 <[email protected]>
Co-authored-by: r.yang <[email protected]>
Co-authored-by: Rahul Kindi <[email protected]>
Co-authored-by: Perkz Zheng <[email protected]>
Co-authored-by: Daya Khudia <[email protected]>
Co-authored-by: Dean Wyatte <[email protected]>
sfc-gh-zhwang added a commit to neevaco/FasterTransformer that referenced this pull request Oct 5, 2023
* Merge with main (#1)

* Update beam_search_topk_kernels.cu

fix: fix bug of beam search

* fix: change int of some kernels to int64_t to prevent overflow

* fix: gpt tensor shapes inconsistency (NVIDIA#505)

Signed-off-by: AkiyamaYummy <[email protected]>

* Update gpt_guide.md (NVIDIA#529)

* fix: fix bug of gpt buffer and gpt gemm overflow

* Update T5DecodingWeight.cc

fix: fix loading bug of t5

* [Enhancement]add pytorch backend support for gptneox (NVIDIA#550)

* add pytorch backend support for gptneox

Signed-off-by: AkiyamaYummy <[email protected]>

* fix early stopping invalid

* 1) Some unused parameters and logic have been removed. 2) Revisions that would affect pipeline parallelism have been reverted. 3) The code has been made capable of direct validation on TabbyML/NeoX-1.3B.

Signed-off-by: AkiyamaYummy <[email protected]>

* Change the names of classes, removing 'parallel' from their names

Signed-off-by: AkiyamaYummy <[email protected]>

* Format the code.

Signed-off-by: AkiyamaYummy <[email protected]>

* Only print results when rank is 0.

Signed-off-by: AkiyamaYummy <[email protected]>

* Add dist.init_process_group().

Signed-off-by: AkiyamaYummy <[email protected]>

* update docs

Signed-off-by: AkiyamaYummy <[email protected]>

---------

Signed-off-by: AkiyamaYummy <[email protected]>

* Update cublasMMWrapper.cc

Fix the CUBLAS_VERSION checking of cublasMMWrapper

* Update cublasMMWrapper.cc

* fix overflow in softmax_kernel when process long seqlen and big batch_size (NVIDIA#524)

* Update unfused_attention_kernels.cu

fix bug of softmax kernel

* [Enhancement]create huggingface_gptneox_convert.py (NVIDIA#569)

* create huggingface_gptneox_convert.py

Signed-off-by: AkiyamaYummy <[email protected]>

* adjust HF's multi bin files

Signed-off-by: AkiyamaYummy <[email protected]>

* update gptneox_guide.md

Signed-off-by: AkiyamaYummy <[email protected]>

---------

Signed-off-by: AkiyamaYummy <[email protected]>

* perf(bloom): improve performance of huggingface_bloom_convert.py, decrease the time cost and the mem using (NVIDIA#568)

Co-authored-by: r.yang <[email protected]>

* Fix/gpt early stop (NVIDIA#584)

* fix: fix bug of early stopping of gpt

* [bugfix] Fix 2-shot All Reduce correctness issue (indexing bug). (NVIDIA#672)

FasterTransformer 2-shot all reduce is implemented as a reduce-scatter + all-gather. There is an indexing bug in the all-gather step. Prior to this change, 2-shot all reduce was only producing correct results on device 0. Now, all devices have the correct results.

* fix: swap tensor bug (NVIDIA#683)

* Support size_per_head=112 (NVIDIA#660)

* fix multi-gpu build

* add support for size_per_head=112 for gpt decoder

* remove mpi_cxx from multi-gpu build for now (NVIDIA#705)

---------

Signed-off-by: AkiyamaYummy <[email protected]>
Co-authored-by: byshiue <[email protected]>
Co-authored-by: _yummy_ <[email protected]>
Co-authored-by: Ying Sheng <[email protected]>
Co-authored-by: zhangxin81 <[email protected]>
Co-authored-by: 杨睿 <[email protected]>
Co-authored-by: r.yang <[email protected]>
Co-authored-by: Rahul Kindi <[email protected]>
Co-authored-by: Perkz Zheng <[email protected]>
Co-authored-by: Daya Khudia <[email protected]>
Co-authored-by: Dean Wyatte <[email protected]>

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

---------

Signed-off-by: AkiyamaYummy <[email protected]>
Co-authored-by: Asim Shankar <[email protected]>
Co-authored-by: byshiue <[email protected]>
Co-authored-by: _yummy_ <[email protected]>
Co-authored-by: Ying Sheng <[email protected]>
Co-authored-by: zhangxin81 <[email protected]>
Co-authored-by: 杨睿 <[email protected]>
Co-authored-by: r.yang <[email protected]>
Co-authored-by: Rahul Kindi <[email protected]>
Co-authored-by: Perkz Zheng <[email protected]>
Co-authored-by: Daya Khudia <[email protected]>
Co-authored-by: Dean Wyatte <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants