Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#22 from gongshaotian/drr
Browse files Browse the repository at this point in the history
Add buildProgram in drr_attention_fuse_test.cc
  • Loading branch information
yuanlehome authored Aug 30, 2023
2 parents 5aeec24 + e959e60 commit 5b6717d
Showing 1 changed file with 141 additions and 0 deletions.
141 changes: 141 additions & 0 deletions test/cpp/ir/pattern_rewrite/drr_attention_fuse_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,147 @@ class AttentionFusePass : public ir::Pass {
ir::FrozenRewritePatternSet patterns_;
};

void BuildProgram(ir::Builder &builder) {
paddle::dialect::FullOp full_input_op =
builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{1, 300, 256},
0.9,
phi::DataType::FLOAT32,
phi::CPUPlace());

// left
paddle::dialect::FullOp full_mat1_y_op =
builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{256, 256},
1.1,
phi::DataType::FLOAT32,
phi::CPUPlace());
paddle::dialect::MatmulOp matmul_op1 =
builder.Build<paddle::dialect::MatmulOp>(
full_input_op.out(), full_mat1_y_op.out(), false, false);

paddle::dialect::FullOp full_eleadd1_y_op =
builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{256},
1.5,
phi::DataType::FLOAT32,
phi::CPUPlace());

paddle::dialect::AddOp add_op1 = builder.Build<paddle::dialect::AddOp>(
matmul_op1.out(), full_eleadd1_y_op.out());

paddle::dialect::ReshapeOp reshape_op1 =
builder.Build<paddle::dialect::ReshapeOp>(
add_op1.out(), std::vector<int64_t>{1, 300, 8, 32});

paddle::dialect::TransposeOp transpose_op1 =
builder.Build<paddle::dialect::TransposeOp>(reshape_op1.out(),
std::vector<int>{0, 2, 1, 3});

// middle
paddle::dialect::FullOp full_mat2_y_op =
builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{256, 256},
1.1,
phi::DataType::FLOAT32,
phi::CPUPlace());

paddle::dialect::MatmulOp matmul_op2 =
builder.Build<paddle::dialect::MatmulOp>(
full_input_op.out(), full_mat2_y_op.out(), false, false);

paddle::dialect::FullOp full_eleadd2_y_op =
builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{256},
1.5,
phi::DataType::FLOAT32,
phi::CPUPlace());
paddle::dialect::AddOp add_op2 = builder.Build<paddle::dialect::AddOp>(
matmul_op2.out(), full_eleadd2_y_op.out());

paddle::dialect::ReshapeOp reshape_op2 =
builder.Build<paddle::dialect::ReshapeOp>(
add_op2.out(), std::vector<int>{1, 300, 8, 32});

paddle::dialect::TransposeOp transpose_op2 =
builder.Build<paddle::dialect::TransposeOp>(reshape_op2.out(),
std::vector<int>{0, 2, 1, 3});

// right
paddle::dialect::FullOp full_mat3_y_op =
builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{256, 256},
1.1,
phi::DataType::FLOAT32,
phi::CPUPlace());

paddle::dialect::MatmulOp matmul_op3 =
builder.Build<paddle::dialect::MatmulOp>(
full_input_op.out(), full_mat3_y_op.out(), false, false);

paddle::dialect::FullOp full_eleadd3_y_op =
builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{256},
1.5,
phi::DataType::FLOAT32,
phi::CPUPlace());

paddle::dialect::AddOp add_op3 = builder.Build<paddle::dialect::AddOp>(
matmul_op3.out(), full_eleadd3_y_op.out());

paddle::dialect::ReshapeOp reshape_op3 =
builder.Build<paddle::dialect::ReshapeOp>(
add_op3.out(), std::vector<int>{1, 300, 8, 32});

paddle::dialect::TransposeOp transpose_op3 =
builder.Build<paddle::dialect::TransposeOp>(reshape_op3.out(),
std::vector<int>{0, 2, 1, 3});

paddle::dialect::ScaleOp scale_op1 = builder.Build<paddle::dialect::ScaleOp>(
transpose_op3.out(), 0.1767766922712326, 0.0, true);

paddle::dialect::MatmulOp matmul_op4 =
builder.Build<paddle::dialect::MatmulOp>(
scale_op1.out(), transpose_op2.out(), false, true);

paddle::dialect::SoftmaxOp softmax_op1 =
builder.Build<paddle::dialect::SoftmaxOp>(matmul_op4.out(), -1);

// tail
paddle::dialect::MatmulOp matmul_op5 =
builder.Build<paddle::dialect::MatmulOp>(
softmax_op1.out(), transpose_op1.out(), false, false);

paddle::dialect::TransposeOp transpose_op4 =
builder.Build<paddle::dialect::TransposeOp>(matmul_op5.out(),
std::vector<int>{0, 2, 1, 3});

paddle::dialect::ReshapeOp reshape_op4 =
builder.Build<paddle::dialect::ReshapeOp>(transpose_op4.out(),
std::vector<int>{1, 300, 256});

paddle::dialect::FullOp full_mat4_y_op =
builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{256, 256},
1.1,
phi::DataType::FLOAT32,
phi::CPUPlace());
paddle::dialect::MatmulOp matmul_op6 =
builder.Build<paddle::dialect::MatmulOp>(
reshape_op4.out(), full_mat4_y_op.out(), false, false);

paddle::dialect::FullOp full_eleadd4_y_op =
builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{256},
1.5,
phi::DataType::FLOAT32,
phi::CPUPlace());
paddle::dialect::AddOp add_op4 = builder.Build<paddle::dialect::AddOp>(
matmul_op6.out(), full_eleadd4_y_op.out());

paddle::dialect::FullOp full_slice_axes_op =
builder.Build<paddle::dialect::FullOp>(
std::vector<int>{64}, 2, phi::DataType::INT64, phi::CPUPlace());
paddle::dialect::FullOp full_slice_starts_op =
builder.Build<paddle::dialect::FullOp>(
std::vector<int>{64}, 2, phi::DataType::INT64, phi::CPUPlace());
paddle::dialect::SliceOp slice_op1 = builder.Build<paddle::dialect::SliceOp>(
add_op4.out(), full_slice_axes_op.out(), full_slice_starts_op);

builder.Build<paddle::dialect::FetchOp>(slice_op1.out(), "out", 0);
}

/*
TEST(DrrTest, AttentionFuse) {
ir::IrContext *ctx = ir::IrContext::Instance();
Expand Down

0 comments on commit 5b6717d

Please sign in to comment.