Commit

fix
zrr1999 committed Feb 3, 2024
1 parent 09ad99a · commit 3779df7
Showing 4 changed files with 50 additions and 48 deletions.
76 changes: 39 additions & 37 deletions paddle/phi/kernels/fusion/gpu/fused_multi_transformer_int8_op.cu
@@ -76,16 +76,15 @@ void FusedMultiTransformerINT8Kernel(
AttnLayerNorm<T, T, int8_t>(dev_ctx, epsilon, bsz_seq, dim_embed);
phi::DenseTensor ln_mean, ln_var;
ln_mean.Resize({{bsz_seq}});
- auto *ln_mean_data = dev_ctx.Alloc<U>(&ln_mean, ln_mean.numel() * sizeof(U));
+ auto *ln_mean_data =
+ dev_ctx.template Alloc<U>(&ln_mean, ln_mean.numel() * sizeof(U));
ln_var.Resize({{bsz_seq}});
- auto *ln_var_data = dev_ctx.Alloc<U>(&ln_var, ln_var.numel() * sizeof(U));
+ auto *ln_var_data =
+ dev_ctx.template Alloc<U>(&ln_var, ln_var.numel() * sizeof(U));

// 2. qkv
// x: qkv's input [batch_size, seq_len, dim_embed]
// y: qkv's weight: [3, num_head, dim_head, dim_embed]
auto qkv_weights = ctx.MultiInput<phi::DenseTensor>("QKVW");
auto qkv_biases = ctx.MultiInput<phi::DenseTensor>("QKVBias");
const bool trans_qkvw = ctx.Attr<bool>("trans_qkvw");
const auto qkv_w_dims = qkv_weights[0]->dims();
int num_head = trans_qkvw ? qkv_w_dims[1] : qkv_w_dims[2];
int dim_head = trans_qkvw ? qkv_w_dims[2] : qkv_w_dims[3];
@@ -99,7 +98,8 @@ void FusedMultiTransformerINT8Kernel(
dev_ctx, bsz_seq, output_size, input_size, compute_bias);
phi::DenseTensor qkv_out;
qkv_out.Resize({{bsz, seq_len, 3, num_head, dim_head}});
- auto *qkv_out_data = dev_ctx.Alloc<T>(&qkv_out, qkv_out.numel() * sizeof(T));
+ auto *qkv_out_data =
+ dev_ctx.template Alloc<T>(&qkv_out, qkv_out.numel() * sizeof(T));

// 3. fmha
AttnDropoutParam attn_param(
@@ -135,32 +135,33 @@ void FusedMultiTransformerINT8Kernel(

phi::DenseTensor transpose_out_2, qk_out;
transpose_out_2.Resize({{3, bsz, num_head, seq_len, dim_head}});
- auto *transpose_out_2_data =
- dev_ctx.Alloc<T>(&transpose_out_2, transpose_out_2.numel() * sizeof(T));
+ auto *transpose_out_2_data = dev_ctx.template Alloc<T>(
+ &transpose_out_2, transpose_out_2.numel() * sizeof(T));

qk_out.Resize({{bsz, num_head, seq_len, out_seq_len}});
- auto *qk_out_data = dev_ctx.Alloc<T>(&qk_out, qk_out.numel() * sizeof(T));
+ auto *qk_out_data =
+ dev_ctx.template Alloc<T>(&qk_out, qk_out.numel() * sizeof(T));

phi::DenseTensor softmax_out;
phi::DenseTensor attn_dropout_mask_out, attn_dropout_out;
phi::DenseTensor qktv_out, fmha_out;
softmax_out.Resize({{bsz, num_head, seq_len, out_seq_len}});
auto *softmax_out_data =
- dev_ctx.Alloc<T>(&softmax_out, softmax_out.numel() * sizeof(T));
+ dev_ctx.template Alloc<T>(&softmax_out, softmax_out.numel() * sizeof(T));

attn_dropout_mask_out.Resize({{bsz, num_head, seq_len, out_seq_len}});
- auto *attn_dropout_mask_out_data = dev_ctx.Alloc<T>(
+ auto *attn_dropout_mask_out_data = dev_ctx.template Alloc<T>(
&attn_dropout_mask_out, attn_dropout_mask_out.numel() * sizeof(T));
attn_dropout_out.Resize({{bsz, num_head, seq_len, out_seq_len}});
- auto *attn_dropout_data_data =
- dev_ctx.Alloc<T>(&attn_dropout_out, attn_dropout_out.numel() * sizeof(T));
+ auto *attn_dropout_data_data = dev_ctx.template Alloc<T>(
+ &attn_dropout_out, attn_dropout_out.numel() * sizeof(T));

qktv_out.Resize({{bsz, num_head, seq_len, dim_head}});
auto *qktv_out_data =
- dev_ctx.Alloc<T>(&qktv_out, qktv_out.numel() * sizeof(T));
+ dev_ctx.template Alloc<T>(&qktv_out, qktv_out.numel() * sizeof(T));
fmha_out.Resize({{bsz, seq_len, num_head, dim_head}});
auto *fmha_out_data =
- dev_ctx.Alloc<T>(&fmha_out, fmha_out.numel() * sizeof(T));
+ dev_ctx.template Alloc<T>(&fmha_out, fmha_out.numel() * sizeof(T));

// 4. out_linear
auto out_linear_weights = ctx.MultiInput<phi::DenseTensor>("OutLinearW");
@@ -184,12 +185,12 @@ void FusedMultiTransformerINT8Kernel(
T *bias_dropout_residual_out_data = nullptr;
if (pre_layer_norm) {
bias_dropout_residual_out.Resize({{bsz, seq_len, dim_embed}});
- bias_dropout_residual_out_data =
- dev_ctx.Alloc<T>(&bias_dropout_residual_out,
- bias_dropout_residual_out.numel() * sizeof(T));
+ bias_dropout_residual_out_data = dev_ctx.template Alloc<T>(
+ &bias_dropout_residual_out,
+ bias_dropout_residual_out.numel() * sizeof(T));
}
dropout_mask_out.Resize({{bsz, seq_len, dim_embed}});
- auto *dropout_mask_out_data = dev_ctx.Alloc<uint8_t>(
+ auto *dropout_mask_out_data = dev_ctx.template Alloc<uint8_t>(
&dropout_mask_out, dropout_mask_out.numel() * sizeof(uint8_t));

// 6. ffn matmul1
@@ -203,7 +204,7 @@ void FusedMultiTransformerINT8Kernel(
phi::DenseTensor ffn1_out;
ffn1_out.Resize({{bsz_seq, dim_ffn}});
auto *ffn1_out_data =
- dev_ctx.Alloc<T>(&ffn1_out, ffn1_out.numel() * sizeof(T));
+ dev_ctx.template Alloc<T>(&ffn1_out, ffn1_out.numel() * sizeof(T));

// 7. ffn act + bias
DropoutParam ffn1_dropout_param(true, 0, true, true, 0.0, nullptr, 0);
@@ -213,10 +214,10 @@ void FusedMultiTransformerINT8Kernel(
dev_ctx, bsz_seq, dim_ffn, ffn1_dropout_param);
phi::DenseTensor ffn1_dropout_out, ffn1_dropout_mask;
ffn1_dropout_out.Resize({{bsz_seq, dim_ffn}});
- auto *ffn1_dropout_out_data =
- dev_ctx.Alloc<T>(&ffn1_dropout_out, ffn1_dropout_out.numel() * sizeof(T));
+ auto *ffn1_dropout_out_data = dev_ctx.template Alloc<T>(
+ &ffn1_dropout_out, ffn1_dropout_out.numel() * sizeof(T));
ffn1_dropout_mask.Resize({{bsz_seq, dim_ffn}});
- auto *ffn1_dropout_mask_data = dev_ctx.Alloc<uint8_t>(
+ auto *ffn1_dropout_mask_data = dev_ctx.template Alloc<uint8_t>(
&ffn1_dropout_mask, ffn1_dropout_mask.numel() * sizeof(uint8_t));

// 8. ffn2 matmul
@@ -244,24 +245,25 @@ void FusedMultiTransformerINT8Kernel(
n_max = std::max({output_size, dim_embed, dim_ffn});

input_workspace.Resize({{(m_max * k_max + 31) / 32 * 32}});
- dev_ctx.Alloc<int8_t>(&input_workspace,
- input_workspace.numel() * sizeof(int8_t));
+ dev_ctx.template Alloc<int8_t>(&input_workspace,
+ input_workspace.numel() * sizeof(int8_t));

output_workspace.Resize({{(n_max * m_max + 31) / 32 * 32}});
- dev_ctx.Alloc<int32_t>(&output_workspace,
- output_workspace.numel() * sizeof(int32_t));
+ dev_ctx.template Alloc<int32_t>(&output_workspace,
+ output_workspace.numel() * sizeof(int32_t));

cublaslt_workspace.Resize({{3000000}});
- dev_ctx.Alloc<int8_t>(&cublaslt_workspace,
- cublaslt_workspace.numel() * sizeof(int8_t));
+ dev_ctx.template Alloc<int8_t>(&cublaslt_workspace,
+ cublaslt_workspace.numel() * sizeof(int8_t));

// calc
auto *out = ctx.Output<phi::DenseTensor>("Out");
- auto *from_data = dev_ctx.Alloc<T>(out, out->numel() * sizeof(T));
+ auto *from_data = dev_ctx.template Alloc<T>(out, out->numel() * sizeof(T));
phi::DenseTensor *from_tensor = out;
phi::DenseTensor tmp_out;
tmp_out.Resize({{bsz, seq_len, dim_embed}});
- auto *tmp_out_data = dev_ctx.Alloc<T>(&tmp_out, tmp_out.numel() * sizeof(T));
+ auto *tmp_out_data =
+ dev_ctx.template Alloc<T>(&tmp_out, tmp_out.numel() * sizeof(T));

auto *x_data = input_x->data<T>();
phi::DenseTensor *buf0 = nullptr;
@@ -667,9 +669,9 @@ void FusedMultiTransformerINT8Kernel(
} // namespace fusion
} // namespace phi

- // PD_REGISTER_KERNEL(fused_multi_transformer_int8,
- // GPU,
- // ALL_LAYOUT,
- // phi::fusion::FusedMultiTransformerINT8Kernel,
- // float,
- // plat::dtype::float16) {}
+ PD_REGISTER_KERNEL(fused_multi_transformer_int8,
+ GPU,
+ ALL_LAYOUT,
+ phi::fusion::FusedMultiTransformerINT8Kernel,
+ float,
+ plat::dtype::float16) {}
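
Nearly every change in this file is the same one-line fix: dev_ctx.Alloc<T>(...) becomes dev_ctx.template Alloc<T>(...). Because the type of dev_ctx depends on the kernel's Context template parameter, Alloc is a dependent name, and C++ requires the template disambiguator before an explicit <T> argument list; without it, conforming compilers parse the < as a less-than operator. The standalone sketch below only illustrates that language rule; ToyContext and AllocBuffer are names invented for the example, not the real phi API. The last hunk additionally enables the previously commented-out PD_REGISTER_KERNEL call, registering the kernel for float and float16.

```cpp
#include <cstddef>
#include <cstdio>

// Toy stand-in for a device context exposing a member function template,
// loosely modeled on the allocation pattern in the diff above.
struct ToyContext {
  template <typename T>
  T* Alloc(void* tensor, std::size_t bytes) const {
    (void)tensor;  // unused in this sketch
    std::printf("allocating %zu bytes\n", bytes);
    return nullptr;  // a real context would hand back device memory
  }
};

// Inside a function template, `ctx` has a dependent type, so the compiler
// cannot assume `Alloc` names a template. Without the `template`
// disambiguator, `<` is parsed as less-than and the call does not compile.
template <typename T, typename Context>
T* AllocBuffer(const Context& ctx, void* tensor, std::size_t bytes) {
  // return ctx.Alloc<T>(tensor, bytes);        // error: missing 'template'
  return ctx.template Alloc<T>(tensor, bytes);  // OK
}

int main() {
  ToyContext ctx;
  AllocBuffer<float>(ctx, nullptr, 16 * sizeof(float));
  return 0;
}
```
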
12 changes: 6 additions & 6 deletions paddle/phi/kernels/fusion/gpu/fused_multi_transformer_op.cu
@@ -1394,9 +1394,9 @@ void FusedMultiTransformerKernel(
} // namespace fusion
} // namespace phi

- // PD_REGISTER_KERNEL(fused_multi_transformer,
- // GPU,
- // ALL_LAYOUT,
- // phi::fusion::FusedMultiTransformerKernel,
- // float,
- // plat::dtype::float16) {}
+ PD_REGISTER_KERNEL(fused_multi_transformer,
+ GPU,
+ ALL_LAYOUT,
+ phi::fusion::FusedMultiTransformerKernel,
+ float,
+ plat::dtype::float16) {}
8 changes: 4 additions & 4 deletions paddle/phi/kernels/fusion/gpu/fused_multi_transformer_op.cu.h
@@ -745,9 +745,9 @@ inline __device__ void convert_from_float(float4 &dst, float4 src) { // NOLINT
dst = src;
}

- inline __device__ void convert_from_float(phi::phi::float16 &dst, // NOLINT
+ inline __device__ void convert_from_float(phi::float16 &dst, // NOLINT
float src) {
- dst = static_cast<phi::phi::float16>(src);
+ dst = static_cast<phi::float16>(src);
}

inline __device__ void convert_from_float(uint4 &dst, Float8_ src) { // NOLINT
@@ -1782,7 +1782,7 @@ class CublasFusedMLP {
cudaDataType_t mat_type = CUDA_R_32F;
cudaDataType_t scale_type = CUDA_R_32F;
cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
- if (std::is_same<T, phi::phi::float16>::value) {
+ if (std::is_same<T, phi::float16>::value) {
mat_type = CUDA_R_16F;
if (FLAGS_gemm_use_half_precision_compute_type) {
// This option default value is true, it tends to result NaN, but get
Expand Down Expand Up @@ -1979,7 +1979,7 @@ class CublasFusedMLP {
const uint64_t cublas_row,
const uint64_t cublas_col) {
cudaDataType_t mat_type = CUDA_R_32F;
- if (std::is_same<T, phi::phi::float16>::value) {
+ if (std::is_same<T, phi::float16>::value) {
mat_type = CUDA_R_16F;
}
if (std::is_same<T, phi::bfloat16>::value) {
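
The edits in this header are namespace fixes: phi::phi::float16, a doubled qualifier, becomes phi::float16, both in the convert_from_float overload and in the std::is_same checks that select the cuBLAS matrix data type in CublasFusedMLP. For reference, the sketch below shows the same dispatch shape, mapping the element type to a descriptor at compile time; Half16, BFloat16 and MatType are stand-ins invented for the example, not the real phi::float16 or cudaDataType_t definitions.

```cpp
#include <cstdint>
#include <type_traits>

// Stand-in element types and matrix-type tags for the sketch.
struct Half16 {
  std::uint16_t bits;
};
struct BFloat16 {
  std::uint16_t bits;
};

enum class MatType { kF32, kF16, kBF16 };

// Same shape as the dispatch in CublasFusedMLP: default to fp32 and switch
// the descriptor when the element type is a 16-bit float.
template <typename T>
constexpr MatType PickMatType() {
  if (std::is_same<T, Half16>::value) {
    return MatType::kF16;
  }
  if (std::is_same<T, BFloat16>::value) {
    return MatType::kBF16;
  }
  return MatType::kF32;
}

static_assert(PickMatType<float>() == MatType::kF32, "float maps to fp32");
static_assert(PickMatType<Half16>() == MatType::kF16, "half maps to fp16");

int main() { return 0; }
```
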
2 changes: 1 addition & 1 deletion python/paddle/incubate/nn/functional/fused_transformer.py
@@ -413,7 +413,7 @@ def fused_bias_dropout_residual_layer_norm(
x.shape[len(x.shape) - 1] == ln_bias.shape[0]
), "The dim of ln_bias must equal to the last dim of x."

- if in_dynamic_mode():
+ if in_dynamic_or_pir_mode():
if default_main_program().random_seed != 0:
seed = default_main_program().random_seed
(
