diff --git a/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc index 66c6a7ddf8c59..60caf1e4b4e94 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc @@ -138,8 +138,7 @@ class ScaleOpPattern : public pir::OpRewritePattern { bool MatchAndRewrite(paddle::dialect::ScaleOp op, pir::PatternRewriter &rewriter) const override { - auto scale_factor_gen_op = - op->operand_source(1).dyn_cast().owner(); + auto scale_factor_gen_op = op->operand_source(1).defining_op(); if (auto full_op = scale_factor_gen_op->dyn_cast()) { @@ -190,8 +189,7 @@ class ReshapeOpPattern bool MatchAndRewrite(paddle::dialect::ReshapeOp op, pir::PatternRewriter &rewriter) const override { - auto scale_factor_gen_op = - op->operand_source(1).dyn_cast().owner(); + auto scale_factor_gen_op = op->operand_source(1).defining_op(); if (auto full_op = scale_factor_gen_op->dyn_cast()) { @@ -232,8 +230,7 @@ class Pool2dOpPattern bool MatchAndRewrite(paddle::dialect::Pool2dOp op, pir::PatternRewriter &rewriter) const override { - auto kernel_size_gen_op = - op->operand_source(1).dyn_cast().owner(); + auto kernel_size_gen_op = op->operand_source(1).defining_op(); if (auto full_op = kernel_size_gen_op->dyn_cast()) { @@ -279,13 +276,11 @@ class IsCloseOpPattern bool MatchAndRewrite(paddle::dialect::IscloseOp op, pir::PatternRewriter &rewriter) const override { auto rtol_op = op->operand_source(2) - .dyn_cast() - .owner() + .defining_op() ->dyn_cast(); auto atol_op = op->operand_source(3) - .dyn_cast() - .owner() + .defining_op() ->dyn_cast(); if (rtol_op && atol_op) { @@ -318,13 +313,11 @@ class SliceOpPattern : public pir::OpRewritePattern { bool MatchAndRewrite(paddle::dialect::SliceOp op, pir::PatternRewriter &rewriter) const override { auto start_gen_op = op->operand_source(1) - .dyn_cast() - .owner() + .defining_op() ->dyn_cast(); auto end_gen_op = op->operand_source(2) - .dyn_cast() - .owner() + .defining_op() ->dyn_cast(); if (start_gen_op && end_gen_op) { @@ -360,16 +353,13 @@ class ConcatOpPattern bool MatchAndRewrite(paddle::dialect::ConcatOp op, pir::PatternRewriter &rewriter) const override { - auto axis_gen_op = op->operand_source(1).dyn_cast().owner(); + auto axis_gen_op = op->operand_source(1).defining_op(); if (auto full_op = axis_gen_op->dyn_cast()) { - int axis = phi::Scalar(full_op.attribute("value") - .dyn_cast<::pir::FloatAttribute>() - .data()) - .to(); + int axis = static_cast( + full_op.attribute("value").dyn_cast<::pir::FloatAttribute>().data()); auto input_ops = op->operand_source(0) - .dyn_cast() - .owner() + .defining_op() ->dyn_cast() .inputs(); @@ -413,12 +403,10 @@ class SplitOpPattern : public pir::OpRewritePattern { bool MatchAndRewrite(paddle::dialect::SplitOp op, pir::PatternRewriter &rewriter) const override { auto sections_gen_op = op->operand_source(1) - .dyn_cast() - .owner() + .defining_op() ->dyn_cast(); auto axis_gen_op = op->operand_source(2) - .dyn_cast() - .owner() + .defining_op() ->dyn_cast(); if (sections_gen_op && axis_gen_op) { auto section_attr = sections_gen_op.attribute("value") @@ -432,11 +420,9 @@ class SplitOpPattern : public pir::OpRewritePattern { section_attr[i].dyn_cast<::pir::Int64Attribute>().data()); } } - - int axis = phi::Scalar(axis_gen_op.attribute("value") - .dyn_cast<::pir::FloatAttribute>() - .data()) - .to(); + int axis = static_cast(axis_gen_op.attribute("value") + .dyn_cast<::pir::FloatAttribute>() + .data()); auto input_ele = op->operand_source(0) .type() @@ -448,15 +434,77 @@ class SplitOpPattern : public pir::OpRewritePattern { auto cinn_split = rewriter.Build( op->operand_source(0), vec_sections, axis); - auto build_split = - op->result(0).first_use().owner()->dyn_cast<::pir::SplitOp>(); + auto orig_out = op.result(0); + for (auto it = orig_out.use_begin(); it != orig_out.use_end();) { + auto slice_op = (it++)->owner(); + CHECK(slice_op->isa<::pir::SliceOp>()) + << "Currently only support pir::slice as downstream op"; + int index = slice_op->dyn_cast<::pir::SliceOp>() + .attribute("index") + .dyn_cast<::pir::Int32Attribute>() + .data(); + rewriter.ReplaceAllUsesWith(slice_op->result(0), + cinn_split.result(index)); + rewriter.EraseOp(slice_op); + } + rewriter.EraseOp(op); + + return true; + } + return false; + } +}; + +class SplitWithNumOpPattern + : public pir::OpRewritePattern { + public: + using pir::OpRewritePattern< + paddle::dialect::SplitWithNumOp>::OpRewritePattern; + + bool MatchAndRewrite(paddle::dialect::SplitWithNumOp op, + pir::PatternRewriter &rewriter) const override { + auto axis_gen_op = op->operand_source(1).defining_op(); + if (auto full_op = axis_gen_op->dyn_cast()) { + int axis = static_cast( + full_op.attribute("value").dyn_cast<::pir::FloatAttribute>().data()); - for (size_t i = 0; i < build_split->num_results(); ++i) { - rewriter.ReplaceAllUsesWith(build_split->result(i), - cinn_split.result(i)); + auto input_ele = op->operand_source(0) + .type() + .dyn_cast(); + if (axis < 0) { + axis += input_ele.dims().size(); } + std::vector sections; + + auto split_dim = input_ele.dims()[axis]; + + auto split_num = + op->attribute("num").dyn_cast<::pir::Int32Attribute>().data(); + auto part_ele = (split_dim + split_num - 1) / split_num; - rewriter.EraseOp(build_split); + int total_split_num = 0; + for (int i = 0; i < split_num - 1; ++i) { + sections.push_back(part_ele); + total_split_num += part_ele; + } + + sections.push_back(split_dim - total_split_num); + + auto cinn_split = rewriter.Build( + op->operand_source(0), sections, axis); + + auto orig_out = op.result(0); + for (auto it = orig_out.use_begin(); it != orig_out.use_end();) { + auto slice_op = (it++)->owner(); + CHECK(slice_op->isa<::pir::SliceOp>()); + int index = slice_op->dyn_cast<::pir::SliceOp>() + .attribute("index") + .dyn_cast<::pir::Int32Attribute>() + .data(); + rewriter.ReplaceAllUsesWith(slice_op->result(0), + cinn_split.result(index)); + rewriter.EraseOp(slice_op); + } rewriter.EraseOp(op); @@ -472,10 +520,8 @@ class AddNOpPattern : public pir::OpRewritePattern { bool MatchAndRewrite(paddle::dialect::AddNOp op, pir::PatternRewriter &rewriter) const override { - auto combine_op = op->operand_source(0) - .dyn_cast() - .owner() - ->dyn_cast(); + auto combine_op = + op->operand_source(0).defining_op()->dyn_cast(); auto input_ops = combine_op.inputs(); auto tmp = input_ops[0]; @@ -501,8 +547,7 @@ class ExpandOpPattern bool MatchAndRewrite(paddle::dialect::ExpandOp op, pir::PatternRewriter &rewriter) const override { auto out_shape_gen_op = op->operand_source(1) - .dyn_cast() - .owner() + .defining_op() ->dyn_cast(); if (out_shape_gen_op) { @@ -541,63 +586,6 @@ class ExpandOpPattern } }; -class SplitWithNumOpPattern - : public pir::OpRewritePattern { - public: - using pir::OpRewritePattern< - paddle::dialect::SplitWithNumOp>::OpRewritePattern; - - bool MatchAndRewrite(paddle::dialect::SplitWithNumOp op, - pir::PatternRewriter &rewriter) const override { - auto axis_gen_op = op->operand_source(1).dyn_cast().owner(); - if (auto full_op = axis_gen_op->dyn_cast()) { - int axis = phi::Scalar(full_op.attribute("value") - .dyn_cast<::pir::FloatAttribute>() - .data()) - .to(); - - auto input_ele = op->operand_source(0) - .type() - .dyn_cast(); - if (axis < 0) { - axis += input_ele.dims().size(); - } - std::vector sections; - - auto split_dim = input_ele.dims()[axis]; - - auto split_num = - op->attribute("num").dyn_cast<::pir::Int32Attribute>().data(); - auto part_ele = (split_dim + split_num - 1) / split_num; - - int total_split_num = 0; - for (int i = 0; i < split_num - 1; ++i) { - sections.push_back(part_ele); - total_split_num += part_ele; - } - - sections.push_back(split_dim - total_split_num); - - auto cinn_split = rewriter.Build( - op->operand_source(0), sections, axis); - - int index = 0; - auto orig_out = op.result(0); - for (auto it = orig_out.use_begin(); it != orig_out.use_end();) { - auto split_op = (it++)->owner(); - rewriter.ReplaceAllUsesWith(split_op->result(0), - cinn_split.result(index++)); - rewriter.EraseOp(split_op); - } - - rewriter.EraseOp(op); - - return true; - } - return false; - } -}; - class UniformOpPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { diff --git a/test/cpp/pir/cinn/pir_all_path_test.cc b/test/cpp/pir/cinn/pir_all_path_test.cc index eabd8e490f0d4..4ef73aadc7f4a 100644 --- a/test/cpp/pir/cinn/pir_all_path_test.cc +++ b/test/cpp/pir/cinn/pir_all_path_test.cc @@ -51,6 +51,44 @@ std::vector<::pir::Type> CreateDenseTensorTypes(const phi::DDim& dims) { return op_output_types; } +static void RunAndCheckResult(::pir::Program* program, + const bool check_result = true, + const float gt_val = 2.0) { + ::pir::IrContext* ctx = ::pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + + cinn::dialect::ir::PdOp2CinnOpConverter(program); + + pir::PassManager pm(ctx); + pm.AddPass( + std::make_unique()); + + pm.AddPass(pir::CreateDeadCodeEliminationPass()); + pm.AddPass(pir::CreateBuildCinnPass()); + pm.AddPass(cinn::dialect::ir::CreateCinnGroupLoweringPass()); + CHECK_EQ(pm.Run(program), true); + + paddle::platform::Place place = paddle::platform::CUDAPlace(0); + + auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(program, place); + + paddle::framework::Scope exe_scope; + + paddle::framework::InterpreterCore executor( + place, {"out@fetch"}, kernel_program->block(), &exe_scope); + + executor.Run({}, true); + + auto out_tensor = + executor.local_scope()->FindVar("out@fetch")->Get(); + + if (check_result) { + bool res0 = simple_cmp(out_tensor.data()[0], gt_val); + EXPECT_EQ(res0, true); + } +} + std::shared_ptr<::pir::Program> BuildGroupProgram() { ::pir::IrContext* ctx = ::pir::IrContext::Instance(); ctx->GetOrRegisterDialect(); @@ -86,35 +124,8 @@ TEST(GroupOp, TestBuild) { // Step 1: Construct pir::Program ::pir::IrContext* ctx = ::pir::IrContext::Instance(); std::shared_ptr<::pir::Program> program = BuildGroupProgram(); - ctx->GetOrRegisterDialect(); - ctx->GetOrRegisterDialect(); - - cinn::dialect::ir::PdOp2CinnOpConverter(program.get()); - - pir::PassManager pm(ctx); - pm.AddPass( - std::make_unique()); - pm.AddPass(pir::CreateBuildCinnPass()); - pm.AddPass(cinn::dialect::ir::CreateCinnGroupLoweringPass()); - CHECK_EQ(pm.Run(program.get()), true); - - paddle::platform::Place place = paddle::platform::CUDAPlace(0); - auto kernel_program = - paddle::dialect::PdOpLowerToKernelPass(program.get(), place); - - paddle::framework::Scope exe_scope; - - paddle::framework::InterpreterCore executor( - place, {"out@fetch"}, kernel_program->block(), &exe_scope); - - executor.Run({}, true); - - auto out_tensor = - executor.local_scope()->FindVar("out@fetch")->Get(); - - bool res0 = simple_cmp(out_tensor.data()[0], 1.0 / 768); - EXPECT_EQ(res0, true); + RunAndCheckResult(program.get(), true, 1.0 / 768); } std::shared_ptr<::pir::Program> BuildLayerNormProgram() { @@ -193,35 +204,8 @@ TEST(GroupOp, TestBuildLayerNorm) { // Step 1: Construct pir::Program ::pir::IrContext* ctx = ::pir::IrContext::Instance(); std::shared_ptr<::pir::Program> program = BuildLayerNormProgram(); - ctx->GetOrRegisterDialect(); - ctx->GetOrRegisterDialect(); - - cinn::dialect::ir::PdOp2CinnOpConverter(program.get()); - - pir::PassManager pm(ctx); - pm.AddPass( - std::make_unique()); - pm.AddPass(pir::CreateBuildCinnPass()); - pm.AddPass(cinn::dialect::ir::CreateCinnGroupLoweringPass()); - CHECK_EQ(pm.Run(program.get()), true); - - paddle::platform::Place place = paddle::platform::CUDAPlace(0); - - auto kernel_program = - paddle::dialect::PdOpLowerToKernelPass(program.get(), place); - - paddle::framework::Scope exe_scope; - paddle::framework::InterpreterCore executor( - place, {"out@fetch"}, kernel_program->block(), &exe_scope); - - // TODO(phlrain): fix exec error - executor.Run({}, true); - - auto out_tensor = - executor.local_scope()->FindVar("out@fetch")->Get(); - // auto out_tensor = - // executor.local_scope()->FindVar("out@fetch")->Get(); + RunAndCheckResult(program.get(), false); } std::shared_ptr<::pir::Program> BuildDropOutProgram() { @@ -278,32 +262,8 @@ TEST(GroupOp, TestBuildDropout) { // Step 1: Construct pir::Program ::pir::IrContext* ctx = ::pir::IrContext::Instance(); std::shared_ptr<::pir::Program> program = BuildDropOutProgram(); - ctx->GetOrRegisterDialect(); - ctx->GetOrRegisterDialect(); - cinn::dialect::ir::PdOp2CinnOpConverter(program.get()); - - pir::PassManager pm(ctx); - pm.AddPass( - std::make_unique()); - pm.AddPass(pir::CreateBuildCinnPass()); - pm.AddPass(cinn::dialect::ir::CreateCinnGroupLoweringPass()); - CHECK_EQ(pm.Run(program.get()), true); - - paddle::platform::Place place = paddle::platform::CUDAPlace(0); - - auto kernel_program = - paddle::dialect::PdOpLowerToKernelPass(program.get(), place); - - paddle::framework::Scope exe_scope; - - paddle::framework::InterpreterCore executor( - place, {"out@fetch"}, kernel_program->block(), &exe_scope); - - executor.Run({}, true); - - auto out_tensor = - executor.local_scope()->FindVar("out@fetch")->Get(); + RunAndCheckResult(program.get(), false); } std::shared_ptr<::pir::Program> BuildScaleGroupProgram() { @@ -332,35 +292,8 @@ TEST(GroupOp, TestBuildScale) { // Step 1: Construct pir::Program ::pir::IrContext* ctx = ::pir::IrContext::Instance(); std::shared_ptr<::pir::Program> program = BuildScaleGroupProgram(); - ctx->GetOrRegisterDialect(); - ctx->GetOrRegisterDialect(); - - cinn::dialect::ir::PdOp2CinnOpConverter(program.get()); - - pir::PassManager pm(ctx); - pm.AddPass( - std::make_unique()); - pm.AddPass(pir::CreateBuildCinnPass()); - pm.AddPass(cinn::dialect::ir::CreateCinnGroupLoweringPass()); - CHECK_EQ(pm.Run(program.get()), true); - - paddle::platform::Place place = paddle::platform::CUDAPlace(0); - - auto kernel_program = - paddle::dialect::PdOpLowerToKernelPass(program.get(), place); - - paddle::framework::Scope exe_scope; - paddle::framework::InterpreterCore executor( - place, {"out@fetch"}, kernel_program->block(), &exe_scope); - - executor.Run({}, true); - - auto out_tensor = - executor.local_scope()->FindVar("out@fetch")->Get(); - - bool res0 = simple_cmp(out_tensor.data()[0], 0.5); - EXPECT_EQ(res0, true); + RunAndCheckResult(program.get(), true, 0.5); } std::shared_ptr<::pir::Program> BuildScaleTensorGroupProgram() { @@ -395,35 +328,8 @@ TEST(GroupOp, TestBuildScaleTensor) { // Step 1: Construct pir::Program ::pir::IrContext* ctx = ::pir::IrContext::Instance(); std::shared_ptr<::pir::Program> program = BuildScaleTensorGroupProgram(); - ctx->GetOrRegisterDialect(); - ctx->GetOrRegisterDialect(); - - cinn::dialect::ir::PdOp2CinnOpConverter(program.get()); - - pir::PassManager pm(ctx); - pm.AddPass( - std::make_unique()); - pm.AddPass(pir::CreateBuildCinnPass()); - pm.AddPass(cinn::dialect::ir::CreateCinnGroupLoweringPass()); - CHECK_EQ(pm.Run(program.get()), true); - - paddle::platform::Place place = paddle::platform::CUDAPlace(0); - - auto kernel_program = - paddle::dialect::PdOpLowerToKernelPass(program.get(), place); - - paddle::framework::Scope exe_scope; - - paddle::framework::InterpreterCore executor( - place, {"out@fetch"}, kernel_program->block(), &exe_scope); - - executor.Run({}, true); - - auto out_tensor = - executor.local_scope()->FindVar("out@fetch")->Get(); - bool res0 = simple_cmp(out_tensor.data()[0], 0.5); - EXPECT_EQ(res0, true); + RunAndCheckResult(program.get(), true, 0.5); } std::shared_ptr<::pir::Program> BuildPowerProgram() { @@ -464,35 +370,8 @@ TEST(GroupOp, TestBuildPower) { // Step 1: Construct pir::Program ::pir::IrContext* ctx = ::pir::IrContext::Instance(); std::shared_ptr<::pir::Program> program = BuildPowerProgram(); - ctx->GetOrRegisterDialect(); - ctx->GetOrRegisterDialect(); - - cinn::dialect::ir::PdOp2CinnOpConverter(program.get()); - - pir::PassManager pm(ctx); - pm.AddPass( - std::make_unique()); - pm.AddPass(pir::CreateBuildCinnPass()); - pm.AddPass(cinn::dialect::ir::CreateCinnGroupLoweringPass()); - CHECK_EQ(pm.Run(program.get()), true); - - paddle::platform::Place place = paddle::platform::CUDAPlace(0); - - auto kernel_program = - paddle::dialect::PdOpLowerToKernelPass(program.get(), place); - - paddle::framework::Scope exe_scope; - - paddle::framework::InterpreterCore executor( - place, {"out@fetch"}, kernel_program->block(), &exe_scope); - executor.Run({}, true); - - auto out_tensor = - executor.local_scope()->FindVar("out@fetch")->Get(); - - bool res0 = simple_cmp(out_tensor.data()[0], 16.0); - EXPECT_EQ(res0, true); + RunAndCheckResult(program.get(), true, 16.0); } std::shared_ptr<::pir::Program> BuildLayerNorm2Program() { @@ -587,33 +466,8 @@ TEST(GroupOp, TestBuildLayerNorm2) { // Step 1: Construct pir::Program ::pir::IrContext* ctx = ::pir::IrContext::Instance(); std::shared_ptr<::pir::Program> program = BuildLayerNorm2Program(); - ctx->GetOrRegisterDialect(); - ctx->GetOrRegisterDialect(); - - cinn::dialect::ir::PdOp2CinnOpConverter(program.get()); - - pir::PassManager pm(ctx); - pm.AddPass( - std::make_unique()); - pm.AddPass(pir::CreateBuildCinnPass()); - pm.AddPass(cinn::dialect::ir::CreateCinnGroupLoweringPass()); - CHECK_EQ(pm.Run(program.get()), true); - - paddle::platform::Place place = paddle::platform::CUDAPlace(0); - - auto kernel_program = - paddle::dialect::PdOpLowerToKernelPass(program.get(), place); - - paddle::framework::Scope exe_scope; - paddle::framework::InterpreterCore executor( - place, {"out@fetch"}, kernel_program->block(), &exe_scope); - - // TODO(phlrain): fix exec error - executor.Run({}, true); - - // auto out_tensor = - // executor.local_scope()->FindVar("out@fetch")->Get(); + RunAndCheckResult(program.get(), false); } std::shared_ptr<::pir::Program> BuildSum2GroupProgram() { @@ -649,41 +503,8 @@ TEST(GroupOp, TestBuildSum2Group) { // Step 1: Construct pir::Program ::pir::IrContext* ctx = ::pir::IrContext::Instance(); std::shared_ptr<::pir::Program> program = BuildSum2GroupProgram(); - ctx->GetOrRegisterDialect(); - ctx->GetOrRegisterDialect(); - - cinn::dialect::ir::PdOp2CinnOpConverter(program.get()); - - pir::PassManager pm(ctx); - pm.AddPass( - std::make_unique()); - pm.AddPass(pir::CreateBuildCinnPass()); - pm.AddPass(cinn::dialect::ir::CreateCinnGroupLoweringPass()); - CHECK_EQ(pm.Run(program.get()), true); - - paddle::platform::Place place = paddle::platform::CUDAPlace(0); - - auto kernel_program = - paddle::dialect::PdOpLowerToKernelPass(program.get(), place); - - paddle::framework::Scope exe_scope; - paddle::framework::InterpreterCore executor( - place, {"out@fetch"}, kernel_program->block(), &exe_scope); - - executor.Run({}, true); - - auto out_tensor = - executor.local_scope()->FindVar("out@fetch")->Get(); - - auto out_tensor2 = - executor.local_scope()->FindVar("out2@fetch")->Get(); - - bool res0 = simple_cmp(out_tensor.data()[0], 1.0); - EXPECT_EQ(res0, true); - - bool res1 = (out_tensor2.data()[0] == 0.0); - EXPECT_EQ(res1, true); + RunAndCheckResult(program.get(), true, 1.0); } std::shared_ptr<::pir::Program> BuildConcatProgram() { @@ -719,37 +540,8 @@ TEST(GroupOp, TestBuildConcat) { // Step 1: Construct pir::Program ::pir::IrContext* ctx = ::pir::IrContext::Instance(); std::shared_ptr<::pir::Program> program = BuildConcatProgram(); - ctx->GetOrRegisterDialect(); - ctx->GetOrRegisterDialect(); - - cinn::dialect::ir::PdOp2CinnOpConverter(program.get()); - - pir::PassManager pm(ctx); - pm.AddPass( - std::make_unique()); - - pm.AddPass(pir::CreateDeadCodeEliminationPass()); - pm.AddPass(pir::CreateBuildCinnPass()); - pm.AddPass(cinn::dialect::ir::CreateCinnGroupLoweringPass()); - CHECK_EQ(pm.Run(program.get()), true); - - paddle::platform::Place place = paddle::platform::CUDAPlace(0); - - auto kernel_program = - paddle::dialect::PdOpLowerToKernelPass(program.get(), place); - - paddle::framework::Scope exe_scope; - - paddle::framework::InterpreterCore executor( - place, {"out@fetch"}, kernel_program->block(), &exe_scope); - executor.Run({}, true); - - auto out_tensor = - executor.local_scope()->FindVar("out@fetch")->Get(); - - bool res0 = simple_cmp(out_tensor.data()[0], 2.0); - EXPECT_EQ(res0, true); + RunAndCheckResult(program.get(), true, 2.0); } std::shared_ptr<::pir::Program> BuildSliceProgram() { @@ -778,43 +570,13 @@ std::shared_ptr<::pir::Program> BuildSliceProgram() { return program; } -// TEST(GroupOp, TestBuildSlice) { -// // Step 1: Construct pir::Program -// ::pir::IrContext* ctx = ::pir::IrContext::Instance(); -// std::shared_ptr<::pir::Program> program = BuildSliceProgram(); -// ctx->GetOrRegisterDialect(); -// ctx->GetOrRegisterDialect(); - -// cinn::dialect::ir::PdOp2CinnOpConverter(program.get()); - -// pir::PassManager pm(ctx); -// pm.AddPass( -// std::make_unique()); - -// pm.AddPass(pir::CreateDeadCodeEliminationPass()); -// pm.AddPass(pir::CreateBuildCinnPass()); -// CHECK_EQ(pm.Run(program.get()), true); - -// auto res = cinn::dialect::ir::CINNGroupLoweringPass(program.get()); - -// paddle::platform::Place place = paddle::platform::CUDAPlace(0); - -// auto kernel_program = -// paddle::dialect::PdOpLowerToKernelPass(res.get(), place); - -// paddle::framework::Scope exe_scope; - -// paddle::framework::InterpreterCore executor( -// place, {"out@fetch"}, kernel_program->block(), &exe_scope); - -// executor.Run({}, true); - -// // auto out_tensor = -// // executor.local_scope()->FindVar("out@fetch")->Get(); +TEST(GroupOp, TestBuildSlice) { + // Step 1: Construct pir::Program + ::pir::IrContext* ctx = ::pir::IrContext::Instance(); + std::shared_ptr<::pir::Program> program = BuildSliceProgram(); -// // bool res0 = simple_cmp(out_tensor.data()[0], 2.0); -// // EXPECT_EQ(res0, true); -// } + RunAndCheckResult(program.get(), true, 2.0); +} std::shared_ptr<::pir::Program> BuildSplitProgram() { ::pir::IrContext* ctx = ::pir::IrContext::Instance(); @@ -840,39 +602,8 @@ TEST(GroupOp, TestBuildSplit) { // Step 1: Construct pir::Program ::pir::IrContext* ctx = ::pir::IrContext::Instance(); std::shared_ptr<::pir::Program> program = BuildSplitProgram(); - ctx->GetOrRegisterDialect(); - ctx->GetOrRegisterDialect(); - - cinn::dialect::ir::PdOp2CinnOpConverter(program.get()); - - pir::PassManager pm(ctx); - pm.AddPass( - std::make_unique()); - - pm.AddPass(pir::CreateDeadCodeEliminationPass()); - pm.AddPass(pir::CreateBuildCinnPass()); - CHECK_EQ(pm.Run(program.get()), true); - // TODO(phlrain): codengen will failed in split op - // auto res = cinn::dialect::ir::CINNGroupLoweringPass(program.get()); - - // paddle::platform::Place place = paddle::platform::CUDAPlace(0); - - // auto kernel_program = - // paddle::dialect::PdOpLowerToKernelPass(res.get(), place); - - // paddle::framework::Scope exe_scope; - - // paddle::framework::InterpreterCore executor( - // place, {"out@fetch"}, kernel_program->block(), &exe_scope); - - // executor.Run({}, true); - - // auto out_tensor = - // executor.local_scope()->FindVar("out@fetch")->Get(); - - // bool res0 = simple_cmp(out_tensor.data()[0], 2.0); - // EXPECT_EQ(res0, true); + RunAndCheckResult(program.get(), true, 2.0); } std::shared_ptr<::pir::Program> BuildAddNProgram() { @@ -915,38 +646,8 @@ TEST(GroupOp, TestBuildAddN) { // Step 1: Construct pir::Program ::pir::IrContext* ctx = ::pir::IrContext::Instance(); std::shared_ptr<::pir::Program> program = BuildAddNProgram(); - ctx->GetOrRegisterDialect(); - ctx->GetOrRegisterDialect(); - - cinn::dialect::ir::PdOp2CinnOpConverter(program.get()); - - pir::PassManager pm(ctx); - pm.AddPass( - std::make_unique()); - - pm.AddPass(pir::CreateDeadCodeEliminationPass()); - pm.AddPass(pir::CreateBuildCinnPass()); - pm.AddPass(cinn::dialect::ir::CreateCinnGroupLoweringPass()); - - CHECK_EQ(pm.Run(program.get()), true); - paddle::platform::Place place = paddle::platform::CUDAPlace(0); - - auto kernel_program = - paddle::dialect::PdOpLowerToKernelPass(program.get(), place); - - paddle::framework::Scope exe_scope; - - paddle::framework::InterpreterCore executor( - place, {"out@fetch"}, kernel_program->block(), &exe_scope); - - executor.Run({}, true); - - auto out_tensor = - executor.local_scope()->FindVar("out@fetch")->Get(); - - bool res0 = simple_cmp(out_tensor.data()[0], 6.0); - EXPECT_EQ(res0, true); + RunAndCheckResult(program.get(), true, 6.0); } std::shared_ptr<::pir::Program> BuildSplitSectionProgram() { @@ -966,8 +667,8 @@ std::shared_ptr<::pir::Program> BuildSplitSectionProgram() { .Build( x, std::vector({3, 5, 8}), -1) .out(); - auto out_list = builder.Build(split_arr).outputs(); - builder.Build(out_list[0], "out", 0); + auto out = builder.Build(split_arr, 0).result(0); + builder.Build(out, "out", 0); return program; } @@ -975,35 +676,6 @@ TEST(GroupOp, TestBuildSplitSection) { // Step 1: Construct pir::Program ::pir::IrContext* ctx = ::pir::IrContext::Instance(); std::shared_ptr<::pir::Program> program = BuildSplitSectionProgram(); - ctx->GetOrRegisterDialect(); - ctx->GetOrRegisterDialect(); - - cinn::dialect::ir::PdOp2CinnOpConverter(program.get()); - - pir::PassManager pm(ctx); - pm.AddPass( - std::make_unique()); - - pm.AddPass(pir::CreateDeadCodeEliminationPass()); - pm.AddPass(pir::CreateBuildCinnPass()); - pm.AddPass(cinn::dialect::ir::CreateCinnGroupLoweringPass()); - CHECK_EQ(pm.Run(program.get()), true); - - paddle::platform::Place place = paddle::platform::CUDAPlace(0); - - auto kernel_program = - paddle::dialect::PdOpLowerToKernelPass(program.get(), place); - - paddle::framework::Scope exe_scope; - - paddle::framework::InterpreterCore executor( - place, {"out@fetch"}, kernel_program->block(), &exe_scope); - - executor.Run({}, true); - - auto out_tensor = - executor.local_scope()->FindVar("out@fetch")->Get(); - bool res0 = simple_cmp(out_tensor.data()[0], 2.0); - EXPECT_EQ(res0, true); + RunAndCheckResult(program.get(), 2.0); }