diff --git a/python/paddle/incubate/nn/functional/block_multihead_attention.py b/python/paddle/incubate/nn/functional/block_multihead_attention.py index 4bffb8f2e94b9f..9ee18feaad3c83 100644 --- a/python/paddle/incubate/nn/functional/block_multihead_attention.py +++ b/python/paddle/incubate/nn/functional/block_multihead_attention.py @@ -67,13 +67,27 @@ def block_multihead_attention( cu_seqlens_k (Tensor): The cum sequence lengths of key. Its shape is [batchsize + 1, 1]. block_tables (Tensor): The block tables, used to index the cache. Its shape is [batchsize, block_num_per_seq]. pre_key_cache (Tensor): The pre caches of key. Its shape is [batchsize, num_head, pre_cache_length, head_size]. - pre_key_value (Tensor): The pre caches of value. Its shape is [batchsize, num_head, pre_cache_length, head_size]. + pre_value_cache (Tensor): The pre caches of value. Its shape is [batchsize, num_head, pre_cache_length, head_size]. + cache_k_quant_scales (Tensor): The quant scales of cache key. Its shape depends on quant mode (dynamic or static). If dynamic quantization is enabled, its shape is [batchsize, num_head], otherwise its shape is [num_head]. + cache_v_quant_scales (Tensor): The quant scales of cache value. Its shape depends on quant mode (dynamic or static). If dynamic quantization is enabled, its shape is [batchsize, num_head], otherwise its shape is [num_head]. + cache_k_dequant_scales (Tensor): The dequant scales of cache key. Its shape depends on quant mode (dynamic or static). If dynamic quantization is enabled, its shape is [batchsize, num_head], otherwise its shape is [num_head]. + cache_v_dequant_scales (Tensor): The dequant scales of cache value. Its shape depends on quant mode (dynamic or static). If dynamic quantization is enabled, its shape is [batchsize, num_head], otherwise its shape is [num_head]. + qkv_out_scale (Tensor): The dequant scale of qkv, which is the input of BLHA. If the dtype of qkv is `int32`, this input will be applied. Its shape is [3 * num_head * head_size], and its dtype should be `float32`. + qkv_bias (Tensor): The bias of qkv. Its shape is [3 * num_head * head_size]. + out_shift (Tensor): Shift bias of fmha_out, which is the 1st return value. Its shape is [num_head * head_size]. + out_smooth (Tensor): Smooth weight of fmha_out. Its shape is [num_head * head_size]. rope_emb (Tensor): The RoPE embedding. Its shape is [2, batchsize, max_seq_len, 1, head_size // 2]. mask (Tensor): The mask of qk_matmul in encoder. Its shape is [batchsize, 1, max_seq_len, max_seq_len]. tgt_mask (Tensor): The mask of qk_matmul in decoder. Its shape is [batchsize, 1, 1, max_seq_len]. max_seq_len (Int): The max length of the input. Default is -1. block_size (Int): The block_size of cache. Default is 64. use_neox_style (Bool): Whether neox_style RoPE is used or not. Default is False. + use_dynamic_cachekv_quant (Bool): Whether dynamic cache kv quantization is applied or not. Default is False. + quant_round_type (Int): The quant rount type in cache kv quantization and fmha_out quantization. If 0 is set, value will be rounding to nearest ties to even. If 1 is set, value will be rounding to nearest ties away from zero. + quant_max_bound (Float32): The max bound of float type to int type. + quant_min_bound (Float32): The min bound of float type to int type. + out_scale (Float32): The quant scale of fmha_out. Default is -1, which means do not apply quantization for fmha_out. + compute_dtype (Str): A compute dtype, is used to represent the input data type. Default is "default", which means compute dtype is determined by input dtype. However, if the dtype of input is Int32, this value should be set to actual dtype of the model. Returns: Tensor|(output, qkv_out, cache_k_out, cache_v_out), which output is the output of block_multihead_attention layers, qkv_out is inplace with input `qkv`, cache_k_out and cache_v_out are inplace with input `cache_k` and `cache_v`. @@ -229,6 +243,14 @@ def block_multihead_attention( ... block_tables, ... None, # pre_key_cache ... None, # pre_value_cache + ... None, # cache_k_quant_scales + ... None, # cache_v_quant_scales + ... None, # cache_k_dequant_scales + ... None, # cache_v_dequant_scales + ... None, # qkv_out_scale + ... None, # qkv_bias + ... None, # out_shift + ... None, # out_smooth ... None, # rotary_embs ... None, # attn_mask ... None, # tgt_mask