moe fmt add remove padding (#114)
* add & fix zeus int8

* ptq int8 moe add grouped gemm

* moe fmt add remove padding

* moe fmt add remove padding

* moe fmt add remove padding
tianyan01 authored Feb 22, 2024
1 parent fd736c0 commit c586e29
Showing 4 changed files with 234 additions and 89 deletions.
88 changes: 80 additions & 8 deletions paddle/fluid/operators/fused/fmha_ref.h
@@ -83,7 +83,7 @@ class AttnDropoutParam {
const phi::DenseTensor* seed_;
};

template <typename T, int VecSize, bool do_transpose = true>
__global__ void TransposeRemovingPadding(const T* input_data,
T* output_data,
const int batch_size,
@@ -96,6 +96,8 @@ __global__ void TransposeRemovingPadding(const T* input_data,
// transpose and remove padding
// [batch_size, num_head, seq_len, head_dim] -> [token_num, num_head,
// head_dim]
//
// if do_transpose is false, the input layout is [bsz, seq_len, num_head, head_dim]
// (a standalone index-mapping sketch follows this kernel)
int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
const int dim_embed = num_head * head_dim;
using LoadT = phi::AlignedVector<T, VecSize>;
@@ -112,15 +114,20 @@ __global__ void TransposeRemovingPadding(const T* input_data,
const int ori_seq_id = ori_token_idx % seq_len;
const int ori_head_id = (linear_index % dim_embed) / head_dim;
const int ori_head_lane = (linear_index % dim_embed) % head_dim;
int ori_idx = ori_batch_id * num_head * seq_len * head_dim +
ori_head_id * seq_len * head_dim +
ori_seq_id * head_dim + ori_head_lane;
if (!do_transpose) {
ori_idx = ori_token_idx * dim_embed + linear_index % dim_embed;
}
phi::Load<T, VecSize>(&input_data[ori_idx], &src_vec);
phi::Store<T, VecSize>(src_vec, &output_data[linear_index]);
}
}
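
For reference, a standalone host-side sketch of the index mapping this kernel performs, covering both layouts. It assumes the common convention that padding_offset[i] holds the number of padded slots preceding compacted token i (the line computing ori_token_idx is not visible in this hunk); the function name and that convention are illustrative assumptions, not part of the commit.

// Standalone sketch (not part of this commit): maps one element of the
// compacted [token_num, num_head * head_dim] output back to its source index.
// Assumes padding_offset[i] = number of padded slots before compacted token i.
#include <cstdint>

int64_t SourceIndex(int64_t linear_index, const int* padding_offset,
                    int seq_len, int num_head, int head_dim,
                    bool do_transpose) {
  const int dim_embed = num_head * head_dim;
  const int token_idx = static_cast<int>(linear_index / dim_embed);
  const int ori_token_idx = token_idx + padding_offset[token_idx];  // assumed convention
  const int ori_batch_id = ori_token_idx / seq_len;
  const int ori_seq_id = ori_token_idx % seq_len;
  const int ori_head_id = (linear_index % dim_embed) / head_dim;
  const int ori_head_lane = (linear_index % dim_embed) % head_dim;
  if (do_transpose) {
    // source laid out as [batch_size, num_head, seq_len, head_dim]
    return static_cast<int64_t>(ori_batch_id) * num_head * seq_len * head_dim +
           ori_head_id * seq_len * head_dim + ori_seq_id * head_dim + ori_head_lane;
  }
  // source laid out as [batch_size, seq_len, num_head, head_dim]; no transpose needed
  return static_cast<int64_t>(ori_token_idx) * dim_embed + linear_index % dim_embed;
}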

template <typename T, bool do_transpose = true>
void InvokeTransposeRemovePadding(const phi::GPUContext& dev_ctx,
const T* input_data,
T* output_data,
@@ -143,7 +150,7 @@ void InvokeTransposeRemovePadding(const phi::GPUContext& dev_ctx,
const int32_t pack_num = elem_cnt / PackSize;
const int32_t block_size = 128;
int32_t grid_size = (pack_num + block_size - 1) / block_size;
TransposeRemovingPadding<T, PackSize, do_transpose>
<<<grid_size, block_size, 0, dev_ctx.stream()>>>(input_data,
output_data,
batch_size,
@@ -845,15 +852,80 @@ class FlashAttnFMHARef {
seed,
mask,
dropout_param_.dropout_prob_,
(seq_len_ != 1 && src_mask_tensor == nullptr),
return_softmax,
dropout_param_.is_test_,
"",
fmha_out_tensor,
softmax_out_tensor, // softmax
softmax_lse_out_tensor, // softmax_lse
seed_offset); // seed_offset
}

void RemovePaddingComputeForward(const phi::DenseTensor* src_mask_tensor,
const phi::DenseTensor* padding_offset_tensor,
phi::DenseTensor* transpose_2_out_tensor,
phi::DenseTensor* input_tensor,
phi::DenseTensor* softmax_lse_out_tensor,
phi::DenseTensor* seed_offset,
phi::DenseTensor* softmax_out_tensor,
phi::DenseTensor* qktv_out_tensor,
phi::DenseTensor* fmha_out_tensor,
const int token_num) {
// input shape: [3, bs, num_head, seq_len, head_dim]
// transpose with perm [0, 1, 3, 2, 4],
// output shape: [3, bs, seq_len, num_head, head_dim]
// (a shape walk-through with example sizes follows this method)
std::vector<int> perm_1 = {0, 1, 3, 2, 4};
TransposeGPUKernelDriver<T>(
dev_ctx_, *transpose_2_out_tensor, perm_1, input_tensor);

phi::DenseTensor q, k, v;
q = input_tensor->Slice(0, 1);
k = input_tensor->Slice(1, 2);
v = input_tensor->Slice(2, 3);
// bs, seq_len, num_head, head_dim
auto dim = phi::make_ddim({batch_size_, seq_len_, num_head_, head_dim_});
q.Resize(dim);
k.Resize(dim);
v.Resize(dim);

auto seed = paddle::make_optional(
(dropout_param_.is_fix_seed_ && dropout_param_.seed_ != nullptr),
*dropout_param_.seed_);
auto mask = paddle::make_optional(*src_mask_tensor);
bool return_softmax =
(!dropout_param_.is_test_ && dropout_param_.dropout_prob_ > 0.0f);
// causal only when q.shape[1] (seq_len_) != 1 and no explicit mask is given
phi::FlashAttnKernel<T, phi::GPUContext>(
dev_ctx_,
q,
k,
v,
seed,
mask,
dropout_param_.dropout_prob_,
(seq_len_ != 1 && src_mask_tensor == nullptr),
return_softmax,
dropout_param_.is_test_,
"",
qktv_out_tensor, // already transposed
softmax_out_tensor, // softmax
softmax_lse_out_tensor, // softmax_lse
seed_offset); // seed_offset

if (padding_offset_tensor) {
InvokeTransposeRemovePadding<T, false>(dev_ctx_,
qktv_out_tensor->data<T>(), // input
fmha_out_tensor->data<T>(), // output
batch_size_,
num_head_,
seq_len_,
head_dim_,
token_num,
padding_offset_tensor->data<int>());
}
}
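
As referenced above, a short walk-through of the shape flow in RemovePaddingComputeForward, using made-up example sizes (bs = 2, seq_len = 4, num_head = 8, head_dim = 64, token_num = 6). This is only an illustration of the comments in the method; the sizes and the exact fmha_out layout are assumptions, not code from the repository.

// Standalone sketch (not part of this commit): prints the tensor shapes at
// each stage of RemovePaddingComputeForward for hypothetical sizes.
#include <cstdio>

int main() {
  const int bs = 2, seq_len = 4, num_head = 8, head_dim = 64;
  const int token_num = 6;  // bs * seq_len = 8 slots, 2 of them are padding

  std::printf("transpose_2_out_tensor     : [3, %d, %d, %d, %d]\n", bs, num_head, seq_len, head_dim);
  std::printf("after perm {0, 1, 3, 2, 4} : [3, %d, %d, %d, %d]\n", bs, seq_len, num_head, head_dim);
  std::printf("q / k / v slices           : [%d, %d, %d, %d]\n", bs, seq_len, num_head, head_dim);
  std::printf("qktv_out (FlashAttnKernel) : [%d, %d, %d, %d]\n", bs, seq_len, num_head, head_dim);
  // InvokeTransposeRemovePadding<T, false> drops the padded rows without
  // another head transpose, producing the compacted output:
  std::printf("fmha_out                   : [%d, %d]\n", token_num, num_head * head_dim);
  return 0;
}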

void ComputeBackward(const phi::DenseTensor& transpose_2_out_tensor,
const phi::DenseTensor* src_mask_tensor,
const phi::DenseTensor& softmax_out_tensor,
