Skip to content

Commit

Permalink
fix: gpt tensor shapes inconsistency (#505)
Browse files Browse the repository at this point in the history
Signed-off-by: AkiyamaYummy <[email protected]>
  • Loading branch information
zhang-ge-hao authored Mar 17, 2023
1 parent bb94e2d commit e838426
Showing 1 changed file with 22 additions and 22 deletions.
44 changes: 22 additions & 22 deletions src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -184,9 +184,9 @@ void ParallelGpt<T>::allocateBuffer(size_t batch_size,
lp_logprob_buf_ = (float*)allocator_->reMalloc(lp_logprob_buf_, sizeof(float) * batchxbeam * max_input_len);
}
if (shared_contexts_ratio_ > 0.0f) {
shared_contexts_idx_ = (int*)allocator_->reMalloc(shared_contexts_idx_, 3 * batch_size * sizeof(int), false);
batch_to_compact_idx_ = shared_contexts_idx_ + batch_size;
compact_idx_ = shared_contexts_idx_ + 2 * batch_size;
shared_contexts_idx_ = (int*)allocator_->reMalloc(shared_contexts_idx_, 3 * batchxbeam * sizeof(int), false);
batch_to_compact_idx_ = shared_contexts_idx_ + batchxbeam;
compact_idx_ = shared_contexts_idx_ + 2 * batchxbeam;
compact_size_ = (int*)allocator_->reMalloc(compact_size_, sizeof(int), false);
}
generation_should_stop_ = (bool*)allocator_->reMalloc(generation_should_stop_, sizeof(bool), true, true);
Expand Down Expand Up @@ -879,24 +879,6 @@ void ParallelGpt<T>::forward(std::unordered_map<std::string, Tensor>* outp
}
POP_RANGE;

int compact_size;
bool use_shared_contexts = (shared_contexts_ratio_ > 0.0f) && (max_input_length >= 1) && (batch_size > 1);
PUSH_RANGE("find context dups");
if (use_shared_contexts) {
invokeFindContextDups(shared_contexts_idx_,
batch_to_compact_idx_,
compact_idx_,
compact_size_,
input_tensors->at("input_ids").getPtr<int>(),
batch_size,
max_input_length,
stream_);
cudaD2Hcpy(&compact_size, compact_size_, 1);
use_shared_contexts = compact_size <= shared_contexts_ratio_ * batch_size;
sync_check_cuda_error();
}
POP_RANGE;

// NOTE: p/prompt-tuning process here (lookup prompt embedding tables by task name ids)
// get p/prompt-tuning weight for each batch --> shape [batch, beam_width]
// --> ptrs with shape [prompt_len, hidden_size]
Expand Down Expand Up @@ -1038,6 +1020,24 @@ void ParallelGpt<T>::forward(std::unordered_map<std::string, Tensor>* outp
sync_check_cuda_error();
POP_RANGE;

int compact_size;
bool use_shared_contexts = (shared_contexts_ratio_ > 0.0f) && (max_input_length >= 1) && (batch_size > 1);
PUSH_RANGE("find context dups");
if (use_shared_contexts) {
invokeFindContextDups(shared_contexts_idx_,
batch_to_compact_idx_,
compact_idx_,
compact_size_,
tiled_input_ids_buf_,
batch_size * beam_width,
max_input_length,
stream_);
cudaD2Hcpy(&compact_size, compact_size_, 1);
use_shared_contexts = compact_size <= shared_contexts_ratio_ * batch_size * beam_width;
sync_check_cuda_error();
}
POP_RANGE;

TensorMap decoder_input_tensors(
{{"decoder_input",
Tensor(MEMORY_GPU,
Expand All @@ -1057,7 +1057,7 @@ void ParallelGpt<T>::forward(std::unordered_map<std::string, Tensor>* outp
decoder_input_tensors.insert("compact_idx",
Tensor(MEMORY_GPU, TYPE_INT32, {(size_t)compact_size}, compact_idx_));
decoder_input_tensors.insert("batch_to_compact_idx",
Tensor(MEMORY_GPU, TYPE_INT32, {batch_size}, batch_to_compact_idx_));
Tensor(MEMORY_GPU, TYPE_INT32, {batch_size * beam_width}, batch_to_compact_idx_));
}
if (gpt_variant_params_.use_attention_linear_bias) {
decoder_input_tensors.insert("linear_bias_slopes",
Expand Down

0 comments on commit e838426

Please sign in to comment.