diff --git a/paddle/fluid/operators/fused/attention_layer_norm.h b/paddle/fluid/operators/fused/attention_layer_norm.h deleted file mode 100644 index 92cbc37059eb14..00000000000000 --- a/paddle/fluid/operators/fused/attention_layer_norm.h +++ /dev/null @@ -1,113 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include "paddle/phi/kernels/funcs/layer_norm_impl.cu.h" - -namespace paddle { -namespace operators { - -// NOTE: T must be the same as OutType in ComputeBackward -template -class AttnLayerNorm { - public: - AttnLayerNorm(const phi::GPUContext& dev_ctx, - float epsilon, - int64_t batch_size, - int64_t feature_size) - : dev_ctx_(dev_ctx), - epsilon_(epsilon), - batch_size_(batch_size), - feature_size_(feature_size) {} - - ~AttnLayerNorm() {} - - void ComputeForward(const InType* x_data, - const phi::funcs::LayerNormParamType* scale_data, - const phi::funcs::LayerNormParamType* bias_data, - OutType* y_data, - phi::funcs::LayerNormParamType* mean_data, - phi::funcs::LayerNormParamType* var_data, - const float* dequant_out_scale_data = nullptr, - const int quant_out_scale_offset = 0, - const float quant_in_scale = 1.0, - const int quant_round_type = 1, - const float quant_max_bound = 127.0, - const float quant_min_bound = -127.0) { - auto stream = dev_ctx_.stream(); - - switch (phi::funcs::GetDesiredBlockDim(feature_size_)) { - FIXED_BLOCK_DIM_CASE( - phi::funcs::LayerNormForward, - kBlockDim, - false, - InType, - OutType> - <<>>(x_data, - scale_data, - bias_data, - y_data, - mean_data, - var_data, - epsilon_, - feature_size_, - dequant_out_scale_data, - quant_out_scale_offset, - quant_in_scale, - quant_round_type, - quant_max_bound, - quant_min_bound)); - default: - PADDLE_THROW( - phi::errors::InvalidArgument("Feature_size must be larger than 1")); - break; - } - } - - void ComputeBackward(const T* x_data, - const T* d_y_data, - const phi::funcs::LayerNormParamType* scale_data, - const phi::funcs::LayerNormParamType* mean_data, - const phi::funcs::LayerNormParamType* var_data, - T* d_x_data, - phi::funcs::LayerNormParamType* d_scale_data, - phi::funcs::LayerNormParamType* d_bias_data) { - phi::funcs::LayerNormBackward>( - x_data, - d_y_data, - scale_data, - mean_data, - var_data, - d_x_data, - d_scale_data, - d_bias_data, - epsilon_, - batch_size_, - feature_size_, - dev_ctx_); - } - - private: - const phi::GPUContext& dev_ctx_; - - int64_t batch_size_; - int64_t feature_size_; - - float epsilon_; -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/fused/fmha_ref.h b/paddle/fluid/operators/fused/fmha_ref.h deleted file mode 100644 index 2a43eea07535ab..00000000000000 --- a/paddle/fluid/operators/fused/fmha_ref.h +++ /dev/null @@ -1,750 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include "paddle/fluid/operators/fused/fused_softmax_mask.cu.h" -#include "paddle/phi/kernels/funcs/blas/blas.h" -#include "paddle/phi/kernels/funcs/broadcast_function.h" -#include "paddle/phi/kernels/funcs/concat_and_split_functor.h" -#include "paddle/phi/kernels/funcs/dropout_impl.cu.h" -#include "paddle/phi/kernels/funcs/elementwise_base.h" -#include "paddle/phi/kernels/funcs/elementwise_functor.h" -#include "paddle/phi/kernels/funcs/functors.h" -#include "paddle/phi/kernels/funcs/transpose_function.cu.h" -#include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h" - -namespace paddle { -namespace operators { - -class AttnDropoutParam { - public: - AttnDropoutParam() { - is_test_ = false; - dropout_implementation_ = "downgrade_in_infer"; - dropout_prob_ = 0.5; - is_upscale_in_train_ = false; - is_fix_seed_ = false; - seed_val_ = 0; - seed_ = nullptr; - } - AttnDropoutParam(bool is_test, - const std::string dropout_implementation, - float dropout_prob, - bool is_upscale_in_train, - bool is_fix_seed, - int seed_val, - const phi::DenseTensor* seed) { - is_test_ = is_test; - dropout_implementation_ = dropout_implementation; - dropout_prob_ = dropout_prob; - is_upscale_in_train_ = is_upscale_in_train; - is_fix_seed_ = is_fix_seed; - seed_val_ = seed_val; - seed_ = seed; - } - bool is_test_; - std::string dropout_implementation_; - float dropout_prob_; - bool is_upscale_in_train_; - bool is_fix_seed_; - int seed_val_; - const phi::DenseTensor* seed_; -}; - -template -__global__ void TransposeRemovingPadding(const T* input_data, - T* output_data, - const int batch_size, - const int num_head, - const int seq_len, - const int head_dim, - const int token_num, - const int elem_cnt, - const int* padding_offset) { - // transpose and remove padding - // [batch_size, num_head, seq_len, head_dim] -> [token_num, num_head, - // head_dim] - int64_t idx = blockDim.x * blockIdx.x + threadIdx.x; - const int dim_embed = num_head * head_dim; - using LoadT = phi::AlignedVector; - LoadT src_vec; - - for (int32_t linear_index = idx * VecSize, - step = gridDim.x * blockDim.x * VecSize; - linear_index < elem_cnt; - linear_index += step) { - const int token_idx = linear_index / dim_embed; - const int ori_token_idx = - token_idx + (padding_offset == nullptr ? 0 : padding_offset[token_idx]); - const int ori_batch_id = ori_token_idx / seq_len; - const int ori_seq_id = ori_token_idx % seq_len; - const int ori_head_id = (linear_index % dim_embed) / head_dim; - const int ori_head_lane = (linear_index % dim_embed) % head_dim; - const int ori_idx = ori_batch_id * num_head * seq_len * head_dim + - ori_head_id * seq_len * head_dim + - ori_seq_id * head_dim + ori_head_lane; - phi::Load(&input_data[ori_idx], &src_vec); - phi::Store(src_vec, &output_data[linear_index]); - } -} - -template -void InvokeTransposeRemovePadding(const phi::GPUContext& dev_ctx, - const T* input_data, - T* output_data, - const int batch_size, - const int num_head, - const int seq_len, - const int head_dim, - const int token_num, - const int* padding_offset) { - // [batch_size, num_head, seq_len, head_dim] -> [token_num, num_head, - // head_dim] - constexpr int VEC_16B = 16; - const int elem_cnt = token_num * num_head * head_dim; - constexpr int PackSize = VEC_16B / sizeof(T); - PADDLE_ENFORCE_EQ( - head_dim % PackSize, - 0, - phi::errors::PreconditionNotMet( - "dim_head=%d must be divisible by vec_size=%d", head_dim, PackSize)); - const int32_t pack_num = elem_cnt / PackSize; - const int32_t block_size = 128; - int32_t grid_size = (pack_num + block_size - 1) / block_size; - TransposeRemovingPadding - <<>>(input_data, - output_data, - batch_size, - num_head, - seq_len, - head_dim, - token_num, - elem_cnt, - padding_offset); -} - -template -class FMHARef { - public: - FMHARef(const phi::GPUContext& dev_ctx, - int64_t batch_size, - int64_t seq_len, - int64_t num_head, - int64_t head_dim, - AttnDropoutParam param) - : dev_ctx_(dev_ctx), - batch_size_(batch_size), - seq_len_(seq_len), - num_head_(num_head), - head_dim_(head_dim), - dropout_param_(param) {} - - ~FMHARef() {} - - void ComputeForward(const phi::DenseTensor& qkv_input_tensor, - const phi::DenseTensor* cache_kv_tensor, - const phi::DenseTensor* src_mask_tensor, - phi::DenseTensor* transpose_2_out_tensor, - phi::DenseTensor* cache_kv_out_tensor, - phi::DenseTensor* qk_out_tensor, - phi::DenseTensor* src_mask_out_tensor, - phi::DenseTensor* softmax_out_tensor, - phi::DenseTensor* dropout_mask_out_tensor, - phi::DenseTensor* dropout_out_tensor, - phi::DenseTensor* qktv_out_tensor, - phi::DenseTensor* fmha_out_tensor) { - // input shape: [bs, seq_len, 3, num_head, head_dim] - // transpose with perm [2, 0, 3, 1, 4], - // output_shape: [3, bs, num_head, seq_len, head_dim] - std::vector perm_1 = {2, 0, 3, 1, 4}; - phi::funcs::TransposeGPUKernelDriver( - dev_ctx_, qkv_input_tensor, perm_1, transpose_2_out_tensor); - T* qkv_data = transpose_2_out_tensor->data(); - T* qk_out_data = qk_out_tensor->data(); - T* qktv_out_data = qktv_out_tensor->data(); - T* softmax_out_data = softmax_out_tensor->data(); - T* fmha_out_data = fmha_out_tensor->data(); - - auto out_seq_len = seq_len_; - if (cache_kv_tensor) { - // kv [2, bs, num_head, seq_len, head_dim] - auto kv_tensor = transpose_2_out_tensor->Slice(1, 3); - phi::funcs::ConcatFunctor concat; - // out [2, bs, num_head, cache_seq_len + seq_len, head_dim] - concat(dev_ctx_, {*cache_kv_tensor, kv_tensor}, 3, cache_kv_out_tensor); - out_seq_len = cache_kv_out_tensor->dims()[3]; - } - - int64_t q_size = batch_size_ * seq_len_ * num_head_ * head_dim_; - T* q_ptr = qkv_data; - T* k_ptr = nullptr; - T* v_ptr = nullptr; - - if (cache_kv_tensor) { - int64_t k_size = cache_kv_out_tensor->numel() / 2; - k_ptr = cache_kv_out_tensor->data(); - v_ptr = k_ptr + k_size; - } else { - int64_t k_size = q_size; - k_ptr = q_ptr + q_size; - v_ptr = k_ptr + k_size; - } - - { - // NOTE(wangxi): We scale Q with 1/sqrt(Dh) before QK^T, because for - // float16 calculation, INF may appear in QK^T if we do not scale before. - float alpha = 1.0 / sqrt(head_dim_); - auto q_tensor = transpose_2_out_tensor->Slice(0, 1); - auto functor = phi::funcs::ScaleFunctor(alpha); - std::vector ins = {&q_tensor}; - std::vector outs = {&q_tensor}; - phi::funcs::ElementwiseKernel(dev_ctx_, ins, &outs, functor); - } - - // q*k^t, batched_gemm - CBLAS_TRANSPOSE transA = CblasNoTrans; - CBLAS_TRANSPOSE transB = CblasTrans; - auto blas = phi::funcs::GetBlas(dev_ctx_); - int gemm_batch_size = batch_size_ * num_head_; - int gemm_m = seq_len_; - int gemm_n = out_seq_len; - int gemm_k = head_dim_; - T alpha = static_cast(1.0); - T beta = static_cast(0.0); - int64_t stride_a = gemm_m * gemm_k; - int64_t stride_b = gemm_k * gemm_n; - blas.BatchedGEMM(transA, - transB, - gemm_m, - gemm_n, - gemm_k, - alpha, - q_ptr, - k_ptr, - beta, - qk_out_data, - gemm_batch_size, - stride_a, - stride_b); - int softmax_axis = -1; - if (src_mask_tensor != nullptr) { - if (src_mask_out_tensor == nullptr && seq_len_ == out_seq_len) { - LaunchFusedSoftmaxMaskKernel(qk_out_data, - src_mask_tensor->data(), - softmax_out_data, - batch_size_, - num_head_, - seq_len_, - dev_ctx_.stream()); - } else { - std::vector ins; - std::vector outs; - ins.emplace_back(qk_out_tensor); - ins.emplace_back(src_mask_tensor); - outs.emplace_back(src_mask_out_tensor); - int elewise_add_axis = -1; - phi::funcs::BroadcastKernel(dev_ctx_, - ins, - &outs, - phi::funcs::AddFunctor(), - elewise_add_axis); - - phi::SoftmaxForwardCUDAKernelDriver( - dev_ctx_, *src_mask_out_tensor, softmax_axis, softmax_out_tensor); - } - } else { - phi::SoftmaxForwardCUDAKernelDriver( - dev_ctx_, *qk_out_tensor, softmax_axis, softmax_out_tensor); - } - - transB = CblasNoTrans; - gemm_m = seq_len_; - gemm_n = head_dim_; - gemm_k = out_seq_len; - alpha = static_cast(1.0); - stride_a = gemm_m * gemm_k; - stride_b = gemm_k * gemm_n; - - if (dropout_param_.dropout_prob_) { - phi::funcs::DropoutFwGPUKernelDriver( - static_cast(dev_ctx_), - dropout_param_.is_test_, - dropout_param_.dropout_prob_, - dropout_param_.is_upscale_in_train_, - dropout_param_.is_fix_seed_, - dropout_param_.seed_val_, - static_cast(*softmax_out_tensor), - dropout_param_.seed_, - dropout_mask_out_tensor, - dropout_out_tensor, - false); - T* dropout_out_data = dropout_out_tensor->data(); - blas.BatchedGEMM(transA, - transB, - gemm_m, - gemm_n, - gemm_k, - alpha, - dropout_out_data, - v_ptr, - beta, - qktv_out_data, - gemm_batch_size, - stride_a, - stride_b); - } else { - // softmax_out * v, batched_gemm - // output shape: [batch_size, num_heads, seq_len, head_dim] - blas.BatchedGEMM(transA, - transB, - gemm_m, - gemm_n, - gemm_k, - alpha, - softmax_out_data, - v_ptr, - beta, - qktv_out_data, - gemm_batch_size, - stride_a, - stride_b); - } - // transpose: [0, 2, 1, 3] - // output shape: [batch_size, seq_len, num_heads, head_dim] - std::vector perm_3 = {0, 2, 1, 3}; - phi::funcs::TransposeGPUKernelDriver( - dev_ctx_, *qktv_out_tensor, perm_3, fmha_out_tensor); - } - - void ComputeForwardWithoutTranspose( - const phi::DenseTensor* cache_kv_tensor, - const phi::DenseTensor* src_mask_tensor, - const phi::DenseTensor* padding_offset_tensor, - phi::DenseTensor* q_transpose_out_tensor, - phi::DenseTensor* kv_transpose_out_tensor, - phi::DenseTensor* cache_kv_out_tensor, - phi::DenseTensor* qk_out_tensor, - phi::DenseTensor* src_mask_out_tensor, - phi::DenseTensor* softmax_out_tensor, - phi::DenseTensor* dropout_mask_out_tensor, - phi::DenseTensor* dropout_out_tensor, - phi::DenseTensor* qktv_out_tensor, - phi::DenseTensor* fmha_out_tensor, - const int token_num) { - // input shape: [bs, seq_len, 3, num_head, head_dim] - // transpose with perm [2, 0, 3, 1, 4], - // output_shape: [3, bs, num_head, seq_len, head_dim] - T* qk_out_data = qk_out_tensor->data(); - T* qktv_out_data = qktv_out_tensor->data(); - T* softmax_out_data = softmax_out_tensor->data(); - T* dropout_out_data = dropout_out_tensor->data(); - T* fmha_out_data = fmha_out_tensor->data(); - - auto out_seq_len = seq_len_; - if (cache_kv_tensor) { - // kv [2, bs, num_head, seq_len, head_dim] - phi::funcs::ConcatFunctor concat; - // out [2, bs, num_head, cache_seq_len + seq_len, head_dim] - concat(dev_ctx_, - {*cache_kv_tensor, *kv_transpose_out_tensor}, - 3, - cache_kv_out_tensor); - out_seq_len = cache_kv_out_tensor->dims()[3]; - } - - int64_t q_size = batch_size_ * seq_len_ * num_head_ * head_dim_; - T* q_ptr = q_transpose_out_tensor->data(); - T* k_ptr = nullptr; - T* v_ptr = nullptr; - - if (cache_kv_tensor) { - int64_t k_size = cache_kv_out_tensor->numel() / 2; - k_ptr = cache_kv_out_tensor->data(); - v_ptr = k_ptr + k_size; - } else { - int64_t k_size = q_size; - k_ptr = kv_transpose_out_tensor->data(); - v_ptr = k_ptr + k_size; - } - - { - // NOTE(wangxi): We scale Q with 1/sqrt(Dh) before QK^T, because for - // float16 calculation, INF may appear in QK^T if we do not scale before. - float alpha = 1.0 / sqrt(head_dim_); - auto functor = phi::funcs::ScaleFunctor(alpha); - std::vector ins = {q_transpose_out_tensor}; - std::vector outs = {q_transpose_out_tensor}; - phi::funcs::ElementwiseKernel(dev_ctx_, ins, &outs, functor); - } - - // q*k^t, batched_gemm - CBLAS_TRANSPOSE transA = CblasNoTrans; - CBLAS_TRANSPOSE transB = CblasTrans; - auto blas = phi::funcs::GetBlas(dev_ctx_); - int gemm_batch_size = batch_size_ * num_head_; - int gemm_m = seq_len_; - int gemm_n = out_seq_len; - int gemm_k = head_dim_; - T alpha = static_cast(1.0); - T beta = static_cast(0.0); - int64_t stride_a = gemm_m * gemm_k; - int64_t stride_b = gemm_k * gemm_n; - blas.BatchedGEMM(transA, - transB, - gemm_m, - gemm_n, - gemm_k, - alpha, - q_ptr, - k_ptr, - beta, - qk_out_data, - gemm_batch_size, - stride_a, - stride_b); - int softmax_axis = -1; - if (src_mask_tensor != nullptr) { - if (src_mask_out_tensor == nullptr && seq_len_ == out_seq_len) { - LaunchFusedSoftmaxMaskKernel(qk_out_data, - src_mask_tensor->data(), - softmax_out_data, - batch_size_, - num_head_, - seq_len_, - dev_ctx_.stream()); - } else { - std::vector ins; - std::vector outs; - ins.emplace_back(qk_out_tensor); - ins.emplace_back(src_mask_tensor); - outs.emplace_back(src_mask_out_tensor); - int elewise_add_axis = -1; - phi::funcs::BroadcastKernel(dev_ctx_, - ins, - &outs, - phi::funcs::AddFunctor(), - elewise_add_axis); - - phi::SoftmaxForwardCUDAKernelDriver( - dev_ctx_, *src_mask_out_tensor, softmax_axis, softmax_out_tensor); - } - } else { - phi::SoftmaxForwardCUDAKernelDriver( - dev_ctx_, *qk_out_tensor, softmax_axis, softmax_out_tensor); - } - - transB = CblasNoTrans; - gemm_m = seq_len_; - gemm_n = head_dim_; - gemm_k = out_seq_len; - alpha = static_cast(1.0); - stride_a = gemm_m * gemm_k; - stride_b = gemm_k * gemm_n; - - if (dropout_param_.dropout_prob_) { - phi::funcs::DropoutFwGPUKernelDriver( - static_cast(dev_ctx_), - dropout_param_.is_test_, - dropout_param_.dropout_prob_, - dropout_param_.is_upscale_in_train_, - dropout_param_.is_fix_seed_, - dropout_param_.seed_val_, - static_cast(*softmax_out_tensor), - dropout_param_.seed_, - dropout_mask_out_tensor, - dropout_out_tensor, - false); - blas.BatchedGEMM(transA, - transB, - gemm_m, - gemm_n, - gemm_k, - alpha, - dropout_out_data, - v_ptr, - beta, - qktv_out_data, - gemm_batch_size, - stride_a, - stride_b); - } else { - // softmax_out * v, batched_gemm - // output shape: [batch_size, num_heads, seq_len, head_dim] - blas.BatchedGEMM(transA, - transB, - gemm_m, - gemm_n, - gemm_k, - alpha, - softmax_out_data, - v_ptr, - beta, - qktv_out_data, - gemm_batch_size, - stride_a, - stride_b); - } - // transpose: [0, 2, 1, 3] - // output shape: [batch_size, seq_len, num_heads, head_dim] - if (!padding_offset_tensor) { - std::vector perm_3 = {0, 2, 1, 3}; - phi::funcs::TransposeGPUKernelDriver( - dev_ctx_, *qktv_out_tensor, perm_3, fmha_out_tensor); - } else { - InvokeTransposeRemovePadding(dev_ctx_, - qktv_out_data, - fmha_out_data, - batch_size_, - num_head_, - seq_len_, - head_dim_, - token_num, - padding_offset_tensor->data()); - } - } - - void ComputeBackward(const phi::DenseTensor& transpose_2_out_tensor, - const phi::DenseTensor* src_mask_tensor, - const phi::DenseTensor& softmax_out_tensor, - const phi::DenseTensor& dropout_mask_out_tensor, - const phi::DenseTensor& dropout_out_tensor, - const phi::DenseTensor& qk_out_tensor, - const phi::DenseTensor& src_mask_out_tensor, - const phi::DenseTensor& fmha_out_grad_tensor, - phi::DenseTensor* qktv_out_grad_tensor, - phi::DenseTensor* dropout_out_grad_tensor, - phi::DenseTensor* softmax_out_grad_tensor, - phi::DenseTensor* src_mask_out_grad_tensor, - phi::DenseTensor* qk_out_grad_tensor, - phi::DenseTensor* transpose_2_out_grad_tensor, - phi::DenseTensor* src_mask_grad_tensor, - phi::DenseTensor* qkv_input_grad_tensor) { - auto blas = phi::funcs::GetBlas(dev_ctx_); - int q_size = batch_size_ * seq_len_ * num_head_ * head_dim_; - int k_size = q_size; - int softmax_axis = -1; - - T* qkv_grad_data = transpose_2_out_grad_tensor->data(); - T* q_grad_ptr = qkv_grad_data; - T* k_grad_ptr = q_grad_ptr + q_size; - T* v_grad_ptr = k_grad_ptr + k_size; - const T* qkv_data = transpose_2_out_tensor.data(); - const T* q_ptr = qkv_data; - const T* k_ptr = q_ptr + q_size; - const T* v_ptr = k_ptr + k_size; - - const T* softmax_out_data = softmax_out_tensor.data(); - T* softmax_out_grad_data = softmax_out_grad_tensor->data(); - T* qktv_out_grad_data = qktv_out_grad_tensor->data(); - - // transpose bw - std::vector perm_3 = {0, 2, 1, 3}; - phi::funcs::TransposeGPUKernelDriver( - dev_ctx_, fmha_out_grad_tensor, perm_3, qktv_out_grad_tensor); - - // recall batchedgemm(nn) fw: softmax_out_data(x) * v_ptr(y) = - // qktv_out_data(out) - CBLAS_TRANSPOSE transA = CblasTrans; - CBLAS_TRANSPOSE transB = CblasNoTrans; - int gemm_batch_size = batch_size_ * num_head_; - int gemm_m = seq_len_; - int gemm_n = head_dim_; - int gemm_k = seq_len_; - T alpha = static_cast(1.0); - T beta = static_cast(0.0); - int64_t stride_a = gemm_m * gemm_k; - int64_t stride_b = gemm_k * gemm_n; - // bw: dy = x^t * dout - if (dropout_param_.dropout_prob_) { - const T* dropout_out_data = dropout_out_tensor.data(); - blas.BatchedGEMM(transA, - transB, - gemm_m, - gemm_n, - gemm_k, - alpha, - dropout_out_data, - qktv_out_grad_data, - beta, - v_grad_ptr, - gemm_batch_size, - stride_a, - stride_b); - } else { - blas.BatchedGEMM(transA, - transB, - gemm_m, - gemm_n, - gemm_k, - alpha, - softmax_out_data, - qktv_out_grad_data, - beta, - v_grad_ptr, - gemm_batch_size, - stride_a, - stride_b); - } - // bw: dx = dout * y^t - transA = CblasNoTrans; - transB = CblasTrans; - gemm_m = seq_len_; - gemm_n = seq_len_; - gemm_k = head_dim_; - stride_a = gemm_m * gemm_k; - stride_b = gemm_k * gemm_n; - if (dropout_param_.dropout_prob_) { - T* dropout_out_grad_data = dropout_out_grad_tensor->data(); - blas.BatchedGEMM(transA, - transB, - gemm_m, - gemm_n, - gemm_k, - alpha, - qktv_out_grad_data, - v_ptr, - beta, - dropout_out_grad_data, - gemm_batch_size, - stride_a, - stride_b); - } else { - blas.BatchedGEMM(transA, - transB, - gemm_m, - gemm_n, - gemm_k, - alpha, - qktv_out_grad_data, - v_ptr, - beta, - softmax_out_grad_data, - gemm_batch_size, - stride_a, - stride_b); - } - // dropout bw - if (dropout_param_.dropout_prob_) { - phi::funcs::DropoutGradGPUKernelDriver( - static_cast(dev_ctx_), - false, - dropout_param_.dropout_prob_, - dropout_param_.is_upscale_in_train_, - static_cast(*dropout_out_grad_tensor), - dropout_mask_out_tensor, - softmax_out_grad_tensor, - false); - } - - if (src_mask_tensor != nullptr) { - phi::SoftmaxBackwardCUDAKernelDriver(dev_ctx_, - softmax_out_tensor, - *softmax_out_grad_tensor, - softmax_axis, - src_mask_out_grad_tensor); - // recall LaunchElementwiseCudaKernel fw: src_mask_out = qk_out + - // src_mask - // Special case when dy is not needed and dx doesn't reduce - if (qk_out_grad_tensor != nullptr && src_mask_grad_tensor == nullptr && - qk_out_tensor.dims() == src_mask_out_tensor.dims()) { - VLOG(4) << "Special case when dy is not needed and dx doesn't " - "reduce"; - framework::TensorCopy(*src_mask_out_grad_tensor, - dev_ctx_.GetPlace(), - dev_ctx_, - qk_out_grad_tensor); - } else { - PADDLE_THROW(phi::errors::InvalidArgument( - "Only used for the backward elementwise_add op when" - "dy is not needed and dx is not reduce")); - return; - } - - } else { - phi::SoftmaxBackwardCUDAKernelDriver(dev_ctx_, - softmax_out_tensor, - *softmax_out_grad_tensor, - softmax_axis, - qk_out_grad_tensor); - } - - T* qk_out_grad_data = qk_out_grad_tensor->data(); - // NOTE(wangxi): For we scale Q with 1/sqrt(Dh) in forward, so we set - // alpha = 1.0 in backward. - alpha = static_cast(1.0); - // recall batchedgemm(nt) fw: q_ptr * (k_ptr)^t = qk_out - // bw: dy (seq_len * head_dim) = (dout)^t * x - transA = CblasTrans; - transB = CblasNoTrans; - gemm_m = seq_len_; - gemm_n = head_dim_; - gemm_k = seq_len_; - stride_a = gemm_m * gemm_k; - stride_b = gemm_k * gemm_n; - blas.BatchedGEMM(transA, - transB, - gemm_m, - gemm_n, - gemm_k, - alpha, - qk_out_grad_data, - q_ptr, - beta, - k_grad_ptr, - gemm_batch_size, - stride_a, - stride_b); - // dx (seq_len * head_dim) = dout * y - alpha = static_cast(1.0 / sqrt(head_dim_)); - transA = CblasNoTrans; - transB = CblasNoTrans; - gemm_m = seq_len_; - gemm_n = head_dim_; - gemm_k = seq_len_; - stride_a = gemm_m * gemm_k; - stride_b = gemm_k * gemm_n; - blas.BatchedGEMM(transA, - transB, - gemm_m, - gemm_n, - gemm_k, - alpha, - qk_out_grad_data, - k_ptr, - beta, - q_grad_ptr, - gemm_batch_size, - stride_a, - stride_b); - - // transpose bw - std::vector perm_1 = {1, 3, 0, 2, 4}; - phi::funcs::TransposeGPUKernelDriver( - dev_ctx_, *transpose_2_out_grad_tensor, perm_1, qkv_input_grad_tensor); - } - - private: - const phi::GPUContext& dev_ctx_; - - int64_t batch_size_; - int64_t seq_len_; - int64_t num_head_; - int64_t head_dim_; - - AttnDropoutParam dropout_param_; -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cu b/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cu index b696a183170c33..11614d70165d3a 100644 --- a/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cu +++ b/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cu @@ -61,8 +61,8 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { auto ln_scales = ctx.MultiInput("LnScale"); auto ln_biases = ctx.MultiInput("LnBias"); - auto ln_compute = - AttnLayerNorm(dev_ctx, epsilon, bsz_seq, dim_embed); + auto ln_compute = phi::fusion::AttnLayerNorm( + dev_ctx, epsilon, bsz_seq, dim_embed); phi::DenseTensor ln_mean, ln_var; ln_mean.Resize({{bsz_seq}}); auto *ln_mean_data = @@ -93,10 +93,10 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { dev_ctx.Alloc(&qkv_out, qkv_out.numel() * sizeof(T)); // 3. fmha - AttnDropoutParam attn_param( + phi::fusion::AttnDropoutParam attn_param( true, "upscale_in_train", 0.0, true, true, 0, nullptr); - auto fmha_compute = - FMHARef(dev_ctx, bsz, seq_len, num_head, dim_head, attn_param); + auto fmha_compute = phi::fusion::FMHARef( + dev_ctx, bsz, seq_len, num_head, dim_head, attn_param); auto *src_mask = ctx.Input("SrcMask"); auto cache_kvs = ctx.MultiInput("CacheKV"); auto cache_kv_outs = ctx.MultiOutput("CacheKVOut"); diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h index 8590738297edf7..0a57fb9e873414 100644 --- a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h +++ b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h @@ -27,8 +27,6 @@ limitations under the License. */ #include "paddle/common/flags.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" -#include "paddle/fluid/operators/fused/attention_layer_norm.h" -#include "paddle/fluid/operators/fused/fmha_ref.h" #include "paddle/fluid/operators/fused/fused_dropout_helper.h" #include "paddle/phi/api/include/tensor.h" #include "paddle/phi/backends/dynload/cublasLt.h" @@ -38,6 +36,7 @@ limitations under the License. */ #include "paddle/phi/kernels/funcs/fused_gemm_epilogue.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/fusion/gpu/attn_gemm.h" +#include "paddle/phi/kernels/fusion/gpu/fmha_ref.h" #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #include "paddle/fluid/distributed/collective/process_group.h" @@ -711,13 +710,13 @@ struct Qk_dot { } }; -template +template inline __device__ float block_sum(float *red_smem, float sum) { - int warp = threadIdx.x / WARP_SIZE; - int lane = threadIdx.x % WARP_SIZE; + int warp = threadIdx.x / WARP_SIZE_T; + int lane = threadIdx.x % WARP_SIZE_T; #pragma unroll - for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + for (int mask = WARP_SIZE_T / 2; mask >= 1; mask /= 2) { sum += __shfl_xor_sync(uint32_t(-1), sum, mask); } @@ -789,8 +788,8 @@ __global__ void masked_multihead_attention_kernel( static_assert(Dh_MAX % THREADS_PER_KEY == 0, ""); static_assert(Dh_MAX % THREADS_PER_VALUE == 0, ""); - constexpr int WARP_SIZE = 32; - constexpr int WARPS_PER_BLOCK = THREADS_PER_BLOCK / WARP_SIZE; + constexpr int WARP_SIZE_TMP = 32; + constexpr int WARPS_PER_BLOCK = THREADS_PER_BLOCK / WARP_SIZE_TMP; extern __shared__ char smem_[]; @@ -824,7 +823,7 @@ __global__ void masked_multihead_attention_kernel( constexpr int QK_VEC_SIZE = sizeof(Qk_vec) / sizeof(T); static_assert(Dh_MAX % QK_VEC_SIZE == 0, ""); // Use block reduction if needed - // static_assert(Dh_MAX / QK_VEC_SIZE <= WARP_SIZE, ""); + // static_assert(Dh_MAX / QK_VEC_SIZE <= WARP_SIZE_TMP, ""); constexpr int QK_VECS_PER_WARP = Dh_MAX / QK_VEC_SIZE; // cache_k, [B, num_head, head_dim / x, max_seq_len, x] @@ -944,16 +943,16 @@ __global__ void masked_multihead_attention_kernel( qk = dot(q, k); - if (QK_VECS_PER_WARP <= WARP_SIZE) { + if (QK_VECS_PER_WARP <= WARP_SIZE_TMP) { #pragma unroll for (int mask = QK_VECS_PER_WARP / 2; mask >= 1; mask /= 2) { qk += __shfl_xor_sync(shfl_mask(QK_VECS_PER_WARP), qk, mask); } } } - if (QK_VECS_PER_WARP > WARP_SIZE) { + if (QK_VECS_PER_WARP > WARP_SIZE_TMP) { constexpr int WARPS_PER_RED = - (QK_VECS_PER_WARP + WARP_SIZE - 1) / WARP_SIZE; + (QK_VECS_PER_WARP + WARP_SIZE_TMP - 1) / WARP_SIZE_TMP; qk = block_sum(&red_smem[WARPS_PER_RED], qk); } if (tid == 0) { @@ -994,7 +993,7 @@ __global__ void masked_multihead_attention_kernel( } constexpr int K_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_KEY; - constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY; + constexpr int K_PER_WARP = WARP_SIZE_TMP / THREADS_PER_KEY; T *k_cache = ¶ms.cache_kv[bhi * params.max_seq_length * Dh + ki]; int ti_end = div_up(act_time_step, K_PER_WARP) * K_PER_WARP; @@ -1031,12 +1030,12 @@ __global__ void masked_multihead_attention_kernel( } #pragma unroll - for (int mask = WARP_SIZE / 2; mask >= THREADS_PER_KEY; mask /= 2) { + for (int mask = WARP_SIZE_TMP / 2; mask >= THREADS_PER_KEY; mask /= 2) { qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); } - const int warp = tid / WARP_SIZE; - const int lane = tid % WARP_SIZE; + const int warp = tid / WARP_SIZE_TMP; + const int lane = tid % WARP_SIZE_TMP; if (lane == 0) { red_smem[warp] = qk_max; diff --git a/paddle/fluid/operators/fused/xpu_fused_common_function.h b/paddle/fluid/operators/fused/xpu_fused_common_function.h deleted file mode 100644 index 63a22838e8c35e..00000000000000 --- a/paddle/fluid/operators/fused/xpu_fused_common_function.h +++ /dev/null @@ -1,225 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#ifdef PADDLE_WITH_XPU -#include "paddle/fluid/platform/device/device_wrapper.h" - -namespace paddle { -namespace operators { - -struct XPUDropoutParam { - float dropout_prob; - bool is_upscale_in_train; - bool is_test; - bool fix_seed; - const phi::DenseTensor *tensor_seed; - int seed_val; - - XPUDropoutParam() { - fix_seed = false; - is_test = false; - is_upscale_in_train = false; - dropout_prob = 0.5; - tensor_seed = nullptr; - seed_val = 0; - } - - XPUDropoutParam(const framework::ExecutionContext &context, - const int dropout_index) { - std::string pre_fix = "dropout"; - std::string str_index = std::to_string(dropout_index); - if (dropout_index > 0) { - pre_fix = pre_fix + str_index + "_"; - } else { - pre_fix = pre_fix + "_"; - } - dropout_prob = context.Attr(pre_fix + "rate"); - auto &dropout_implementation = - context.Attr(pre_fix + "implementation"); - is_upscale_in_train = (dropout_implementation == "upscale_in_train"); - is_test = context.Attr("is_test"); - fix_seed = context.Attr(pre_fix + "fix_seed"); - - std::string str_seed = "Dropout"; - if (dropout_index > 0) { - str_seed = str_seed + str_index + "Seed"; - } else { - str_seed = str_seed + "Seed"; - } - - tensor_seed = context.HasInput(str_seed) - ? context.Input(str_seed) - : nullptr; - if (tensor_seed) { - seed_val = *(tensor_seed->data()); - } else { - seed_val = fix_seed ? context.Attr(pre_fix + "seed") : 0; - } - } - - void initXPUDropoutParam(float dropout_prob_, - bool is_upscale_in_train_, - bool is_test_, - bool fix_seed_, - const phi::DenseTensor *tensor_seed, - int seed_val_) { - dropout_prob = dropout_prob_; - is_upscale_in_train = is_upscale_in_train_; - is_test = is_test_; - fix_seed = fix_seed_; - if (tensor_seed) { - seed_val = *(tensor_seed->data()); - } else { - seed_val = fix_seed ? seed_val_ : 0; - } - } - - void initXPUDropoutParam(const framework::ExecutionContext &context, - int dropout_index) { - std::string pre_fix = "dropout"; - std::string str_index = std::to_string(dropout_index); - if (dropout_index > 0) { - pre_fix = pre_fix + str_index + "_"; - } else { - pre_fix = pre_fix + "_"; - } - dropout_prob = context.Attr(pre_fix + "rate"); - auto &dropout_implementation = - context.Attr(pre_fix + "implementation"); - is_upscale_in_train = (dropout_implementation == "upscale_in_train"); - is_test = context.Attr("is_test"); - fix_seed = context.Attr(pre_fix + "fix_seed"); - std::string str_seed = "Dropout"; - if (dropout_index > 0) { - str_seed = str_seed + str_index + "Seed"; - } else { - str_seed = str_seed + "Seed"; - } - tensor_seed = context.HasInput(str_seed) - ? context.Input(str_seed) - : nullptr; - - if (tensor_seed) { - seed_val = *(tensor_seed->data()); - } else { - seed_val = fix_seed ? context.Attr(pre_fix + "seed") : 0; - } - } -}; - -/****************** - * check is l3 - *******************/ - -static bool is_in_l3(const void *addr) { - int64_t addr_int = (int64_t)addr; - int addr_int_high = addr_int >> 32; - return (addr_int_high == 0); -} - -/************************* - * dropout - *************************/ - -template -void Dropout(xpu::Context *xpu_ctx, - const T *x, - T *mask, - T *y, - const XPUDropoutParam ¶m, - int len) { - using XPUType = typename XPUTypeTrait::Type; - int r = XPU_SUCCESS; - if (param.dropout_prob == 0.0f) { - r = xpu::copy(xpu_ctx, - reinterpret_cast(x), - reinterpret_cast(y), - len); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy"); - return; - } - if (!param.is_test) { - if (param.dropout_prob == 1.0f) { - r = xpu::constant( - xpu_ctx, reinterpret_cast(y), len, XPUType(0)); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant"); - r = xpu::constant( - xpu_ctx, reinterpret_cast(mask), len, XPUType(0)); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant"); - } else { - r = xpu::dropout(xpu_ctx, - reinterpret_cast(x), - reinterpret_cast(y), - reinterpret_cast(mask), - param.seed_val, - len, - param.is_upscale_in_train, - param.dropout_prob); - - PADDLE_ENFORCE_XDNN_SUCCESS(r, "dropout"); - } - } else { - float scale = (param.is_upscale_in_train) - ? (1.0) - : (static_cast(1.0f - param.dropout_prob)); - r = xpu::scale(xpu_ctx, - reinterpret_cast(x), - reinterpret_cast(y), - len, - false, - scale, - 0.0f); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale"); - } -} - -template -void DropoutGrad(xpu::Context *xpu_ctx, - const T *dy, - const T *mask, - T *dx, - const XPUDropoutParam ¶m, - int len) { - using XPUType = typename XPUTypeTrait::Type; - if (param.dropout_prob == 0.0f) { - int r = xpu::copy(xpu_ctx, - reinterpret_cast(dy), - reinterpret_cast(dx), - len); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy"); - return; - } - if (!param.is_upscale_in_train) { - int r = xpu::mul(xpu_ctx, - reinterpret_cast(dy), - reinterpret_cast(mask), - reinterpret_cast(dx), - len); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "mul"); - } else { - int r = xpu::dropout_grad(xpu_ctx, - reinterpret_cast(mask), - reinterpret_cast(dy), - reinterpret_cast(dx), - param.dropout_prob, - len); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "dropout_grad"); - } -} - -} // namespace operators -} // namespace paddle -#endif