diff --git a/paddle/fluid/pir/dialect/operator/ir/ops.yaml b/paddle/fluid/pir/dialect/operator/ir/ops.yaml
index b55f9f0f619987..161594179566fa 100644
--- a/paddle/fluid/pir/dialect/operator/ir/ops.yaml
+++ b/paddle/fluid/pir/dialect/operator/ir/ops.yaml
@@ -676,6 +676,17 @@
   view : (mean -> mean_out), (variance -> variance_out)
   backward : fused_bn_add_activation_grad
 
+- op : fused_multi_transformer
+  args : (Tensor x, Tensor[] ln_scales, Tensor[] ln_biases, Tensor[] qkv_weights, Tensor[] qkv_biases, Tensor[] cache_kvs, Tensor[] pre_caches, Tensor rotary_tensor, Tensor time_step, Tensor seq_lengths, Tensor src_mask, Tensor[] out_linear_weights, Tensor[] out_linear_biases, Tensor[] ffn_ln_scales, Tensor[] ffn_ln_biases, Tensor[] ffn1_weights, Tensor[] ffn1_biases, Tensor[] ffn2_weights, Tensor[] ffn2_biases, bool pre_layer_norm = true, float epsilon = 1e-5, float dropout_rate = .5f, int rotary_emb_dims = 0, bool is_test = false, str dropout_implementation = "downgrade_in_infer", str act_method = "gelu", bool trans_qkvw =true, int ring_id = -1)
+  optional : qkv_biases, cache_kvs, pre_caches, rotary_tensor, time_step, seq_lengths, src_mask, out_linear_biases, ffn1_biases, ffn2_biases, cache_kv_outs
+  output : Tensor[](cache_kv_outs){out_linear_weights.size()}, Tensor(out)
+  infer_meta :
+    func : FusedMultiTransformerInferMeta
+  kernel :
+    func : fused_multi_transformer
+    data_type : x
+  support_dygraph_mode : true
+
 - op : fused_softmax_mask
   args : (Tensor x, Tensor mask)
   output : Tensor(out)
diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml
index 7975a627b34303..24b7396f5502d5 100755
--- a/paddle/phi/api/yaml/legacy_ops.yaml
+++ b/paddle/phi/api/yaml/legacy_ops.yaml
@@ -592,17 +592,6 @@
   backward: fused_gemm_epilogue_grad
   optional: reserve_space
 
-- op : fused_multi_transformer
-  args : (Tensor x, Tensor[] ln_scales, Tensor[] ln_biases, Tensor[] qkv_weights, Tensor[] qkv_biases, Tensor[] cache_kvs, Tensor[] pre_caches, Tensor rotary_tensor, Tensor time_step, Tensor seq_lengths, Tensor src_mask, Tensor[] out_linear_weights, Tensor[] out_linear_biases, Tensor[] ffn_ln_scales, Tensor[] ffn_ln_biases, Tensor[] ffn1_weights, Tensor[] ffn1_biases, Tensor[] ffn2_weights, Tensor[] ffn2_biases, bool pre_layer_norm = true, float epsilon = 1e-5, float dropout_rate = .5f, int rotary_emb_dims = 0, bool is_test = false, str dropout_implementation = "downgrade_in_infer", str act_method = "gelu", bool trans_qkvw =true, int ring_id = -1)
-  optional : qkv_biases, cache_kvs, pre_caches, rotary_tensor, time_step, seq_lengths, src_mask, out_linear_biases, ffn1_biases, ffn2_biases, cache_kv_outs
-  output : Tensor[](cache_kv_outs){out_linear_weights.size()}, Tensor(out)
-  infer_meta :
-    func : FusedMultiTransformerInferMeta
-  kernel :
-    func : fused_multi_transformer
-    data_type : x
-  support_dygraph_mode : true
-
 - op : fused_multi_transformer_int8
   args : (Tensor x, Tensor[] ln_scales, Tensor[] ln_biases, Tensor[] qkv_weights, Tensor[] qkv_biases, Tensor[] cache_kvs, Tensor time_step, Tensor src_mask, Tensor[] out_linear_weights, Tensor[] out_linear_biases, Tensor[] ffn_ln_scales, Tensor[] ffn_ln_biases, Tensor[] ffn1_weights, Tensor[] ffn1_biases, Tensor[] ffn2_weights, Tensor[] ffn2_biases, Tensor[] qkv_out_scales, Tensor[] out_linear_out_scales, Tensor[] ffn1_out_scales, Tensor[] ffn2_out_scales, bool pre_layer_norm, float epsilon, float dropout_rate, bool is_test, str dropout_implementation, str act_method, bool trans_qkvw, int ring_id, int num_head, int dim_head, int dim_ffn, float[] qkv_in_scale, float[] out_linear_in_scale, float[] ffn1_in_scale, float[] ffn2_in_scale, int quant_round_type, float quant_max_bound, float quant_min_bound)
   optional : time_step, src_mask
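Note (illustrative, not part of the patch): the two hunks above move the `fused_multi_transformer` definition verbatim from `legacy_ops.yaml` into `ops.yaml`, so the PIR code generator now emits its `_C_ops` binding while the legacy dygraph binding remains available for the old eager path. A minimal probe of that expectation is sketched below; attribute presence depends on a build that compiles the fused kernels (GPU builds).

```python
# Illustrative probe only -- not part of this patch.
from paddle import _C_ops, _legacy_C_ops

# Binding generated from the new ops.yaml entry (used on the PIR path).
print(hasattr(_C_ops, "fused_multi_transformer"))
# Legacy dygraph binding (used on the old eager path further down).
print(hasattr(_legacy_C_ops, "fused_multi_transformer"))
```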
diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml
index 53e0cea953b879..99129fa6530c60 100755
--- a/paddle/phi/api/yaml/op_compat.yaml
+++ b/paddle/phi/api/yaml/op_compat.yaml
@@ -3572,6 +3572,31 @@
   outputs :
     {out : Out, intermediate_out : IntermediateOut}
 
+- op: fused_multi_transformer
+  inputs:
+    x: X
+    ln_scales: LnScale
+    ln_biases: LnBias
+    qkv_weights: QKVW
+    qkv_biases: QKVBias
+    cache_kvs: CacheKV
+    pre_caches: PreCaches
+    rotary_tensor: RotaryPosEmb
+    time_step: TimeStep
+    seq_lengths: SeqLengths
+    src_mask: SrcMask
+    out_linear_weights: OutLinearW
+    out_linear_biases: OutLinearBias
+    ffn_ln_scales: FFNLnScale
+    ffn_ln_biases: FFNLnBias
+    ffn1_weights: FFN1Weight
+    ffn1_biases: FFN1Bias
+    ffn2_weights: FFN2Weight
+    ffn2_biases: FFN2Bias
+  outputs:
+    out: Out
+    cache_kv_outs: CacheKVOut
+
 - op: fusion_squared_mat_sub
   inputs :
     x : X
diff --git a/python/paddle/incubate/nn/functional/fused_transformer.py b/python/paddle/incubate/nn/functional/fused_transformer.py
index 48416ff899769f..9e40a6aa2afdbc 100644
--- a/python/paddle/incubate/nn/functional/fused_transformer.py
+++ b/python/paddle/incubate/nn/functional/fused_transformer.py
@@ -1167,36 +1167,78 @@ def fused_multi_transformer(
     )  # semantic transfer
 
     if in_dynamic_or_pir_mode():
-        cache_kv_out, final_out = _C_ops.fused_multi_transformer(
-            x,
-            ln_scales,
-            ln_biases,
-            qkv_weights,
-            qkv_biases,
-            cache_kvs,
-            pre_caches,
-            rotary_embs,
-            time_step,
-            seq_lens,
-            attn_mask,
-            linear_weights,
-            linear_biases,
-            ffn_ln_scales,
-            ffn_ln_biases,
-            ffn1_weights,
-            ffn1_biases,
-            ffn2_weights,
-            ffn2_biases,
-            pre_layer_norm,
-            epsilon,
-            dropout_rate,
-            rotary_emb_dims,
-            not training,
-            mode,
-            activation,
-            trans_qkvw,
-            ring_id,
-        )
+        if in_dynamic_mode():
+            cache_kv_out, final_out = _legacy_C_ops.fused_multi_transformer(
+                x,
+                ln_scales,
+                ln_biases,
+                qkv_weights,
+                qkv_biases,
+                cache_kvs,
+                pre_caches,
+                rotary_embs,
+                time_step,
+                seq_lens,
+                attn_mask,
+                linear_weights,
+                linear_biases,
+                ffn_ln_scales,
+                ffn_ln_biases,
+                ffn1_weights,
+                ffn1_biases,
+                ffn2_weights,
+                ffn2_biases,
+                cache_kvs,
+                'pre_layer_norm',
+                pre_layer_norm,
+                'epsilon',
+                epsilon,
+                'dropout_rate',
+                dropout_rate,
+                'rotary_emb_dims',
+                rotary_emb_dims,
+                'is_test',
+                not training,
+                'dropout_implementation',
+                mode,
+                'act_method',
+                activation,
+                'trans_qkvw',
+                trans_qkvw,
+                'ring_id',
+                ring_id,
+            )
+        else:
+            cache_kv_out, final_out = _C_ops.fused_multi_transformer(
+                x,
+                ln_scales,
+                ln_biases,
+                qkv_weights,
+                qkv_biases,
+                cache_kvs,
+                pre_caches,
+                rotary_embs,
+                time_step,
+                seq_lens,
+                attn_mask,
+                linear_weights,
+                linear_biases,
+                ffn_ln_scales,
+                ffn_ln_biases,
+                ffn1_weights,
+                ffn1_biases,
+                ffn2_weights,
+                ffn2_biases,
+                pre_layer_norm,
+                epsilon,
+                dropout_rate,
+                rotary_emb_dims,
+                not training,
+                mode,
+                activation,
+                trans_qkvw,
+                ring_id,
+            )
         if cache_kvs is not None:
             return final_out, cache_kv_out
         return final_out
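Note (illustrative, not part of the patch): the wrapper change above keeps both execution paths alive. Old dygraph still drives the legacy operator through `_legacy_C_ops`, passing the input tensors first, then `cache_kvs` a second time as the in-place output slot, then flattened `('attr_name', value)` pairs; PIR graph building uses the `_C_ops` binding generated from the new `ops.yaml` entry, where attributes are plain positional arguments in the declared order. The skeleton below restates that branch in condensed form; `inputs`/`attrs` stand in for the long argument lists shown in the diff, and the import path for the mode helpers is assumed to match what `fused_transformer.py` already uses.

```python
# Condensed sketch of the dispatch added above (illustrative only; the real
# call sites pass the full argument lists shown in the diff).
from paddle import _C_ops, _legacy_C_ops
from paddle.framework import in_dynamic_mode, in_dynamic_or_pir_mode


def _run_fused_multi_transformer(x, inputs, attrs, cache_kvs):
    """`inputs`: tensor arguments after x, in yaml order.
    `attrs`: ordered mapping of attribute name -> value."""
    if in_dynamic_or_pir_mode():
        if in_dynamic_mode():
            # Legacy dygraph op: cache_kvs is repeated as the in-place output
            # slot and attributes are flattened into ('name', value) pairs.
            flat_attrs = [item for pair in attrs.items() for item in pair]
            return _legacy_C_ops.fused_multi_transformer(
                x, *inputs, cache_kvs, *flat_attrs
            )
        # PIR path: the generated binding takes attributes positionally, in
        # the order declared by the ops.yaml entry added in this patch.
        return _C_ops.fused_multi_transformer(x, *inputs, *attrs.values())
    raise NotImplementedError("legacy static-graph path omitted from this sketch")
```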
diff --git a/test/legacy_test/test_fused_multi_transformer_op.py b/test/legacy_test/test_fused_multi_transformer_op.py
index 63921b64e93f76..b7fec52341be61 100644
--- a/test/legacy_test/test_fused_multi_transformer_op.py
+++ b/test/legacy_test/test_fused_multi_transformer_op.py
@@ -27,6 +27,7 @@
 from paddle.nn.layer.common import Dropout, Linear
 from paddle.nn.layer.norm import LayerNorm
 from paddle.nn.layer.transformer import _convert_attention_mask
+from paddle.pir_utils import test_with_pir_api
 
 seed = 42
 
@@ -999,19 +1000,20 @@ def GetFusedMultiTransformerOutStatic(self):
             }
         if self.has_pre_cache:
             out = exe.run(
-                paddle.base.default_main_program(),
+                paddle.static.default_main_program(),
                 feed=feed_data,
-                fetch_list=[final_out[0].name],
+                fetch_list=[final_out[0]],
             )
         else:
             out = exe.run(
-                paddle.base.default_main_program(),
+                paddle.static.default_main_program(),
                 feed=feed_data,
-                fetch_list=[final_out.name],
+                fetch_list=[final_out],
             )
         paddle.disable_static()
         return out
 
+    @test_with_pir_api
     def test_fused_multi_transformer_op(self):
         if self.has_cache_kv and not self.gen_cache_kv and self.remove_padding:
             final_out_ref = self.GetVariableDecoderBaselineOut()
@@ -1393,6 +1395,7 @@ def config(self):
             initializer=paddle.nn.initializer.Constant(0.0)
         )
 
+    @test_with_pir_api
     def test_fused_multi_transformer_op(self):
         self.has_pre_cache = True
         self.remove_padding = False
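Note (illustrative, not part of the patch): the test changes switch the fetch targets from `.name` strings to the graph values themselves and run the cases under both IRs via `@test_with_pir_api`, which executes the decorated test once on the legacy static graph and once with PIR enabled. The self-contained toy below shows the same pattern on a trivial op; everything except `test_with_pir_api` and the standard `paddle.static` APIs is made up for illustration.

```python
# Toy test showing the pattern used above (illustrative only).
import unittest

import numpy as np

import paddle
from paddle.pir_utils import test_with_pir_api


class TestStaticFetchBothIRs(unittest.TestCase):
    @test_with_pir_api
    def test_relu(self):
        # The decorator runs this body twice: once on the legacy static graph
        # and once with the PIR program and executor enabled.
        paddle.enable_static()
        main = paddle.static.Program()
        startup = paddle.static.Program()
        with paddle.static.program_guard(main, startup):
            x = paddle.static.data("x", shape=[2, 3], dtype="float32")
            y = paddle.nn.functional.relu(x)
        exe = paddle.static.Executor(paddle.CPUPlace())
        exe.run(startup)
        # Fetch the value itself rather than its .name -- the same adjustment
        # the hunk above makes for final_out, since PIR values do not expose
        # stable string names.
        (out,) = exe.run(
            main,
            feed={"x": np.ones([2, 3], dtype="float32")},
            fetch_list=[y],
        )
        np.testing.assert_allclose(out, np.ones([2, 3], dtype="float32"))
        paddle.disable_static()


if __name__ == "__main__":
    unittest.main()
```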