-
Notifications
You must be signed in to change notification settings - Fork 895
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[fix] fix overflow in softmax_kernel when process long seqlen and big batch… #524
Merged
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Thank you for the fixing. We are checking where this kernel may affect and need more time. |
sfc-gh-ashankar
added a commit
to neevaco/FasterTransformer
that referenced
this pull request
Jul 11, 2023
* Update beam_search_topk_kernels.cu fix: fix bug of beam search * fix: change int of some kernels to int64_t to prevent overflow * fix: gpt tensor shapes inconsistency (NVIDIA#505) Signed-off-by: AkiyamaYummy <[email protected]> * Update gpt_guide.md (NVIDIA#529) * fix: fix bug of gpt buffer and gpt gemm overflow * Update T5DecodingWeight.cc fix: fix loading bug of t5 * [Enhancement]add pytorch backend support for gptneox (NVIDIA#550) * add pytorch backend support for gptneox Signed-off-by: AkiyamaYummy <[email protected]> * fix early stopping invalid * 1) Some unused parameters and logic have been removed. 2) Revisions that would affect pipeline parallelism have been reverted. 3) The code has been made capable of direct validation on TabbyML/NeoX-1.3B. Signed-off-by: AkiyamaYummy <[email protected]> * Change the names of classes, removing 'parallel' from their names Signed-off-by: AkiyamaYummy <[email protected]> * Format the code. Signed-off-by: AkiyamaYummy <[email protected]> * Only print results when rank is 0. Signed-off-by: AkiyamaYummy <[email protected]> * Add dist.init_process_group(). Signed-off-by: AkiyamaYummy <[email protected]> * update docs Signed-off-by: AkiyamaYummy <[email protected]> --------- Signed-off-by: AkiyamaYummy <[email protected]> * Update cublasMMWrapper.cc Fix the CUBLAS_VERSION checking of cublasMMWrapper * Update cublasMMWrapper.cc * fix overflow in softmax_kernel when process long seqlen and big batch_size (NVIDIA#524) * Update unfused_attention_kernels.cu fix bug of softmax kernel * [Enhancement]create huggingface_gptneox_convert.py (NVIDIA#569) * create huggingface_gptneox_convert.py Signed-off-by: AkiyamaYummy <[email protected]> * adjust HF's multi bin files Signed-off-by: AkiyamaYummy <[email protected]> * update gptneox_guide.md Signed-off-by: AkiyamaYummy <[email protected]> --------- Signed-off-by: AkiyamaYummy <[email protected]> * perf(bloom): improve performance of huggingface_bloom_convert.py, decrease the time cost and the mem using (NVIDIA#568) Co-authored-by: r.yang <[email protected]> * Fix/gpt early stop (NVIDIA#584) * fix: fix bug of early stopping of gpt * [bugfix] Fix 2-shot All Reduce correctness issue (indexing bug). (NVIDIA#672) FasterTransformer 2-shot all reduce is implemented as a reduce-scatter + all-gather. There is an indexing bug in the all-gather step. Prior to this change, 2-shot all reduce was only producing correct results on device 0. Now, all devices have the correct results. * fix: swap tensor bug (NVIDIA#683) * Support size_per_head=112 (NVIDIA#660) * fix multi-gpu build * add support for size_per_head=112 for gpt decoder * remove mpi_cxx from multi-gpu build for now (NVIDIA#705) --------- Signed-off-by: AkiyamaYummy <[email protected]> Co-authored-by: byshiue <[email protected]> Co-authored-by: _yummy_ <[email protected]> Co-authored-by: Ying Sheng <[email protected]> Co-authored-by: zhangxin81 <[email protected]> Co-authored-by: 杨睿 <[email protected]> Co-authored-by: r.yang <[email protected]> Co-authored-by: Rahul Kindi <[email protected]> Co-authored-by: Perkz Zheng <[email protected]> Co-authored-by: Daya Khudia <[email protected]> Co-authored-by: Dean Wyatte <[email protected]>
sfc-gh-zhwang
added a commit
to neevaco/FasterTransformer
that referenced
this pull request
Oct 5, 2023
* Merge with main (#1) * Update beam_search_topk_kernels.cu fix: fix bug of beam search * fix: change int of some kernels to int64_t to prevent overflow * fix: gpt tensor shapes inconsistency (NVIDIA#505) Signed-off-by: AkiyamaYummy <[email protected]> * Update gpt_guide.md (NVIDIA#529) * fix: fix bug of gpt buffer and gpt gemm overflow * Update T5DecodingWeight.cc fix: fix loading bug of t5 * [Enhancement]add pytorch backend support for gptneox (NVIDIA#550) * add pytorch backend support for gptneox Signed-off-by: AkiyamaYummy <[email protected]> * fix early stopping invalid * 1) Some unused parameters and logic have been removed. 2) Revisions that would affect pipeline parallelism have been reverted. 3) The code has been made capable of direct validation on TabbyML/NeoX-1.3B. Signed-off-by: AkiyamaYummy <[email protected]> * Change the names of classes, removing 'parallel' from their names Signed-off-by: AkiyamaYummy <[email protected]> * Format the code. Signed-off-by: AkiyamaYummy <[email protected]> * Only print results when rank is 0. Signed-off-by: AkiyamaYummy <[email protected]> * Add dist.init_process_group(). Signed-off-by: AkiyamaYummy <[email protected]> * update docs Signed-off-by: AkiyamaYummy <[email protected]> --------- Signed-off-by: AkiyamaYummy <[email protected]> * Update cublasMMWrapper.cc Fix the CUBLAS_VERSION checking of cublasMMWrapper * Update cublasMMWrapper.cc * fix overflow in softmax_kernel when process long seqlen and big batch_size (NVIDIA#524) * Update unfused_attention_kernels.cu fix bug of softmax kernel * [Enhancement]create huggingface_gptneox_convert.py (NVIDIA#569) * create huggingface_gptneox_convert.py Signed-off-by: AkiyamaYummy <[email protected]> * adjust HF's multi bin files Signed-off-by: AkiyamaYummy <[email protected]> * update gptneox_guide.md Signed-off-by: AkiyamaYummy <[email protected]> --------- Signed-off-by: AkiyamaYummy <[email protected]> * perf(bloom): improve performance of huggingface_bloom_convert.py, decrease the time cost and the mem using (NVIDIA#568) Co-authored-by: r.yang <[email protected]> * Fix/gpt early stop (NVIDIA#584) * fix: fix bug of early stopping of gpt * [bugfix] Fix 2-shot All Reduce correctness issue (indexing bug). (NVIDIA#672) FasterTransformer 2-shot all reduce is implemented as a reduce-scatter + all-gather. There is an indexing bug in the all-gather step. Prior to this change, 2-shot all reduce was only producing correct results on device 0. Now, all devices have the correct results. * fix: swap tensor bug (NVIDIA#683) * Support size_per_head=112 (NVIDIA#660) * fix multi-gpu build * add support for size_per_head=112 for gpt decoder * remove mpi_cxx from multi-gpu build for now (NVIDIA#705) --------- Signed-off-by: AkiyamaYummy <[email protected]> Co-authored-by: byshiue <[email protected]> Co-authored-by: _yummy_ <[email protected]> Co-authored-by: Ying Sheng <[email protected]> Co-authored-by: zhangxin81 <[email protected]> Co-authored-by: 杨睿 <[email protected]> Co-authored-by: r.yang <[email protected]> Co-authored-by: Rahul Kindi <[email protected]> Co-authored-by: Perkz Zheng <[email protected]> Co-authored-by: Daya Khudia <[email protected]> Co-authored-by: Dean Wyatte <[email protected]> * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit * commit --------- Signed-off-by: AkiyamaYummy <[email protected]> Co-authored-by: Asim Shankar <[email protected]> Co-authored-by: byshiue <[email protected]> Co-authored-by: _yummy_ <[email protected]> Co-authored-by: Ying Sheng <[email protected]> Co-authored-by: zhangxin81 <[email protected]> Co-authored-by: 杨睿 <[email protected]> Co-authored-by: r.yang <[email protected]> Co-authored-by: Rahul Kindi <[email protected]> Co-authored-by: Perkz Zheng <[email protected]> Co-authored-by: Daya Khudia <[email protected]> Co-authored-by: Dean Wyatte <[email protected]>
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
when process query of 2000 seqlen and 20 batch_size with model of 32 head num,
qkoffset is up to 32 * 20 * 2000 * 2000 which int32 can not handle.
so, int is replaced by int64_t to avoid overflow.