From e8384260b709b10459b138b22af42c3ff82d5c9e Mon Sep 17 00:00:00 2001
From: _yummy_ <842720660@qq.com>
Date: Fri, 17 Mar 2023 11:37:12 +0800
Subject: [PATCH] fix: gpt tensor shapes inconsistency (#505)

Signed-off-by: AkiyamaYummy <842720660@qq.com>
---
 .../models/multi_gpu_gpt/ParallelGpt.cc | 44 +++++++++----------
 1 file changed, 22 insertions(+), 22 deletions(-)

diff --git a/src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.cc b/src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.cc
index 2b9e4f3c4..ad9c3527b 100644
--- a/src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.cc
+++ b/src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.cc
@@ -184,9 +184,9 @@ void ParallelGpt<T>::allocateBuffer(size_t batch_size,
         lp_logprob_buf_ = (float*)allocator_->reMalloc(lp_logprob_buf_, sizeof(float) * batchxbeam * max_input_len);
     }
     if (shared_contexts_ratio_ > 0.0f) {
-        shared_contexts_idx_  = (int*)allocator_->reMalloc(shared_contexts_idx_, 3 * batch_size * sizeof(int), false);
-        batch_to_compact_idx_ = shared_contexts_idx_ + batch_size;
-        compact_idx_          = shared_contexts_idx_ + 2 * batch_size;
+        shared_contexts_idx_  = (int*)allocator_->reMalloc(shared_contexts_idx_, 3 * batchxbeam * sizeof(int), false);
+        batch_to_compact_idx_ = shared_contexts_idx_ + batchxbeam;
+        compact_idx_          = shared_contexts_idx_ + 2 * batchxbeam;
         compact_size_         = (int*)allocator_->reMalloc(compact_size_, sizeof(int), false);
     }
     generation_should_stop_ = (bool*)allocator_->reMalloc(generation_should_stop_, sizeof(bool), true, true);
@@ -879,24 +879,6 @@ void ParallelGpt<T>::forward(std::unordered_map<std::string, Tensor>* outp
     }
     POP_RANGE;
 
-    int  compact_size;
-    bool use_shared_contexts = (shared_contexts_ratio_ > 0.0f) && (max_input_length >= 1) && (batch_size > 1);
-    PUSH_RANGE("find context dups");
-    if (use_shared_contexts) {
-        invokeFindContextDups(shared_contexts_idx_,
-                              batch_to_compact_idx_,
-                              compact_idx_,
-                              compact_size_,
-                              input_tensors->at("input_ids").getPtr<int>(),
-                              batch_size,
-                              max_input_length,
-                              stream_);
-        cudaD2Hcpy(&compact_size, compact_size_, 1);
-        use_shared_contexts = compact_size <= shared_contexts_ratio_ * batch_size;
-        sync_check_cuda_error();
-    }
-    POP_RANGE;
-
     // NOTE: p/prompt-tuning process here (lookup prompt embedding tables by task name ids)
     // get p/prompt-tuning weight for each batch --> shape [batch, beam_width]
     // --> ptrs with shape [prompt_len, hidden_size]
@@ -1038,6 +1020,24 @@ void ParallelGpt<T>::forward(std::unordered_map<std::string, Tensor>* outp
         sync_check_cuda_error();
         POP_RANGE;
 
+        int  compact_size;
+        bool use_shared_contexts = (shared_contexts_ratio_ > 0.0f) && (max_input_length >= 1) && (batch_size > 1);
+        PUSH_RANGE("find context dups");
+        if (use_shared_contexts) {
+            invokeFindContextDups(shared_contexts_idx_,
+                                  batch_to_compact_idx_,
+                                  compact_idx_,
+                                  compact_size_,
+                                  tiled_input_ids_buf_,
+                                  batch_size * beam_width,
+                                  max_input_length,
+                                  stream_);
+            cudaD2Hcpy(&compact_size, compact_size_, 1);
+            use_shared_contexts = compact_size <= shared_contexts_ratio_ * batch_size * beam_width;
+            sync_check_cuda_error();
+        }
+        POP_RANGE;
+
         TensorMap decoder_input_tensors(
             {{"decoder_input",
               Tensor(MEMORY_GPU,
@@ -1057,7 +1057,7 @@ void ParallelGpt<T>::forward(std::unordered_map<std::string, Tensor>* outp
             decoder_input_tensors.insert("compact_idx",
                                          Tensor(MEMORY_GPU, TYPE_INT32, {(size_t)compact_size}, compact_idx_));
             decoder_input_tensors.insert("batch_to_compact_idx",
-                                         Tensor(MEMORY_GPU, TYPE_INT32, {batch_size}, batch_to_compact_idx_));
+                                         Tensor(MEMORY_GPU, TYPE_INT32, {batch_size * beam_width}, batch_to_compact_idx_));
         }
         if (gpt_variant_params_.use_attention_linear_bias) {
             decoder_input_tensors.insert("linear_bias_slopes",
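
Illustrative note (not part of the patch): the sketch below is a standalone, hypothetical host-side program that mirrors the pointer arithmetic in allocateBuffer() above and shows why the duplicate-context bookkeeping must be sized by batch_size * beam_width once the input IDs have been tiled per beam; all names here are local to the sketch.

// shared_contexts_shape_sketch.cc -- illustrative only, not FasterTransformer code
#include <cstdio>

int main()
{
    // After the inputs are tiled, every batch entry is repeated beam_width times,
    // so the duplicate-context search walks batchxbeam rows, not batch_size.
    const int batch_size = 2;
    const int beam_width = 3;
    const int batchxbeam = batch_size * beam_width;

    // One contiguous allocation of 3 * batchxbeam ints, carved into three views,
    // mirroring the reMalloc + pointer offsets in the patched allocateBuffer().
    int  block[3 * batch_size * beam_width] = {};
    int* shared_contexts_idx  = block;
    int* batch_to_compact_idx = block + batchxbeam;
    int* compact_idx          = block + 2 * batchxbeam;

    // With the old batch_size-based sizing, batch_to_compact_idx held only
    // batch_size entries while the decoder indexes batchxbeam rows -- the
    // shape inconsistency this patch fixes.
    std::printf("rows per index view: %d\n", batchxbeam);
    (void)shared_contexts_idx;
    (void)batch_to_compact_idx;
    (void)compact_idx;
    return 0;
}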