diff --git a/paddle/fluid/pybind/inference_api.cc b/paddle/fluid/pybind/inference_api.cc
index 9975c0c01e0d3..03a95e870b810 100644
--- a/paddle/fluid/pybind/inference_api.cc
+++ b/paddle/fluid/pybind/inference_api.cc
@@ -345,10 +345,15 @@ void PaddleTensorShareExternalData(paddle_infer::Tensor &tensor,  // NOLINT
         static_cast(paddle_tensor.data()),
         shape,
         ToPaddleInferPlace(paddle_tensor.place().GetType()));
+  } else if (paddle_tensor.dtype() == phi::DataType::UINT8) {
+    tensor.ShareExternalData(
+        static_cast<uint8_t *>(paddle_tensor.data<uint8_t>()),
+        shape,
+        ToPaddleInferPlace(paddle_tensor.place().GetType()));
   } else {
     PADDLE_THROW(platform::errors::Unimplemented(
         "Unsupported data type. Now share_external_data only supports INT32, "
-        "INT64, FLOAT32, FLOAT16, BFLOAT16 and BOOL."));
+        "INT64, UINT8, FLOAT32, FLOAT16, BFLOAT16 and BOOL."));
   }
 }
 
diff --git a/paddle/phi/api/yaml/fused_ops.yaml b/paddle/phi/api/yaml/fused_ops.yaml
index 7e5975b42894e..ebea9f25f85c4 100644
--- a/paddle/phi/api/yaml/fused_ops.yaml
+++ b/paddle/phi/api/yaml/fused_ops.yaml
@@ -33,14 +33,14 @@
     data_type : x
 
 - op : block_multihead_attention_
-  args : (Tensor qkv, Tensor key_cache, Tensor value_cache, Tensor seq_lens_encoder, Tensor seq_lens_decoder, Tensor seq_lens_this_time, Tensor padding_offsets, Tensor cum_offsets, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor block_tables, Tensor pre_key_cache, Tensor pre_value_cache, Tensor rope_emb, Tensor mask, Tensor tgt_mask, int max_seq_len, int block_size, bool use_neox_style)
+  args : (Tensor qkv, Tensor key_cache, Tensor value_cache, Tensor seq_lens_encoder, Tensor seq_lens_decoder, Tensor seq_lens_this_time, Tensor padding_offsets, Tensor cum_offsets, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor block_tables, Tensor pre_key_cache, Tensor pre_value_cache, Tensor rope_emb, Tensor mask, Tensor tgt_mask, Tensor cache_k_quant_scales, Tensor cache_v_quant_scales, Tensor cache_k_dequant_scales, Tensor cache_v_dequant_scales, Tensor qkv_out_scale, Tensor qkv_bias, Tensor out_shift, Tensor out_smooth, int max_seq_len, int block_size, bool use_neox_style, bool dynamic_cachekv_quant=false, int quant_round_type=1, float quant_max_bound=127.0, float quant_min_bound=-127.0, float out_scale=-1, str compute_dtype = "default")
   output : Tensor(fmha_out), Tensor(qkv_out), Tensor(key_cache_out), Tensor(value_cache_out)
   infer_meta :
     func : BlockMultiheadAttentionInferMeta
   kernel :
     func : block_multihead_attention
     data_type : qkv
-  optional : pre_key_cache, pre_value_cache, rope_emb, mask, tgt_mask
+  optional : pre_key_cache, pre_value_cache, rope_emb, mask, tgt_mask, cache_k_quant_scales, cache_v_quant_scales, cache_k_dequant_scales, cache_v_dequant_scales, qkv_out_scale, qkv_bias, out_shift, out_smooth
   inplace : (qkv -> qkv_out), (key_cache -> key_cache_out), (value_cache -> value_cache_out)
   support_dygraph_mode : true
 
diff --git a/paddle/phi/infermeta/fusion.cc b/paddle/phi/infermeta/fusion.cc
index 7cce4584ac78d..b9712bfa3e210 100644
--- a/paddle/phi/infermeta/fusion.cc
+++ b/paddle/phi/infermeta/fusion.cc
@@ -132,9 +132,23 @@ void BlockMultiheadAttentionInferMeta(const MetaTensor& qkv,
                                       const MetaTensor& rope_emb,
                                       const MetaTensor& mask,
                                       const MetaTensor& tgt_mask,
+                                      const MetaTensor& cache_k_quant_scales,
+                                      const MetaTensor& cache_v_quant_scales,
+                                      const MetaTensor& cache_k_dequant_scales,
+                                      const MetaTensor& cache_v_dequant_scales,
+                                      const MetaTensor& qkv_out_scale,
+                                      const MetaTensor& qkv_bias,
+                                      const MetaTensor& out_shift,
+                                      const MetaTensor& out_smooth,
                                       int max_seq_len,
                                       int block_size,
                                       bool use_neox_style,
+                                      bool dynamic_cachekv_quant,
+                                      const int quant_round_type,
+                                      const float quant_max_bound,
+                                      const float quant_min_bound,
+                                      const float out_scale,
+                                      const std::string& compute_dtype,
                                       MetaTensor* fmha_out,
                                       MetaTensor* qkv_out,
                                       MetaTensor* key_cache_out,
@@ -159,13 +173,74 @@ void BlockMultiheadAttentionInferMeta(const MetaTensor& qkv,
           "The input_dims[1] must be equal to 3 * num_head * dim_head"));
 
   fmha_out->set_dims({input_dims[0], num_head * dim_head});
-  fmha_out->set_dtype(qkv.dtype());
   qkv_out->set_dims(qkv.dims());
-  qkv_out->set_dtype(qkv.dtype());
   key_cache_out->set_dims(key_cache_dims);
   key_cache_out->set_dtype(key_cache.dtype());
   value_cache_out->set_dims(key_cache_dims);
   value_cache_out->set_dtype(value_cache.dtype());
+
+  auto FBADtypeCheck = [](const MetaTensor& check_tensor,
+                          const std::string& tensor_name,
+                          const std::string& compute_dtype) {
+    if (compute_dtype == "bf16") {
+      PADDLE_ENFORCE_EQ(
+          check_tensor.dtype(),
+          phi::DataType::BFLOAT16,
+          phi::errors::InvalidArgument(
+              "Input(%s) dtype must be the same with Attr(compute_dtype)",
+              tensor_name));
+    } else if (compute_dtype == "fp16") {
+      PADDLE_ENFORCE_EQ(
+          check_tensor.dtype(),
+          phi::DataType::FLOAT16,
+          phi::errors::InvalidArgument(
+              "Input(%s) dtype must be the same with Attr(compute_dtype)",
+              tensor_name));
+    } else if (compute_dtype == "fp32") {
+      PADDLE_ENFORCE_EQ(
+          check_tensor.dtype(),
+          phi::DataType::FLOAT32,
+          phi::errors::InvalidArgument(
+              "Input(%s) dtype must be the same with Attr(compute_dtype)",
+              tensor_name));
+    }
+  };
+
+  // In the case of quantization enabled, the dtype for computation is
+  // determined based on compute_dtype.
+  if (qkv.dtype() == phi::DataType::INT32) {
+    PADDLE_ENFORCE_NE(
+        compute_dtype,
+        "default",
+        phi::errors::InvalidArgument(
+            "If Input(x) dtype is INT32, Attr(compute_dtype) must be set."));
+    if (out_scale > 0) {
+      fmha_out->set_dtype(phi::DataType::INT8);
+    } else {
+      if (compute_dtype == "bf16") {
+        fmha_out->set_dtype(phi::DataType::BFLOAT16);
+      } else if (compute_dtype == "fp16") {
+        fmha_out->set_dtype(phi::DataType::FLOAT16);
+      } else if (compute_dtype == "fp32") {
+        fmha_out->set_dtype(phi::DataType::FLOAT32);
+      } else {
+        PADDLE_THROW(phi::errors::InvalidArgument(
+            "In the case of quantization enabled with Input(x) INT32, "
+            "Attr(compute_dtype) must be set in (bf16, fp16, fp32), "
+            "but get compute_dtype (%s)",
+            compute_dtype));
+      }
+    }
+  } else {
+    if (compute_dtype != "default") {
+      FBADtypeCheck(qkv, "qkv", compute_dtype);
+    }
+    if (out_scale > 0) {
+      fmha_out->set_dtype(phi::DataType::INT8);
+    } else {
+      fmha_out->set_dtype(qkv.dtype());
+    }
+  }
 }
 
 void Conv1dXPUInferMeta(const MetaTensor& x,
diff --git a/paddle/phi/infermeta/fusion.h b/paddle/phi/infermeta/fusion.h
index 6fa9c5baab384..fe03331e570d8 100644
--- a/paddle/phi/infermeta/fusion.h
+++ b/paddle/phi/infermeta/fusion.h
@@ -54,9 +54,23 @@ void BlockMultiheadAttentionInferMeta(const MetaTensor& qkv,
                                       const MetaTensor& rope_emb,
                                       const MetaTensor& mask,
                                       const MetaTensor& tgt_mask,
+                                      const MetaTensor& cache_k_quant_scales,
+                                      const MetaTensor& cache_v_quant_scales,
+                                      const MetaTensor& cache_k_dequant_scales,
+                                      const MetaTensor& cache_v_dequant_scales,
+                                      const MetaTensor& qkv_out_scale,
+                                      const MetaTensor& qkv_bias,
+                                      const MetaTensor& out_shift,
+                                      const MetaTensor& out_smooth,
                                       int max_seq_len,
                                       int block_size,
                                       bool use_neox_style,
+                                      bool dynamic_cachekv_quant,
+                                      const int quant_round_type,
+                                      const float quant_max_bound,
+                                      const float quant_min_bound,
+                                      const float
out_scale, + const std::string& compute_dtype, MetaTensor* fmha_out, MetaTensor* qkv_out, MetaTensor* key_cache_out, diff --git a/paddle/phi/kernels/fusion/gpu/block_attn.h b/paddle/phi/kernels/fusion/gpu/block_attn.h index 55cb234fe14b0..73be0901c6f36 100644 --- a/paddle/phi/kernels/fusion/gpu/block_attn.h +++ b/paddle/phi/kernels/fusion/gpu/block_attn.h @@ -43,6 +43,9 @@ struct Block_AttN_params { T *k_cache; T *v_cache; + uint8_t *k_cache_I; + uint8_t *v_cache_I; + const int *block_tables; const int *sequence_lengths{nullptr}; @@ -66,6 +69,11 @@ struct Block_AttN_params { bool add_qkv_bias; bool neox_rotary_style; + + const float *cache_k_quant_scales = nullptr; + const float *cache_v_quant_scales = nullptr; + const float *cache_k_dequant_scales = nullptr; + const float *cache_v_dequant_scales = nullptr; }; template __global__ __launch_bounds__(THREADS_PER_BLOCK) void block_attention_kernel( @@ -114,11 +123,30 @@ __global__ __launch_bounds__(THREADS_PER_BLOCK) void block_attention_kernel( __shared__ float red_smem[WARPS_PER_BLOCK * 2]; using Qk_vec = typename Qk_vec_::Type; using Qk_vec_RoPE = typename Qk_vec_RoPE_::Type; + using QK_Packed_Int8_t = + typename packed_type::value>::type; __shared__ __align__(sizeof(Qk_vec)) T q_smem[Dh_MAX]; const int tid = threadIdx.x; const int hi = blockIdx.x; + float k_quant_scale; + float v_quant_scale; + float k_dequant_scale; + float v_dequant_scale; + + if (USE_CACHE_INT8 == 1) { // static + k_quant_scale = params.cache_k_quant_scales[hi]; + v_quant_scale = params.cache_v_quant_scales[hi]; + k_dequant_scale = params.cache_k_dequant_scales[hi]; + v_dequant_scale = params.cache_v_dequant_scales[hi]; + } else if (USE_CACHE_INT8 == 2) { // dynamic + k_quant_scale = params.cache_k_quant_scales[bi * params.num_head + hi]; + v_quant_scale = params.cache_v_quant_scales[bi * params.num_head + hi]; + k_dequant_scale = params.cache_k_dequant_scales[bi * params.num_head + hi]; + v_dequant_scale = params.cache_v_dequant_scales[bi * params.num_head + hi]; + } + const int bhi = bi * params.num_head + hi; const int ti = params.cum_offsets ? bi * params.seq_len - params.cum_offsets[bi] : -1; @@ -258,7 +286,14 @@ __global__ __launch_bounds__(THREADS_PER_BLOCK) void block_attention_kernel( if (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) { const int offset = base_cache_offset + tid * QK_VEC_SIZE; - *reinterpret_cast(¶ms.k_cache[offset]) = k; + if (!USE_CACHE_INT8) { + *reinterpret_cast(¶ms.k_cache[offset]) = k; + } else { + QK_Packed_Int8_t k_tmp = round_tmp( + mul(k_quant_scale, k)); + *reinterpret_cast(¶ms.k_cache_I[offset]) = + k_tmp; + } } qk = dot(q, k); @@ -283,6 +318,7 @@ __global__ __launch_bounds__(THREADS_PER_BLOCK) void block_attention_kernel( __syncthreads(); using K_vec = typename K_vec_::Type; + using K_vec_I = typename K_vec_I_::Type; constexpr int K_VEC_SIZE = sizeof(K_vec) / sizeof(T); static_assert(Dh_MAX % K_VEC_SIZE == 0, ""); constexpr int K_ELTS_PER_THREAD = Dh_MAX / THREADS_PER_KEY; @@ -317,11 +353,19 @@ __global__ __launch_bounds__(THREADS_PER_BLOCK) void block_attention_kernel( #pragma unroll for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { if (ti < act_time_step) { - k[ii] = (Dh == Dh_MAX || ii * THREADS_PER_KEY * K_VEC_SIZE < Dh) - ? *reinterpret_cast( - params.k_cache + k_offset + - ii * THREADS_PER_KEY * K_VEC_SIZE) - : k_vec_zero; + if (!USE_CACHE_INT8) { + k[ii] = (Dh == Dh_MAX || ii * THREADS_PER_KEY * K_VEC_SIZE < Dh) + ? 
*reinterpret_cast( + params.k_cache + k_offset + + ii * THREADS_PER_KEY * K_VEC_SIZE) + : k_vec_zero; + } else { + mul_pointer_v2( + &k[ii], + k_dequant_scale, + reinterpret_cast(params.k_cache_I + k_offset + + ii * THREADS_PER_KEY * K_VEC_SIZE)); + } } else { k[ii] = k_vec_zero; } @@ -375,6 +419,8 @@ __global__ __launch_bounds__(THREADS_PER_BLOCK) void block_attention_kernel( constexpr int V_VEC_SIZE = Dh_MAX / THREADS_PER_VALUE; using V_vec = typename V_vec_::Type; + using V_Packed_Int8_t = + typename packed_type::value>::type; int vo = tid / THREADS_PER_VALUE; int vi = (tid % THREADS_PER_VALUE) * V_VEC_SIZE; @@ -397,7 +443,14 @@ __global__ __launch_bounds__(THREADS_PER_BLOCK) void block_attention_kernel( physical_block_number * params.num_head * BLOCK_SIZE * Dh + hi * BLOCK_SIZE * Dh + block_offset * Dh + vi; V_vec v; - v = *reinterpret_cast(params.v_cache + v_offset); + if (!USE_CACHE_INT8) { + v = *reinterpret_cast(params.v_cache + v_offset); + } else { + mul_pointer_v2( + &v, + v_dequant_scale, + reinterpret_cast(params.v_cache_I + v_offset)); + } #if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) float logit = logits_smem[ti]; out = fma(logit, cast_to_float(v), out); @@ -423,7 +476,14 @@ __global__ __launch_bounds__(THREADS_PER_BLOCK) void block_attention_kernel( v = add(v, v_bias); } - *reinterpret_cast(params.v_cache + base_cache_offset + vi) = v; + if (!USE_CACHE_INT8) { + *reinterpret_cast(params.v_cache + base_cache_offset + vi) = v; + } else { + V_Packed_Int8_t v_tmp = round_tmp( + mul(v_quant_scale, v)); + *reinterpret_cast(params.v_cache_I + + base_cache_offset + vi) = v_tmp; + } #if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) out = fma(logits_smem[act_time_step], cast_to_float(v), out); @@ -487,40 +547,83 @@ inline size_t smem_size_in_bytes(const Block_AttN_params ¶ms, return max(logits_table_sz, red_sz); } -#define BLHA_LAUNCH_KERNEL(T, \ - Dh, \ - Dh_MAX, \ - THDS_PER_KEY, \ - THDS_PER_VALUE, \ - THDS_PER_BLOCK, \ - BLOCK_SIZE, \ - stream, \ - load_func, \ - store_func) \ - size_t smem_sz = \ - smem_size_in_bytes(params, Dh, THDS_PER_VALUE, THDS_PER_BLOCK); \ - constexpr auto kernel_fn = block_attention_kernel; \ - if (smem_sz > 0xc000) { \ - cudaFuncSetAttribute( \ - kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_sz); \ - } \ - dim3 grid(params.num_head, params.batch_size); \ - kernel_fn<<>>( \ - params, load_func, store_func); +#define BLHA_LAUNCH_KERNEL(T, \ + Dh, \ + Dh_MAX, \ + THDS_PER_KEY, \ + THDS_PER_VALUE, \ + THDS_PER_BLOCK, \ + BLOCK_SIZE, \ + stream, \ + load_func, \ + store_func, \ + use_cachekv_int8) \ + size_t smem_sz = \ + smem_size_in_bytes(params, Dh, THDS_PER_VALUE, THDS_PER_BLOCK); \ + if (params.cache_k_quant_scales) { \ + if (use_cachekv_int8 == 2) { \ + constexpr auto kernel_fn = block_attention_kernel; \ + if (smem_sz > 0xc000) { \ + cudaFuncSetAttribute( \ + kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_sz); \ + } \ + dim3 grid(params.num_head, params.batch_size); \ + kernel_fn<<>>( \ + params, load_func, store_func); \ + } else if (use_cachekv_int8 == 1) { \ + constexpr auto kernel_fn = block_attention_kernel; \ + if (smem_sz > 0xc000) { \ + cudaFuncSetAttribute( \ + kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_sz); \ + } \ + dim3 grid(params.num_head, params.batch_size); \ + kernel_fn<<>>( \ + params, load_func, store_func); \ + } \ + } else { \ + constexpr auto kernel_fn = block_attention_kernel; \ + if (smem_sz > 0xc000) { \ + cudaFuncSetAttribute( \ + kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, 
smem_sz); \ + } \ + dim3 grid(params.num_head, params.batch_size); \ + kernel_fn<<>>( \ + params, load_func, store_func); \ + } template void dispatch_blha_impl_blocksize(const Block_AttN_params ¶ms, const cudaStream_t &stream, LoadFunc load_func, - StoreFunc store_func) { + StoreFunc store_func, + const int use_cachekv_int8) { constexpr int THREADS_PER_VALUE = Dh_MAX * sizeof(T) / 16; constexpr int BLOCKSIZE = 512; if (params.block_size == 16) { @@ -533,7 +636,8 @@ void dispatch_blha_impl_blocksize(const Block_AttN_params ¶ms, 16, stream, load_func, - store_func) + store_func, + use_cachekv_int8) } else if (params.block_size == 32) { BLHA_LAUNCH_KERNEL(T, Dh, @@ -544,7 +648,8 @@ void dispatch_blha_impl_blocksize(const Block_AttN_params ¶ms, 32, stream, load_func, - store_func) + store_func, + use_cachekv_int8) } else if (params.block_size == 64) { BLHA_LAUNCH_KERNEL(T, Dh, @@ -555,7 +660,8 @@ void dispatch_blha_impl_blocksize(const Block_AttN_params ¶ms, 64, stream, load_func, - store_func) + store_func, + use_cachekv_int8) } else if (params.block_size == 128) { BLHA_LAUNCH_KERNEL(T, Dh, @@ -566,7 +672,8 @@ void dispatch_blha_impl_blocksize(const Block_AttN_params ¶ms, 128, stream, load_func, - store_func) + store_func, + use_cachekv_int8) } else if (params.block_size == 256) { BLHA_LAUNCH_KERNEL(T, Dh, @@ -577,7 +684,8 @@ void dispatch_blha_impl_blocksize(const Block_AttN_params ¶ms, 256, stream, load_func, - store_func) + store_func, + use_cachekv_int8) } else { PADDLE_THROW(phi::errors::Unimplemented("block_size = %d is unsupport!", params.block_size)); @@ -589,19 +697,20 @@ void dispatch_blha_impl_headsize(const phi::GPUContext &dev_ctx, const Block_AttN_params ¶ms, int dim_head, LoadFunc load_func, - StoreFunc store_func) { + StoreFunc store_func, + const int use_cachekv_int8) { switch (dim_head) { case 32: dispatch_blha_impl_blocksize( - params, dev_ctx.stream(), load_func, store_func); + params, dev_ctx.stream(), load_func, store_func, use_cachekv_int8); break; case 64: dispatch_blha_impl_blocksize( - params, dev_ctx.stream(), load_func, store_func); + params, dev_ctx.stream(), load_func, store_func, use_cachekv_int8); break; case 128: dispatch_blha_impl_blocksize( - params, dev_ctx.stream(), load_func, store_func); + params, dev_ctx.stream(), load_func, store_func, use_cachekv_int8); break; default: PADDLE_THROW( @@ -613,12 +722,14 @@ template void DispatchBLHA(const phi::GPUContext &dev_ctx, const phi::DenseTensor &qkv_tensor, const Block_AttN_params ¶ms, + int use_cachekv_int8, int num_head, int dim_head, phi::DenseTensor *out_tensor) { MMHALoad load_func(qkv_tensor.data()); MMHAStore store_func(out_tensor->data()); - dispatch_blha_impl_headsize(dev_ctx, params, dim_head, load_func, store_func); + dispatch_blha_impl_headsize( + dev_ctx, params, dim_head, load_func, store_func, use_cachekv_int8); } template @@ -644,11 +755,34 @@ void blha(const phi::GPUContext &dev_ctx, const int rotary_emb_dims, float inv_sqrt_dh, const bool add_qkv_bias = true, - const bool neox_rotary_style = false) { + const bool neox_rotary_style = false, + const int quant_round_type = 1, + const float quant_max_bound = 127.0f, + const float quant_min_bound = -127.0f, + const phi::DenseTensor *cache_k_quant_scales = nullptr, + const phi::DenseTensor *cache_v_quant_scales = nullptr, + const phi::DenseTensor *cache_k_dequant_scales = nullptr, + const phi::DenseTensor *cache_v_dequant_scales = nullptr, + const phi::DenseTensor *dequant_qkv_scales = nullptr, + const phi::DenseTensor *shift = nullptr, + const 
phi::DenseTensor *smooth = nullptr, + const float quant_fmha_out_scale = -1, + int use_cachekv_int8 = 0) { Block_AttN_params params; - params.k_cache = k_cache->data(); - params.v_cache = v_cache->data(); + if (cache_k_quant_scales) { + VLOG(1) << "blha quant cachekv"; + params.k_cache_I = k_cache->data(); + params.v_cache_I = v_cache->data(); + params.cache_k_quant_scales = cache_k_quant_scales->data(); + params.cache_v_quant_scales = cache_v_quant_scales->data(); + params.cache_k_dequant_scales = cache_k_dequant_scales->data(); + params.cache_v_dequant_scales = cache_v_dequant_scales->data(); + } else { + VLOG(1) << "blha not quant cachekv"; + params.k_cache = k_cache->data(); + params.v_cache = v_cache->data(); + } params.max_num_blocks_per_seq = max_num_blocks_per_seq; params.neox_rotary_style = neox_rotary_style; @@ -693,7 +827,13 @@ void blha(const phi::GPUContext &dev_ctx, << " block_size: " << block_size << " timestep: " << timestep << " rope_stride: " << params.rope_stride; - DispatchBLHA(dev_ctx, qkv_tensor, params, num_head, dim_head, out_tensor); + DispatchBLHA(dev_ctx, + qkv_tensor, + params, + use_cachekv_int8, + num_head, + dim_head, + out_tensor); } inline cudaError_t GetNumBlocks(int64_t n, int *num_blocks) { @@ -745,6 +885,84 @@ inline cudaError_t GetNumBlocks(Func func, return cudaSuccess; } +template +__global__ void cache_int8_kernel( + const T *__restrict__ qkv, // [num_tokens, 3, num_heads, head_size] + uint8_t *__restrict__ key_cache, // [num_blocks, num_heads, block_size, + // head_size] + uint8_t *__restrict__ value_cache, // [num_blocks, num_heads, block_size, + // head_size] + const int *__restrict__ block_tables, // [bsz, max_blocks_per_seq] + const int *__restrict__ padding_offsets, // [num_tokens] + const int *__restrict__ seq_lens, // [bsz] + const float *cache_k_scales, + const float *cache_v_scales, + const int max_seq_len, + const int max_blocks_per_seq, + const int num_heads, + const int head_size, + const int block_size, + const int pre_cache_length, + const int elem_cnt, + const int round_type, + const float max_bound, + const float min_bound) { + using LoadT = phi::AlignedVector; + using LoadKVT = phi::AlignedVector; + LoadT src_vec; + LoadKVT cache_vec; + + int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x; + const int64_t hidden_size = num_heads * head_size; + const int64_t offset = 2 * hidden_size; + for (int32_t linear_index = global_thread_idx * VecSize, + step = gridDim.x * blockDim.x * VecSize; + linear_index < elem_cnt; + linear_index += step) { + const int token_idx = linear_index / offset; + const int bias = linear_index % offset; + const int qkv_id = bias / hidden_size + 1; // skip q + const int qkv_bias = bias % hidden_size; + const int hi = qkv_bias / head_size; + const int h_bias = qkv_bias % head_size; + const int ori_token_idx = token_idx + padding_offsets[token_idx]; + const int ori_bi = ori_token_idx / max_seq_len; + if (seq_lens[ori_bi] == 0) continue; + const int ori_seq_id = ori_token_idx % max_seq_len + pre_cache_length; + + const int *block_table_now = block_tables + ori_bi * max_blocks_per_seq; + const int block_idx = block_table_now[ori_seq_id / block_size]; + const int block_offset = ori_seq_id % block_size; + + const int tgt_idx = block_idx * num_heads * block_size * head_size + + hi * block_size * head_size + block_offset * head_size + + h_bias; + const int ori_idx = token_idx * 3 * hidden_size + qkv_id * hidden_size + + hi * head_size + h_bias; + phi::Load(&qkv[ori_idx], &src_vec); + + const float scale = 
qkv_id == 1 ? cache_k_scales[hi] : cache_v_scales[hi]; +#pragma unroll + for (int i = 0; i < VecSize; i++) { + float quant_value = scale * static_cast(src_vec[i]); + if (round_type == 0) { + quant_value = static_cast(roundWithTiesToEven(quant_value)); + } else { + quant_value = static_cast(round(quant_value)); + } + quant_value = quant_value > max_bound ? max_bound : quant_value; + quant_value = quant_value < min_bound ? min_bound : quant_value; + cache_vec[i] = static_cast(quant_value + 128.0f); + } + + if (qkv_id == 1) { + phi::Store(cache_vec, &key_cache[tgt_idx]); + } else { + phi::Store(cache_vec, &value_cache[tgt_idx]); + } + } +} + template __global__ void cache_kernel( const T *__restrict__ qkv, // [num_tokens, 3, num_heads, head_size] @@ -801,6 +1019,91 @@ __global__ void cache_kernel( } } +template +__global__ void write_pre_cache_int8_to_cache( + uint8_t *__restrict__ key_cache, // [num_blocks, num_heads, block_size, + // head_size] + uint8_t *__restrict__ value_cache, + const T *__restrict__ pre_key_cache, // [bsz, pre_cache_len, num_head, + // head_dim] + const T *__restrict__ pre_value_cache, + const int *__restrict__ block_tables, // [bsz, max_blocks_per_seq] + const int *__restrict__ seq_lens, + const float *cache_k_scales, + const float *cache_v_scales, + const int max_seq_len, + const int max_blocks_per_seq, + const int num_heads, + const int head_size, + const int block_size, + const int pre_cache_length, + const int elem_cnt, + const int round_type, + const float max_bound, + const float min_bound) { + using LoadT = phi::AlignedVector; + using LoadKVT = phi::AlignedVector; + LoadT src_vec; + LoadKVT cache_vec; + + int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x; + const int hidden_size = pre_cache_length * head_size; + const int cache_hidden_size = num_heads * hidden_size; + const int offset = 2 * cache_hidden_size; + + for (int32_t linear_index = global_thread_idx * VecSize, + step = gridDim.x * blockDim.x * VecSize; + linear_index < elem_cnt; + linear_index += step) { + const int batch_id = linear_index / offset; + if (seq_lens[batch_id] == 0) continue; + const int *block_table_now = block_tables + batch_id * max_blocks_per_seq; + + const int32_t cache_seq_id = (linear_index % hidden_size) / head_size; + const int32_t head_id = (linear_index % cache_hidden_size) / hidden_size; + const int32_t size_id = linear_index % head_size; + + const int32_t kv_id = (linear_index % offset) / cache_hidden_size; + const int32_t read_id = batch_id * cache_hidden_size + + head_id * hidden_size + cache_seq_id * head_size + + size_id; + if (kv_id == 0) { + phi::Load(&pre_key_cache[read_id], &src_vec); + } else { + phi::Load(&pre_value_cache[read_id], &src_vec); + } + + const int block_idx = block_table_now[cache_seq_id / block_size]; + const int block_offset = cache_seq_id % block_size; + + const int tgt_idx = block_idx * num_heads * block_size * head_size + + head_id * block_size * head_size + + block_offset * head_size + size_id; + + const float scale = + kv_id == 0 ? cache_k_scales[head_id] : cache_v_scales[head_id]; + +#pragma unroll + for (int i = 0; i < VecSize; i++) { + float quant_value = scale * static_cast(src_vec[i]); + if (round_type == 0) { + quant_value = static_cast(roundWithTiesToEven(quant_value)); + } else { + quant_value = static_cast(round(quant_value)); + } + quant_value = quant_value > max_bound ? max_bound : quant_value; + quant_value = quant_value < min_bound ? 
min_bound : quant_value; + cache_vec[i] = static_cast(quant_value + 128.0f); + } + + if (kv_id == 0) { + phi::Store(cache_vec, &key_cache[tgt_idx]); + } else { + phi::Store(cache_vec, &value_cache[tgt_idx]); + } + } +} + template __global__ void write_pre_cache_to_cache( T *__restrict__ key_cache, // [num_blocks, num_heads, block_size, @@ -872,6 +1175,8 @@ void CacheKernel( const phi::DenseTensor &seq_lens, const paddle::optional &pre_key_cache, const paddle::optional &pre_value_cache, + const paddle::optional &cache_k_scales, + const paddle::optional &cache_v_scales, const int batch_size, const int num_tokens, const int num_heads, @@ -881,8 +1186,8 @@ void CacheKernel( phi::DenseTensor *key_cache_out, phi::DenseTensor *value_cache_out, const int round_type = 0, - const float max_bound = 0.0, - const float min_bound = 0.0) { + const float max_bound = 127.0, + const float min_bound = -127.0) { typedef PDDataTypeTraits traits_; typedef typename traits_::DataType DataType_; @@ -897,44 +1202,311 @@ void CacheKernel( int grid_size = 1; GetNumBlocks(pack_num, &grid_size); - VLOG(3) << "cache kv not quant"; - cache_kernel - <<>>( - reinterpret_cast(const_cast(qkv.data())), - reinterpret_cast(key_cache_out->data()), - reinterpret_cast(value_cache_out->data()), - block_tables.data(), - padding_offsets.data(), - seq_lens.data(), - max_seq_len, - max_blocks_per_seq, - num_heads, - head_size, - block_size, - pre_cache_length, - elem_nums); + if (cache_k_scales) { + VLOG(1) << "cache kv quant"; + cache_int8_kernel + <<>>( + reinterpret_cast(const_cast(qkv.data())), + key_cache_out->data(), + value_cache_out->data(), + block_tables.data(), + padding_offsets.data(), + seq_lens.data(), + cache_k_scales.get().data(), + cache_v_scales.get().data(), + max_seq_len, + max_blocks_per_seq, + num_heads, + head_size, + block_size, + pre_cache_length, + elem_nums, + round_type, + max_bound, + min_bound); + } else { + VLOG(1) << "cache kv not quant"; + cache_kernel + <<>>( + reinterpret_cast(const_cast(qkv.data())), + reinterpret_cast(key_cache_out->data()), + reinterpret_cast(value_cache_out->data()), + block_tables.data(), + padding_offsets.data(), + seq_lens.data(), + max_seq_len, + max_blocks_per_seq, + num_heads, + head_size, + block_size, + pre_cache_length, + elem_nums); + } if (pre_key_cache) { // stage 2: write pre_cache to cache [:pre_cache_length] elem_nums = batch_size * num_heads * pre_cache_length * head_size * 2; pack_num = elem_nums / PackSize; GetNumBlocks(pack_num, &grid_size); - write_pre_cache_to_cache + if (cache_k_scales) { + write_pre_cache_int8_to_cache + <<>>( + key_cache_out->data(), + value_cache_out->data(), + reinterpret_cast( + pre_key_cache.get().data()), + reinterpret_cast( + pre_value_cache.get().data()), + block_tables.data(), + seq_lens.data(), + cache_k_scales->data(), + cache_v_scales->data(), + max_seq_len, + max_blocks_per_seq, + num_heads, + head_size, + block_size, + pre_cache_length, + elem_nums, + round_type, + max_bound, + min_bound); + } else { + write_pre_cache_to_cache + <<>>( + reinterpret_cast(key_cache_out->data()), + reinterpret_cast(value_cache_out->data()), + reinterpret_cast( + pre_key_cache.get().data()), + reinterpret_cast( + pre_value_cache.get().data()), + block_tables.data(), + seq_lens.data(), + max_seq_len, + max_blocks_per_seq, + num_heads, + head_size, + block_size, + pre_cache_length, + elem_nums); + } + } +} + +template +__global__ void quant_write_cache_int8_kernel( + const T *__restrict__ qkv, // [num_tokens, 3, num_heads, head_size] + uint8_t 
*__restrict__ key_cache, // [num_blocks, num_heads, block_size, + // head_size] + uint8_t *__restrict__ value_cache, // [num_blocks, num_heads, block_size, + // head_size] + const int *__restrict__ block_tables, // [bsz, max_blocks_per_seq] + const int *__restrict__ padding_offsets, // [num_tokens] + const int *__restrict__ seq_lens, // [bsz] + const int max_seq_len, + const int pre_cache_length, + const int max_blocks_per_seq, + const int num_tokens, + const int num_heads, + const int head_size, + const int block_size, + float *k_quant_scales, + float *v_quant_scales, + float *k_dequant_scales, + float *v_dequant_scales) { + const int hi = blockIdx.x; + const int b_id = blockIdx.y; + if (seq_lens[b_id] <= 0) return; + const int qkv_id = blockIdx.z; + + using InVec = phi::AlignedVector; + using OutVec = phi::AlignedVector; + + InVec in_vec; + OutVec out_vec; + InVec abs_max_vec; +#pragma unroll + for (int i = 0; i < VecSize; ++i) { + abs_max_vec[i] = 0.0f; + } + + uint8_t *dst_ptr; + float *quant_scales; + float *dequant_scales; + if (qkv_id == 0) { + dst_ptr = key_cache; + quant_scales = k_quant_scales; + dequant_scales = k_dequant_scales; + } else { + dst_ptr = value_cache; + quant_scales = v_quant_scales; + dequant_scales = v_dequant_scales; + } + + T local_abs_max; + + for (int idx = threadIdx.x * VecSize; idx < num_tokens * head_size; + idx += blockDim.x * VecSize) { + int token_idx = idx / head_size; + int h_offset = idx % head_size; + int linear_idx = token_idx * 3 * num_heads * head_size + + (qkv_id + 1) * num_heads * head_size + hi * head_size + + h_offset; + + Load(qkv + linear_idx, &in_vec); +#pragma unroll + for (int i = 0; i < VecSize; ++i) { + abs_max_vec[i] = MaxFunc()(abs_max_vec[i], AbsFunc()(in_vec[i])); + } + } + + local_abs_max = LocalReduceMax(abs_max_vec); + T abs_max_val = BlockReduceAbsMax(local_abs_max, 0xffffffff); + + __shared__ float quant_scale; + if (threadIdx.x == 0) { + quant_scale = 127.0f / static_cast(abs_max_val); + } + + __syncthreads(); + for (int idx = threadIdx.x * VecSize; idx < num_tokens * head_size; + idx += blockDim.x * VecSize) { + int token_idx = idx / head_size; + int h_offset = idx % head_size; + int linear_idx = token_idx * 3 * num_heads * head_size + + (qkv_id + 1) * num_heads * head_size + hi * head_size + + h_offset; + + Load(qkv + linear_idx, &in_vec); +#pragma unroll + for (int i = 0; i < VecSize; ++i) { + out_vec[i] = QuantFunc()(in_vec[i], quant_scale); + } + + const int ori_token_idx = token_idx + padding_offsets[token_idx]; + const int ori_bi = ori_token_idx / max_seq_len; + if (ori_bi != b_id) continue; + const int ori_seq_id = ori_token_idx % max_seq_len + pre_cache_length; + + const int *block_table_now = block_tables + ori_bi * max_blocks_per_seq; + const int block_idx = block_table_now[ori_seq_id / block_size]; + const int block_offset = ori_seq_id % block_size; + // [max_block_num, num_head, block_size, head_dim/x, x] + Store(out_vec, + dst_ptr + block_idx * num_heads * block_size * head_size + + hi * block_size * head_size + block_offset * head_size + + h_offset); + } + + if (threadIdx.x == 0) { + quant_scales[b_id * num_heads + hi] = quant_scale; + dequant_scales[b_id * num_heads + hi] = 1.0f / quant_scale; + } +} + +template +void DynamicQuantCacheKernel( + const phi::GPUContext &dev_ctx, + const phi::DenseTensor &qkv, // [token_num, 3, num_head, head_dim] + const phi::DenseTensor &block_tables, + const phi::DenseTensor &padding_offsets, + const phi::DenseTensor &seq_lens, + const phi::DenseTensor &k_quant_scales, + const 
phi::DenseTensor &v_quant_scales, + const phi::DenseTensor &k_dequant_scales, + const phi::DenseTensor &v_dequant_scales, + const paddle::optional &pre_key_cache, + const paddle::optional &pre_value_cache, + const int batch_size, + const int num_heads, + const int head_size, + const int max_seq_len, + const int pre_cache_length, + phi::DenseTensor *key_cache_out, + phi::DenseTensor *value_cache_out) { + typedef PDDataTypeTraits traits_; + typedef typename traits_::DataType DataType_; + + const int num_tokens = padding_offsets.dims()[0]; + const int max_blocks_per_seq = block_tables.dims()[1]; + const int32_t block_size = key_cache_out->dims()[2]; + constexpr int PackSize = 16 / sizeof(T); + + assert(head_size % PackSize == 0); + + const DataType_ *qkv_ptr = reinterpret_cast(qkv.data()); + + // [max_block_num, num_head, block_size, head_dim] + + uint8_t *cache_k_ptr = key_cache_out->data(); + uint8_t *cache_v_ptr = value_cache_out->data(); + + float *k_quant_scales_data = + const_cast(k_quant_scales.data()); + float *k_dequant_scales_data = + const_cast(k_dequant_scales.data()); + + float *v_quant_scales_data = + const_cast(v_quant_scales.data()); + float *v_dequant_scales_data = + const_cast(v_dequant_scales.data()); + + constexpr int block_sz = 1024; + + const int bsz = seq_lens.dims()[0]; + + dim3 grid(num_heads, bsz, 2); + + // [token_num, 3, num_head, head_dim/x, x]->[max_block_num, num_head, + // block_size, head_dim/x, x] Quant and Write kv + quant_write_cache_int8_kernel + <<>>(qkv_ptr, + cache_k_ptr, + cache_v_ptr, + block_tables.data(), + padding_offsets.data(), + seq_lens.data(), + max_seq_len, + pre_cache_length, + max_blocks_per_seq, + num_tokens, + num_heads, + head_size, + block_size, + k_quant_scales_data, + v_quant_scales_data, + k_dequant_scales_data, + v_dequant_scales_data); + + if (pre_key_cache) { + // stage 2: write pre_cache to cache [:pre_cache_length] + const int elem_nums = + batch_size * num_heads * pre_cache_length * head_size * 2; + const int pack_num = elem_nums / PackSize; + const int blocksize = 128; + int grid_size = 1; + GetNumBlocks(pack_num, &grid_size); + write_pre_cache_int8_to_cache <<>>( - reinterpret_cast(key_cache_out->data()), - reinterpret_cast(value_cache_out->data()), + key_cache_out->data(), + value_cache_out->data(), reinterpret_cast(pre_key_cache.get().data()), reinterpret_cast( pre_value_cache.get().data()), block_tables.data(), seq_lens.data(), + k_quant_scales.data(), + v_quant_scales.data(), max_seq_len, max_blocks_per_seq, num_heads, head_size, block_size, pre_cache_length, - elem_nums); + elem_nums, + 1, + 127.0f, + -127.0f); } } diff --git a/paddle/phi/kernels/fusion/gpu/block_multi_head_attention_kernel.cu b/paddle/phi/kernels/fusion/gpu/block_multi_head_attention_kernel.cu index 366e1789c829a..57754fd3b82aa 100644 --- a/paddle/phi/kernels/fusion/gpu/block_multi_head_attention_kernel.cu +++ b/paddle/phi/kernels/fusion/gpu/block_multi_head_attention_kernel.cu @@ -15,6 +15,8 @@ #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/flash_attn_kernel.h" +#include "paddle/phi/kernels/funcs/broadcast_function.h" +#include "paddle/phi/kernels/funcs/elementwise_functor.h" #include "paddle/phi/kernels/fusion/cutlass/variable_length_memory_efficient_attention.h" #include "paddle/phi/kernels/fusion/gpu/block_attn.h" #include "paddle/phi/kernels/gpu/flash_attn_utils.h" @@ -23,8 +25,186 @@ namespace phi { namespace fusion { +template +inline HOSTDEVICE data_t RoundWithTiesToEven(data_t 
x) { + data_t xLower = floor(x); + data_t xUpper = ceil(x); + // x is in interval [xl,xu]. Choose closest of two bounds, breaking ties to + // even. + data_t dLower = x - xLower; + data_t dUpper = xUpper - x; + return static_cast( + (dLower == dUpper ? fmod(xLower, 2.0F) == 0.0F : dLower < dUpper) + ? xLower + : xUpper); +} + +template +__forceinline__ __device__ T add_mul(T a, T b, T c) { + return (a + b) * c; +} + +template <> +__forceinline__ __device__ half add_mul(half a, half b, half c) { + return __hmul(__hadd(a, b), c); +} + +#if CUDA_VERSION >= 11000 && defined(CUDA_BFLOAT16_AVALIABLE) +template <> +__forceinline__ __device__ __nv_bfloat16 +add_mul<__nv_bfloat16>(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) { + return __hmul(__hadd(a, b), c); +} +#endif + +template +__forceinline__ __device__ int8_t quant_helper(const data_t input, + const float scale, + const int round_type, + const float max_bound, + const float min_bound) { + float quant_value = max_bound * scale * static_cast(input); + + if (round_type == 0) { + quant_value = static_cast(RoundWithTiesToEven(quant_value)); + } else { + quant_value = static_cast(round(quant_value)); + } + quant_value = quant_value > max_bound ? max_bound : quant_value; + quant_value = quant_value < min_bound ? min_bound : quant_value; + return static_cast(quant_value); +} + +template +__forceinline__ __device__ int8_t quant_helper(const data_t input, + const data_t shift, + const data_t smooth, + const float scale, + const int round_type, + const float max_bound, + const float min_bound) { + auto smooth_out = add_mul(input, shift, smooth); + float quant_value = max_bound * scale * static_cast(smooth_out); + + if (round_type == 0) { + quant_value = static_cast(RoundWithTiesToEven(quant_value)); + } else { + quant_value = static_cast(round(quant_value)); + } + quant_value = quant_value > max_bound ? max_bound : quant_value; + quant_value = quant_value < min_bound ? 
min_bound : quant_value; + return static_cast(quant_value); +} + +template +__global__ void QuantKernel(const data_t* input, + char4* output, + const float scale, + const int m, + const int n, + const int round_type, + const float max_bound, + const float min_bound) { + int n_id = (blockIdx.x * blockDim.x + threadIdx.x) << 2; + int m_id = blockIdx.y * blockDim.y + threadIdx.y; + bool check = ((m_id < m) && (n_id < n)); + + if (check) { + char4 tmp; + tmp.x = quant_helper( + input[m_id * n + n_id], scale, round_type, max_bound, min_bound); + tmp.y = quant_helper( + input[m_id * n + n_id + 1], scale, round_type, max_bound, min_bound); + tmp.z = quant_helper( + input[m_id * n + n_id + 2], scale, round_type, max_bound, min_bound); + tmp.w = quant_helper( + input[m_id * n + n_id + 3], scale, round_type, max_bound, min_bound); + + output[(m_id * n + n_id) >> 2] = tmp; + } +} + +template +__global__ void QuantKernel(const data_t* input, + const data_t* shift, + const data_t* smooth, + char4* output, + const float scale, + const int m, + const int n, + const int round_type, + const float max_bound, + const float min_bound) { + int n_id = (blockIdx.x * blockDim.x + threadIdx.x) << 2; + int m_id = blockIdx.y * blockDim.y + threadIdx.y; + bool check = ((m_id < m) && (n_id < n)); + + if (check) { + char4 tmp; + tmp.x = quant_helper(input[m_id * n + n_id], + shift[n_id], + smooth[n_id], + scale, + round_type, + max_bound, + min_bound); + tmp.y = quant_helper(input[m_id * n + n_id + 1], + shift[n_id + 1], + smooth[n_id + 1], + scale, + round_type, + max_bound, + min_bound); + tmp.z = quant_helper(input[m_id * n + n_id + 2], + shift[n_id + 2], + smooth[n_id + 2], + scale, + round_type, + max_bound, + min_bound); + tmp.w = quant_helper(input[m_id * n + n_id + 3], + shift[n_id + 3], + smooth[n_id + 3], + scale, + round_type, + max_bound, + min_bound); + + output[(m_id * n + n_id) >> 2] = tmp; + } +} + +template +__global__ void DequantKernel(T* output, + const int32_t* input, + const int64_t m, // batch size + const int64_t n, // hidden + const float* dequant_out_scale_data) { + int64_t numel = m * n; + int64_t stride = blockDim.x * gridDim.x * VecSize; + int64_t idx = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize; + int64_t col_id = idx % n; + + phi::AlignedVector in_vec; + phi::AlignedVector out_scale_vec; + phi::AlignedVector out_vec; + + for (; idx < numel; idx += stride) { + phi::Load(input + idx, &in_vec); + phi::Load(dequant_out_scale_data + col_id, &out_scale_vec); + +#pragma unroll + for (int i = 0; i < VecSize; ++i) { + out_vec[i] = + static_cast(static_cast(in_vec[i]) * out_scale_vec[i]); + } + + phi::Store(out_vec, output + idx); + } +} + template -void BlockMultiheadAttentionKernel( +void DispatchWithDtype( const Context& dev_ctx, const DenseTensor& qkv, const DenseTensor& key_cache, @@ -42,16 +222,41 @@ void BlockMultiheadAttentionKernel( const paddle::optional& rope_emb, const paddle::optional& mask, const paddle::optional& tgt_mask, + const paddle::optional& cache_k_quant_scales, + const paddle::optional& cache_v_quant_scales, + const paddle::optional& cache_k_dequant_scales, + const paddle::optional& cache_v_dequant_scales, + const paddle::optional& qkv_out_scale, + const paddle::optional& qkv_bias, + const paddle::optional& out_shift, + const paddle::optional& out_smooth, int max_seq_len, int block_size, bool use_neox_style, + const bool dynamic_cachekv_quant, + const int quant_round_type, + const float quant_max_bound, + const float quant_min_bound, + const float out_scale, + const 
std::string& compute_dtype, DenseTensor* fmha_out, DenseTensor* qkv_out, DenseTensor* key_cache_out, DenseTensor* value_cache_out) { - dev_ctx.template Alloc(fmha_out); - InitValue( - dev_ctx, fmha_out->data(), fmha_out->numel(), static_cast(0.)); + phi::DenseTensor qkv_buf; + phi::DenseTensor fmha_buf; + + VLOG(1) << "fmha_out " << fmha_out->dims(); + if (out_scale <= 0) { + dev_ctx.template Alloc(fmha_out); + fmha_buf = *fmha_out; + } else { + fmha_buf.Resize(fmha_out->dims()); + dev_ctx.template Alloc(&fmha_buf); + dev_ctx.template Alloc(fmha_out); + } + + InitValue(dev_ctx, fmha_buf.data(), fmha_buf.numel(), static_cast(0.)); const auto& input_dims = qkv.dims(); const auto& key_cache_dims = key_cache.dims(); const int token_num = input_dims[0]; @@ -131,12 +336,44 @@ void BlockMultiheadAttentionKernel( } VLOG(3) << "encoder"; VLOG(3) << "max_enc_len_this_time: " << max_enc_len_this_time; + + if (qkv_out_scale) { + VLOG(1) << "qkv_out_scale: " << qkv_out_scale.get_ptr()->dims(); + qkv_buf.Resize(qkv.dims()); + dev_ctx.template Alloc(&qkv_buf, qkv_buf.numel() * sizeof(T)); + + int64_t numel = qkv.numel(); + constexpr int64_t thread_per_block = 512; + constexpr int DequantKernelVecSize = 4; + int64_t block_per_grid = + (numel / DequantKernelVecSize + thread_per_block - 1) / + thread_per_block; + DequantKernel + <<>>( + qkv_buf.data(), + qkv.data(), + input_dims[0], + input_dims[1], + qkv_out_scale.get_ptr()->data()); + } else { + VLOG(1) << "qkv_out_scale is none"; + qkv_buf = qkv; + } + + if (qkv_bias) { + VLOG(1) << "has bias"; + std::vector ins = {&qkv_buf, qkv_bias.get_ptr()}; + std::vector outs = {&qkv_buf}; + phi::funcs::BroadcastKernel( + dev_ctx, ins, &outs, phi::funcs::AddFunctor()); + } + if (max_enc_len_this_time > 0) { const int* sequence_lengths_data = seq_lens_encoder.data(); if (rope_emb) { rotary_qk_variable(dev_ctx, - qkv_out->data(), - qkv.data(), + qkv_buf.data(), + qkv_buf.data(), rope_emb.get().data(), padding_offsets.data(), sequence_lengths_data, @@ -154,7 +391,7 @@ void BlockMultiheadAttentionKernel( unpadding_q.data(), unpadding_k.data(), unpadding_v.data(), - qkv.data(), + qkv_buf.data(), padding_offsets.data(), sequence_lengths_data, token_num, @@ -179,7 +416,7 @@ void BlockMultiheadAttentionKernel( false, true /* is_test*/, "" /*rng_name*/, - fmha_out, + &fmha_buf, &softmax_out, &softmax_lse, &seed_offset); @@ -221,7 +458,7 @@ void BlockMultiheadAttentionKernel( InvokeTransposeRemovePadding(dev_ctx, qktv_out.data(), sequence_lengths_data, - fmha_out->data(), + fmha_buf.data(), bsz, num_head, max_enc_len_this_time, @@ -230,29 +467,56 @@ void BlockMultiheadAttentionKernel( token_num, padding_offsets.data()); } + VLOG(3) << "flash end"; - CacheKernel(dev_ctx, - qkv, - block_tables, - padding_offsets, - seq_lens_encoder, - pre_key_cache, - pre_value_cache, - bsz, - token_num, - num_head, - dim_head, - max_seq_len, - pre_cache_length, - key_cache_out, - value_cache_out); + if (cache_k_quant_scales && dynamic_cachekv_quant) { + DynamicQuantCacheKernel(dev_ctx, + qkv_buf, + block_tables, + padding_offsets, + seq_lens_encoder, + *(cache_k_quant_scales.get_ptr()), + *(cache_v_quant_scales.get_ptr()), + *(cache_k_dequant_scales.get_ptr()), + *(cache_v_dequant_scales.get_ptr()), + pre_key_cache, + pre_value_cache, + bsz, + num_head, + dim_head, + max_seq_len, + pre_cache_length, + key_cache_out, + value_cache_out); + } else { + CacheKernel(dev_ctx, + qkv_buf, + block_tables, + padding_offsets, + seq_lens_encoder, + pre_key_cache, + pre_value_cache, + cache_k_quant_scales, 
+ cache_v_quant_scales, + bsz, + token_num, + num_head, + dim_head, + max_seq_len, + pre_cache_length, + key_cache_out, + value_cache_out, + quant_round_type, + quant_max_bound, + quant_min_bound); + } VLOG(3) << "cache end"; } VLOG(3) << "encoder done"; VLOG(3) << "max_dec_len_this_time: " << max_dec_len_this_time; if (max_dec_len_this_time > 0) { GetDecoderTensor(dev_ctx, - qkv, + qkv_buf, nullptr, cum_offsets.data(), &qkv_out_decoder, @@ -263,6 +527,15 @@ void BlockMultiheadAttentionKernel( max_seq_len, dim_head); VLOG(3) << "qkv_out_decoder: " << qkv_out_decoder.dims(); + int cachekv_quant_mode = 0; + if (cache_k_quant_scales) { + if (dynamic_cachekv_quant) { + cachekv_quant_mode = 2; + } else { + cachekv_quant_mode = 1; + } + } + VLOG(1) << "cachekv_quant_mode " << cachekv_quant_mode; blha(dev_ctx, qkv_out_decoder, nullptr, // qkv_bias @@ -273,7 +546,7 @@ void BlockMultiheadAttentionKernel( rope_emb ? &rope_emb.get() : nullptr, // rope_emb key_cache_out, value_cache_out, - fmha_out, + &fmha_buf, bsz, max_block_per_seq, block_size, @@ -285,10 +558,261 @@ void BlockMultiheadAttentionKernel( rope_emb ? 1 : 0, 1. / sqrt(dim_head), /*compute_bias*/ false, - use_neox_style); + use_neox_style, + quant_round_type, + quant_max_bound, + quant_min_bound, + cache_k_quant_scales ? cache_k_quant_scales.get_ptr() : nullptr, + cache_v_quant_scales ? cache_v_quant_scales.get_ptr() : nullptr, + cache_k_dequant_scales ? cache_k_dequant_scales.get_ptr() : nullptr, + cache_v_dequant_scales ? cache_v_dequant_scales.get_ptr() : nullptr, + nullptr, // dequant_qkv_scales + nullptr, // shift + nullptr, // smooth + -1, // quant_fmha_out_scale + cachekv_quant_mode); VLOG(3) << "blha end"; } - VLOG(3) << "decoder done"; + if (out_scale > 0) { + int m = fmha_out->dims()[0]; + int n = fmha_out->dims()[1]; + dim3 grid((n >> 2 + 31) / 32, (m + 31) / 32); + dim3 block(32, 32); + if (out_shift && out_smooth) { + QuantKernel<<>>( + fmha_buf.data(), + out_shift.get_ptr()->data(), + out_smooth.get_ptr()->data(), + reinterpret_cast(fmha_out->data()), + out_scale, + m, + n, + quant_round_type, + quant_max_bound, + quant_min_bound); + } else { + QuantKernel<<>>( + fmha_buf.data(), + reinterpret_cast(fmha_out->data()), + out_scale, + m, + n, + quant_round_type, + quant_max_bound, + quant_min_bound); + } + VLOG(3) << "decoder done"; + } +} + +template +void BlockMultiheadAttentionKernel( + const Context& dev_ctx, + const DenseTensor& qkv, + const DenseTensor& key_cache, + const DenseTensor& value_cache, + const DenseTensor& seq_lens_encoder, + const DenseTensor& seq_lens_decoder, + const DenseTensor& seq_lens_this_time, + const DenseTensor& padding_offsets, + const DenseTensor& cum_offsets, + const DenseTensor& cu_seqlens_q, + const DenseTensor& cu_seqlens_k, + const DenseTensor& block_tables, + const paddle::optional& pre_key_cache, + const paddle::optional& pre_value_cache, + const paddle::optional& rope_emb, + const paddle::optional& mask, + const paddle::optional& tgt_mask, + const paddle::optional& cache_k_quant_scales, + const paddle::optional& cache_v_quant_scales, + const paddle::optional& cache_k_dequant_scales, + const paddle::optional& cache_v_dequant_scales, + const paddle::optional& qkv_out_scale, + const paddle::optional& qkv_bias, + const paddle::optional& out_shift, + const paddle::optional& out_smooth, + int max_seq_len, + int block_size, + bool use_neox_style, + const bool dynamic_cachekv_quant, + const int quant_round_type, + const float quant_max_bound, + const float quant_min_bound, + const float 
out_scale, + const std::string& compute_dtype, + DenseTensor* fmha_out, + DenseTensor* qkv_out, + DenseTensor* key_cache_out, + DenseTensor* value_cache_out) { + if (qkv.dtype() == phi::DataType::INT32) { + VLOG(1) << "qkv.dtype() int32"; + if (compute_dtype == "fp16") { + VLOG(1) << "compute_dtype fp16"; + DispatchWithDtype(dev_ctx, + qkv, + key_cache, + value_cache, + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + padding_offsets, + cum_offsets, + cu_seqlens_q, + cu_seqlens_k, + block_tables, + pre_key_cache, + pre_value_cache, + rope_emb, + mask, + tgt_mask, + cache_k_quant_scales, + cache_v_quant_scales, + cache_k_dequant_scales, + cache_v_dequant_scales, + qkv_out_scale, + qkv_bias, + out_shift, + out_smooth, + max_seq_len, + block_size, + use_neox_style, + dynamic_cachekv_quant, + quant_round_type, + quant_max_bound, + quant_min_bound, + out_scale, + compute_dtype, + fmha_out, + qkv_out, + key_cache_out, + value_cache_out); + } else if (compute_dtype == "bf16") { +#if CUDA_VERSION >= 11000 && defined(CUDA_BFLOAT16_AVALIABLE) + DispatchWithDtype(dev_ctx, + qkv, + key_cache, + value_cache, + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + padding_offsets, + cum_offsets, + cu_seqlens_q, + cu_seqlens_k, + block_tables, + pre_key_cache, + pre_value_cache, + rope_emb, + mask, + tgt_mask, + cache_k_quant_scales, + cache_v_quant_scales, + cache_k_dequant_scales, + cache_v_dequant_scales, + qkv_out_scale, + qkv_bias, + out_shift, + out_smooth, + max_seq_len, + block_size, + use_neox_style, + dynamic_cachekv_quant, + quant_round_type, + quant_max_bound, + quant_min_bound, + out_scale, + compute_dtype, + fmha_out, + qkv_out, + key_cache_out, + value_cache_out); +#endif + } + } else { + VLOG(1) << "qkv.dtype() NOT int32"; + if (std::is_same::value) { + DispatchWithDtype(dev_ctx, + qkv, + key_cache, + value_cache, + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + padding_offsets, + cum_offsets, + cu_seqlens_q, + cu_seqlens_k, + block_tables, + pre_key_cache, + pre_value_cache, + rope_emb, + mask, + tgt_mask, + cache_k_quant_scales, + cache_v_quant_scales, + cache_k_dequant_scales, + cache_v_dequant_scales, + qkv_out_scale, + qkv_bias, + out_shift, + out_smooth, + max_seq_len, + block_size, + use_neox_style, + dynamic_cachekv_quant, + quant_round_type, + quant_max_bound, + quant_min_bound, + out_scale, + compute_dtype, + fmha_out, + qkv_out, + key_cache_out, + value_cache_out); + } else if (std::is_same::value) { +#if CUDA_VERSION >= 11000 && defined(CUDA_BFLOAT16_AVALIABLE) + DispatchWithDtype(dev_ctx, + qkv, + key_cache, + value_cache, + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + padding_offsets, + cum_offsets, + cu_seqlens_q, + cu_seqlens_k, + block_tables, + pre_key_cache, + pre_value_cache, + rope_emb, + mask, + tgt_mask, + cache_k_quant_scales, + cache_v_quant_scales, + cache_k_dequant_scales, + cache_v_dequant_scales, + qkv_out_scale, + qkv_bias, + out_shift, + out_smooth, + max_seq_len, + block_size, + use_neox_style, + dynamic_cachekv_quant, + quant_round_type, + quant_max_bound, + quant_min_bound, + out_scale, + compute_dtype, + fmha_out, + qkv_out, + key_cache_out, + value_cache_out); +#endif + } + } } } // namespace fusion @@ -300,11 +824,13 @@ PD_REGISTER_KERNEL(block_multihead_attention, ALL_LAYOUT, phi::fusion::BlockMultiheadAttentionKernel, phi::dtype::bfloat16, - phi::dtype::float16) {} + phi::dtype::float16, + int32_t) {} #else PD_REGISTER_KERNEL(block_multihead_attention, GPU, ALL_LAYOUT, 
phi::fusion::BlockMultiheadAttentionKernel, - phi::dtype::float16) {} + phi::dtype::float16, + int32_t) {} #endif diff --git a/paddle/phi/kernels/fusion/gpu/mmha_util.cu.h b/paddle/phi/kernels/fusion/gpu/mmha_util.cu.h index 44c4e327d20c8..7d4e0c81198b1 100644 --- a/paddle/phi/kernels/fusion/gpu/mmha_util.cu.h +++ b/paddle/phi/kernels/fusion/gpu/mmha_util.cu.h @@ -95,6 +95,80 @@ struct bf16_8_t { }; #endif +//------------------------------------ +template +struct num_elems; +template <> +struct num_elems { + static constexpr int value = 1; +}; +template <> +struct num_elems { + static constexpr int value = 2; +}; +template <> +struct num_elems { + static constexpr int value = 4; +}; +template <> +struct num_elems { + static constexpr int value = 2; +}; +template <> +struct num_elems { + static constexpr int value = 4; +}; +template <> +struct num_elems { + static constexpr int value = 8; +}; +#ifdef ENABLE_BF16 +template <> +struct num_elems<__nv_bfloat162> { + static constexpr int value = 2; +}; +template <> +struct num_elems { + static constexpr int value = 4; +}; +template <> +struct num_elems { + static constexpr int value = 8; +}; +#endif + +//------------------------------------ +template +struct packed_type; +template +struct packed_type { + using type = T; +}; +template <> +struct packed_type { + using type = uint16_t; +}; +template <> +struct packed_type { + using type = uint32_t; +}; +template <> +struct packed_type { + using type = uint64_t; +}; +template <> +struct packed_type { + using type = float2; +}; +template <> +struct packed_type { + using type = float4; +}; +template <> +struct packed_type { + using type = Float8_; +}; + template struct Qk_vec_ {}; template <> @@ -244,6 +318,42 @@ struct K_vec_ { }; #endif // ENABLE_BF16 +//------------------------------------ +template +struct K_vec_I_ { + using Type = uint8_t; +}; + +template <> +struct K_vec_I_ { + using Type = uint16_t; +}; + +template <> +struct K_vec_I_ { + using Type = uint32_t; +}; + +template <> +struct K_vec_I_ { + using Type = uint64_t; +}; + +#ifdef ENABLE_BF16 +template <> +struct K_vec_I_ { + using Type = uint16_t; +}; +template <> +struct K_vec_I_ { + using Type = uint32_t; +}; +template <> +struct K_vec_I_ { + using Type = uint64_t; +}; +#endif // ENABLE_BF16 + template struct V_vec_ {}; template <> @@ -515,6 +625,183 @@ inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) { } #endif // ENABLE_BF16 +template +inline __device__ void mul_pointer_v2(T* c, float a, IntT* b); + +template <> +inline __device__ void mul_pointer_v2(float4* c, float a, uint32_t* b) { + uint8_t* b_tmp = reinterpret_cast(b); + c->x = a * (static_cast(b_tmp[0]) - 128.0); + c->y = a * (static_cast(b_tmp[1]) - 128.0); + c->z = a * (static_cast(b_tmp[2]) - 128.0); + c->w = a * (static_cast(b_tmp[3]) - 128.0); +} + +template <> +inline __device__ void mul_pointer_v2(float* c, float a, uint8_t* b) { + c[0] = a * (static_cast(b[0]) - 128.0); +} + +inline __device__ void convert_(float16* result, uint32_t const& source) { +#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) + uint32_t* h = reinterpret_cast(result); + uint32_t const i8s = reinterpret_cast(source); + + static constexpr uint32_t mask_for_elt_01 = 0x5150; + static constexpr uint32_t mask_for_elt_23 = 0x5352; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(h[0]) + : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_01)); + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(h[1]) + : "r"(i8s), "n"(start_byte_for_fp16), 
"n"(mask_for_elt_23)); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(h[0]) + : "r"(h[0]), "r"(I8s_TO_F16s_MAGIC_NUM)); + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(h[1]) + : "r"(h[1]), "r"(I8s_TO_F16s_MAGIC_NUM)); +#endif +} + +// float16 * 2 <- uint8_t * 2 +template <> +inline __device__ void mul_pointer_v2(uint32_t* c, float a, uint16_t* b) { +#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) + uint32_t tmp_uint32 = 0; + uint32_t* h = &tmp_uint32; + uint16_t tmp_b = *b; + uint32_t i8s = *reinterpret_cast(&tmp_b); + + static constexpr uint32_t mask_for_elt_01 = 0x5150; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(h[0]) + : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_01)); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(h[0]) + : "r"(h[0]), "r"(I8s_TO_F16s_MAGIC_NUM)); + + half2 tmp_half2 = *reinterpret_cast(h); + tmp_half2.x *= static_cast(a); + tmp_half2.y *= static_cast(a); + + c[0] = *reinterpret_cast(&tmp_half2); + +#endif +} + +// float16 * 4 <- uint8_t * 4 +template <> +inline __device__ void mul_pointer_v2(uint2* c, float a, uint32_t* b) { + float16* c_prime = reinterpret_cast(c); + float16 a_prime = static_cast(a); + convert_(c_prime, *b); +#pragma unroll + for (int i = 0; i < 4; ++i) { + c_prime[i] *= a_prime; + } +} +// float16 * 8 <- uint8_t * 8 +template <> +inline __device__ void mul_pointer_v2(uint4* c, float a, uint64_t* b) { + uint2* tmp_c = reinterpret_cast(c); + uint32_t* tmp_b = reinterpret_cast(b); +#pragma unroll + for (int i = 0; i < 2; ++i) { + mul_pointer_v2(tmp_c + i, a, tmp_b + i); + } +} + +#ifdef ENABLE_BF16 +inline __device__ static void convert_(__nv_bfloat16* result, + uint32_t const& source) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + + uint32_t* bf16_result_ptr = reinterpret_cast(result); + uint32_t const i8s = reinterpret_cast(source); + + static constexpr uint32_t fp32_base = 0x4B000000; + float fp32_intermediates[4]; + + uint32_t* fp32_intermediates_casted = + reinterpret_cast(fp32_intermediates); + fp32_intermediates_casted[0] = __byte_perm(i8s, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(i8s, fp32_base, 0x7651); + fp32_intermediates_casted[2] = __byte_perm(i8s, fp32_base, 0x7652); + fp32_intermediates_casted[3] = __byte_perm(i8s, fp32_base, 0x7653); + +#pragma unroll + for (int ii = 0; ii < 4; ++ii) { + fp32_intermediates[ii] -= (8388608.f + 128.f); + } + +#pragma unroll + for (int ii = 0; ii < 2; ++ii) { + bf16_result_ptr[ii] = __byte_perm(fp32_intermediates_casted[2 * ii + 0], + fp32_intermediates_casted[2 * ii + 1], + 0x7632); + } +#endif +} + +template <> +inline __device__ void mul_pointer_v2(__nv_bfloat162* c, float a, uint16_t* b) { + using Packed_Int8_t = typename packed_type::type; + Packed_Int8_t int8_vec_4_val = *reinterpret_cast(b); + uint8_t* int8_vec_pointer = reinterpret_cast(&int8_vec_4_val); + + uint32_t* bf16_result_ptr = reinterpret_cast(c); + uint32_t const i8s = int8_vec_4_val; + + static constexpr uint32_t fp32_base = 0x4B000000; + float fp32_intermediates[2]; + + uint32_t* fp32_intermediates_casted = + reinterpret_cast(fp32_intermediates); + fp32_intermediates_casted[0] = __byte_perm(i8s, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(i8s, fp32_base, 0x7651); + +#pragma unroll + for (int ii = 0; ii < 2; ++ii) { + fp32_intermediates[ii] -= (8388608.f + 128.f); + } + 
+ bf16_result_ptr[0] = __byte_perm( + fp32_intermediates_casted[0], fp32_intermediates_casted[1], 0x7632); + __nv_bfloat16 scale = static_cast<__nv_bfloat16>(a); + c->x = c->x * scale; + c->y = c->y * scale; +} + +template <> +inline __device__ void mul_pointer_v2(bf16_4_t* c, float a, uint32_t* b) { + __nv_bfloat16 a_prime = static_cast<__nv_bfloat16>(a); + __nv_bfloat16* c_prime = reinterpret_cast<__nv_bfloat16*>(c); + convert_(c_prime, *b); +#pragma unroll + for (int i = 0; i < 4; ++i) { + c_prime[i] = c_prime[i] * a_prime; + } +} + +template <> +inline __device__ void mul_pointer_v2(bf16_8_t* c, float a, uint64_t* b) { + bf16_4_t* tmp_c = reinterpret_cast(c); + uint64_t bb = *b; + uint32_t* tmp_b = reinterpret_cast(&bb); +#pragma unroll + for (int i = 0; i < 2; ++i) { + mul_pointer_v2(tmp_c + i, a, tmp_b + i); + } +} +#endif // ENABLE_BF16 + template inline __device__ Acc mul(A a, B b); @@ -523,6 +810,123 @@ inline __device__ float mul(float a, float b) { return a * b; } +#ifdef ENABLE_BF16 +template <> +inline __device__ __nv_bfloat162 mul(float a, __nv_bfloat162 b) { + __nv_bfloat162 ret; + ret.x = static_cast<__nv_bfloat16>(a) * b.x; + ret.y = static_cast<__nv_bfloat16>(a) * b.y; + return ret; +} + +template <> +inline __device__ bf16_4_t mul(float a, bf16_4_t b) { + bf16_4_t ret; + ret.x = mul<__nv_bfloat162, float, __nv_bfloat162>(a, b.x); + ret.y = mul<__nv_bfloat162, float, __nv_bfloat162>(a, b.y); + return ret; +} + +template <> +inline __device__ bf16_8_t mul(float a, bf16_8_t b) { + bf16_8_t ret; + ret.x = mul<__nv_bfloat162, float, __nv_bfloat162>(a, b.x); + ret.y = mul<__nv_bfloat162, float, __nv_bfloat162>(a, b.y); + ret.z = mul<__nv_bfloat162, float, __nv_bfloat162>(a, b.z); + ret.w = mul<__nv_bfloat162, float, __nv_bfloat162>(a, b.w); + return ret; +} +#endif // ENABLE_BF16 + +template <> +inline __device__ uint32_t mul(float a, uint32_t b) { + union { + float16 out[2]; + uint32_t t_out; + }; + + union { + float16 in[2]; + uint32_t t_in; + }; + t_in = b; +#pragma unroll + for (int i = 0; i < 2; ++i) { + out[i] = static_cast(a) * in[i]; + } + return t_out; +} + +template <> +inline __device__ float16 mul(float a, float16 b) { + return static_cast(a) * b; +} + +template <> +inline __device__ uint2 mul(float a, uint2 b) { + union { + uint2 tmp_in; + float16 tmp_in_fp16[4]; + }; + tmp_in = b; + union { + uint2 ret; + float16 tmp_out_fp16[4]; + }; + +#pragma unroll + for (int i = 0; i < 4; ++i) { + tmp_out_fp16[i] = mul(a, tmp_in_fp16[i]); + } + return ret; +} + +template <> +inline __device__ uint4 mul(float a, uint4 b) { + union { + uint4 tmp_in; + float16 tmp_in_fp16[8]; + }; + tmp_in = b; + union { + uint4 ret; + float16 tmp_out_fp16[8]; + }; +#pragma unroll + for (int i = 0; i < 8; ++i) { + tmp_out_fp16[i] = mul(a, tmp_in_fp16[i]); + } + return ret; +} + +template <> +inline __device__ float2 mul(float a, float2 b) { + float2 c; + c.x = a * b.x; + c.y = a * b.y; + return c; +} + +template <> +inline __device__ float4 mul(float a, float4 b) { + float4 c; + c.x = a * b.x; + c.y = a * b.y; + c.z = a * b.z; + c.w = a * b.w; + return c; +} + +template <> +inline __device__ Float8_ mul(float a, Float8_ b) { + Float8_ c; + c.x = mul(a, b.x); + c.y = mul(a, b.y); + c.z = mul(a, b.z); + c.w = mul(a, b.w); + return c; +} + template <> inline __device__ float2 mul(float2 a, float2 b) { float2 c; @@ -1196,6 +1600,171 @@ inline __device__ Float8_ cast_to_float(bf16_8_t u) { } #endif // ENABLE_BF16 +template +inline __device__ T roundWithTiesToEven(T x) { + T xLower = floor(x); + T 
xUpper = ceil(x); + // x is in interval [xl,xu]. Choose closest of two bounds, breaking ties to + // even. + T dLower = x - xLower; + T dUpper = xUpper - x; + return static_cast( + (dLower == dUpper ? fmod(xLower, 2.0F) == 0.0F : dLower < dUpper) + ? xLower + : xUpper); +} + +template +inline __device__ T round_tmp(D val); + +template <> +inline __device__ uint8_t round_tmp(float val) { + float quant_value = roundWithTiesToEven(val); + quant_value = quant_value > 127.0f ? 127.0f : quant_value; + quant_value = quant_value < -127.0f ? -127.0f : quant_value; + return static_cast(quant_value + 128.0); +} + +template <> +inline __device__ uint8_t round_tmp(float16 val) { + float quant_value = roundWithTiesToEven(static_cast(val)); + quant_value = quant_value > 127.0f ? 127.0f : quant_value; + quant_value = quant_value < -127.0f ? -127.0f : quant_value; + return static_cast(quant_value + 128.0); +} + +#ifdef ENABLE_BF16 +template <> +inline __device__ uint8_t round_tmp(__nv_bfloat16 val) { + float quant_value = + static_cast(roundWithTiesToEven(static_cast(val))); + quant_value = quant_value > 127.0f ? 127.0f : quant_value; + quant_value = quant_value < -127.0f ? -127.0f : quant_value; + return static_cast(quant_value + 128.0); +} +#endif + +template <> +inline __device__ uint16_t round_tmp(float2 val) { + union { + uint16_t ret; + uint8_t tmp[2]; + }; + tmp[0] = round_tmp(val.x); + tmp[1] = round_tmp(val.y); + return ret; +} + +template <> +inline __device__ uint32_t round_tmp(float4 val) { + union { + uint32_t ret; + uint8_t tmp[4]; + }; + tmp[0] = round_tmp(val.x); + tmp[1] = round_tmp(val.y); + tmp[2] = round_tmp(val.z); + tmp[3] = round_tmp(val.w); + return ret; +} + +template <> +inline __device__ uint16_t round_tmp(uint32_t val) { + union { + uint8_t int8[2]; + uint16_t ret; + }; + union { + float16 fp16[2]; + uint32_t tmp; + }; + tmp = val; + +#pragma unroll + for (int i = 0; i < 2; ++i) { + int8[i] = round_tmp(fp16[i]); + } + + return ret; +} + +template <> +inline __device__ uint32_t round_tmp(uint2 val) { + union { + uint8_t int8[4]; + uint32_t ret; + }; + + union { + uint2 ui2; + float16 tmp_fp16[4]; + }; + ui2 = val; + +#pragma unroll + for (int i = 0; i < 4; ++i) { + int8[i] = round_tmp(tmp_fp16[i]); + } + return ret; +} + +template <> +inline __device__ uint64_t round_tmp(uint4 val) { + union { + uint8_t int8[8]; + uint64_t ret; + }; + + union { + uint4 ui4; + float16 tmp_fp16[8]; + }; + ui4 = val; + +#pragma unroll + for (int i = 0; i < 8; ++i) { + int8[i] = round_tmp(tmp_fp16[i]); + } + return ret; +} + +#ifdef ENABLE_BF16 +template <> +inline __device__ uint16_t round_tmp(__nv_bfloat162 val) { + union { + uint8_t tmp[2]; + uint16_t ret; + }; + tmp[0] = round_tmp(val.x); + tmp[1] = round_tmp(val.y); + return ret; +} + +template <> +inline __device__ uint32_t round_tmp(bf16_4_t val) { + union { + uint16_t tmp[2]; + uint32_t ret; + }; + tmp[0] = round_tmp(val.x); + tmp[1] = round_tmp(val.y); + return ret; +} + +template <> +inline __device__ uint64_t round_tmp(bf16_8_t val) { + union { + uint16_t int16[4]; + uint64_t int64; + }; + int16[0] = round_tmp(val.x); + int16[1] = round_tmp(val.y); + int16[2] = round_tmp(val.z); + int16[3] = round_tmp(val.w); + return int64; +} +#endif + inline __device__ float2 rotary_embedding_coefficient(const int zid, const int rot_embed_dim, const float t_step) { @@ -2178,6 +2747,122 @@ struct Qk_dot { } }; +constexpr int32_t WARP_SIZE = 32; +constexpr int32_t HALF_WARP = 16; +constexpr float QUANT_MAX_BOUND = 127.0; +constexpr float QUANT_MIN_BOUND = 
-127.0; + +template +struct QuantFunc { + __host__ __device__ uint8_t operator()(T x, float quant_scale) { + float tmp = static_cast(x) * quant_scale; + tmp = round(tmp); + if (tmp > QUANT_MAX_BOUND) + tmp = QUANT_MAX_BOUND; + else if (tmp < QUANT_MIN_BOUND) + tmp = QUANT_MIN_BOUND; + return static_cast(tmp + 128.0f); + } +}; + +template +struct MaxFunc { + __device__ T operator()(T a, T b) { return max(a, b); } +}; + +template <> +struct MaxFunc { + __device__ half operator()(half a, half b) { +#if __CUDA_ARCH__ >= 800 + return __hmax(a, b); +#else + return max(static_cast(a), static_cast(b)); +#endif + } +}; + +#if CUDA_VERSION >= 11000 && defined(ENABLE_BF16) +template <> +struct MaxFunc<__nv_bfloat16> { + __device__ __nv_bfloat16 operator()(__nv_bfloat16 a, __nv_bfloat16 b) { +#if __CUDA_ARCH__ >= 800 + return __hmax(a, b); +#else + return max(static_cast(a), static_cast(b)); +#endif + } +}; +#endif + +template +struct AbsFunc { + __device__ T operator()(T x) { return abs(x); } +}; + +template <> +struct AbsFunc { + __device__ half operator()(half x) { +#if __CUDA_ARCH__ >= 800 + return __habs(x); +#else + return abs(static_cast(x)); +#endif + } +}; + +#if CUDA_VERSION >= 11000 && defined(ENABLE_BF16) +template <> +struct AbsFunc<__nv_bfloat16> { + __device__ __nv_bfloat16 operator()(__nv_bfloat16 x) { +#if __CUDA_ARCH__ >= 800 + return __habs(x); +#else + return abs(static_cast(x)); +#endif + } +}; +#endif + +template +__inline__ __device__ T LocalReduceMax(Vec& vec) { // NOLINT + T local_max = static_cast(0.0); +#pragma unroll + for (int i = 0; i < VecSize; ++i) { + local_max = vec[i] > local_max ? vec[i] : local_max; + } + return local_max; +} + +template +__inline__ __device__ T WarpReduceAbsMax(T val, unsigned lane_mask) { +#pragma unroll + for (int mask = HALF_WARP; mask > 0; mask >>= 1) { + val = MaxFunc()(val, __shfl_xor_sync(lane_mask, val, mask, WARP_SIZE)); + } + return val; +} + +template +__inline__ __device__ T BlockReduceAbsMax(T val, unsigned mask) { + static __shared__ T smem[WARP_SIZE]; + int32_t lane_id = threadIdx.x % WARP_SIZE; + int32_t warp_id = threadIdx.x / WARP_SIZE; + + val = WarpReduceAbsMax(val, mask); + + if (lane_id == 0) { + smem[warp_id] = val; + } + + __syncthreads(); + + T abs_max_val = (threadIdx.x < (blockDim.x / WARP_SIZE)) + ? smem[threadIdx.x] + : static_cast(0.0f); + abs_max_val = WarpReduceAbsMax(abs_max_val, mask); + return abs_max_val; +} + } // namespace fusion } // namespace phi diff --git a/python/paddle/incubate/nn/functional/block_multihead_attention.py b/python/paddle/incubate/nn/functional/block_multihead_attention.py index 1573b2d0fec99..4bffb8f2e94b9 100644 --- a/python/paddle/incubate/nn/functional/block_multihead_attention.py +++ b/python/paddle/incubate/nn/functional/block_multihead_attention.py @@ -30,12 +30,26 @@ def block_multihead_attention( block_tables, pre_key_cache=None, pre_value_cache=None, + cache_k_quant_scales=None, + cache_v_quant_scales=None, + cache_k_dequant_scales=None, + cache_v_dequant_scales=None, + qkv_out_scale=None, + qkv_bias=None, + out_shift=None, + out_smooth=None, rope_emb=None, mask=None, tgt_mask=None, max_seq_len=-1, block_size=64, use_neox_style=False, + use_dynamic_cachekv_quant=False, + quant_round_type=1, + quant_max_bound=127.0, + quant_min_bound=-127.0, + out_scale=-1, + compute_dtype="default", ): """ Block Multi-head attention for text summarization. 
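The reduction and quantization helpers that close out this header follow the standard pattern: AbsFunc / MaxFunc feed a shfl_xor butterfly in WarpReduceAbsMax, BlockReduceAbsMax stages one value per warp through shared memory and reduces again, and QuantFunc maps x to round(x * quant_scale) clipped to [-127, 127] and stored with a +128 offset. A plain-Python sketch of that shape (the function names below are illustrative, not Paddle APIs):

import random

WARP_SIZE, HALF_WARP = 32, 16

def warp_reduce_abs_max(lane_vals):
    vals = [abs(v) for v in lane_vals]                 # AbsFunc per lane
    mask = HALF_WARP
    while mask > 0:                                    # for (mask = HALF_WARP; mask > 0; mask >>= 1)
        vals = [max(vals[i], vals[i ^ mask]) for i in range(WARP_SIZE)]   # __shfl_xor_sync + MaxFunc
        mask >>= 1
    return vals[0]                                     # every lane ends up holding the warp-wide abs max

def quant_func(x, quant_scale):
    tmp = x * quant_scale
    tmp = min(max(round(tmp), -127.0), 127.0)          # note: Python round() breaks ties to even,
    return int(tmp + 128)                              # while the device round() rounds halves away from zero

lane_vals = [random.uniform(-8.0, 8.0) for _ in range(WARP_SIZE)]
abs_max = warp_reduce_abs_max(lane_vals)
assert abs_max == max(abs(v) for v in lane_vals)
print(quant_func(lane_vals[0], 127.0 / abs_max))       # a uint8 code in [1, 255]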
@@ -257,9 +271,23 @@ def block_multihead_attention( rope_emb, mask, tgt_mask, + cache_k_quant_scales, + cache_v_quant_scales, + cache_k_dequant_scales, + cache_v_dequant_scales, + qkv_out_scale, + qkv_bias, + out_shift, + out_smooth, max_seq_len, block_size, use_neox_style, + use_dynamic_cachekv_quant, + quant_round_type, + quant_max_bound, + quant_min_bound, + out_scale, + compute_dtype, ) helper = LayerHelper('block_multihead_attention', **locals()) @@ -287,6 +315,22 @@ def block_multihead_attention( inputs['mask'] = mask if tgt_mask is not None: inputs['tgt_mask'] = tgt_mask + if cache_k_quant_scales is not None: + inputs["cache_k_quant_scales"] = cache_k_quant_scales + if cache_v_quant_scales is not None: + inputs["cache_v_quant_scales"] = cache_v_quant_scales + if cache_k_dequant_scales is not None: + inputs["cache_k_dequant_scales"] = cache_k_dequant_scales + if cache_v_dequant_scales is not None: + inputs["cache_v_dequant_scales"] = cache_v_dequant_scales + if qkv_out_scale is not None: + inputs["qkv_out_scale"] = qkv_out_scale + if qkv_bias is not None: + inputs["qkv_bias"] = qkv_bias + if out_shift is not None: + inputs["out_shift"] = out_shift + if out_smooth is not None: + inputs["out_smooth"] = out_smooth outputs = { 'fmha_out': out, @@ -302,6 +346,12 @@ def block_multihead_attention( 'max_seq_len': max_seq_len, 'block_size': block_size, 'use_neox_style': use_neox_style, + 'dynamic_cachekv_quant': use_dynamic_cachekv_quant, + 'quant_round_type': quant_round_type, + 'quant_max_bound': quant_max_bound, + 'quant_min_bound': quant_min_bound, + 'out_scale': out_scale, + 'compute_dtype': compute_dtype, }, ) return out, qkv, key_cache, value_cache diff --git a/test/legacy_test/test_block_multihead_attention.py b/test/legacy_test/test_block_multihead_attention.py index a5c884351e180..04919ca3d8240 100644 --- a/test/legacy_test/test_block_multihead_attention.py +++ b/test/legacy_test/test_block_multihead_attention.py @@ -25,6 +25,7 @@ from paddle.static import Program, program_guard paddle.seed(2023) +np.random.seed(2023) is_sm8x = ( @@ -85,6 +86,9 @@ def naive_attention_impl( pre_cache_v=None, mask=None, scale=1.0, + cache_k_dequant_scales=None, + cache_v_dequant_scales=None, + use_cachekv_int8="None", ): batch = query.shape[0] heads = query.shape[1] @@ -95,10 +99,22 @@ def naive_attention_impl( key = key.reshape([batch, kv_head, 1, seq_len, head_dim]) key = paddle.tile(key, [1, 1, heads // kv_head, 1, 1]) key = key.reshape([batch, heads, seq_len, head_dim]) + + if use_cachekv_int8 == "dynamic": + unsqueeze_shape = [2, 3] + elif use_cachekv_int8 == "static": + unsqueeze_shape = [0, 2, 3] if pre_cache_k is not None: key = paddle.concat([pre_cache_k, key], axis=2) if cache_k is not None: - key = paddle.concat([cache_k, key], axis=2) + if cache_k_dequant_scales is not None: + dequant_cache_k = ( + (cache_k.astype('float32') - 128.0) + * cache_k_dequant_scales.unsqueeze(unsqueeze_shape) + ).astype(key.dtype) + key = paddle.concat([dequant_cache_k, key], axis=2) + else: + key = paddle.concat([cache_k, key], axis=2) value = value.reshape([batch, kv_head, 1, seq_len, head_dim]) value = paddle.tile(value, [1, 1, heads // kv_head, 1, 1]) @@ -106,7 +122,14 @@ def naive_attention_impl( if pre_cache_v is not None: value = paddle.concat([pre_cache_v, value], axis=2) if cache_v is not None: - value = paddle.concat([cache_v, value], axis=2) + if cache_v_dequant_scales is not None: + dequant_cache_v = ( + (cache_v.astype('float32') - 128.0) + * cache_v_dequant_scales.unsqueeze(unsqueeze_shape) + 
).astype(value.dtype) + value = paddle.concat([dequant_cache_v, value], axis=2) + else: + value = paddle.concat([cache_v, value], axis=2) qk_res = paddle.matmul(query, key, transpose_y=True) attention = qk_res * scale @@ -355,12 +378,20 @@ def test_all(self): self.block_tables, None, # pre_key_cache None, # pre_value_cache + None, # cache_k_quant_scales + None, # cache_v_quant_scales + None, # cache_k_dequant_scales + None, # cache_v_dequant_scales + None, # qkv_out_scale + None, # qkv_bias + None, # out_shift + None, # out_smooth None, # rotary_embs None, # attn_mask None, # tgt_mask self.seq_len, self.blocksize, - False, # use_neox_rotary_style + False, # use_neox_rotary_style, )[0] np.testing.assert_allclose( @@ -451,6 +482,14 @@ def test_all(self): self.block_tables, None, # pre_key_cache None, # pre_value_cache + None, # cache_k_quant_scales + None, # cache_v_quant_scales + None, # cache_k_dequant_scales + None, # cache_v_dequant_scales + None, # qkv_out_scale + None, # qkv_bias + None, # out_shift + None, # out_smooth None, # rotary_embs None, # attn_mask None, # tgt_mask @@ -637,6 +676,14 @@ def test_all(self): self.block_tables, None, # pre_key_cache None, # pre_value_cache + None, # cache_k_quant_scales + None, # cache_v_quant_scales + None, # cache_k_dequant_scales + None, # cache_v_dequant_scales + None, # qkv_out_scale + None, # qkv_bias + None, # out_shift + None, # out_smooth self.rope_emb, # rotary_embs None, # attn_mask None, # tgt_mask @@ -739,6 +786,14 @@ def test_all(self): self.block_tables, None, # pre_key_cache None, # pre_value_cache + None, # cache_k_quant_scales + None, # cache_v_quant_scales + None, # cache_k_dequant_scales + None, # cache_v_dequant_scales + None, # qkv_out_scale + None, # qkv_bias + None, # out_shift + None, # out_smooth self.rope_emb, # rotary_embs None, # attn_mask None, # tgt_mask @@ -762,7 +817,7 @@ def test_all(self): "core is not compiled with CUDA and cuda version need larger than or equal to 11.4" "and device's compute capability must be 8.x or 90", ) -class TestBlockMultiHeadAttnPreCacbe(unittest.TestCase): +class TestBlockMultiHeadAttnPreCache(unittest.TestCase): def setUp(self): paddle.disable_static() self.name = "TestBlockMultiHeadAttnPreCacbe" @@ -910,6 +965,14 @@ def test_all(self): self.block_tables, self.pre_cache_k, # pre_key_cache self.pre_cache_v, # pre_value_cache + None, # cache_k_quant_scales + None, # cache_v_quant_scales + None, # cache_k_dequant_scales + None, # cache_v_dequant_scales + None, # qkv_out_scale + None, # qkv_bias + None, # out_shift + None, # out_smooth None, # rotary_embs self.attention_mask, # attn_mask None, # tgt_mask @@ -1006,6 +1069,14 @@ def test_all(self): self.block_tables, self.pre_cache_k, # pre_key_cache self.pre_cache_v, # pre_value_cache + None, # cache_k_quant_scales + None, # cache_v_quant_scales + None, # cache_k_dequant_scales + None, # cache_v_dequant_scales + None, # qkv_out_scale + None, # qkv_bias + None, # out_shift + None, # out_smooth None, # rotary_embs self.attention_mask, # attn_mask None, # tgt_mask @@ -1200,6 +1271,14 @@ def test_all(self): block_tables, None, # pre_key_cache None, # pre_value_cache + None, # cache_k_quant_scales + None, # cache_v_quant_scales + None, # cache_k_dequant_scales + None, # cache_v_dequant_scales + None, # qkv_out_scale + None, # qkv_bias + None, # out_shift + None, # out_smooth None, # rotary_embs None, # attn_mask None, # tgt_mask @@ -1233,5 +1312,1575 @@ def test_all(self): ) +@unittest.skipIf( + not core.is_compiled_with_cuda() + or 
get_cuda_version() < 11040 + or not is_sm_supported, + "core is not compiled with CUDA and cuda version need larger than or equal to 11.4" + "and device's compute capability must be 8.x or 90", +) +class TestBlockMultiHeadAttnEncDecPTQDequant(unittest.TestCase): + def setUp(self): + paddle.disable_static() + self.name = "TestBlockMultiHeadAttnEncDec" + self.place = paddle.CUDAPlace(0) + self.batch_size = 2 + self.num_head = 8 + self.seq_len = 64 + self.max_dec_len = 64 + self.dim_head = 64 + self.hid_dim = self.num_head * self.dim_head + self.blocksize = 64 + self.block_num_per_seq = ( + self.seq_len + self.max_dec_len + self.blocksize - 1 + ) // self.blocksize + self.max_block_num = self.block_num_per_seq * self.batch_size + self.free_list = list(range(self.max_block_num - 1, -1, -1)) + self.seq_lens_encoder = paddle.to_tensor( + [ + self.seq_len, + ] + * self.batch_size, + "int32", + ) + self.seq_lens_decoder = paddle.to_tensor( + [ + 0, + ] + * self.batch_size, + "int32", + ) + self.seq_lens_this_time = self.seq_lens_encoder + self.shape = ( + self.batch_size, + self.num_head, + self.seq_len, + self.dim_head, + ) + self.cache_shape = ( + self.max_block_num, + self.num_head, + self.blocksize, + self.dim_head, + ) + self.dtype = 'float16' + self.attention_mask = create_attn_mask( + self.dtype, + self.batch_size, + [ + self.seq_len, + ] + * self.batch_size, + ) + self.scale = 1.0 / np.sqrt(self.shape[-1]) + self.cache_k = paddle.zeros(shape=self.cache_shape, dtype=self.dtype) + self.cache_v = paddle.zeros(shape=self.cache_shape, dtype=self.dtype) + self.block_tables = paddle.zeros( + shape=(self.batch_size, self.block_num_per_seq), dtype="int32" + ) + for i in range(self.batch_size): + need_block_num = ( + self.seq_len + self.max_dec_len + self.blocksize - 1 + ) // self.blocksize + for j in range(need_block_num): + self.block_tables[i, j] = self.free_list.pop() + ( + self.padding_offset, + self.cum_offset, + self.cu_seqlens_q, + self.cu_seqlens_k, + ) = get_padding_offset( + self.batch_size, self.seq_len, self.seq_lens_this_time + ) + self.token_num = self.padding_offset.shape[0] + + def test_all(self): + paddle.disable_static() + # encoder + query = np.random.randint(-65535, 65535, self.shape, 'int32') + q = paddle.to_tensor( + query, place=self.place, dtype='int32', stop_gradient=False + ) + key = np.random.randint(-65535, 65535, self.shape, 'int32') + k = paddle.to_tensor( + key, place=self.place, dtype='int32', stop_gradient=False + ) + value = np.random.randint(-65535, 65535, self.shape, 'int32') + v = paddle.to_tensor( + value, place=self.place, dtype='int32', stop_gradient=False + ) + + qkv = paddle.stack( + [ + q.transpose([0, 2, 1, 3]).reshape( + [self.token_num, self.hid_dim] + ), + k.transpose([0, 2, 1, 3]).reshape( + [self.token_num, self.hid_dim] + ), + v.transpose([0, 2, 1, 3]).reshape( + [self.token_num, self.hid_dim] + ), + ], + axis=1, + ).reshape([self.token_num, -1]) + + q = q.transpose([0, 2, 1, 3]).reshape([self.token_num, self.hid_dim]) + k = k.transpose([0, 2, 1, 3]).reshape([self.token_num, self.hid_dim]) + v = v.transpose([0, 2, 1, 3]).reshape([self.token_num, self.hid_dim]) + + q_out_scale = 10.0 / paddle.max(q, axis=0).astype('float32') + k_out_scale = 10.0 / paddle.max(k, axis=0).astype('float32') + v_out_scale = 10.0 / paddle.max(v, axis=0).astype('float32') + + qkv_out_scale = paddle.concat( + [q_out_scale, k_out_scale, v_out_scale], axis=0 + ) + + q_bias = paddle.ones([self.hid_dim], dtype=self.dtype) + k_bias = paddle.ones([self.hid_dim], dtype=self.dtype) + 
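For readers skimming the new tests: every setUp in these classes repeats the same paged KV-cache bookkeeping, reserving ceil((seq_len + max_dec_len) / blocksize) cache blocks per sequence from a shared free list and recording the mapping in block_tables. A stand-alone sketch of just that bookkeeping, in plain Python with this test's default sizes:

seq_len, max_dec_len, blocksize, batch_size = 64, 64, 64, 2
block_num_per_seq = (seq_len + max_dec_len + blocksize - 1) // blocksize   # ceil division -> 2
max_block_num = block_num_per_seq * batch_size                             # 4 physical cache blocks

free_list = list(range(max_block_num - 1, -1, -1))
block_tables = [[-1] * block_num_per_seq for _ in range(batch_size)]
for i in range(batch_size):
    for j in range(block_num_per_seq):
        block_tables[i][j] = free_list.pop()     # logical block j of sequence i -> a physical block id

print(block_tables)                              # [[0, 1], [2, 3]]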
v_bias = paddle.ones([self.hid_dim], dtype=self.dtype) + + qkv_bias = paddle.concat([q_bias, k_bias, v_bias], axis=-1) + + # dequant + q = (q.astype('float32') * q_out_scale).astype(self.dtype) + k = (k.astype('float32') * k_out_scale).astype(self.dtype) + v = (v.astype('float32') * v_out_scale).astype(self.dtype) + + # add bias + q = q + q_bias + k = k + k_bias + v = v + v_bias + + # transpose to origin + q = q.reshape( + [self.batch_size, self.seq_len, self.num_head, self.dim_head] + ).transpose([0, 2, 1, 3]) + k = k.reshape( + [self.batch_size, self.seq_len, self.num_head, self.dim_head] + ).transpose([0, 2, 1, 3]) + v = v.reshape( + [self.batch_size, self.seq_len, self.num_head, self.dim_head] + ).transpose([0, 2, 1, 3]) + + out_ = naive_attention_impl( + q, k, v, None, None, None, None, self.attention_mask, self.scale + ) + out_ = remove_padding( + self.seq_lens_this_time, self.cu_seqlens_q, out_, self.token_num + ) + out = block_multihead_attention( + qkv, + self.cache_k, + self.cache_v, + self.seq_lens_encoder, + self.seq_lens_decoder, + self.seq_lens_this_time, + self.padding_offset, + self.cum_offset, + self.cu_seqlens_q, + self.cu_seqlens_k, + self.block_tables, + None, # pre_key_cache + None, # pre_value_cache + None, # cache_k_quant_scales + None, # cache_v_quant_scales + None, # cache_k_dequant_scales + None, # cache_v_dequant_scales + qkv_out_scale, # qkv_out_scale + qkv_bias, # qkv_bias + None, # out_shift + None, # out_smooth + None, # rotary_embs + None, # attn_mask + None, # tgt_mask + self.seq_len, + self.blocksize, + False, # use_neox_rotary_style, + compute_dtype="fp16", + )[0] + + np.testing.assert_allclose( + out.numpy(), + out_.numpy(), + rtol=100, + atol=1, + ) + + # decoder + naive_cache_k, naive_cache_v = block_cache_to_naive_cache( + self.cache_k, + self.cache_v, + self.batch_size, + self.block_tables, + self.seq_len, + ) + + self.seq_lens_decoder[:] = self.seq_lens_encoder + self.seq_lens_encoder[:] = 0 + self.seq_lens_this_time[:] = 1 + self.shape = ( + self.batch_size, + self.num_head, + 1, + self.dim_head, + ) + query = np.random.randint(-65535, 65535, self.shape, 'int32') + q = paddle.to_tensor( + query, place=self.place, dtype='int32', stop_gradient=False + ) + key = np.random.randint(-65535, 65535, self.shape, 'int32') + k = paddle.to_tensor( + key, place=self.place, dtype='int32', stop_gradient=False + ) + value = np.random.randint(-65535, 65535, self.shape, 'int32') + v = paddle.to_tensor( + value, place=self.place, dtype='int32', stop_gradient=False + ) + + qkv = paddle.stack( + [ + q.transpose([0, 2, 1, 3]).reshape( + [self.batch_size, self.hid_dim] + ), + k.transpose([0, 2, 1, 3]).reshape( + [self.batch_size, self.hid_dim] + ), + v.transpose([0, 2, 1, 3]).reshape( + [self.batch_size, self.hid_dim] + ), + ], + axis=1, + ).reshape([self.batch_size, -1]) + + q = q.transpose([0, 2, 1, 3]).reshape([self.batch_size, self.hid_dim]) + k = k.transpose([0, 2, 1, 3]).reshape([self.batch_size, self.hid_dim]) + v = v.transpose([0, 2, 1, 3]).reshape([self.batch_size, self.hid_dim]) + + q_out_scale = 1.0 / paddle.max(q, axis=0).astype('float32') + k_out_scale = 1.0 / paddle.max(k, axis=0).astype('float32') + v_out_scale = 1.0 / paddle.max(v, axis=0).astype('float32') + + qkv_out_scale = paddle.concat( + [q_out_scale, k_out_scale, v_out_scale], axis=0 + ) + + q_bias = paddle.ones([self.hid_dim], dtype=self.dtype) * 0.1 + k_bias = paddle.ones([self.hid_dim], dtype=self.dtype) * 0.1 + v_bias = paddle.ones([self.hid_dim], dtype=self.dtype) * 0.1 + + qkv_bias = 
paddle.concat([q_bias, k_bias, v_bias], axis=-1) + + # dequant + q = (q.astype('float32') * q_out_scale).astype(self.dtype) + k = (k.astype('float32') * k_out_scale).astype(self.dtype) + v = (v.astype('float32') * v_out_scale).astype(self.dtype) + + # add bias + q = q + q_bias + k = k + k_bias + v = v + v_bias + + # transpose to origin + q = q.reshape( + [self.batch_size, 1, self.num_head, self.dim_head] + ).transpose([0, 2, 1, 3]) + k = k.reshape( + [self.batch_size, 1, self.num_head, self.dim_head] + ).transpose([0, 2, 1, 3]) + v = v.reshape( + [self.batch_size, 1, self.num_head, self.dim_head] + ).transpose([0, 2, 1, 3]) + + ( + self.padding_offset, + self.cum_offset, + self.cu_seqlens_q, + self.cu_seqlens_k, + ) = get_padding_offset(self.batch_size, 1, self.seq_lens_this_time) + + out_ = ( + naive_attention_impl( + q, + k, + v, + naive_cache_k, + naive_cache_v, + None, + None, + None, + self.scale, + ) + .transpose([0, 2, 1, 3]) + .reshape([self.batch_size, -1]) + ) + out = block_multihead_attention( + qkv, + self.cache_k, + self.cache_v, + self.seq_lens_encoder, + self.seq_lens_decoder, + self.seq_lens_this_time, + self.padding_offset, + self.cum_offset, + self.cu_seqlens_q, + self.cu_seqlens_k, + self.block_tables, + None, # pre_key_cache + None, # pre_value_cache + None, # cache_k_quant_scales + None, # cache_v_quant_scales + None, # cache_k_dequant_scales + None, # cache_v_dequant_scales + qkv_out_scale, # qkv_out_scale + qkv_bias, # qkv_bias + None, # out_shift + None, # out_smooth + None, # rotary_embs + None, # attn_mask + None, # tgt_mask + 1, # seq_len, + self.blocksize, + False, # use_neox_rotary_style + compute_dtype="fp16", + )[0] + # NOTE: The diff of decoder is a little big + np.testing.assert_allclose( + out.numpy(), + out_.numpy(), + rtol=100, + atol=1, + ) + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or get_cuda_version() < 11040 + or not is_sm_supported, + "core is not compiled with CUDA and cuda version need larger than or equal to 11.4" + "and device's compute capability must be 8.x or 90", +) +class TestBlockMultiHeadAttnEncDecPTQDequantQuantShiftSmooth(unittest.TestCase): + def setUp(self): + paddle.disable_static() + self.name = "TestBlockMultiHeadAttnEncDec" + self.place = paddle.CUDAPlace(0) + self.batch_size = 2 + self.num_head = 8 + self.seq_len = 64 + self.max_dec_len = 64 + self.dim_head = 64 + self.hid_dim = self.num_head * self.dim_head + self.blocksize = 64 + self.block_num_per_seq = ( + self.seq_len + self.max_dec_len + self.blocksize - 1 + ) // self.blocksize + self.max_block_num = self.block_num_per_seq * self.batch_size + self.free_list = list(range(self.max_block_num - 1, -1, -1)) + self.seq_lens_encoder = paddle.to_tensor( + [ + self.seq_len, + ] + * self.batch_size, + "int32", + ) + self.seq_lens_decoder = paddle.to_tensor( + [ + 0, + ] + * self.batch_size, + "int32", + ) + self.seq_lens_this_time = self.seq_lens_encoder + self.shape = ( + self.batch_size, + self.num_head, + self.seq_len, + self.dim_head, + ) + self.cache_shape = ( + self.max_block_num, + self.num_head, + self.blocksize, + self.dim_head, + ) + self.dtype = 'float16' + self.attention_mask = create_attn_mask( + self.dtype, + self.batch_size, + [ + self.seq_len, + ] + * self.batch_size, + ) + self.scale = 1.0 / np.sqrt(self.shape[-1]) + self.cache_k = paddle.zeros(shape=self.cache_shape, dtype=self.dtype) + self.cache_v = paddle.zeros(shape=self.cache_shape, dtype=self.dtype) + self.block_tables = paddle.zeros( + shape=(self.batch_size, self.block_num_per_seq), 
dtype="int32" + ) + for i in range(self.batch_size): + need_block_num = ( + self.seq_len + self.max_dec_len + self.blocksize - 1 + ) // self.blocksize + for j in range(need_block_num): + self.block_tables[i, j] = self.free_list.pop() + ( + self.padding_offset, + self.cum_offset, + self.cu_seqlens_q, + self.cu_seqlens_k, + ) = get_padding_offset( + self.batch_size, self.seq_len, self.seq_lens_this_time + ) + self.token_num = self.padding_offset.shape[0] + + def test_all(self): + paddle.disable_static() + # encoder + query = np.random.randint(-65535, 65535, self.shape, 'int32') + q = paddle.to_tensor( + query, place=self.place, dtype='int32', stop_gradient=False + ) + key = np.random.randint(-65535, 65535, self.shape, 'int32') + k = paddle.to_tensor( + key, place=self.place, dtype='int32', stop_gradient=False + ) + value = np.random.randint(-65535, 65535, self.shape, 'int32') + v = paddle.to_tensor( + value, place=self.place, dtype='int32', stop_gradient=False + ) + + qkv = paddle.stack( + [ + q.transpose([0, 2, 1, 3]).reshape( + [self.token_num, self.hid_dim] + ), + k.transpose([0, 2, 1, 3]).reshape( + [self.token_num, self.hid_dim] + ), + v.transpose([0, 2, 1, 3]).reshape( + [self.token_num, self.hid_dim] + ), + ], + axis=1, + ).reshape([self.token_num, -1]) + + q = q.transpose([0, 2, 1, 3]).reshape([self.token_num, self.hid_dim]) + k = k.transpose([0, 2, 1, 3]).reshape([self.token_num, self.hid_dim]) + v = v.transpose([0, 2, 1, 3]).reshape([self.token_num, self.hid_dim]) + + q_out_scale = 1.0 / paddle.max(q, axis=0).astype('float32') + k_out_scale = 1.0 / paddle.max(k, axis=0).astype('float32') + v_out_scale = 1.0 / paddle.max(v, axis=0).astype('float32') + + qkv_out_scale = paddle.concat( + [q_out_scale, k_out_scale, v_out_scale], axis=0 + ) + + q_bias = paddle.ones([self.hid_dim], dtype=self.dtype) + k_bias = paddle.ones([self.hid_dim], dtype=self.dtype) + v_bias = paddle.ones([self.hid_dim], dtype=self.dtype) + + qkv_bias = paddle.concat([q_bias, k_bias, v_bias], axis=-1) + + # dequant + q = (q.astype('float32') * q_out_scale).astype(self.dtype) + k = (k.astype('float32') * k_out_scale).astype(self.dtype) + v = (v.astype('float32') * v_out_scale).astype(self.dtype) + + # add bias + q = q + q_bias + k = k + k_bias + v = v + v_bias + + # transpose to origin + q = q.reshape( + [self.batch_size, self.seq_len, self.num_head, self.dim_head] + ).transpose([0, 2, 1, 3]) + k = k.reshape( + [self.batch_size, self.seq_len, self.num_head, self.dim_head] + ).transpose([0, 2, 1, 3]) + v = v.reshape( + [self.batch_size, self.seq_len, self.num_head, self.dim_head] + ).transpose([0, 2, 1, 3]) + + out_ = naive_attention_impl( + q, k, v, None, None, None, None, self.attention_mask, self.scale + ) + + out_ = remove_padding( + self.seq_lens_this_time, self.cu_seqlens_q, out_, self.token_num + ) + + # shift smooth + shift = np.random.random([self.num_head * self.dim_head]) + shift = paddle.to_tensor(shift, dtype=self.dtype, place=self.place) + + smooth = np.random.random([self.num_head * self.dim_head]) + smooth = paddle.to_tensor(smooth, dtype=self.dtype, place=self.place) + + out_ = (out_ + shift) * smooth + + # quant + out_ *= 127.0 + + out_ = paddle.where(out_ <= -127, paddle.full_like(out_, -127), out_) + out_ = paddle.where(out_ >= 127, paddle.full_like(out_, 127), out_) + out_ = paddle.round(out_).astype('int8') + + out = block_multihead_attention( + qkv, + self.cache_k, + self.cache_v, + self.seq_lens_encoder, + self.seq_lens_decoder, + self.seq_lens_this_time, + self.padding_offset, + 
self.cum_offset, + self.cu_seqlens_q, + self.cu_seqlens_k, + self.block_tables, + None, # pre_key_cache + None, # pre_value_cache + None, # cache_k_quant_scales + None, # cache_v_quant_scales + None, # cache_k_dequant_scales + None, # cache_v_dequant_scales + qkv_out_scale, # qkv_out_scale + qkv_bias, # qkv_bias + shift, # out_shift + smooth, # out_smooth + None, # rotary_embs + None, # attn_mask + None, # tgt_mask + self.seq_len, + self.blocksize, + False, # use_neox_rotary_style, + compute_dtype="fp16", + out_scale=1.0, + )[0] + + np.testing.assert_allclose( + out.numpy(), + out_.numpy(), + rtol=1, + atol=1, + ) + + # decoder + naive_cache_k, naive_cache_v = block_cache_to_naive_cache( + self.cache_k, + self.cache_v, + self.batch_size, + self.block_tables, + self.seq_len, + ) + + self.seq_lens_decoder[:] = self.seq_lens_encoder + self.seq_lens_encoder[:] = 0 + self.seq_lens_this_time[:] = 1 + self.shape = ( + self.batch_size, + self.num_head, + 1, + self.dim_head, + ) + query = np.random.randint(-65535, 65535, self.shape, 'int32') + q = paddle.to_tensor( + query, place=self.place, dtype='int32', stop_gradient=False + ) + key = np.random.randint(-65535, 65535, self.shape, 'int32') + k = paddle.to_tensor( + key, place=self.place, dtype='int32', stop_gradient=False + ) + value = np.random.randint(-65535, 65535, self.shape, 'int32') + v = paddle.to_tensor( + value, place=self.place, dtype='int32', stop_gradient=False + ) + + qkv = paddle.stack( + [ + q.transpose([0, 2, 1, 3]).reshape( + [self.batch_size, self.hid_dim] + ), + k.transpose([0, 2, 1, 3]).reshape( + [self.batch_size, self.hid_dim] + ), + v.transpose([0, 2, 1, 3]).reshape( + [self.batch_size, self.hid_dim] + ), + ], + axis=1, + ).reshape([self.batch_size, -1]) + + q = q.transpose([0, 2, 1, 3]).reshape([self.batch_size, self.hid_dim]) + k = k.transpose([0, 2, 1, 3]).reshape([self.batch_size, self.hid_dim]) + v = v.transpose([0, 2, 1, 3]).reshape([self.batch_size, self.hid_dim]) + + q_out_scale = 1.0 / paddle.max(q, axis=0).astype('float32') + k_out_scale = 1.0 / paddle.max(k, axis=0).astype('float32') + v_out_scale = 1.0 / paddle.max(v, axis=0).astype('float32') + + qkv_out_scale = paddle.concat( + [q_out_scale, k_out_scale, v_out_scale], axis=0 + ) + + q_bias = paddle.ones([self.hid_dim], dtype=self.dtype) * 0.1 + k_bias = paddle.ones([self.hid_dim], dtype=self.dtype) * 0.1 + v_bias = paddle.ones([self.hid_dim], dtype=self.dtype) * 0.1 + + qkv_bias = paddle.concat([q_bias, k_bias, v_bias], axis=-1) + + # dequant + q = (q.astype('float32') * q_out_scale).astype(self.dtype) + k = (k.astype('float32') * k_out_scale).astype(self.dtype) + v = (v.astype('float32') * v_out_scale).astype(self.dtype) + + # add bias + q = q + q_bias + k = k + k_bias + v = v + v_bias + + # transpose to origin + q = q.reshape( + [self.batch_size, 1, self.num_head, self.dim_head] + ).transpose([0, 2, 1, 3]) + k = k.reshape( + [self.batch_size, 1, self.num_head, self.dim_head] + ).transpose([0, 2, 1, 3]) + v = v.reshape( + [self.batch_size, 1, self.num_head, self.dim_head] + ).transpose([0, 2, 1, 3]) + + ( + self.padding_offset, + self.cum_offset, + self.cu_seqlens_q, + self.cu_seqlens_k, + ) = get_padding_offset(self.batch_size, 1, self.seq_lens_this_time) + + out_ = ( + naive_attention_impl( + q, + k, + v, + naive_cache_k, + naive_cache_v, + None, + None, + None, + self.scale, + ) + .transpose([0, 2, 1, 3]) + .reshape([self.batch_size, -1]) + ) + + # shift smooth + shift = np.random.random([self.num_head * self.dim_head]) + shift = paddle.to_tensor(shift, 
dtype=self.dtype, place=self.place) + + smooth = np.random.random([self.num_head * self.dim_head]) + smooth = paddle.to_tensor(smooth, dtype=self.dtype, place=self.place) + + out_ = (out_ + shift) * smooth + + # quant + out_ *= 127.0 + + out_ = paddle.where(out_ <= -127, paddle.full_like(out_, -127), out_) + out_ = paddle.where(out_ >= 127, paddle.full_like(out_, 127), out_) + out_ = paddle.round(out_).astype('int8') + + out = block_multihead_attention( + qkv, + self.cache_k, + self.cache_v, + self.seq_lens_encoder, + self.seq_lens_decoder, + self.seq_lens_this_time, + self.padding_offset, + self.cum_offset, + self.cu_seqlens_q, + self.cu_seqlens_k, + self.block_tables, + None, # pre_key_cache + None, # pre_value_cache + None, # cache_k_quant_scales + None, # cache_v_quant_scales + None, # cache_k_dequant_scales + None, # cache_v_dequant_scales + qkv_out_scale, # qkv_out_scale + qkv_bias, # qkv_bias + shift, # out_shift + smooth, # out_smooth + None, # rotary_embs + None, # attn_mask + None, # tgt_mask + 1, # seq_len, + self.blocksize, + False, # use_neox_rotary_style + compute_dtype="fp16", + out_scale=1.0, + )[0] + # NOTE: The diff of decoder is a little big + np.testing.assert_allclose( + out.numpy(), + out_.numpy(), + rtol=20, + atol=57, + ) + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or get_cuda_version() < 11040 + or not is_sm_supported, + "core is not compiled with CUDA and cuda version need larger than or equal to 11.4" + "and device's compute capability must be 8.x or 90", +) +class TestBlockMultiHeadAttnEncDecQuant(unittest.TestCase): + def setUp(self): + paddle.disable_static() + self.name = "TestBlockMultiHeadAttnEncDec" + self.place = paddle.CUDAPlace(0) + self.batch_size = 2 + self.num_head = 8 + self.seq_len = 64 + self.max_dec_len = 64 + self.dim_head = 64 + self.hid_dim = self.num_head * self.dim_head + self.blocksize = 64 + self.block_num_per_seq = ( + self.seq_len + self.max_dec_len + self.blocksize - 1 + ) // self.blocksize + self.max_block_num = self.block_num_per_seq * self.batch_size + self.free_list = list(range(self.max_block_num - 1, -1, -1)) + self.seq_lens_encoder = paddle.to_tensor( + [ + self.seq_len, + ] + * self.batch_size, + "int32", + ) + self.seq_lens_decoder = paddle.to_tensor( + [ + 0, + ] + * self.batch_size, + "int32", + ) + self.seq_lens_this_time = self.seq_lens_encoder + self.shape = ( + self.batch_size, + self.num_head, + self.seq_len, + self.dim_head, + ) + self.cache_shape = ( + self.max_block_num, + self.num_head, + self.blocksize, + self.dim_head, + ) + self.dtype = 'float16' + self.attention_mask = create_attn_mask( + self.dtype, + self.batch_size, + [ + self.seq_len, + ] + * self.batch_size, + ) + self.scale = 1.0 / np.sqrt(self.shape[-1]) + self.cache_k = paddle.zeros(shape=self.cache_shape, dtype=self.dtype) + self.cache_v = paddle.zeros(shape=self.cache_shape, dtype=self.dtype) + self.block_tables = paddle.zeros( + shape=(self.batch_size, self.block_num_per_seq), dtype="int32" + ) + for i in range(self.batch_size): + need_block_num = ( + self.seq_len + self.max_dec_len + self.blocksize - 1 + ) // self.blocksize + for j in range(need_block_num): + self.block_tables[i, j] = self.free_list.pop() + ( + self.padding_offset, + self.cum_offset, + self.cu_seqlens_q, + self.cu_seqlens_k, + ) = get_padding_offset( + self.batch_size, self.seq_len, self.seq_lens_this_time + ) + self.token_num = self.padding_offset.shape[0] + + def test_all(self): + paddle.disable_static() + # encoder + query = np.random.random(self.shape) + q = 
paddle.to_tensor( + query, place=self.place, dtype=self.dtype, stop_gradient=False + ) + key = np.random.random(self.shape) + k = paddle.to_tensor( + key, place=self.place, dtype=self.dtype, stop_gradient=False + ) + value = np.random.random(self.shape) + v = paddle.to_tensor( + value, place=self.place, dtype=self.dtype, stop_gradient=False + ) + + qkv = paddle.stack( + [ + q.transpose([0, 2, 1, 3]).reshape( + [self.token_num, self.hid_dim] + ), + k.transpose([0, 2, 1, 3]).reshape( + [self.token_num, self.hid_dim] + ), + v.transpose([0, 2, 1, 3]).reshape( + [self.token_num, self.hid_dim] + ), + ], + axis=1, + ).reshape([self.token_num, -1]) + out_ = naive_attention_impl( + q, k, v, None, None, None, None, self.attention_mask, self.scale + ) + # quant + out_ *= 127.0 + + out_ = paddle.where(out_ <= -127, paddle.full_like(out_, -127), out_) + out_ = paddle.where(out_ >= 127, paddle.full_like(out_, 127), out_) + out_ = paddle.round(out_).astype('int8') + + out_ = remove_padding( + self.seq_lens_this_time, self.cu_seqlens_q, out_, self.token_num + ) + out = block_multihead_attention( + qkv, + self.cache_k, + self.cache_v, + self.seq_lens_encoder, + self.seq_lens_decoder, + self.seq_lens_this_time, + self.padding_offset, + self.cum_offset, + self.cu_seqlens_q, + self.cu_seqlens_k, + self.block_tables, + None, # pre_key_cache + None, # pre_value_cache + None, # cache_k_quant_scales + None, # cache_v_quant_scales + None, # cache_k_dequant_scales + None, # cache_v_dequant_scales + None, # qkv_out_scale + None, # qkv_bias + None, # out_shift + None, # out_smooth + None, # rotary_embs + None, # attn_mask + None, # tgt_mask + self.seq_len, + self.blocksize, + False, # use_neox_rotary_style, + out_scale=1.0, + )[0] + + np.testing.assert_allclose( + out.numpy(), + out_.numpy(), + rtol=0.1, + atol=1, + ) + + # decoder + naive_cache_k, naive_cache_v = block_cache_to_naive_cache( + self.cache_k, + self.cache_v, + self.batch_size, + self.block_tables, + self.seq_len, + ) + + self.seq_lens_decoder[:] = self.seq_lens_encoder + self.seq_lens_encoder[:] = 0 + self.seq_lens_this_time[:] = 1 + self.shape = ( + self.batch_size, + self.num_head, + 1, + self.dim_head, + ) + query = np.random.random(self.shape) + q = paddle.to_tensor( + query, place=self.place, dtype=self.dtype, stop_gradient=False + ) + key = np.random.random(self.shape) + k = paddle.to_tensor( + key, place=self.place, dtype=self.dtype, stop_gradient=False + ) + value = np.random.random(self.shape) + v = paddle.to_tensor( + value, place=self.place, dtype=self.dtype, stop_gradient=False + ) + + qkv = paddle.stack( + [ + q.transpose([0, 2, 1, 3]).reshape( + [self.batch_size, self.hid_dim] + ), + k.transpose([0, 2, 1, 3]).reshape( + [self.batch_size, self.hid_dim] + ), + v.transpose([0, 2, 1, 3]).reshape( + [self.batch_size, self.hid_dim] + ), + ], + axis=1, + ).reshape([self.batch_size, -1]) + ( + self.padding_offset, + self.cum_offset, + self.cu_seqlens_q, + self.cu_seqlens_k, + ) = get_padding_offset(self.batch_size, 1, self.seq_lens_this_time) + + out_ = ( + naive_attention_impl( + q, + k, + v, + naive_cache_k, + naive_cache_v, + None, + None, + None, + self.scale, + ) + .transpose([0, 2, 1, 3]) + .reshape([self.batch_size, -1]) + ) + # quant + out_ *= 127.0 + + out_ = paddle.where(out_ <= -127, paddle.full_like(out_, -127), out_) + out_ = paddle.where(out_ >= 127, paddle.full_like(out_, 127), out_) + out_ = paddle.round(out_).astype('int8') + + out = block_multihead_attention( + qkv, + self.cache_k, + self.cache_v, + self.seq_lens_encoder, + 
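A note on the tolerances used for the int8 comparisons in this class: after scaling by 127 and rounding, two float results that differ only by accumulated fp16 rounding noise can still land one code apart when they sit near a quantization boundary, which is presumably why atol=1 (one int8 code) is used; rtol carries little meaning on integer codes. A tiny illustration with assumed values, not taken from the test:

import numpy as np

def to_int8_code(x, quant_max_bound=127.0):
    return np.round(np.clip(x * quant_max_bound, -quant_max_bound, quant_max_bound)).astype(np.int8)

# Two results that differ by ~1e-5 (typical fp16 accumulation noise) straddle a code boundary:
print(to_int8_code(np.float32(0.499996)), to_int8_code(np.float32(0.500004)))   # 63 64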
self.seq_lens_decoder, + self.seq_lens_this_time, + self.padding_offset, + self.cum_offset, + self.cu_seqlens_q, + self.cu_seqlens_k, + self.block_tables, + None, # pre_key_cache + None, # pre_value_cache + None, # cache_k_quant_scales + None, # cache_v_quant_scales + None, # cache_k_dequant_scales + None, # cache_v_dequant_scales + None, # qkv_out_scale + None, # qkv_bias + None, # out_shift + None, # out_smooth + None, # rotary_embs + None, # attn_mask + None, # tgt_mask + 1, # seq_len, + self.blocksize, + False, # use_neox_rotary_style + out_scale=1.0, + )[0] + # NOTE: The diff of decoder is a little big + np.testing.assert_allclose( + out.numpy(), + out_.numpy(), + rtol=0.1, + atol=1, + ) + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or get_cuda_version() < 11040 + or not is_sm_supported, + "core is not compiled with CUDA and cuda version need larger than or equal to 11.4" + "and device's compute capability must be 8.x or 90", +) +class TestBlockMultiHeadAttnEncDecCacheKVDynamicQuant(unittest.TestCase): + def setUp(self): + paddle.disable_static() + self.name = "TestBlockMultiHeadAttnEncDec" + self.place = paddle.CUDAPlace(0) + self.batch_size = 2 + self.num_head = 8 + self.seq_len = 64 + self.max_dec_len = 64 + self.dim_head = 64 + self.hid_dim = self.num_head * self.dim_head + self.blocksize = 64 + self.block_num_per_seq = ( + self.seq_len + self.max_dec_len + self.blocksize - 1 + ) // self.blocksize + self.max_block_num = self.block_num_per_seq * self.batch_size + self.free_list = list(range(self.max_block_num - 1, -1, -1)) + self.seq_lens_encoder = paddle.to_tensor( + [ + self.seq_len, + ] + * self.batch_size, + "int32", + ) + self.seq_lens_decoder = paddle.to_tensor( + [ + 0, + ] + * self.batch_size, + "int32", + ) + self.seq_lens_this_time = self.seq_lens_encoder + self.shape = ( + self.batch_size, + self.num_head, + self.seq_len, + self.dim_head, + ) + self.cache_shape = ( + self.max_block_num, + self.num_head, + self.blocksize, + self.dim_head, + ) + self.dtype = 'float16' + self.attention_mask = create_attn_mask( + self.dtype, + self.batch_size, + [ + self.seq_len, + ] + * self.batch_size, + ) + self.scale = 1.0 / np.sqrt(self.shape[-1]) + self.cache_k = paddle.zeros(shape=self.cache_shape, dtype='uint8') + self.cache_v = paddle.zeros(shape=self.cache_shape, dtype='uint8') + self.cache_k_quant_scales = paddle.zeros( + shape=[self.batch_size, self.num_head], dtype='float32' + ) + self.cache_v_quant_scales = paddle.zeros( + shape=[self.batch_size, self.num_head], dtype='float32' + ) + self.cache_k_dequant_scales = paddle.zeros( + shape=[self.batch_size, self.num_head], dtype='float32' + ) + self.cache_v_dequant_scales = paddle.zeros( + shape=[self.batch_size, self.num_head], dtype='float32' + ) + + self.block_tables = paddle.zeros( + shape=(self.batch_size, self.block_num_per_seq), dtype="int32" + ) + for i in range(self.batch_size): + need_block_num = ( + self.seq_len + self.max_dec_len + self.blocksize - 1 + ) // self.blocksize + for j in range(need_block_num): + self.block_tables[i, j] = self.free_list.pop() + ( + self.padding_offset, + self.cum_offset, + self.cu_seqlens_q, + self.cu_seqlens_k, + ) = get_padding_offset( + self.batch_size, self.seq_len, self.seq_lens_this_time + ) + self.token_num = self.padding_offset.shape[0] + + def test_all(self): + paddle.disable_static() + # encoder + query = np.random.random(self.shape) + q = paddle.to_tensor( + query, place=self.place, dtype=self.dtype, stop_gradient=False + ) + key = np.random.random(self.shape) + k = 
paddle.to_tensor( + key, place=self.place, dtype=self.dtype, stop_gradient=False + ) + value = np.random.random(self.shape) + v = paddle.to_tensor( + value, place=self.place, dtype=self.dtype, stop_gradient=False + ) + + qkv = paddle.stack( + [ + q.transpose([0, 2, 1, 3]).reshape( + [self.token_num, self.hid_dim] + ), + k.transpose([0, 2, 1, 3]).reshape( + [self.token_num, self.hid_dim] + ), + v.transpose([0, 2, 1, 3]).reshape( + [self.token_num, self.hid_dim] + ), + ], + axis=1, + ).reshape([self.token_num, -1]) + out_ = naive_attention_impl( + q, k, v, None, None, None, None, self.attention_mask, self.scale + ) + + out_ = remove_padding( + self.seq_lens_this_time, self.cu_seqlens_q, out_, self.token_num + ) + out = block_multihead_attention( + qkv, + self.cache_k, + self.cache_v, + self.seq_lens_encoder, + self.seq_lens_decoder, + self.seq_lens_this_time, + self.padding_offset, + self.cum_offset, + self.cu_seqlens_q, + self.cu_seqlens_k, + self.block_tables, + None, # pre_key_cache + None, # pre_value_cache + self.cache_k_quant_scales, # cache_k_quant_scales + self.cache_v_quant_scales, # cache_v_quant_scales + self.cache_k_dequant_scales, # cache_k_dequant_scales + self.cache_v_dequant_scales, # cache_v_dequant_scales + None, # qkv_out_scale + None, # qkv_bias + None, # out_shift + None, # out_smooth + None, # rotary_embs + None, # attn_mask + None, # tgt_mask + self.seq_len, + self.blocksize, + False, # use_neox_rotary_style, + use_dynamic_cachekv_quant=True, + )[0] + + np.testing.assert_allclose( + out.numpy(), + out_.numpy(), + rtol=0.1, + atol=1, + ) + + # decoder + naive_cache_k, naive_cache_v = block_cache_to_naive_cache( + self.cache_k, + self.cache_v, + self.batch_size, + self.block_tables, + self.seq_len, + ) + + self.seq_lens_decoder[:] = self.seq_lens_encoder + self.seq_lens_encoder[:] = 0 + self.seq_lens_this_time[:] = 1 + self.shape = ( + self.batch_size, + self.num_head, + 1, + self.dim_head, + ) + query = np.random.random(self.shape) + q = paddle.to_tensor( + query, place=self.place, dtype=self.dtype, stop_gradient=False + ) + key = np.random.random(self.shape) + k = paddle.to_tensor( + key, place=self.place, dtype=self.dtype, stop_gradient=False + ) + value = np.random.random(self.shape) + v = paddle.to_tensor( + value, place=self.place, dtype=self.dtype, stop_gradient=False + ) + + qkv = paddle.stack( + [ + q.transpose([0, 2, 1, 3]).reshape( + [self.batch_size, self.hid_dim] + ), + k.transpose([0, 2, 1, 3]).reshape( + [self.batch_size, self.hid_dim] + ), + v.transpose([0, 2, 1, 3]).reshape( + [self.batch_size, self.hid_dim] + ), + ], + axis=1, + ).reshape([self.batch_size, -1]) + ( + self.padding_offset, + self.cum_offset, + self.cu_seqlens_q, + self.cu_seqlens_k, + ) = get_padding_offset(self.batch_size, 1, self.seq_lens_this_time) + + out_ = ( + naive_attention_impl( + q, + k, + v, + naive_cache_k, + naive_cache_v, + None, + None, + None, + self.scale, + cache_k_dequant_scales=self.cache_k_dequant_scales, + cache_v_dequant_scales=self.cache_v_dequant_scales, + use_cachekv_int8="dynamic", + ) + .transpose([0, 2, 1, 3]) + .reshape([self.batch_size, -1]) + ) + # quant + + out = block_multihead_attention( + qkv, + self.cache_k, + self.cache_v, + self.seq_lens_encoder, + self.seq_lens_decoder, + self.seq_lens_this_time, + self.padding_offset, + self.cum_offset, + self.cu_seqlens_q, + self.cu_seqlens_k, + self.block_tables, + None, # pre_key_cache + None, # pre_value_cache + self.cache_k_quant_scales, # cache_k_quant_scales + self.cache_v_quant_scales, # 
cache_v_quant_scales + self.cache_k_dequant_scales, # cache_k_dequant_scales + self.cache_v_dequant_scales, # cache_v_dequant_scales + None, # qkv_out_scale + None, # qkv_bias + None, # out_shift + None, # out_smooth + None, # rotary_embs + None, # attn_mask + None, # tgt_mask + 1, # seq_len, + self.blocksize, + False, # use_neox_rotary_style + use_dynamic_cachekv_quant=True, + )[0] + # NOTE: The diff of decoder is a little big + np.testing.assert_allclose( + out.numpy(), + out_.numpy(), + rtol=0.1, + atol=1, + ) + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or get_cuda_version() < 11040 + or not is_sm_supported, + "core is not compiled with CUDA and cuda version need larger than or equal to 11.4" + "and device's compute capability must be 8.x or 90", +) +class TestBlockMultiHeadAttnEncDecCacheKVStaticQuant(unittest.TestCase): + def setUp(self): + paddle.disable_static() + self.name = "TestBlockMultiHeadAttnEncDec" + self.place = paddle.CUDAPlace(0) + self.batch_size = 2 + self.num_head = 8 + self.seq_len = 64 + self.max_dec_len = 64 + self.dim_head = 64 + self.hid_dim = self.num_head * self.dim_head + self.blocksize = 64 + self.block_num_per_seq = ( + self.seq_len + self.max_dec_len + self.blocksize - 1 + ) // self.blocksize + self.max_block_num = self.block_num_per_seq * self.batch_size + self.free_list = list(range(self.max_block_num - 1, -1, -1)) + self.seq_lens_encoder = paddle.to_tensor( + [ + self.seq_len, + ] + * self.batch_size, + "int32", + ) + self.seq_lens_decoder = paddle.to_tensor( + [ + 0, + ] + * self.batch_size, + "int32", + ) + self.seq_lens_this_time = self.seq_lens_encoder + self.shape = ( + self.batch_size, + self.num_head, + self.seq_len, + self.dim_head, + ) + self.cache_shape = ( + self.max_block_num, + self.num_head, + self.blocksize, + self.dim_head, + ) + self.dtype = 'float16' + self.attention_mask = create_attn_mask( + self.dtype, + self.batch_size, + [ + self.seq_len, + ] + * self.batch_size, + ) + self.scale = 1.0 / np.sqrt(self.shape[-1]) + self.cache_k = paddle.zeros(shape=self.cache_shape, dtype='uint8') + self.cache_v = paddle.zeros(shape=self.cache_shape, dtype='uint8') + self.cache_k_quant_scales = paddle.zeros( + shape=[self.num_head], dtype='float32' + ) + self.cache_v_quant_scales = paddle.zeros( + shape=[self.num_head], dtype='float32' + ) + self.cache_k_dequant_scales = paddle.zeros( + shape=[self.num_head], dtype='float32' + ) + self.cache_v_dequant_scales = paddle.zeros( + shape=[self.num_head], dtype='float32' + ) + + self.block_tables = paddle.zeros( + shape=(self.batch_size, self.block_num_per_seq), dtype="int32" + ) + for i in range(self.batch_size): + need_block_num = ( + self.seq_len + self.max_dec_len + self.blocksize - 1 + ) // self.blocksize + for j in range(need_block_num): + self.block_tables[i, j] = self.free_list.pop() + ( + self.padding_offset, + self.cum_offset, + self.cu_seqlens_q, + self.cu_seqlens_k, + ) = get_padding_offset( + self.batch_size, self.seq_len, self.seq_lens_this_time + ) + self.token_num = self.padding_offset.shape[0] + + def test_all(self): + paddle.disable_static() + # encoder + query = np.random.random(self.shape) + q = paddle.to_tensor( + query, place=self.place, dtype=self.dtype, stop_gradient=False + ) + key = np.random.random(self.shape) + k = paddle.to_tensor( + key, place=self.place, dtype=self.dtype, stop_gradient=False + ) + value = np.random.random(self.shape) + v = paddle.to_tensor( + value, place=self.place, dtype=self.dtype, stop_gradient=False + ) + + qkv = paddle.stack( + [ + 
q.transpose([0, 2, 1, 3]).reshape( + [self.token_num, self.hid_dim] + ), + k.transpose([0, 2, 1, 3]).reshape( + [self.token_num, self.hid_dim] + ), + v.transpose([0, 2, 1, 3]).reshape( + [self.token_num, self.hid_dim] + ), + ], + axis=1, + ).reshape([self.token_num, -1]) + + out_ = naive_attention_impl( + q, k, v, None, None, None, None, self.attention_mask, self.scale + ) + + out_ = remove_padding( + self.seq_lens_this_time, self.cu_seqlens_q, out_, self.token_num + ) + + self.cache_k_quant_scales = ( + 127.0 / paddle.max(k, axis=[0, 2, 3]) + ).astype("float32") + self.cache_v_quant_scales = ( + 127.0 / paddle.max(k, axis=[0, 2, 3]) + ).astype("float32") + + self.cache_k_dequant_scales = 1.0 / self.cache_k_quant_scales + self.cache_v_dequant_scales = 1.0 / self.cache_v_quant_scales + + out = block_multihead_attention( + qkv, + self.cache_k, + self.cache_v, + self.seq_lens_encoder, + self.seq_lens_decoder, + self.seq_lens_this_time, + self.padding_offset, + self.cum_offset, + self.cu_seqlens_q, + self.cu_seqlens_k, + self.block_tables, + None, # pre_key_cache + None, # pre_value_cache + self.cache_k_quant_scales, # cache_k_quant_scales + self.cache_v_quant_scales, # cache_v_quant_scales + self.cache_k_dequant_scales, # cache_k_dequant_scales + self.cache_v_dequant_scales, # cache_v_dequant_scales + None, # qkv_out_scale + None, # qkv_bias + None, # out_shift + None, # out_smooth + None, # rotary_embs + None, # attn_mask + None, # tgt_mask + self.seq_len, + self.blocksize, + False, # use_neox_rotary_style, + use_dynamic_cachekv_quant=False, + )[0] + + np.testing.assert_allclose( + out.numpy(), + out_.numpy(), + rtol=0.1, + atol=1, + ) + + # decoder + naive_cache_k, naive_cache_v = block_cache_to_naive_cache( + self.cache_k, + self.cache_v, + self.batch_size, + self.block_tables, + self.seq_len, + ) + + self.seq_lens_decoder[:] = self.seq_lens_encoder + self.seq_lens_encoder[:] = 0 + self.seq_lens_this_time[:] = 1 + self.shape = ( + self.batch_size, + self.num_head, + 1, + self.dim_head, + ) + query = np.random.random(self.shape) + q = paddle.to_tensor( + query, place=self.place, dtype=self.dtype, stop_gradient=False + ) + key = np.random.random(self.shape) + k = paddle.to_tensor( + key, place=self.place, dtype=self.dtype, stop_gradient=False + ) + value = np.random.random(self.shape) + v = paddle.to_tensor( + value, place=self.place, dtype=self.dtype, stop_gradient=False + ) + + qkv = paddle.stack( + [ + q.transpose([0, 2, 1, 3]).reshape( + [self.batch_size, self.hid_dim] + ), + k.transpose([0, 2, 1, 3]).reshape( + [self.batch_size, self.hid_dim] + ), + v.transpose([0, 2, 1, 3]).reshape( + [self.batch_size, self.hid_dim] + ), + ], + axis=1, + ).reshape([self.batch_size, -1]) + ( + self.padding_offset, + self.cum_offset, + self.cu_seqlens_q, + self.cu_seqlens_k, + ) = get_padding_offset(self.batch_size, 1, self.seq_lens_this_time) + + out_ = ( + naive_attention_impl( + q, + k, + v, + naive_cache_k, + naive_cache_v, + None, + None, + None, + self.scale, + cache_k_dequant_scales=self.cache_k_dequant_scales, + cache_v_dequant_scales=self.cache_v_dequant_scales, + use_cachekv_int8="static", + ) + .transpose([0, 2, 1, 3]) + .reshape([self.batch_size, -1]) + ) + + out = block_multihead_attention( + qkv, + self.cache_k, + self.cache_v, + self.seq_lens_encoder, + self.seq_lens_decoder, + self.seq_lens_this_time, + self.padding_offset, + self.cum_offset, + self.cu_seqlens_q, + self.cu_seqlens_k, + self.block_tables, + None, # pre_key_cache + None, # pre_value_cache + self.cache_k_quant_scales, # 
cache_k_quant_scales + self.cache_v_quant_scales, # cache_v_quant_scales + self.cache_k_dequant_scales, # cache_k_dequant_scales + self.cache_v_dequant_scales, # cache_v_dequant_scales + None, # qkv_out_scale + None, # qkv_bias + None, # out_shift + None, # out_smooth + None, # rotary_embs + None, # attn_mask + None, # tgt_mask + 1, # seq_len, + self.blocksize, + False, # use_neox_rotary_style + use_dynamic_cachekv_quant=False, + )[0] + # NOTE: The diff of decoder is a little big + np.testing.assert_allclose( + out.numpy(), + out_.numpy(), + rtol=0.1, + atol=1, + ) + + if __name__ == '__main__': unittest.main()
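Taken together, the last two test classes pin down the two cache-KV int8 layouts: dynamic quant keeps one quant/dequant scale pair per [batch_size, num_head], filled by the kernel at runtime, while static quant precomputes one pair per [num_head] (here 127 / max over a calibration batch, with the dequant scale as its inverse). In both cases the uint8 cache stores value * quant_scale + 128, so the reference dequant is (code - 128) * dequant_scale, exactly as naive_attention_impl does above. A numpy round trip under those assumptions; the quantization step itself is inferred from the tests, not copied from kernel code:

import numpy as np

batch, heads, seq, dim = 2, 8, 64, 64
k = np.random.rand(batch, heads, seq, dim).astype(np.float32)

# static: one scale per head, computed offline from a calibration max (as in the last test)
static_quant = (127.0 / k.max(axis=(0, 2, 3))).astype(np.float32)      # shape [heads]
static_dequant = 1.0 / static_quant

# dynamic: one scale per [batch, head], which the kernel would derive at runtime
dyn_quant = (127.0 / np.abs(k).max(axis=(2, 3))).astype(np.float32)    # shape [batch, heads]
dyn_dequant = 1.0 / dyn_quant

def quantize(x, scale):
    # scale broadcasts over (seq, dim); the +128 offset matches round_tmp / QuantFunc
    return (np.clip(np.round(x * scale), -127, 127) + 128).astype(np.uint8)

cache_static = quantize(k, static_quant[None, :, None, None])
cache_dyn = quantize(k, dyn_quant[:, :, None, None])

recovered = (cache_static.astype(np.float32) - 128.0) * static_dequant[None, :, None, None]
assert np.abs(recovered - k).max() <= static_dequant.max()             # error stays within one code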