diff --git a/test/cpp/ir/pattern_rewrite/drr_attention_fuse_test.cc b/test/cpp/ir/pattern_rewrite/drr_attention_fuse_test.cc index e4b04ed0c6aaf..cee0f1c3963df 100644 --- a/test/cpp/ir/pattern_rewrite/drr_attention_fuse_test.cc +++ b/test/cpp/ir/pattern_rewrite/drr_attention_fuse_test.cc @@ -214,6 +214,147 @@ class AttentionFusePass : public ir::Pass { ir::FrozenRewritePatternSet patterns_; }; +void BuildProgram(ir::Builder &builder) { + paddle::dialect::FullOp full_input_op = + builder.Build(std::vector{1, 300, 256}, + 0.9, + phi::DataType::FLOAT32, + phi::CPUPlace()); + + // left + paddle::dialect::FullOp full_mat1_y_op = + builder.Build(std::vector{256, 256}, + 1.1, + phi::DataType::FLOAT32, + phi::CPUPlace()); + paddle::dialect::MatmulOp matmul_op1 = + builder.Build( + full_input_op.out(), full_mat1_y_op.out(), false, false); + + paddle::dialect::FullOp full_eleadd1_y_op = + builder.Build(std::vector{256}, + 1.5, + phi::DataType::FLOAT32, + phi::CPUPlace()); + + paddle::dialect::AddOp add_op1 = builder.Build( + matmul_op1.out(), full_eleadd1_y_op.out()); + + paddle::dialect::ReshapeOp reshape_op1 = + builder.Build( + add_op1.out(), std::vector{1, 300, 8, 32}); + + paddle::dialect::TransposeOp transpose_op1 = + builder.Build(reshape_op1.out(), + std::vector{0, 2, 1, 3}); + + // middle + paddle::dialect::FullOp full_mat2_y_op = + builder.Build(std::vector{256, 256}, + 1.1, + phi::DataType::FLOAT32, + phi::CPUPlace()); + + paddle::dialect::MatmulOp matmul_op2 = + builder.Build( + full_input_op.out(), full_mat2_y_op.out(), false, false); + + paddle::dialect::FullOp full_eleadd2_y_op = + builder.Build(std::vector{256}, + 1.5, + phi::DataType::FLOAT32, + phi::CPUPlace()); + paddle::dialect::AddOp add_op2 = builder.Build( + matmul_op2.out(), full_eleadd2_y_op.out()); + + paddle::dialect::ReshapeOp reshape_op2 = + builder.Build( + add_op2.out(), std::vector{1, 300, 8, 32}); + + paddle::dialect::TransposeOp transpose_op2 = + builder.Build(reshape_op2.out(), + std::vector{0, 2, 1, 3}); + + // right + paddle::dialect::FullOp full_mat3_y_op = + builder.Build(std::vector{256, 256}, + 1.1, + phi::DataType::FLOAT32, + phi::CPUPlace()); + + paddle::dialect::MatmulOp matmul_op3 = + builder.Build( + full_input_op.out(), full_mat3_y_op.out(), false, false); + + paddle::dialect::FullOp full_eleadd3_y_op = + builder.Build(std::vector{256}, + 1.5, + phi::DataType::FLOAT32, + phi::CPUPlace()); + + paddle::dialect::AddOp add_op3 = builder.Build( + matmul_op3.out(), full_eleadd3_y_op.out()); + + paddle::dialect::ReshapeOp reshape_op3 = + builder.Build( + add_op3.out(), std::vector{1, 300, 8, 32}); + + paddle::dialect::TransposeOp transpose_op3 = + builder.Build(reshape_op3.out(), + std::vector{0, 2, 1, 3}); + + paddle::dialect::ScaleOp scale_op1 = builder.Build( + transpose_op3.out(), 0.1767766922712326, 0.0, true); + + paddle::dialect::MatmulOp matmul_op4 = + builder.Build( + scale_op1.out(), transpose_op2.out(), false, true); + + paddle::dialect::SoftmaxOp softmax_op1 = + builder.Build(matmul_op4.out(), -1); + + // tail + paddle::dialect::MatmulOp matmul_op5 = + builder.Build( + softmax_op1.out(), transpose_op1.out(), false, false); + + paddle::dialect::TransposeOp transpose_op4 = + builder.Build(matmul_op5.out(), + std::vector{0, 2, 1, 3}); + + paddle::dialect::ReshapeOp reshape_op4 = + builder.Build(transpose_op4.out(), + std::vector{1, 300, 256}); + + paddle::dialect::FullOp full_mat4_y_op = + builder.Build(std::vector{256, 256}, + 1.1, + phi::DataType::FLOAT32, + phi::CPUPlace()); + paddle::dialect::MatmulOp matmul_op6 = + builder.Build( + reshape_op4.out(), full_mat4_y_op.out(), false, false); + + paddle::dialect::FullOp full_eleadd4_y_op = + builder.Build(std::vector{256}, + 1.5, + phi::DataType::FLOAT32, + phi::CPUPlace()); + paddle::dialect::AddOp add_op4 = builder.Build( + matmul_op6.out(), full_eleadd4_y_op.out()); + + paddle::dialect::FullOp full_slice_axes_op = + builder.Build( + std::vector{64}, 2, phi::DataType::INT64, phi::CPUPlace()); + paddle::dialect::FullOp full_slice_starts_op = + builder.Build( + std::vector{64}, 2, phi::DataType::INT64, phi::CPUPlace()); + paddle::dialect::SliceOp slice_op1 = builder.Build( + add_op4.out(), full_slice_axes_op.out(), full_slice_starts_op); + + builder.Build(slice_op1.out(), "out", 0); +} + /* TEST(DrrTest, AttentionFuse) { ir::IrContext *ctx = ir::IrContext::Instance();