-
Notifications
You must be signed in to change notification settings - Fork 894
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Enhancement]create huggingface_gptneox_convert.py #569
[Enhancement]create huggingface_gptneox_convert.py #569
Conversation
Signed-off-by: AkiyamaYummy <[email protected]>
Examples:Get HF model files:
Convert to the 1-gpu model files:
Convert to the 2-gpu model files(to use tensor parallel):
Run and validate 1-GPU model files:
Run and validate tensor parallel model files:
Results:
|
Hi, looking forward to you taking a look. I add several cases and hope to make it an easy-to-validate PR. 🤗🤗🤗 |
Signed-off-by: AkiyamaYummy <[email protected]>
Great work! If you’re interested, welcome to also send out a PR to tabby as well |
Thank you for the great work. Can you help adding some examples into the |
Signed-off-by: AkiyamaYummy <[email protected]>
@byshiue Updated. |
will this works for https://huggingface.co/stabilityai/stablelm-tuned-alpha-7b?, which seems same format as gptneox |
if the model can be loaded by GPTNeoXForCausalLM, the script can theoretically convert it. if not, I will consider adding this support when I have time. |
Hi, looking forward to you taking a look, again. 🤗🤗🤗 |
If there is anything else I need to change, I will update it as soon as possible. |
Thank you for the work. We are waiting the internal unit tests and it looks well now. |
@AkiyamaYummy
hello, thank you for the great work! I have a question about this convert script: why do we need scale bias here when tp > 1? |
FT will respectfully add the multi bias in multi GPUs when tp > 1, and all-reduce them afterward. If not scale the biases, it's equal to they will be added multi times. |
* Update beam_search_topk_kernels.cu fix: fix bug of beam search * fix: change int of some kernels to int64_t to prevent overflow * fix: gpt tensor shapes inconsistency (NVIDIA#505) Signed-off-by: AkiyamaYummy <[email protected]> * Update gpt_guide.md (NVIDIA#529) * fix: fix bug of gpt buffer and gpt gemm overflow * Update T5DecodingWeight.cc fix: fix loading bug of t5 * [Enhancement]add pytorch backend support for gptneox (NVIDIA#550) * add pytorch backend support for gptneox Signed-off-by: AkiyamaYummy <[email protected]> * fix early stopping invalid * 1) Some unused parameters and logic have been removed. 2) Revisions that would affect pipeline parallelism have been reverted. 3) The code has been made capable of direct validation on TabbyML/NeoX-1.3B. Signed-off-by: AkiyamaYummy <[email protected]> * Change the names of classes, removing 'parallel' from their names Signed-off-by: AkiyamaYummy <[email protected]> * Format the code. Signed-off-by: AkiyamaYummy <[email protected]> * Only print results when rank is 0. Signed-off-by: AkiyamaYummy <[email protected]> * Add dist.init_process_group(). Signed-off-by: AkiyamaYummy <[email protected]> * update docs Signed-off-by: AkiyamaYummy <[email protected]> --------- Signed-off-by: AkiyamaYummy <[email protected]> * Update cublasMMWrapper.cc Fix the CUBLAS_VERSION checking of cublasMMWrapper * Update cublasMMWrapper.cc * fix overflow in softmax_kernel when process long seqlen and big batch_size (NVIDIA#524) * Update unfused_attention_kernels.cu fix bug of softmax kernel * [Enhancement]create huggingface_gptneox_convert.py (NVIDIA#569) * create huggingface_gptneox_convert.py Signed-off-by: AkiyamaYummy <[email protected]> * adjust HF's multi bin files Signed-off-by: AkiyamaYummy <[email protected]> * update gptneox_guide.md Signed-off-by: AkiyamaYummy <[email protected]> --------- Signed-off-by: AkiyamaYummy <[email protected]> * perf(bloom): improve performance of huggingface_bloom_convert.py, decrease the time cost and the mem using (NVIDIA#568) Co-authored-by: r.yang <[email protected]> * Fix/gpt early stop (NVIDIA#584) * fix: fix bug of early stopping of gpt * [bugfix] Fix 2-shot All Reduce correctness issue (indexing bug). (NVIDIA#672) FasterTransformer 2-shot all reduce is implemented as a reduce-scatter + all-gather. There is an indexing bug in the all-gather step. Prior to this change, 2-shot all reduce was only producing correct results on device 0. Now, all devices have the correct results. * fix: swap tensor bug (NVIDIA#683) * Support size_per_head=112 (NVIDIA#660) * fix multi-gpu build * add support for size_per_head=112 for gpt decoder * remove mpi_cxx from multi-gpu build for now (NVIDIA#705) --------- Signed-off-by: AkiyamaYummy <[email protected]> Co-authored-by: byshiue <[email protected]> Co-authored-by: _yummy_ <[email protected]> Co-authored-by: Ying Sheng <[email protected]> Co-authored-by: zhangxin81 <[email protected]> Co-authored-by: 杨睿 <[email protected]> Co-authored-by: r.yang <[email protected]> Co-authored-by: Rahul Kindi <[email protected]> Co-authored-by: Perkz Zheng <[email protected]> Co-authored-by: Daya Khudia <[email protected]> Co-authored-by: Dean Wyatte <[email protected]>
* Merge with main (#1) * Update beam_search_topk_kernels.cu fix: fix bug of beam search * fix: change int of some kernels to int64_t to prevent overflow * fix: gpt tensor shapes inconsistency (NVIDIA#505) Signed-off-by: AkiyamaYummy <[email protected]> * Update gpt_guide.md (NVIDIA#529) * fix: fix bug of gpt buffer and gpt gemm overflow * Update T5DecodingWeight.cc fix: fix loading bug of t5 * [Enhancement]add pytorch backend support for gptneox (NVIDIA#550) * add pytorch backend support for gptneox Signed-off-by: AkiyamaYummy <[email protected]> * fix early stopping invalid * 1) Some unused parameters and logic have been removed. 2) Revisions that would affect pipeline parallelism have been reverted. 3) The code has been made capable of direct validation on TabbyML/NeoX-1.3B. Signed-off-by: AkiyamaYummy <[email protected]> * Change the names of classes, removing 'parallel' from their names Signed-off-by: AkiyamaYummy <[email protected]> * Format the code. Signed-off-by: AkiyamaYummy <[email protected]> * Only print results when rank is 0. Signed-off-by: AkiyamaYummy <[email protected]> * Add dist.init_process_group(). Signed-off-by: AkiyamaYummy <[email protected]> * update docs Signed-off-by: AkiyamaYummy <[email protected]> --------- Signed-off-by: AkiyamaYummy <[email protected]> * Update cublasMMWrapper.cc Fix the CUBLAS_VERSION checking of cublasMMWrapper * Update cublasMMWrapper.cc * fix overflow in softmax_kernel when process long seqlen and big batch_size (NVIDIA#524) * Update unfused_attention_kernels.cu fix bug of softmax kernel * [Enhancement]create huggingface_gptneox_convert.py (NVIDIA#569) * create huggingface_gptneox_convert.py Signed-off-by: AkiyamaYummy <[email protected]> * adjust HF's multi bin files Signed-off-by: AkiyamaYummy <[email protected]> * update gptneox_guide.md Signed-off-by: AkiyamaYummy <[email protected]> --------- Signed-off-by: AkiyamaYummy <[email protected]> * perf(bloom): improve performance of huggingface_bloom_convert.py, decrease the time cost and the mem using (NVIDIA#568) Co-authored-by: r.yang <[email protected]> * Fix/gpt early stop (NVIDIA#584) * fix: fix bug of early stopping of gpt * [bugfix] Fix 2-shot All Reduce correctness issue (indexing bug). (NVIDIA#672) FasterTransformer 2-shot all reduce is implemented as a reduce-scatter + all-gather. There is an indexing bug in the all-gather step. Prior to this change, 2-shot all reduce was only producing correct results on device 0. Now, all devices have the correct results. * fix: swap tensor bug (NVIDIA#683) * Support size_per_head=112 (NVIDIA#660) * fix multi-gpu build * add support for size_per_head=112 for gpt decoder * remove mpi_cxx from multi-gpu build for now (NVIDIA#705) --------- Signed-off-by: AkiyamaYummy <[email protected]> Co-authored-by: byshiue <[email protected]> Co-authored-by: _yummy_ <[email protected]> Co-authored-by: Ying Sheng <[email protected]> Co-authored-by: zhangxin81 <[email protected]> Co-authored-by: 杨睿 <[email protected]> Co-authored-by: r.yang <[email protected]> Co-authored-by: Rahul Kindi <[email protected]> Co-authored-by: Perkz Zheng <[email protected]> Co-authored-by: Daya Khudia <[email protected]> Co-authored-by: Dean Wyatte <[email protected]> * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit --------- Signed-off-by: AkiyamaYummy <[email protected]> Co-authored-by: Asim Shankar <[email protected]> Co-authored-by: byshiue <[email protected]> Co-authored-by: _yummy_ <[email protected]> Co-authored-by: Ying Sheng <[email protected]> Co-authored-by: zhangxin81 <[email protected]> Co-authored-by: 杨睿 <[email protected]> Co-authored-by: r.yang <[email protected]> Co-authored-by: Rahul Kindi <[email protected]> Co-authored-by: Perkz Zheng <[email protected]> Co-authored-by: Daya Khudia <[email protected]> Co-authored-by: Dean Wyatte <[email protected]>
Let the gptneox HF model file can convert to FT model.
Mainly based on tabby's script, thanks to @wsxiaoys.
But original script got error using tensor parallel.(link)
I fixed the problem and hope to open this script for everyone to use.