Commit

update

zrr1999 committed Feb 28, 2024
1 parent 05cbbbe commit 98bb536
Showing 5 changed files with 115 additions and 45 deletions.
11 changes: 11 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops.yaml
@@ -676,6 +676,17 @@
   view : (mean -> mean_out), (variance -> variance_out)
   backward : fused_bn_add_activation_grad
 
+- op : fused_multi_transformer
+  args : (Tensor x, Tensor[] ln_scales, Tensor[] ln_biases, Tensor[] qkv_weights, Tensor[] qkv_biases, Tensor[] cache_kvs, Tensor[] pre_caches, Tensor rotary_tensor, Tensor time_step, Tensor seq_lengths, Tensor src_mask, Tensor[] out_linear_weights, Tensor[] out_linear_biases, Tensor[] ffn_ln_scales, Tensor[] ffn_ln_biases, Tensor[] ffn1_weights, Tensor[] ffn1_biases, Tensor[] ffn2_weights, Tensor[] ffn2_biases, bool pre_layer_norm = true, float epsilon = 1e-5, float dropout_rate = .5f, int rotary_emb_dims = 0, bool is_test = false, str dropout_implementation = "downgrade_in_infer", str act_method = "gelu", bool trans_qkvw =true, int ring_id = -1)
+  optional : qkv_biases, cache_kvs, pre_caches, rotary_tensor, time_step, seq_lengths, src_mask, out_linear_biases, ffn1_biases, ffn2_biases, cache_kv_outs
+  output : Tensor[](cache_kv_outs){out_linear_weights.size()}, Tensor(out)
+  infer_meta :
+    func : FusedMultiTransformerInferMeta
+  kernel :
+    func : fused_multi_transformer
+    data_type : x
+  support_dygraph_mode : true
+
 - op : fused_softmax_mask
   args : (Tensor x, Tensor mask)
   output : Tensor(out)
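A note on the entry added above: `output : Tensor[](cache_kv_outs){out_linear_weights.size()}, Tensor(out)` declares one KV-cache output per transformer layer (one per entry of `out_linear_weights`) plus the final hidden state `out`; `cache_kv_outs` is listed as optional because callers without caches simply omit `cache_kvs`. A minimal caller-side sketch of that contract; the `[2, batch, num_head, max_seq_len, head_dim]` cache layout and the layer count are illustrative assumptions, not something this commit defines:

```python
import paddle

# One pre-allocated KV cache per layer; the fused op produces cache_kv_outs of
# the same length (and, in the legacy path shown further below, writes back
# into these buffers in place).
num_layers = 2  # i.e. len(out_linear_weights)
batch, num_head, max_seq_len, head_dim = 1, 8, 128, 64

cache_kvs = [
    paddle.zeros([2, batch, num_head, max_seq_len, head_dim], dtype='float32')
    for _ in range(num_layers)
]
assert len(cache_kvs) == num_layers  # cache_kv_outs will have the same length
```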
11 changes: 0 additions & 11 deletions paddle/phi/api/yaml/legacy_ops.yaml
@@ -592,17 +592,6 @@
   backward: fused_gemm_epilogue_grad
   optional: reserve_space
 
-- op : fused_multi_transformer
-  args : (Tensor x, Tensor[] ln_scales, Tensor[] ln_biases, Tensor[] qkv_weights, Tensor[] qkv_biases, Tensor[] cache_kvs, Tensor[] pre_caches, Tensor rotary_tensor, Tensor time_step, Tensor seq_lengths, Tensor src_mask, Tensor[] out_linear_weights, Tensor[] out_linear_biases, Tensor[] ffn_ln_scales, Tensor[] ffn_ln_biases, Tensor[] ffn1_weights, Tensor[] ffn1_biases, Tensor[] ffn2_weights, Tensor[] ffn2_biases, bool pre_layer_norm = true, float epsilon = 1e-5, float dropout_rate = .5f, int rotary_emb_dims = 0, bool is_test = false, str dropout_implementation = "downgrade_in_infer", str act_method = "gelu", bool trans_qkvw =true, int ring_id = -1)
-  optional : qkv_biases, cache_kvs, pre_caches, rotary_tensor, time_step, seq_lengths, src_mask, out_linear_biases, ffn1_biases, ffn2_biases, cache_kv_outs
-  output : Tensor[](cache_kv_outs){out_linear_weights.size()}, Tensor(out)
-  infer_meta :
-    func : FusedMultiTransformerInferMeta
-  kernel :
-    func : fused_multi_transformer
-    data_type : x
-  support_dygraph_mode : true
-
 - op : fused_multi_transformer_int8
   args : (Tensor x, Tensor[] ln_scales, Tensor[] ln_biases, Tensor[] qkv_weights, Tensor[] qkv_biases, Tensor[] cache_kvs, Tensor time_step, Tensor src_mask, Tensor[] out_linear_weights, Tensor[] out_linear_biases, Tensor[] ffn_ln_scales, Tensor[] ffn_ln_biases, Tensor[] ffn1_weights, Tensor[] ffn1_biases, Tensor[] ffn2_weights, Tensor[] ffn2_biases, Tensor[] qkv_out_scales, Tensor[] out_linear_out_scales, Tensor[] ffn1_out_scales, Tensor[] ffn2_out_scales, bool pre_layer_norm, float epsilon, float dropout_rate, bool is_test, str dropout_implementation, str act_method, bool trans_qkvw, int ring_id, int num_head, int dim_head, int dim_ffn, float[] qkv_in_scale, float[] out_linear_in_scale, float[] ffn1_in_scale, float[] ffn2_in_scale, int quant_round_type, float quant_max_bound, float quant_min_bound)
   optional : time_step, src_mask
25 changes: 25 additions & 0 deletions paddle/phi/api/yaml/op_compat.yaml
@@ -3572,6 +3572,31 @@
   outputs :
     {out : Out, intermediate_out : IntermediateOut}
 
+- op: fused_multi_transformer
+  inputs:
+    x: X
+    ln_scales: LnScale
+    ln_biases: LnBias
+    qkv_weights: QKVW
+    qkv_biases: QKVBias
+    cache_kvs: CacheKV
+    pre_caches: PreCaches
+    rotary_tensor: RotaryPosEmb
+    time_step: TimeStep
+    seq_lengths: SeqLengths
+    src_mask: SrcMask
+    out_linear_weights: OutLinearW
+    out_linear_biases: OutLinearBias
+    ffn_ln_scales: FFNLnScale
+    ffn_ln_biases: FFNLnBias
+    ffn1_weights: FFN1Weight
+    ffn1_biases: FFN1Bias
+    ffn2_weights: FFN2Weight
+    ffn2_biases: FFN2Bias
+  outputs:
+    out: Out
+    cache_kv_outs: CacheKVOut
+
 - op: fusion_squared_mat_sub
   inputs :
     x : X
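For orientation, `op_compat.yaml` records how the new-style argument names used by the YAML definition map onto the legacy Fluid operator's input/output slot names, so existing programs and attribute maps can be translated. The same mapping restated as a plain Python dict, purely for illustration (the real translation is driven by Paddle's code generators reading the YAML, not by a snippet like this):

```python
# New-style names (keys) -> legacy Fluid slot names (values); only a few input
# entries are repeated here, the rest follow the same pattern as the YAML above.
FUSED_MULTI_TRANSFORMER_COMPAT = {
    "inputs": {
        "x": "X",
        "ln_scales": "LnScale",
        "qkv_weights": "QKVW",
        "cache_kvs": "CacheKV",
        "rotary_tensor": "RotaryPosEmb",
        "src_mask": "SrcMask",
    },
    "outputs": {"out": "Out", "cache_kv_outs": "CacheKVOut"},
}


def to_legacy_names(kwargs):
    """Rename new-style keyword arguments to the legacy op's input slot names."""
    table = FUSED_MULTI_TRANSFORMER_COMPAT["inputs"]
    return {table[name]: value for name, value in kwargs.items() if name in table}
```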
102 changes: 72 additions & 30 deletions python/paddle/incubate/nn/functional/fused_transformer.py
@@ -1167,36 +1167,78 @@ def fused_multi_transformer(
     ) # semantic transfer
 
     if in_dynamic_or_pir_mode():
-        cache_kv_out, final_out = _C_ops.fused_multi_transformer(
-            x,
-            ln_scales,
-            ln_biases,
-            qkv_weights,
-            qkv_biases,
-            cache_kvs,
-            pre_caches,
-            rotary_embs,
-            time_step,
-            seq_lens,
-            attn_mask,
-            linear_weights,
-            linear_biases,
-            ffn_ln_scales,
-            ffn_ln_biases,
-            ffn1_weights,
-            ffn1_biases,
-            ffn2_weights,
-            ffn2_biases,
-            pre_layer_norm,
-            epsilon,
-            dropout_rate,
-            rotary_emb_dims,
-            not training,
-            mode,
-            activation,
-            trans_qkvw,
-            ring_id,
-        )
+        if in_dynamic_mode():
+            cache_kv_out, final_out = _legacy_C_ops.fused_multi_transformer(
+                x,
+                ln_scales,
+                ln_biases,
+                qkv_weights,
+                qkv_biases,
+                cache_kvs,
+                pre_caches,
+                rotary_embs,
+                time_step,
+                seq_lens,
+                attn_mask,
+                linear_weights,
+                linear_biases,
+                ffn_ln_scales,
+                ffn_ln_biases,
+                ffn1_weights,
+                ffn1_biases,
+                ffn2_weights,
+                ffn2_biases,
+                cache_kvs,
+                'pre_layer_norm',
+                pre_layer_norm,
+                'epsilon',
+                epsilon,
+                'dropout_rate',
+                dropout_rate,
+                'rotary_emb_dims',
+                rotary_emb_dims,
+                'is_test',
+                not training,
+                'dropout_implementation',
+                mode,
+                'act_method',
+                activation,
+                'trans_qkvw',
+                trans_qkvw,
+                'ring_id',
+                ring_id,
+            )
+        else:
+            cache_kv_out, final_out = _C_ops.fused_multi_transformer(
+                x,
+                ln_scales,
+                ln_biases,
+                qkv_weights,
+                qkv_biases,
+                cache_kvs,
+                pre_caches,
+                rotary_embs,
+                time_step,
+                seq_lens,
+                attn_mask,
+                linear_weights,
+                linear_biases,
+                ffn_ln_scales,
+                ffn_ln_biases,
+                ffn1_weights,
+                ffn1_biases,
+                ffn2_weights,
+                ffn2_biases,
+                pre_layer_norm,
+                epsilon,
+                dropout_rate,
+                rotary_emb_dims,
+                not training,
+                mode,
+                activation,
+                trans_qkvw,
+                ring_id,
+            )
         if cache_kvs is not None:
             return final_out, cache_kv_out
         return final_out
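The hunk above splits the eager path from the PIR path: in dynamic (eager) mode the op is now invoked through `_legacy_C_ops`, which takes attributes as alternating name/value pairs after the tensor inputs and re-passes `cache_kvs` as the in-place `CacheKVOut` output, while PIR mode keeps the `_C_ops` binding generated from `ops.yaml`, where attributes are plain positional arguments. A rough sketch of the two calling conventions using a much simpler operator; treat the exact legacy signature as an assumption rather than a reference:

```python
import paddle
from paddle import _C_ops, _legacy_C_ops

x = paddle.rand([2, 3])
y = paddle.rand([2, 3])

# New-style binding generated from the YAML op definition: attributes (if any)
# are ordinary positional arguments after the tensor inputs.
out_new = _C_ops.add(x, y)

# Legacy eager binding: attributes are passed as alternating name/value pairs
# after the tensor inputs, e.g. ('axis', -1) for elementwise_add.
out_legacy = _legacy_C_ops.elementwise_add(x, y, 'axis', -1)

assert paddle.allclose(out_new, out_legacy).item()
```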
11 changes: 7 additions & 4 deletions test/legacy_test/test_fused_multi_transformer_op.py
@@ -27,6 +27,7 @@
 from paddle.nn.layer.common import Dropout, Linear
 from paddle.nn.layer.norm import LayerNorm
 from paddle.nn.layer.transformer import _convert_attention_mask
+from paddle.pir_utils import test_with_pir_api
 
 seed = 42
 
@@ -999,19 +1000,20 @@ def GetFusedMultiTransformerOutStatic(self):
         }
         if self.has_pre_cache:
             out = exe.run(
-                paddle.base.default_main_program(),
+                paddle.static.default_main_program(),
                 feed=feed_data,
-                fetch_list=[final_out[0].name],
+                fetch_list=[final_out[0]],
             )
         else:
             out = exe.run(
-                paddle.base.default_main_program(),
+                paddle.static.default_main_program(),
                 feed=feed_data,
-                fetch_list=[final_out.name],
+                fetch_list=[final_out],
             )
         paddle.disable_static()
         return out
 
+    @test_with_pir_api
     def test_fused_multi_transformer_op(self):
         if self.has_cache_kv and not self.gen_cache_kv and self.remove_padding:
             final_out_ref = self.GetVariableDecoderBaselineOut()
@@ -1393,6 +1395,7 @@ def config(self):
             initializer=paddle.nn.initializer.Constant(0.0)
         )
 
+    @test_with_pir_api
     def test_fused_multi_transformer_op(self):
         self.has_pre_cache = True
         self.remove_padding = False
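The test-side changes swap `paddle.base.default_main_program()` for `paddle.static.default_main_program()`, fetch output variables directly instead of by `.name`, and mark the tests with `test_with_pir_api` so they run under both the legacy static graph and PIR. A stripped-down sketch of that pattern on a trivial graph, assuming the decorator simply re-runs the decorated test with the PIR API enabled (names and shapes are illustrative):

```python
import unittest

import numpy as np

import paddle
from paddle.pir_utils import test_with_pir_api


class TestStaticAddExample(unittest.TestCase):
    @test_with_pir_api
    def test_add(self):
        paddle.enable_static()
        main = paddle.static.Program()
        startup = paddle.static.Program()
        with paddle.static.program_guard(main, startup):
            x = paddle.static.data('x', [2, 3], 'float32')
            out = x + 1.0
        exe = paddle.static.Executor(paddle.CPUPlace())
        exe.run(startup)
        # Fetch the Value/Variable itself rather than out.name, mirroring the
        # fetch_list change in the diff above.
        (res,) = exe.run(
            main, feed={'x': np.ones([2, 3], 'float32')}, fetch_list=[out]
        )
        np.testing.assert_allclose(res, np.full([2, 3], 2.0, dtype='float32'))
        paddle.disable_static()


if __name__ == '__main__':
    unittest.main()
```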
