Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhancement]add pytorch backend support for gptneox #550

Merged
merged 8 commits into from
Apr 18, 2023

Conversation

zhang-ge-hao
Copy link
Contributor

GPT-Neox is good. Hope all PyTorch users can easily use it.

@zhang-ge-hao
Copy link
Contributor Author

@byshiue

Looking forward to your review of this PR. Thank you 😋😋

@@ -150,7 +150,7 @@ void invokeLengthCriterion(bool* finished,

length_criterion<<<grid, block, 0, stream>>>(
finished, should_stop, h_pinned_finished_sum_, sequence_limit_length, batch_size, beam_width, step);
while (((volatile size_t*)h_pinned_finished_sum_)[0] == -1) {};
while (((volatile int*)h_pinned_finished_sum_)[0] == -1) {};
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please don't change this code because it would lead to dead lock under pipeline parallelism.

@@ -0,0 +1,17 @@
# Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please rename the folder to gptneox directly, don't add multi_gpu.
GPT has such prefix because we have separate implementation for single gpu and multi gpu in the past.

@@ -0,0 +1,171 @@
/*
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same to above, don't use Parallelxxx because we fuse multi gpu and single gpu in one implementation now.

@byshiue
Copy link
Collaborator

byshiue commented Apr 10, 2023

Please check the format by the .clang-format in root path.

if (return_cum_log_probs == 2) {
return_context_cum_log_probs = true;
input_tensors.insert(
{"is_return_context_cum_log_probs",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please check the GptNeoXTritonModelInstance to make sure what inputs are supported. For example, is_return_context_cum_log_probs is only supported in GPT and is not supported in Gpt Neox.

@byshiue
Copy link
Collaborator

byshiue commented Apr 10, 2023

Please help sharing a result of running example, making sure it works well.

@@ -0,0 +1,198 @@
# Copyright (c) 2021-2023, NVIDIA CORPORATION. All rights reserved.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please rename to gptneox_example.py.

@@ -28,7 +28,7 @@ namespace fastertransformer {
template<typename T>
struct GptNeoXDecoderLayerWeight {
public:
GptNeoXDecoderLayerWeight() = delete;
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure if any assumption (around the initialization status) breaks here however it increases the complexity of the ways construct GptNeoXDecoderLayerWeight.

…hat would affect pipeline parallelism have been reverted. 3) The code has been made capable of direct validation on TabbyML/NeoX-1.3B.

Signed-off-by: AkiyamaYummy <[email protected]>
Signed-off-by: AkiyamaYummy <[email protected]>
@zhang-ge-hao
Copy link
Contributor Author

zhang-ge-hao commented Apr 15, 2023

@byshiue

Hi, I updated my code.

  1. Some unused parameters and logic have been removed.
  2. Revisions that would affect pipeline parallelism have been reverted.
  3. The code has been made capable of direct validation on TabbyML/NeoX-1.3B.
  4. Rename the classes and format my code.
  5. GptNeoX model in FT does not support bfloat16 originally. And I did not add support for bf16 on this basis, because my development environment seems unable to test bf16 code.

ps: The TabbyML/NeoX-1.3B model can be downloaded at https://huggingface.co/TabbyML/NeoX-1.3B (also contains the FT model files).

Some results of gptneox_example:

Demo prompts in the file "gptneox_input" can also be seen in the output logs.

(base) root@f0305219ab2b:/workspace/FasterTransformer/build# python ../examples/pytorch/gptneox/gptneox_example.py --sample_input_file gptneox_input --time

=============== Arguments ===============
output_len: 32
beam_width: 1
top_k: 1
top_p: 0.0
temperature: 1.0
len_penalty: 0.0
beam_search_diversity_rate: 0.0
tensor_para_size: 1
pipeline_para_size: 1
ckpt_path: ../models/gptneox/c-model/NeoX-1.3B/1-gpu
tokenizer_path: ../models/gptneox/model/NeoX-1.3B
lib_path: ./lib/libth_transformer.so
sample_input_file: gptneox_input
max_batch_size: 8
repetition_penalty: 1.0
max_seq_len: 1024
inference_data_type: fp16
time: True
enable_random_seed: False
=========================================

[INFO] batch size: 2
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
        - Avoid using `tokenizers` before the fork if possible
        - Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
[WARNING] gemm_config.in is not found; using default GEMM algo
[FT][WARNING] Skip NCCL initialization since requested tensor/pipeline parallel sizes are equals to 1.
[INFO] batch 0, beam 0:
[Context]
Hello,

[Output]
 I'm a newbie. I'm trying to install Ubuntu on my laptop. I have a Dell Inspiron 1525 with Windows 7. I have a USB<|endoftext|><|endoftext|>

[INFO] batch 1, beam 0:
[Context]
Gama start,

[Output]
 and the first of the two-day event.

The first day of the Gama start was a bit of a letdown. The first few laps

[INFO] FT-GPT generates 10 batches, taking 1.804 secs to generate 660 tokens, 365.939 tokens/sec.
root@f0305219ab2b:/workspace/FasterTransformer/build# python ../examples/pytorch/gptneox/gptneox_example.py --top_k 50 --enable_random_seed --time --sample_input_file gptneox_input

=============== Arguments ===============
output_len: 32
beam_width: 1
top_k: 50
top_p: 0.0
temperature: 1.0
len_penalty: 0.0
beam_search_diversity_rate: 0.0
tensor_para_size: 1
pipeline_para_size: 1
ckpt_path: ../models/gptneox/c-model/NeoX-1.3B/1-gpu
tokenizer_path: ../models/gptneox/model/NeoX-1.3B
lib_path: ./lib/libth_transformer.so
sample_input_file: gptneox_input
max_batch_size: 8
repetition_penalty: 1.0
max_seq_len: 1024
inference_data_type: fp16
time: True
enable_random_seed: True
=========================================

[INFO] batch size: 2
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
        - Avoid using `tokenizers` before the fork if possible
        - Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
[WARNING] gemm_config.in is not found; using default GEMM algo
[FT][WARNING] Skip NCCL initialization since requested tensor/pipeline parallel sizes are equals to 1.
[INFO] batch 0, beam 0:
[Context]
Hello,

[Output]
 there." "This is just in the mail." "I just opened it." "Yeah." "Just in the mail." "Just in the mail." "<|endoftext|><|endoftext|>

[INFO] batch 1, beam 0:
[Context]
Gama start,

[Output]
 or maybe it could be a new version. Any idea on this?

ToyBox will most probably go live in late August, or early September 2008

[INFO] FT-GPT generates 10 batches, taking 1.791 secs to generate 660 tokens, 368.585 tokens/sec.
root@f0305219ab2b:/workspace/FasterTransformer/build# python ../examples/pytorch/gptneox/gptneox_example.py --beam_width 2 --time --sample_input_file gptneox_input

=============== Arguments ===============
output_len: 32
beam_width: 2
top_k: 1
top_p: 0.0
temperature: 1.0
len_penalty: 0.0
beam_search_diversity_rate: 0.0
tensor_para_size: 1
pipeline_para_size: 1
ckpt_path: ../models/gptneox/c-model/NeoX-1.3B/1-gpu
tokenizer_path: ../models/gptneox/model/NeoX-1.3B
lib_path: ./lib/libth_transformer.so
sample_input_file: gptneox_input
max_batch_size: 8
repetition_penalty: 1.0
max_seq_len: 1024
inference_data_type: fp16
time: True
enable_random_seed: False
=========================================

[INFO] batch size: 2
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
        - Avoid using `tokenizers` before the fork if possible
        - Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
[WARNING] gemm_config.in is not found; using default GEMM algo
[FT][WARNING] Skip NCCL initialization since requested tensor/pipeline parallel sizes are equals to 1.
[INFO] batch 0, beam 0:
[Context]
Hello,

[Output]
 I'm not sure if this is the right place to ask this, but I'm looking for a way to get a list of all the files in a directory<|endoftext|><|endoftext|>

[INFO] batch 0, beam 1:
[Context]
Hello,

[Output]
 I'm not sure if this is the right place to ask this, but I'm looking for a way to get a list of all the files in a folder<|endoftext|><|endoftext|>

[INFO] batch 1, beam 0:
[Context]
Gama start,

[Output]
 and then the rest of the world.

I’m not sure if it’s a good thing or a bad thing, but I’m glad

[INFO] batch 1, beam 1:
[Context]
Gama start,

[Output]
 and then the rest of the world.

I’m not sure if it’s a good thing or a bad thing, but I’ve been

[INFO] FT-GPT generates 10 batches, taking 1.840 secs to generate 660 tokens, 358.770 tokens/sec.
root@f0305219ab2b:/workspace/FasterTransformer/build# mpirun -n 2 --allow-run-as-root python ../examples/pytorch/gptneox/gptneox_example.py --pipeline_para_size 2 --time --sample_input_file gptneox_input

=============== Arguments ===============
output_len: 32
beam_width: 1
top_k: 1
top_p: 0.0
temperature: 1.0
len_penalty: 0.0
beam_search_diversity_rate: 0.0
tensor_para_size: 1
pipeline_para_size: 2
ckpt_path: ../models/gptneox/c-model/NeoX-1.3B/1-gpu
tokenizer_path: ../models/gptneox/model/NeoX-1.3B
lib_path: ./lib/libth_transformer.so
sample_input_file: gptneox_input
max_batch_size: 8
repetition_penalty: 1.0
max_seq_len: 1024
inference_data_type: fp16
time: True
enable_random_seed: False
=========================================


=============== Arguments ===============
output_len: 32
beam_width: 1
top_k: 1
top_p: 0.0
temperature: 1.0
len_penalty: 0.0
beam_search_diversity_rate: 0.0
tensor_para_size: 1
pipeline_para_size: 2
ckpt_path: ../models/gptneox/c-model/NeoX-1.3B/1-gpu
tokenizer_path: ../models/gptneox/model/NeoX-1.3B
lib_path: ./lib/libth_transformer.so
sample_input_file: gptneox_input
max_batch_size: 8
repetition_penalty: 1.0
max_seq_len: 1024
inference_data_type: fp16
time: True
enable_random_seed: False
=========================================

[INFO] batch size: 2
[INFO] batch size: 2
[INFO] WARNING: Have initialized the process group
[INFO] WARNING: Have initialized the process group
[WARNING] gemm_config.in is not found; using default GEMM algo
[WARNING] gemm_config.in is not found; using default GEMM algo
[FT][INFO] NCCL initialized rank=1 world_size=2 tensor_para=NcclParam[rank=0, world_size=1, nccl_comm=0x5650969d38b0] pipeline_para=NcclParam[rank=1, world_size=2, nccl_comm=0x565096d1ce60]
[FT][INFO] NCCL initialized rank=0 world_size=2 tensor_para=NcclParam[rank=0, world_size=1, nccl_comm=0x5603073f7890] pipeline_para=NcclParam[rank=0, world_size=2, nccl_comm=0x560307740c90]
[INFO] batch 0, beam 0:
[Context]
Hello,

[Output]
 I'm a newbie. I'm trying to install Ubuntu on my laptop. I have a Dell Inspiron 1525 with Windows 7. I have a USB<|endoftext|><|endoftext|>

[INFO] batch 1, beam 0:
[Context]
Gama start,

[Output]
 and the first of the two-day event.

The first day of the Gama start was a bit of a letdown. The first few laps

[INFO] FT-GPT generates 10 batches, taking 3.703 secs to generate 660 tokens, 178.245 tokens/sec.
[INFO] FT-GPT generates 10 batches, taking 3.703 secs to generate 660 tokens, 178.239 tokens/sec.

@zhang-ge-hao
Copy link
Contributor Author

@byshiue

Hi, looking forward to your review of this PR, again. Thank you 😋😋

Last week I was too busy with work, so I modified the code on the weekend.

If anything else needs to be added, you can contact me at any time.

@byshiue
Copy link
Collaborator

byshiue commented Apr 17, 2023

Thank you. I will take a look. Besides, please update the document about the pytorch example usage in gptneox_guide.md.

Signed-off-by: AkiyamaYummy <[email protected]>
@zhang-ge-hao
Copy link
Contributor Author

Thank you. I will take a look. Besides, please update the document about the pytorch example usage in gptneox_guide.md.

Updated.

@byshiue byshiue merged commit 169b8df into NVIDIA:main Apr 18, 2023
sfc-gh-ashankar added a commit to neevaco/FasterTransformer that referenced this pull request Jul 11, 2023
* Update beam_search_topk_kernels.cu

fix: fix bug of beam search

* fix: change int of some kernels to int64_t to prevent overflow

* fix: gpt tensor shapes inconsistency (NVIDIA#505)

Signed-off-by: AkiyamaYummy <[email protected]>

* Update gpt_guide.md (NVIDIA#529)

* fix: fix bug of gpt buffer and gpt gemm overflow

* Update T5DecodingWeight.cc

fix: fix loading bug of t5

* [Enhancement]add pytorch backend support for gptneox (NVIDIA#550)

* add pytorch backend support for gptneox

Signed-off-by: AkiyamaYummy <[email protected]>

* fix early stopping invalid

* 1) Some unused parameters and logic have been removed. 2) Revisions that would affect pipeline parallelism have been reverted. 3) The code has been made capable of direct validation on TabbyML/NeoX-1.3B.

Signed-off-by: AkiyamaYummy <[email protected]>

* Change the names of classes, removing 'parallel' from their names

Signed-off-by: AkiyamaYummy <[email protected]>

* Format the code.

Signed-off-by: AkiyamaYummy <[email protected]>

* Only print results when rank is 0.

Signed-off-by: AkiyamaYummy <[email protected]>

* Add dist.init_process_group().

Signed-off-by: AkiyamaYummy <[email protected]>

* update docs

Signed-off-by: AkiyamaYummy <[email protected]>

---------

Signed-off-by: AkiyamaYummy <[email protected]>

* Update cublasMMWrapper.cc

Fix the CUBLAS_VERSION checking of cublasMMWrapper

* Update cublasMMWrapper.cc

* fix overflow in softmax_kernel when process long seqlen and big batch_size (NVIDIA#524)

* Update unfused_attention_kernels.cu

fix bug of softmax kernel

* [Enhancement]create huggingface_gptneox_convert.py (NVIDIA#569)

* create huggingface_gptneox_convert.py

Signed-off-by: AkiyamaYummy <[email protected]>

* adjust HF's multi bin files

Signed-off-by: AkiyamaYummy <[email protected]>

* update gptneox_guide.md

Signed-off-by: AkiyamaYummy <[email protected]>

---------

Signed-off-by: AkiyamaYummy <[email protected]>

* perf(bloom): improve performance of huggingface_bloom_convert.py, decrease the time cost and the mem using (NVIDIA#568)

Co-authored-by: r.yang <[email protected]>

* Fix/gpt early stop (NVIDIA#584)

* fix: fix bug of early stopping of gpt

* [bugfix] Fix 2-shot All Reduce correctness issue (indexing bug). (NVIDIA#672)

FasterTransformer 2-shot all reduce is implemented as a reduce-scatter + all-gather. There is an indexing bug in the all-gather step. Prior to this change, 2-shot all reduce was only producing correct results on device 0. Now, all devices have the correct results.

* fix: swap tensor bug (NVIDIA#683)

* Support size_per_head=112 (NVIDIA#660)

* fix multi-gpu build

* add support for size_per_head=112 for gpt decoder

* remove mpi_cxx from multi-gpu build for now (NVIDIA#705)

---------

Signed-off-by: AkiyamaYummy <[email protected]>
Co-authored-by: byshiue <[email protected]>
Co-authored-by: _yummy_ <[email protected]>
Co-authored-by: Ying Sheng <[email protected]>
Co-authored-by: zhangxin81 <[email protected]>
Co-authored-by: 杨睿 <[email protected]>
Co-authored-by: r.yang <[email protected]>
Co-authored-by: Rahul Kindi <[email protected]>
Co-authored-by: Perkz Zheng <[email protected]>
Co-authored-by: Daya Khudia <[email protected]>
Co-authored-by: Dean Wyatte <[email protected]>
sfc-gh-zhwang added a commit to neevaco/FasterTransformer that referenced this pull request Oct 5, 2023
* Merge with main (#1)

* Update beam_search_topk_kernels.cu

fix: fix bug of beam search

* fix: change int of some kernels to int64_t to prevent overflow

* fix: gpt tensor shapes inconsistency (NVIDIA#505)

Signed-off-by: AkiyamaYummy <[email protected]>

* Update gpt_guide.md (NVIDIA#529)

* fix: fix bug of gpt buffer and gpt gemm overflow

* Update T5DecodingWeight.cc

fix: fix loading bug of t5

* [Enhancement]add pytorch backend support for gptneox (NVIDIA#550)

* add pytorch backend support for gptneox

Signed-off-by: AkiyamaYummy <[email protected]>

* fix early stopping invalid

* 1) Some unused parameters and logic have been removed. 2) Revisions that would affect pipeline parallelism have been reverted. 3) The code has been made capable of direct validation on TabbyML/NeoX-1.3B.

Signed-off-by: AkiyamaYummy <[email protected]>

* Change the names of classes, removing 'parallel' from their names

Signed-off-by: AkiyamaYummy <[email protected]>

* Format the code.

Signed-off-by: AkiyamaYummy <[email protected]>

* Only print results when rank is 0.

Signed-off-by: AkiyamaYummy <[email protected]>

* Add dist.init_process_group().

Signed-off-by: AkiyamaYummy <[email protected]>

* update docs

Signed-off-by: AkiyamaYummy <[email protected]>

---------

Signed-off-by: AkiyamaYummy <[email protected]>

* Update cublasMMWrapper.cc

Fix the CUBLAS_VERSION checking of cublasMMWrapper

* Update cublasMMWrapper.cc

* fix overflow in softmax_kernel when process long seqlen and big batch_size (NVIDIA#524)

* Update unfused_attention_kernels.cu

fix bug of softmax kernel

* [Enhancement]create huggingface_gptneox_convert.py (NVIDIA#569)

* create huggingface_gptneox_convert.py

Signed-off-by: AkiyamaYummy <[email protected]>

* adjust HF's multi bin files

Signed-off-by: AkiyamaYummy <[email protected]>

* update gptneox_guide.md

Signed-off-by: AkiyamaYummy <[email protected]>

---------

Signed-off-by: AkiyamaYummy <[email protected]>

* perf(bloom): improve performance of huggingface_bloom_convert.py, decrease the time cost and the mem using (NVIDIA#568)

Co-authored-by: r.yang <[email protected]>

* Fix/gpt early stop (NVIDIA#584)

* fix: fix bug of early stopping of gpt

* [bugfix] Fix 2-shot All Reduce correctness issue (indexing bug). (NVIDIA#672)

FasterTransformer 2-shot all reduce is implemented as a reduce-scatter + all-gather. There is an indexing bug in the all-gather step. Prior to this change, 2-shot all reduce was only producing correct results on device 0. Now, all devices have the correct results.

* fix: swap tensor bug (NVIDIA#683)

* Support size_per_head=112 (NVIDIA#660)

* fix multi-gpu build

* add support for size_per_head=112 for gpt decoder

* remove mpi_cxx from multi-gpu build for now (NVIDIA#705)

---------

Signed-off-by: AkiyamaYummy <[email protected]>
Co-authored-by: byshiue <[email protected]>
Co-authored-by: _yummy_ <[email protected]>
Co-authored-by: Ying Sheng <[email protected]>
Co-authored-by: zhangxin81 <[email protected]>
Co-authored-by: 杨睿 <[email protected]>
Co-authored-by: r.yang <[email protected]>
Co-authored-by: Rahul Kindi <[email protected]>
Co-authored-by: Perkz Zheng <[email protected]>
Co-authored-by: Daya Khudia <[email protected]>
Co-authored-by: Dean Wyatte <[email protected]>

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

* commit

---------

Signed-off-by: AkiyamaYummy <[email protected]>
Co-authored-by: Asim Shankar <[email protected]>
Co-authored-by: byshiue <[email protected]>
Co-authored-by: _yummy_ <[email protected]>
Co-authored-by: Ying Sheng <[email protected]>
Co-authored-by: zhangxin81 <[email protected]>
Co-authored-by: 杨睿 <[email protected]>
Co-authored-by: r.yang <[email protected]>
Co-authored-by: Rahul Kindi <[email protected]>
Co-authored-by: Perkz Zheng <[email protected]>
Co-authored-by: Daya Khudia <[email protected]>
Co-authored-by: Dean Wyatte <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants