Skip to content

Commit

Permalink
beam search support length_penalty
Browse files Browse the repository at this point in the history
  • Loading branch information
wanglipeng committed Nov 28, 2023
1 parent e54fcce commit 456ab5b
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 9 deletions.
2 changes: 1 addition & 1 deletion paddle/phi/api/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@
backward : flip_grad

- op : beam_search_softmax
args : (Tensor logits, Tensor cum_scores, Tensor sequence_lengths, Tensor stop_flags, Tensor end_ids, Tensor step_ids, Tensor last_cache_ids, Tensor last_beam_offsets, int beam_size, int max_seq_len, int max_dec_len, bool fuse_softmax, bool early_stop)
args : (Tensor logits, Tensor cum_scores, Tensor sequence_lengths, Tensor stop_flags, Tensor end_ids, Tensor step_ids, Tensor last_cache_ids, Tensor last_beam_offsets, int beam_size, int max_seq_len, int max_dec_len, bool fuse_softmax, bool early_stop, float length_penalty=0.0)
output : Tensor(ids_this_time), Tensor(out_cum_scores), Tensor(cache_ids), Tensor(beam_offsets), Tensor(parent_idx), Tensor(stop_flags_out), Tensor(seq_lens_out), Tensor(step_ids_out)
infer_meta :
func : BeamSearchSoftmaxInferMeta
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/infermeta/multiary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -687,6 +687,7 @@ void BeamSearchSoftmaxInferMeta(const MetaTensor& logits,
int max_dec_len,
bool fuse_softmax,
bool early_stop,
float length_penalty,
MetaTensor* ids_this_time,
MetaTensor* out_cum_scores,
MetaTensor* cache_ids,
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/infermeta/multiary.h
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ void BeamSearchSoftmaxInferMeta(const MetaTensor& logits,
int max_dec_len,
bool fuse_softmax,
bool early_stop,
float length_penalty,
MetaTensor* ids_this_time,
MetaTensor* out_cum_scores,
MetaTensor* cache_ids,
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/kernels/fusion/beam_search_softmax.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ void BeamSearchSoftmaxKernel(const Context &dev_ctx,
int max_dec_len,
bool fuse_softmax,
bool early_stop,
float length_penalty,
DenseTensor *ids_this_time,
DenseTensor *out_cum_scores,
DenseTensor *cache_ids,
Expand Down
28 changes: 22 additions & 6 deletions paddle/phi/kernels/fusion/gpu/beam_search_softmax.cu
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ namespace fusion {
max_dec_len, \
fuse_softmax, \
early_stop, \
length_penalty, \
stream); \
break

Expand Down Expand Up @@ -202,6 +203,7 @@ __global__ void batch_topk(const int *topk_tmp_id_buf,
}
}

// Overload of batch_topk used for the early-stop path (presumably selected when early_stop is enabled — body folded in this view).
template <typename T, int K, int THREADBLOCK_SIZE>
__global__ void batch_topk(const int *topk_tmp_id_buf,
const T *topk_tmp_val_buf,
Expand Down Expand Up @@ -391,7 +393,9 @@ __global__ void beam_search_softmax_topk_stage2(const float *tmp_buffer,
T *tmp_vals,
const int voc_parts,
const int packed_top_kmd_size,
const bool fuse_softmax) {
const bool fuse_softmax,
const float length_penalty,
const int *step_ids) {
const int vector_id = blockIdx.x;
const int thread_id = threadIdx.x;
const int PACKED_TOP_KMD_SIZE = packed_top_kmd_size;
Expand All @@ -402,6 +406,10 @@ __global__ void beam_search_softmax_topk_stage2(const float *tmp_buffer,
float *buf_s = reinterpret_cast<float *>(buf_s_);
tmp_buffer += vector_id * PACKED_TOP_KMD_SIZE * voc_parts;

  // cum_log_probs holds length-penalized values, so they must be restored
  // (multiplied by the previous penalty) before accumulation, then re-penalized.
T previous_penalty = static_cast<T>(powf(step_ids[vector_id], length_penalty));
T current_penalty = static_cast<T>(powf(step_ids[vector_id] + 1, length_penalty));

if (fuse_softmax) {
typedef cub::BlockReduce<TopKSoftMax<T, K>, THREADBLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
Expand Down Expand Up @@ -443,7 +451,7 @@ __global__ void beam_search_softmax_topk_stage2(const float *tmp_buffer,
// float val = expf((float)total.topk.vals[i] - total.softmax_md.logit - d_total_log);
float val = total.topk.vals[i] - total.softmax_md.logit - d_total_log;
tmp_ids[i] = total.topk.ids[i];
tmp_vals[i] = val + cum_log_probs[0];
tmp_vals[i] = (val + cum_log_probs[0] * previous_penalty) / current_penalty;
#ifdef DEBUG_BEAM_SEARCH_SOFTMAX
printf("vector_id: %d, vals: %f, logit: %f, d_total_log: %f, id: %d, val: %f, cum_log_probs: %f, res: %f\n", vector_id, total.topk.vals[i], total.softmax_md.logit, d_total_log, tmp_ids[i], val, cum_log_probs[0], tmp_vals[i]);
#endif
Expand Down Expand Up @@ -485,7 +493,7 @@ __global__ void beam_search_softmax_topk_stage2(const float *tmp_buffer,
for (int i = 0; i < K; ++i) {
float val = total.vals[i];
tmp_ids[i] = total.ids[i];
tmp_vals[i] = val + cum_log_probs[0];
tmp_vals[i] = (val + cum_log_probs[0] * previous_penalty) / current_penalty;
}
}
}
Expand All @@ -501,25 +509,27 @@ void invokeBeamSearchSoftmaxTopKStage2(const float *tmp_buffer,
const int voc_parts,
const int packed_top_kmd_size,
const bool fuse_softmax,
const float length_penalty,
const int *step_ids,
cudaStream_t stream) {
int smem_stage2_size = voc_parts * packed_top_kmd_size * sizeof(float);

if (voc_parts <= 32) {
beam_search_softmax_topk_stage2<T, K, 32>
<<<batch_size * beam_size, 32, smem_stage2_size, stream>>>(
tmp_buffer, cum_log_probs, ids, vals, voc_parts, packed_top_kmd_size, fuse_softmax);
tmp_buffer, cum_log_probs, ids, vals, voc_parts, packed_top_kmd_size, fuse_softmax, length_penalty, step_ids);
return;
}
if (voc_parts <= 64) {
beam_search_softmax_topk_stage2<T, K, 64>
<<<batch_size * beam_size, 64, smem_stage2_size, stream>>>(
tmp_buffer, cum_log_probs, ids, vals, voc_parts, packed_top_kmd_size, fuse_softmax);
tmp_buffer, cum_log_probs, ids, vals, voc_parts, packed_top_kmd_size, fuse_softmax, length_penalty, step_ids);
return;
}
if (voc_parts <= 128) {
beam_search_softmax_topk_stage2<T, K, 128>
<<<batch_size * beam_size, 128, smem_stage2_size, stream>>>(
tmp_buffer, cum_log_probs, ids, vals, voc_parts, packed_top_kmd_size, fuse_softmax);
tmp_buffer, cum_log_probs, ids, vals, voc_parts, packed_top_kmd_size, fuse_softmax, length_penalty, step_ids);
return;
}
}
Expand Down Expand Up @@ -681,6 +691,7 @@ void invokeTopKSoftMaxLauncher(const Context &dev_ctx,
const int max_dec_len,
const bool fuse_softmax,
const bool early_stop,
const float length_penalty,
cudaStream_t stream) {
// K = 2 * beam_size
const int block_size = 128;
Expand Down Expand Up @@ -725,6 +736,8 @@ void invokeTopKSoftMaxLauncher(const Context &dev_ctx,
voc_parts,
packed_top_kmd_size,
fuse_softmax,
length_penalty,
step_ids,
stream);
// (bs, bm)
if (early_stop) {
Expand Down Expand Up @@ -808,6 +821,7 @@ void invokeTopkSoftMax(const Context &dev_ctx,
const int max_dec_len,
const bool fuse_softmax,
const bool early_stop,
const float length_penalty,
cudaStream_t stream) {
switch (beam_size) {
CASE_K(1);
Expand Down Expand Up @@ -848,6 +862,7 @@ void BeamSearchSoftmaxKernel(const Context &dev_ctx,
int max_dec_len,
bool fuse_softmax,
bool early_stop,
float length_penalty,
DenseTensor *ids_this_time,
DenseTensor *out_cum_scores,
DenseTensor *cache_ids,
Expand Down Expand Up @@ -913,6 +928,7 @@ void BeamSearchSoftmaxKernel(const Context &dev_ctx,
max_dec_len,
fuse_softmax,
early_stop,
length_penalty,
dev_ctx.stream());
}

Expand Down
6 changes: 4 additions & 2 deletions python/paddle/tensor/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -1122,7 +1122,7 @@ def beam_search_softmax(
max_dec_len,
fuse_softmax,
early_stop,
name=None,
length_penalty=0.0,
):
if in_dygraph_mode():
return _C_ops.beam_search_softmax(
Expand All @@ -1138,7 +1138,8 @@ def beam_search_softmax(
max_seq_len,
max_dec_len,
fuse_softmax,
early_stop
early_stop,
length_penalty
)

inputs = {
Expand All @@ -1157,6 +1158,7 @@ def beam_search_softmax(
attrs['max_dec_len'] = max_dec_len
attrs['fuse_softmax'] = fuse_softmax
attrs['early_stop'] = early_stop
attrs['length_penalty'] = length_penalty

helper = LayerHelper('beam_search_softmax', **locals())
ids_this_time = helper.create_variable_for_type_inference(dtype="int32")
Expand Down

0 comments on commit 456ab5b

Please sign in to comment.