From d3b786b94bc784df9535371e738cba13b6b8c212 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Fri, 17 May 2024 14:07:19 +0800 Subject: [PATCH] [CINN/Fusion] horizontal support dynamic shape and enhance fusion ability (#63913) * [CINN] Support horizontal fusion * Change data type * Support horizontal fusion * Fix compile error * add topo sort in backend fusion * horizontal support dynamic shape and enhance fusion ability * fix * xx * fix some bugs * fix * xxxx * fix * horizontal operator fusion enhance * fix * fix * fix * fix * fix by code review * fix --------- Co-authored-by: jiahongyu --- .../operator/transforms/add_cinn_pass.cc | 2 +- .../transforms/cinn_group_cluster_pass.cc | 5 +- .../hlir/framework/pir/op_lowering_impl.cc | 5 +- .../hlir/framework/pir/trivial_op_impl.cc | 67 +++-- .../cinn/hlir/framework/pir/trivial_op_impl.h | 7 +- .../hlir/framework/pir/trivial_op_util.cc | 77 +++++- .../cinn/hlir/framework/pir/trivial_op_util.h | 13 +- paddle/cinn/operator_fusion/backend/pattern.h | 22 +- .../operator_fusion/backend/pattern_fuser.cc | 252 ++++++++++++++---- .../operator_fusion/backend/pattern_fuser.h | 5 - .../cinn/operator_fusion/frontend/pattern.h | 24 +- .../operator_fusion/frontend/pattern_fuser.cc | 14 +- .../operator_fusion/frontend/pattern_fuser.h | 5 - paddle/cinn/operator_fusion/group_cluster.h | 5 +- paddle/cinn/operator_fusion/pattern.h | 29 +- paddle/cinn/operator_fusion/pattern_fuser.h | 158 +++++++++++ paddle/cinn/operator_fusion/pattern_graph.cc | 8 +- paddle/cinn/operator_fusion/pattern_graph.h | 59 ++-- .../operator_fusion/policy/dim_relation.h | 6 +- paddle/cinn/operator_fusion/utils.h | 134 ++++++++++ test/cpp/pir/cinn/pir_all_path_test.cc | 4 +- .../cinn/inference/test_llama_postprocess.py | 4 +- .../pir/cinn/test_fusion_softmax_subgraph.py | 125 +++++++++ test/ir/pir/cinn/test_horizontal_fusion.py | 63 +++++ 24 files changed, 923 insertions(+), 170 deletions(-) create mode 100644 test/ir/pir/cinn/test_fusion_softmax_subgraph.py create mode 100644 test/ir/pir/cinn/test_horizontal_fusion.py diff --git a/paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc index 088291de41511..28131f53b4c5a 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc @@ -154,7 +154,7 @@ void ApplyDivideGroupOpToFusionOpPass( std::shared_ptr pass_manager = CreatePassManager(); if (FLAGS_group_schedule_tiling_first) { pass_manager->AddPass(cinn::dialect::ir::CreateCinnGroupClusterPass()); - pass_manager->AddPass(cinn::dialect::ir::CreateAddStoreInFusionOpPass()); + // pass_manager->AddPass(cinn::dialect::ir::CreateAddStoreInFusionOpPass()); } else { pass_manager->AddPass( cinn::dialect::ir::CreateDivideGroupOpToFusionOpPass()); diff --git a/paddle/cinn/hlir/dialect/operator/transforms/cinn_group_cluster_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/cinn_group_cluster_pass.cc index c3bf60c601b7d..26487e055c1ff 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/cinn_group_cluster_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/cinn_group_cluster_pass.cc @@ -246,7 +246,7 @@ std::vector GroupSplit(cinn::dialect::GroupOp group_op) { std::function func = [](pir::Operation* op) { return cinn::fusion::FrontendContent(op); }; const auto& contents = cinn::fusion::MapVector(group_op.GetOperators(), func); - auto cluster_result = cinn::fusion::ClusterOps(contents); + auto cluster_result = 
cinn::fusion::ClusterOps(contents, {}); std::vector> result; std::transform( cluster_result.begin(), @@ -390,6 +390,9 @@ class CinnGroupClusterPass : public pir::PatternRewritePass { } bool CanApplyOn(pir::Operation* op) const override { + if (op->isa()) { + return false; + } return op->num_regions() > 0; } }; diff --git a/paddle/cinn/hlir/framework/pir/op_lowering_impl.cc b/paddle/cinn/hlir/framework/pir/op_lowering_impl.cc index a87bbf953de1b..b108cd8a4c727 100644 --- a/paddle/cinn/hlir/framework/pir/op_lowering_impl.cc +++ b/paddle/cinn/hlir/framework/pir/op_lowering_impl.cc @@ -203,7 +203,8 @@ BucketLoweredFuncsWrapper OpLowererImpl::BucketLower( // =========== OpFusion ============ - func_bodies = OperationFusion(ops, func_bodies); + // VLOG(4) << "Bucket Lower output values is : " << group->output_values(); + func_bodies = OperationFusion(ops, func_bodies, group->output_values()); const auto& fusion_group_info = GetFusionGroupInfo(func_bodies); // =========== CodeGen And Optimizer ================ @@ -728,7 +729,7 @@ std::vector OpLowererImpl::PostProcess( group->mut_output_names().clear(); // collect all output tensor. - for (auto op_result : group->GetGroupOutputValues()) { + for (auto op_result : group->output_values()) { if (tensor_map.count(op_result) == 0) { continue; } diff --git a/paddle/cinn/hlir/framework/pir/trivial_op_impl.cc b/paddle/cinn/hlir/framework/pir/trivial_op_impl.cc index d4ab3e8ebc084..b68f0f5f8ebe0 100644 --- a/paddle/cinn/hlir/framework/pir/trivial_op_impl.cc +++ b/paddle/cinn/hlir/framework/pir/trivial_op_impl.cc @@ -187,17 +187,21 @@ std::vector GetOutputIters(const FusibleOp& op) { return AppendBound(std::visit(Visitor(), op), _GetRootExpr(op)); } +std::vector GetAllIterVars(const ir::Expr& expr) { + ir::Expr compute_schedule_block_realize = + (ExprSetFinderUtils::ChildScheduleBlockRealizes * + ExprSetFinderUtils::ScheduleBlockRealizeIsNotInit) + .GetSingle(expr); + + const std::vector& all_iter_expr = + compute_schedule_block_realize.As() + ->iter_values; + return ComposeUtils::ExprVec2VarVec(all_iter_expr); +} + std::vector GetReduceIters(const ReduceOp& op) { auto GetUnorderedAllIterVars = [](const ReduceOp& op) { - ir::Expr compute_schedule_block_realize = - (ExprSetFinderUtils::ChildScheduleBlockRealizes * - ExprSetFinderUtils::ScheduleBlockRealizeIsNotInit) - .GetSingle(_GetRootExpr(op)); - - const std::vector& all_iter_expr = - compute_schedule_block_realize.As() - ->iter_values; - return ComposeUtils::ExprVec2VarVec(all_iter_expr); + return GetAllIterVars(_GetRootExpr(op)); }; // Iter Vars not appearing in outer_iter_vars are pushed into @@ -560,16 +564,39 @@ std::pair SplitReduceOp(const ReduceOp& reduce_op) { return std::make_pair(result_trivial, result_reduce); } +std::vector GetAllForIters(const ir::Expr& expr) { + using cinn::hlir::framework::pir::trivial_fusion_detail::ExprSetFinderUtils:: + ChildFors; + using cinn::hlir::framework::pir::trivial_fusion_detail::ExprSetFinderUtils:: + ChildScheduleBlockRealizes; + using cinn::hlir::framework::pir::trivial_fusion_detail::ExprSetFinderUtils:: + FindFather; + using cinn::hlir::framework::pir::trivial_fusion_detail::ExprSetFinderUtils:: + IsFor; + using cinn::hlir::framework::pir::trivial_fusion_detail::ExprSetFinderUtils:: + ScheduleBlockRealizeIsNotInit; + const auto& all_father_fors = + (ChildScheduleBlockRealizes * ScheduleBlockRealizeIsNotInit * + FindFather(expr) * IsFor)(expr); + std::vector vars; + for (const auto& for_expr : all_father_fors) { + vars.push_back(for_expr.As()->loop_var); + 
} + VLOG(4) << "GetAllForIters : " << expr + << "\n var is : " << utils::Join(vars, ","); + return vars; +} + } // namespace trivial_fusion_detail std::vector OperationFusion( const std::vector<::pir::Operation*>& original_ops, - const std::vector& op_compute_bodies) { - PADDLE_ENFORCE_EQ(FLAGS_group_schedule_tiling_first, - true, - ::common::errors::PreconditionNotMet( - "TrivialFusion must be used with tiling first, set " - "FLAGS_group_schedule_tiling_first=1")); + const std::vector& op_compute_bodies, + const std::vector<::pir::Value>& outputs) { + PADDLE_ENFORCE(FLAGS_group_schedule_tiling_first, + ::common::errors::PreconditionNotMet( + "TrivialFusion must be used with tiling first, set " + "FLAGS_group_schedule_tiling_first=1")); const auto& ops = trivial_fusion_detail::FilterVector( original_ops, [](const ::pir::Operation* op) { if (op->name() == "cinn_op.generate_shape") { @@ -581,10 +608,9 @@ std::vector OperationFusion( std::vector contents; for (int i = 0; i < ops.size(); i++) { contents.emplace_back(ops[i], op_compute_bodies[i]); - // contents.emplace_back(ops[i]); } const auto& fusion_nodes = - cinn::fusion::ClusterOps(contents); + cinn::fusion::ClusterOps(contents, outputs); PADDLE_ENFORCE_EQ(fusion_nodes.size(), 1, @@ -601,6 +627,8 @@ std::vector OperationFusion( FusionGroupInfo GetFusionGroupInfo( const std::vector& op_compute_bodies) { + using trivial_fusion_detail::AppendBound; + using trivial_fusion_detail::GetAllForIters; using trivial_fusion_detail::ReduceOp; using trivial_fusion_detail::ComposeUtils::ConcatVector; using trivial_fusion_detail::ExprSetFinderUtils::ChildScheduleBlockRealizes; @@ -618,7 +646,7 @@ FusionGroupInfo GetFusionGroupInfo( ReduceOp op = ReduceOp(body); if (group_info.reduce_var_name.empty()) { std::vector all_iters = - ConcatVector(GetOutputIters(op), GetReduceIters(op)); + AppendBound(GetAllForIters(body), body); std::transform(all_iters.begin(), all_iters.end(), std::back_inserter(group_info.loop_ranges), @@ -631,7 +659,8 @@ FusionGroupInfo GetFusionGroupInfo( return (int64_t)-1; } }); - std::vector reduce_iters = GetReduceIters(op); + std::vector reduce_iters = fusion::FilterVector( + all_iters, [](const ir::Var& var) { return var->is_reduce_axis; }); for (int64_t i = all_iters.size() - reduce_iters.size(); i < all_iters.size(); i++) { diff --git a/paddle/cinn/hlir/framework/pir/trivial_op_impl.h b/paddle/cinn/hlir/framework/pir/trivial_op_impl.h index 1abe14a48266c..5559cde6796ca 100644 --- a/paddle/cinn/hlir/framework/pir/trivial_op_impl.h +++ b/paddle/cinn/hlir/framework/pir/trivial_op_impl.h @@ -159,6 +159,10 @@ FusibleOp SinkTrivialLoopAlign(TrivialOp trivial_op, ReduceOp reduce_op, std::vector fake_reduce_iter_idx); +std::vector GetAllIterVars(const ir::Expr& expr); + +std::vector GetAllForIters(const ir::Expr& expr); + } // namespace trivial_fusion_detail struct FusionGroupInfo { @@ -178,7 +182,8 @@ FusionGroupInfo GetFusionGroupInfo( std::vector OperationFusion( const std::vector<::pir::Operation*>& ops, - const std::vector& op_compute_bodies); + const std::vector& op_compute_bodies, + const std::vector<::pir::Value>& outputs); } // namespace pir } // namespace framework diff --git a/paddle/cinn/hlir/framework/pir/trivial_op_util.cc b/paddle/cinn/hlir/framework/pir/trivial_op_util.cc index bc29c83dfbf35..faddb313a7943 100644 --- a/paddle/cinn/hlir/framework/pir/trivial_op_util.cc +++ b/paddle/cinn/hlir/framework/pir/trivial_op_util.cc @@ -56,7 +56,7 @@ std::vector VarVec2ExprVec(const std::vector& in) { std::vector 
GetEachTensorLoadExpr(const ir::Expr& body, const ir::Tensor& tensor) { VLOG(4) << "GetEachTensorLoadExpr: " << tensor; - std::set load_exprs = cinn::ir::ir_utils::CollectIRNodesWithoutTensor( + std::vector load_exprs = cinn::ir::ir_utils::CollectIRNodesInOrder( body, [&tensor](const Expr* expr) { return expr->As() && expr->As()->is_addr_tensor() && expr->As()->tensor.as_tensor_ref()->name == @@ -65,7 +65,7 @@ std::vector GetEachTensorLoadExpr(const ir::Expr& body, for (auto& t : load_exprs) { VLOG(4) << "GetEachTensorLoadExpr Found: " << t << " " << t.ptr(); } - return std::vector(load_exprs.begin(), load_exprs.end()); + return load_exprs; } MappingTargetExprToDestExprMutator::MappingTargetExprToDestExprMutator( @@ -83,6 +83,16 @@ void MappingTargetExprToDestExprMutator::Visit(const ir::Load* load, Expr* op) { IRMutator::Visit(load, op); } } + +void MappingTargetExprToDestExprMutator::Visit(const ir::For* for_node, + Expr* op) { + if (for_node == source_.ptr()) { + *op = dest_; + } else { + IRMutator::Visit(for_node, op); + } +} + void MappingTargetExprToDestExprMutator::Visit(const ir::Store* store, Expr* op) { if (store == source_.ptr()) { @@ -91,6 +101,7 @@ void MappingTargetExprToDestExprMutator::Visit(const ir::Store* store, IRMutator::Visit(store, op); } } + void MappingTargetExprToDestExprMutator::Visit(const ir::Reduce* reduce, Expr* op) { if (reduce == source_.ptr()) { @@ -196,7 +207,7 @@ ExprSetFinder ExprSetFinder::operator*(ExprSetFinder x) const { std::vector res; for (const auto& r : rs) { const auto& x_res = x.f_(r); - res.insert(res.begin(), x_res.begin(), x_res.end()); + res.insert(res.end(), x_res.begin(), x_res.end()); } return res; }; @@ -246,6 +257,15 @@ ExprSetFinder ScheduleBlockRealizeNotRoot = FilterMaker( }, "ScheduleBlockRealizeNotRoot"); +ExprSetFinder ScheduleBlockRealizeIsRoot = FilterMaker( + [](const ir::Expr& e) -> bool { + return (e.As() && + e.As() + ->schedule_block.As() + ->name.find("root") != std::string::npos); + }, + "ScheduleBlockRealizeIsRoot"); + ExprSetFinder ScheduleBlockRealizeIsNotInit = FilterMaker( [](const ir::Expr& e) -> bool { return (e.As() && @@ -277,6 +297,12 @@ ExprSetFinder ChildScheduleBlockRealizes = "ChildScheduleBlockRealizes") * ScheduleBlockRealizeNotRoot; +ExprSetFinder ChildRootScheduleBlockRealizes = + Collector( + [](const ir::Expr* e) { return e->As(); }, + "ChildScheduleBlockRealizes") * + ScheduleBlockRealizeIsRoot; + ExprSetFinder IsForIterVar(const ir::Var& var) { return FilterMaker( [var = var](const ir::Expr& e) -> bool { @@ -304,7 +330,7 @@ ExprSetFinder ChildTensorLoads = Collector( ExprSetFinder ChildTensorStores = Collector( [](const ir::Expr* e) { - return e->As() && e->As()->is_addr_tensor(); + return e->As() && e->As()->is_addr_tensor(); }, "ChildTensorStores"); @@ -324,8 +350,10 @@ ExprSetFinder FindFather(const ir::Expr& root) { const auto& f = [&](const auto& child) -> ExprSet { ExprSetFinder find_child = Collector([child](const ir::Expr* e) { return *e == child; }); - const auto& father_collector = Collector( - [&](const ir::Expr* current) { return !find_child(*current).empty(); }); + const auto& father_collector = Collector([&](const ir::Expr* current) { + auto res = (*current != child) && !find_child(*current).empty(); + return res; + }); return father_collector(root); }; return ExprSetFinder(f, "FindFather"); @@ -373,6 +401,35 @@ ExprTransformer WrapForsTransformer(const std::vector& vs) { return ExprTransformer(f); } +ExprTransformer UnsqueezeForTransformer( + const ExprSetFinderUtils::ExprSetFinder& 
followed_finder, + const ir::Var& to_append_var) { + const auto& suqueeze_for_func = [&](const ir::Expr& e) -> ir::Expr { + auto copied_e = ir::ir_utils::IRCopy(e); + ir::Expr followed_expr = followed_finder.GetSingle(copied_e); + // (ExprSetFinderUtils::ChildFors * + // ExprSetFinderUtils::IsForIterVar(following_for_iter_var)).GetSingle(copied_e); + VLOG(6) << "UnsqueezeForTransformer: for insert after " << followed_expr; + if (followed_expr.As()) { + followed_expr.As()->body = ir::Block::Make({WrapForTransformer( + to_append_var)(followed_expr.As()->body)}); + } else if (followed_expr.As()) { + const auto& schedule_block = followed_expr.As() + ->schedule_block.As(); + schedule_block->body = + WrapForTransformer(to_append_var)(schedule_block->body); + } else { + PADDLE_THROW( + "UnsqueezeForTransformer: only support insert after a (For / " + "ScheduleBlockRealizer): %s", + followed_expr); + } + VLOG(6) << "UnsqueezeForTransformer: After changed: " << copied_e; + return copied_e; + }; + return ExprTransformer(suqueeze_for_func); +} + ExprTransformer ChangeTensorLoadTransformer(const ir::Tensor& tensor, const ir::Expr& dst_load) { const auto& f = [&](const ir::Expr& e) -> ir::Expr { @@ -420,6 +477,14 @@ ExprTransformer ChangeVarTransformer(const std::vector& target_vars, return ExprTransformer(f); } +ExprTransformer ReplaceVarTransformer(const std::vector& target_vars, + const std::vector& dest_expr) { + const auto& f = [=](const ir::Expr& e) -> ir::Expr { + return ComposeUtils::CopyedReplaceExpr(e, target_vars, dest_expr); + }; + return ExprTransformer(f); +} + bool IsReduceBool(const ir::Expr& lhs, const ir::Expr& rhs) { return lhs.type().is_bool() || rhs.type().is_bool(); } diff --git a/paddle/cinn/hlir/framework/pir/trivial_op_util.h b/paddle/cinn/hlir/framework/pir/trivial_op_util.h index f9172a4ed1167..6b474c0be9798 100644 --- a/paddle/cinn/hlir/framework/pir/trivial_op_util.h +++ b/paddle/cinn/hlir/framework/pir/trivial_op_util.h @@ -78,6 +78,7 @@ struct MappingTargetExprToDestExprMutator : public ir::IRMutator<> { void Visit(const ir::Load* load, Expr* op) override; void Visit(const ir::Store* store, Expr* op) override; void Visit(const ir::Reduce* reduce, Expr* op) override; + void Visit(const ir::For* for_node, Expr* op) override; private: ir::Expr source_; @@ -132,7 +133,7 @@ template ExprSetFinder Collector(Teller t, std::string name = "") { return ExprSetFinder( [=](const ir::Expr& x) -> ExprSet { - const auto& rs = cinn::ir::ir_utils::CollectIRNodesWithoutTensor(x, t); + const auto& rs = cinn::ir::ir_utils::CollectIRNodesInOrder(x, t); return std::vector(rs.begin(), rs.end()); }, name); @@ -170,6 +171,8 @@ extern ExprSetFinder ChildScheduleBlocks; extern ExprSetFinder ChildScheduleBlockRealizes; +extern ExprSetFinder ChildRootScheduleBlockRealizes; + extern ExprSetFinder For2Min; extern ExprSetFinder For2Max; @@ -230,6 +233,14 @@ std::vector CreateInnerBlockVars( ExprTransformer ChangeVarTransformer(const std::vector& target_vars, const std::vector& dest_vars); +ExprTransformer ReplaceVarTransformer(const std::vector& target_vars, + const std::vector& dest_exprs); + +// insert after followed_finder. 
only support For and ScheduleBlockRealizer +ExprTransformer UnsqueezeForTransformer( + const ExprSetFinderUtils::ExprSetFinder& followed_finder, + const ir::Var& to_append_var); + ExprTransformer SubstitudeByScheduleBlockRealize(const ir::Expr& realize); ExprTransformer WrapScheduleRealizer(const std::vector& block_vars, diff --git a/paddle/cinn/operator_fusion/backend/pattern.h b/paddle/cinn/operator_fusion/backend/pattern.h index 7006fb02b829e..323a0c5a4673c 100644 --- a/paddle/cinn/operator_fusion/backend/pattern.h +++ b/paddle/cinn/operator_fusion/backend/pattern.h @@ -39,12 +39,15 @@ using FusionOp = std::variant; template <> struct TrivialPattern { explicit TrivialPattern(const std::vector& ops, + pir::Operation* sink_op, const TrivialOp& op) - : ops_(ops), trivial_op(op) {} + : ops_(ops), sink_op(sink_op), trivial_op(op) {} std::vector ops_; + pir::Operation* sink_op; TrivialOp trivial_op; static std::string name() { return "Trivial"; } std::vector ops() const { return ops_; } + pir::Operation* sink() const { return sink_op; } }; template <> @@ -68,21 +71,4 @@ struct UnsupportPattern { static std::string name() { return "Unsupport"; } }; -template <> -struct HorizontalFusionPattern { - explicit HorizontalFusionPattern( - const std::vector>& patterns) - : patterns_(patterns) {} - std::vector> patterns_; - std::vector ops() const { - std::vector result; - for (const auto& pattern : patterns_) { - auto ops = GetOpsInPattern(pattern); - ExtendVector(&result, ops); - } - return result; - } - static std::string name() { return "HorizontalFusionPattern"; } -}; - } // namespace cinn::fusion diff --git a/paddle/cinn/operator_fusion/backend/pattern_fuser.cc b/paddle/cinn/operator_fusion/backend/pattern_fuser.cc index 61e8fd658f94a..567ee35612172 100644 --- a/paddle/cinn/operator_fusion/backend/pattern_fuser.cc +++ b/paddle/cinn/operator_fusion/backend/pattern_fuser.cc @@ -34,8 +34,8 @@ StmtPattern ConvertToStmtPattern( kind == hlir::framework::kBroadcast || kind == hlir::framework::kInjective) { CHECK(content.expr.has_value()); - return TrivialPattern({content.op}, - TrivialOp(content.expr.value())); + return TrivialPattern( + {content.op}, content.op, TrivialOp(content.expr.value())); } else { CHECK(false); return UnsupportPattern({content.op}); @@ -74,17 +74,7 @@ StmtPattern MergePatternImpl( const auto& trivial_op = cinn::hlir::framework::pir::trivial_fusion_detail::TrivalxOther_Fusion( first.trivial_op, second.trivial_op); - return TrivialPattern(ops, trivial_op); -} - -template <> -StmtPattern MergePatternImpl( - const HorizontalFusionPattern& first, - const HorizontalFusionPattern& second) { - const auto& contents = - UniqueConcatVector(GetOpsInPattern(first), - GetOpsInPattern(second)); - return HorizontalFusionPattern({first, second}); + return TrivialPattern(ops, second.sink(), trivial_op); } /// Start: Tmp Transform Operation for ReduceTree @@ -141,71 +131,241 @@ std::vector ReduceTreeTrivialTransformRecursive( } /// End: Tmp Transform Operation for reduce tree +/// +struct FusionOp2Expr { + std::vector operator()(const TrivialOp& op) { + return {op.GetFuncBody()}; + } + std::vector operator()(const ReduceOp& op) { + const auto& t_r = SplitReduceOp(op); + return {t_r.first.GetFuncBody(), t_r.second.GetFuncBody()}; + } +}; -std::vector GetFusionOpFromPattern( +std::vector GetExprFromPattern( const StmtPattern& pattern); -struct FusionOpGetter { - std::vector operator()( +ir::Expr UnSqueezeExpr(const ir::Expr& expr, + const std::vector padding_vec) { + using 
cinn::hlir::framework::pir::trivial_fusion_detail::AppendBound; + using cinn::hlir::framework::pir::trivial_fusion_detail::GetAllForIters; + using cinn::hlir::framework::pir::trivial_fusion_detail::ExprSetFinderUtils:: + ChildFors; + using cinn::hlir::framework::pir::trivial_fusion_detail::ExprSetFinderUtils:: + ChildRootScheduleBlockRealizes; + using cinn::hlir::framework::pir::trivial_fusion_detail::ExprSetFinderUtils:: + ChildScheduleBlockRealizes; + using cinn::hlir::framework::pir::trivial_fusion_detail::ExprSetFinderUtils:: + IsForIterVar; + using cinn::hlir::framework::pir::trivial_fusion_detail:: + ExprTransformerUtils::ReplaceVarTransformer; + using cinn::hlir::framework::pir::trivial_fusion_detail:: + ExprTransformerUtils::UnsqueezeForTransformer; + VLOG(4) << "UnSqueezeExpr: " << expr + << "\npadding vector: " << utils::Join(padding_vec, ", "); + const auto& vars_in_expr = AppendBound(GetAllForIters(expr), expr); + // get the all vars. + int counter = 0; + auto GenNextName = [&counter]() { + counter += 1; + return "expand_var_" + std::to_string(counter); + }; + std::vector vars; + int pointer = 0; + for (int i = 0; i < vars_in_expr.size() + padding_vec.size(); i++) { + if (std::find(padding_vec.begin(), padding_vec.end(), i) != + padding_vec.end()) { + vars.emplace_back(Expr(0), Expr(1), GenNextName()); + } else { + vars.push_back(vars_in_expr[pointer++]); + } + } + // update the is_reduce of expand_var. + for (int i : padding_vec) { + if (i == 0) { + vars[i]->is_reduce_axis = false; + } else { + vars[i]->is_reduce_axis = vars[i - 1]->is_reduce_axis; + } + } + + // sequencely unsqueeze the ir::Expr. + ir::Expr result = expr; + for (int i : padding_vec) { + if (i > 0) { + result = UnsqueezeForTransformer((ChildFors * IsForIterVar(vars[i - 1])), + vars[i])(result); + } else { + result = UnsqueezeForTransformer(ChildRootScheduleBlockRealizes, + vars[i])(result); + } + } + return result; +} + +struct IrExprGetter { + std::vector operator()( const TrivialPattern& pattern) { - return {pattern.trivial_op}; + return FusionOp2Expr()(pattern.trivial_op); } - std::vector operator()(const ReducePattern& pattern) { - return {pattern.reduce_op}; + std::vector operator()(const ReducePattern& pattern) { + return FusionOp2Expr()(pattern.reduce_op); } - std::vector operator()( + std::vector operator()( const ReduceTreePattern& pattern) { - return ReduceTransformRecursive(pattern.GetRootPattern().reduce_op, - pattern); + const auto& fusion_op = + ReduceTransformRecursive(pattern.GetRootPattern().reduce_op, pattern); + std::function(const FusionOp& f)> func = + [](const FusionOp& op) { return std::visit(FusionOp2Expr(), op); }; + return VectorFlatMap(fusion_op, func); } - std::vector operator()( + std::vector operator()( const ReduceTreePlusTrivialPattern& pattern) { - return ReduceTreeTrivialTransformRecursive(pattern.sink_trivial.trivial_op, - pattern); + std::function(const FusionOp& f)> func = + [](const FusionOp& op) { return std::visit(FusionOp2Expr(), op); }; + const auto& fusion_ops = ReduceTreeTrivialTransformRecursive( + pattern.sink_trivial.trivial_op, pattern); + return VectorFlatMap(fusion_ops, func); } - std::vector operator()( + std::vector operator()( const HorizontalFusionPattern& pattern) { - std::vector result; - for (const auto& sub_pattern : pattern.patterns_) { - result = ConcatVector(result, GetFusionOpFromPattern(sub_pattern)); + std::vector result; + VLOG(4) << "Get Fusion Ops from HorizontalFusionPattern: " + << pattern.padding_patterns_.size(); + for (const auto& 
sub_pattern : pattern.padding_patterns_) { + std::function func = + [&sub_pattern](const ir::Expr& expr) { + return UnSqueezeExpr(expr, sub_pattern.padding_pos); + }; + result = ConcatVector( + result, MapVector(GetExprFromPattern(sub_pattern.pattern), func)); } return result; } - std::vector operator()( + std::vector operator()( const UnsupportPattern& pattern) { CHECK(false) << "Not Implemented."; } }; // tmp transform for reduce_tree and reduce_tree_trivial. -std::vector GetFusionOpFromPattern( - const StmtPattern& pattern) { - return std::visit(FusionOpGetter(), pattern.variant()); +std::vector GetOutputTensors(const ir::Expr& op_expr) { + using cinn::hlir::framework::pir::trivial_fusion_detail::ExprSetFinderUtils:: + ChildScheduleBlockRealizes; + using cinn::hlir::framework::pir::trivial_fusion_detail::ExprSetFinderUtils:: + ChildTensorStores; + using cinn::hlir::framework::pir::trivial_fusion_detail::ExprSetFinderUtils:: + ScheduleBlockRealizeIsNotInit; + const auto& tensors = + (ChildScheduleBlockRealizes * ScheduleBlockRealizeIsNotInit * + ChildTensorStores)(op_expr); + std::function func = [](const ir::Expr& expr) { + return expr.As()->tensor.as_tensor_ref(); + }; + return MapVector(tensors, func); } -struct FusionOp2Expr { - std::vector operator()(const TrivialOp& op) { - return {op.GetFuncBody()}; +std::vector GetInputTensors(const ir::Expr& op_expr) { + using cinn::hlir::framework::pir::trivial_fusion_detail::ExprSetFinderUtils:: + ChildScheduleBlockRealizes; + using cinn::hlir::framework::pir::trivial_fusion_detail::ExprSetFinderUtils:: + ChildTensorLoads; + using cinn::hlir::framework::pir::trivial_fusion_detail::ExprSetFinderUtils:: + ScheduleBlockRealizeIsNotInit; + const auto& exprs = + (ChildScheduleBlockRealizes * ScheduleBlockRealizeIsNotInit * + ChildTensorLoads)(op_expr); + std::function func = [](const ir::Expr& expr) { + return expr.As()->tensor.as_tensor_ref(); + }; + const auto& inputs = MapVector(exprs, func); + const auto& outputs = GetOutputTensors(op_expr); + return FilterVector(inputs, [&outputs](const ir::Tensor& tensor) { + return std::find(outputs.begin(), outputs.end(), tensor) == outputs.end(); + }); +} + +std::vector TopoSort(const std::vector& op_exprs) { + // Topo Sort is important for CINN GroupSchedule. 
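[Editor's aside] The TopoSort helper introduced in the next hunk is a plain Kahn's algorithm over tensor def/use edges between the lowered bodies, so producers are always emitted before their consumers. A rough standalone sketch of the same idea in Python (op and tensor names below are made up for illustration, not taken from this PR):

from collections import defaultdict, deque

def topo_sort(ops):
    # ops: list of (name, input_tensors, output_tensors); an op must run
    # after every op that defines one of its inputs.
    used_by = defaultdict(list)            # tensor name -> consumer op indices
    for i, (_, inputs, _) in enumerate(ops):
        for t in inputs:
            used_by[t].append(i)
    downstreams = defaultdict(list)        # op index -> consumer op indices
    degree = [0] * len(ops)
    for i, (_, _, outputs) in enumerate(ops):
        for t in outputs:
            for j in used_by[t]:
                downstreams[i].append(j)
                degree[j] += 1
    queue = deque(i for i, d in enumerate(degree) if d == 0)
    order = []
    while queue:
        cur = queue.popleft()
        order.append(ops[cur][0])
        for j in downstreams[cur]:
            degree[j] -= 1
            if degree[j] == 0:
                queue.append(j)
    assert len(order) == len(ops), "def/use graph has a cycle"
    return order

# exp reads the output of sub, sub reads the output of max:
print(topo_sort([("exp", ["t_sub"], ["t_exp"]),
                 ("max", ["x"], ["t_max"]),
                 ("sub", ["x", "t_max"], ["t_sub"])]))
# -> ['max', 'sub', 'exp']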
+ std::map> tensor2defining_op; + std::map> tensor2used_op; + for (const auto& op : op_exprs) { + auto inputs = GetInputTensors(op); + auto outputs = GetOutputTensors(op); + if (VLOG_IS_ON(4)) { + VLOG(4) << "Ir::Expr is: \n" << op; + VLOG(4) << "Inputs: "; + for (const auto& input : inputs) { + VLOG(4) << input->name; + } + VLOG(4) << "Outputs: "; + for (const auto& output : outputs) { + VLOG(4) << output->name; + } + } + for (const auto& input : inputs) { + tensor2used_op[input].push_back(&op); + } + for (const auto& output : outputs) { + tensor2defining_op[output].push_back(&op); + } } - std::vector operator()(const ReduceOp& op) { - const auto& t_r = SplitReduceOp(op); - return {t_r.first.GetFuncBody(), t_r.second.GetFuncBody()}; + + // Collect Downstreams + std::map> op2downstreams; + std::map degrees; + for (const auto& op : op_exprs) { + degrees[&op] = 0; } -}; + for (const auto& op : op_exprs) { + auto outputs = GetOutputTensors(op); + std::vector downstreams; + for (const auto& output : outputs) { + downstreams = ConcatVector(downstreams, tensor2used_op[output]); + } + for (const auto& downstream : downstreams) { + degrees[downstream]++; + } + op2downstreams[&op] = downstreams; + } + + // Topo Sort + std::vector result; + std::queue q; + for (const auto& op : op_exprs) { + if (degrees[&op] == 0) { + q.push(&op); + } + } + while (!q.empty()) { + auto* cur = q.front(); + VLOG(4) << "Topo Sort Visit Order is:" << GetOutputTensors(*cur)[0]->name; + q.pop(); + result.push_back(cur); + for (const auto& downstream : op2downstreams[cur]) { + degrees[downstream]--; + if (degrees[downstream] == 0) { + q.push(downstream); + } + } + } + CHECK_EQ(result.size(), op_exprs.size()); + std::vector sorted_result; + for (const auto& op : result) { + sorted_result.push_back(*op); + } + return sorted_result; +} std::vector GetExprFromPattern( const StmtPattern& pattern) { - const auto& fusion_ops = GetFusionOpFromPattern(pattern); - std::vector results; - for (const auto& op : fusion_ops) { - results = ConcatVector(results, std::visit(FusionOp2Expr(), op)); - } - return results; + const auto& results = std::visit(IrExprGetter(), pattern.variant()); + return TopoSort(results); } } // namespace cinn::fusion diff --git a/paddle/cinn/operator_fusion/backend/pattern_fuser.h b/paddle/cinn/operator_fusion/backend/pattern_fuser.h index a460d67f8e02f..345972f36cab2 100644 --- a/paddle/cinn/operator_fusion/backend/pattern_fuser.h +++ b/paddle/cinn/operator_fusion/backend/pattern_fuser.h @@ -41,11 +41,6 @@ StmtPattern MergePatternImpl( const TrivialPattern& first, const TrivialPattern& second); -template <> -StmtPattern MergePatternImpl( - const HorizontalFusionPattern& first, - const HorizontalFusionPattern& second); - std::vector GetExprFromPattern( const StmtPattern& pattern); diff --git a/paddle/cinn/operator_fusion/frontend/pattern.h b/paddle/cinn/operator_fusion/frontend/pattern.h index e267483e56586..8ba6a8cccdd83 100644 --- a/paddle/cinn/operator_fusion/frontend/pattern.h +++ b/paddle/cinn/operator_fusion/frontend/pattern.h @@ -47,11 +47,14 @@ struct hash { namespace cinn::fusion { template <> struct TrivialPattern { - explicit TrivialPattern(const std::vector& ops) - : ops_(ops) {} + explicit TrivialPattern(const std::vector& ops, + pir::Operation* sink_op) + : ops_(ops), sink_op(sink_op) {} std::vector ops_; + pir::Operation* sink_op; static std::string name() { return "Trivial"; } std::vector ops() const { return ops_; } + pir::Operation* sink() const { return sink_op; } }; template <> @@ -72,21 +75,4 
@@ struct UnsupportPattern { static std::string name() { return "Unsupport"; } }; -template <> -struct HorizontalFusionPattern { - explicit HorizontalFusionPattern( - const std::vector>& patterns) - : patterns_(patterns) {} - std::vector> patterns_; - std::vector ops() const { - std::vector result; - for (const auto& pattern : patterns_) { - auto ops = GetOpsInPattern(pattern); - ExtendVector(&result, ops); - } - return result; - } - static std::string name() { return "HorizontalFusionPattern"; } -}; - } // namespace cinn::fusion diff --git a/paddle/cinn/operator_fusion/frontend/pattern_fuser.cc b/paddle/cinn/operator_fusion/frontend/pattern_fuser.cc index 332230f409979..f7dca31623879 100644 --- a/paddle/cinn/operator_fusion/frontend/pattern_fuser.cc +++ b/paddle/cinn/operator_fusion/frontend/pattern_fuser.cc @@ -28,7 +28,7 @@ StmtPattern ConvertToStmtPattern( } else if (kind == hlir::framework::kElementWise || kind == hlir::framework::kBroadcast || kind == hlir::framework::kInjective) { - return TrivialPattern({content.op}); + return TrivialPattern({content.op}, content.op); } else { return UnsupportPattern({content.op}); } @@ -58,17 +58,7 @@ StmtPattern MergePatternImpl( const auto& contents = UniqueConcatVector(GetOpsInPattern(first), GetOpsInPattern(second)); - return TrivialPattern(contents); -} - -template <> -StmtPattern MergePatternImpl( - const HorizontalFusionPattern& first, - const HorizontalFusionPattern& second) { - const auto& contents = - UniqueConcatVector(GetOpsInPattern(first), - GetOpsInPattern(second)); - return HorizontalFusionPattern({first, second}); + return TrivialPattern(contents, second.sink()); } } // namespace cinn::fusion diff --git a/paddle/cinn/operator_fusion/frontend/pattern_fuser.h b/paddle/cinn/operator_fusion/frontend/pattern_fuser.h index d92b8429ad16b..cd6dfbcf8fed4 100644 --- a/paddle/cinn/operator_fusion/frontend/pattern_fuser.h +++ b/paddle/cinn/operator_fusion/frontend/pattern_fuser.h @@ -41,9 +41,4 @@ StmtPattern MergePatternImpl( const TrivialPattern& first, const TrivialPattern& second); -template <> -StmtPattern MergePatternImpl( - const HorizontalFusionPattern& first, - const HorizontalFusionPattern& second); - } // namespace cinn::fusion diff --git a/paddle/cinn/operator_fusion/group_cluster.h b/paddle/cinn/operator_fusion/group_cluster.h index 2c6f2072ad528..cf9523dc5dd6c 100644 --- a/paddle/cinn/operator_fusion/group_cluster.h +++ b/paddle/cinn/operator_fusion/group_cluster.h @@ -27,7 +27,8 @@ namespace cinn::fusion { template inline std::vector> ClusterOps( - const std::vector>& contents) { + const std::vector>& contents, + const std::vector& output_values) { std::function)> func = [](const fusion::PatternContent& content) { return content.op; }; const auto& origin_ops = fusion::MapVector(contents, func); @@ -36,7 +37,7 @@ inline std::vector> ClusterOps( VLOG(4) << "Input Group with size " << origin_ops.size() << " :\n" << fusion::OpsDebugStr(origin_ops); - std::vector outputs; + std::vector outputs = output_values; const auto& ops = [&] { std::vector ops; for (const auto& content : contents) { diff --git a/paddle/cinn/operator_fusion/pattern.h b/paddle/cinn/operator_fusion/pattern.h index 908b4a4348bfc..fd31fb81f06fb 100644 --- a/paddle/cinn/operator_fusion/pattern.h +++ b/paddle/cinn/operator_fusion/pattern.h @@ -79,9 +79,34 @@ struct ReduceTreePlusTrivialPattern { }; template -class UnsupportPattern {}; +struct StmtPattern; + +template +struct UnsupportPattern {}; + template -class HorizontalFusionPattern {}; +struct 
HorizontalFusionPattern { + struct PaddingStmtPattern { + StmtPattern pattern; + std::vector padding_pos; + PaddingStmtPattern(const StmtPattern& pattern, + const std::vector& padding_pos) + : pattern(pattern), padding_pos(padding_pos) {} + }; + explicit HorizontalFusionPattern( + const std::vector& patterns) + : padding_patterns_(patterns) {} + std::vector padding_patterns_; + std::vector ops() const { + std::vector result; + for (const auto& pattern : padding_patterns_) { + auto ops = GetOpsInPattern(pattern.pattern); + ExtendVector(&result, ops); + } + return result; + } + static std::string name() { return "HorizontalFusionPattern"; } +}; template using StmtPatternBase = std::variant, diff --git a/paddle/cinn/operator_fusion/pattern_fuser.h b/paddle/cinn/operator_fusion/pattern_fuser.h index 802031b6b2304..d981abf651045 100644 --- a/paddle/cinn/operator_fusion/pattern_fuser.h +++ b/paddle/cinn/operator_fusion/pattern_fuser.h @@ -55,6 +55,113 @@ std::vector GetOpsInPattern(const StmtPattern& pattern) { pattern.variant()); } +using LoopFramework = std::vector; + +// std::optional({}) means not sure. +// std::optional will cause SegmentFault, TODO: fix this. +using MaybeLoopFramework = LoopFramework; + +template +MaybeLoopFramework GetLoopFramework(const StmtPattern& pattern); + +template +struct LoopFrameworkVisitor { + MaybeLoopFramework operator()(const ReducePattern& pattern) { + pir::Operation* reduce_op = pattern.GetReduceOp(); + const auto& flatten_loops = GetDimExprsFromValue(reduce_op->result(0)); + const auto& reduce_axes = GetReduceAxisIdx(reduce_op); + const auto& reduce_loops = GatherVector( + GetDimExprsFromValue(reduce_op->operand(0).source()), reduce_axes); + return ConcatVector(flatten_loops, reduce_loops); + } + + MaybeLoopFramework operator()(const ReduceTreePattern& pattern) { + return GetLoopFramework(StmtPattern(pattern.GetRootPattern())); + } + + MaybeLoopFramework operator()(const TrivialPattern& pattern) { + pir::Operation* t_op = pattern.sink(); + const auto& dims = GetDimExprsFromValue(t_op->result(0)); + const auto& exprs = GetDimExprsFromValue(t_op->result(0)); + return exprs; + } + + MaybeLoopFramework operator()(const HorizontalFusionPattern& pattern) { + // Horizontal Fusion must have the same loop framework. + VLOG(4) << "Get horizontal fusion pattern for loop framework."; + const auto& base_exprs = GetLoopFramework( + StmtPattern(pattern.padding_patterns_.back().pattern)); + const auto& padding_vector = pattern.padding_patterns_.back().padding_pos; + std::vector exprs( + base_exprs.size() + padding_vector.size(), 1); + int pointer = 0; + for (int i = 0; i < exprs.size(); i++) { + if (std::find(padding_vector.begin(), padding_vector.end(), i) == + padding_vector.end()) { + exprs[i] = base_exprs[pointer++]; + } + } + return exprs; + } + + MaybeLoopFramework operator()( + const ReduceTreePlusTrivialPattern& pattern) { + const auto& sink_trivial = pattern.sink_trivial; + const auto& trivial_loop = + GetLoopFramework(StmtPattern(pattern.sink_trivial)); + if (pattern.fake_reduce_iter_idx.empty()) { + // we add reduce loop to the end; + int reduce_axes_len = + GetReduceAxisIdx(pattern.tree.GetRootPattern().GetReduceOp()).size(); + const auto& reduce_loop = + GetLoopFramework(StmtPattern(pattern.tree.GetRootPattern())); + return ConcatVector( + trivial_loop, + SliceVector(reduce_loop, -reduce_axes_len, reduce_loop.size())); + } else { + // we always put fake into the end to make the loop framework consistent. 
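[Editor's aside] In the HorizontalFusionPattern branch above, the padding_pos recorded for the last padded pattern is used to rebuild the shared loop framework by inserting extent-1 loops at those positions. A small illustrative sketch; the concrete values are chosen to mirror the reshape case in the new tests, not read from the pass itself:

def pad_loops(base_dims, padding_pos):
    # insert an extent-1 loop at every padding position, keep the
    # remaining base extents in their original order
    total = len(base_dims) + len(padding_pos)
    it = iter(base_dims)
    return [1 if i in padding_pos else next(it) for i in range(total)]

# a [17, 17] trivial pattern padded to line up with a [1, 17, 1, 1, 17, 1, 1]
# reshape would carry padding positions [0, 2, 3, 5, 6]:
print(pad_loops([17, 17], [0, 2, 3, 5, 6]))   # -> [1, 17, 1, 1, 17, 1, 1]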
+ const auto& non_fake = GatherVector( + trivial_loop, + ExcludeIndex(trivial_loop.size(), pattern.fake_reduce_iter_idx)); + const auto& fake = + GatherVector(trivial_loop, pattern.fake_reduce_iter_idx); + return ConcatVector(non_fake, fake); + } + } + + MaybeLoopFramework operator()(const UnsupportPattern& pattern) { + PADDLE_ENFORCE(false, "Not support GetLoopRange."); + } +}; + +template +MaybeLoopFramework GetLoopFramework(const StmtPattern& pattern) { + return std::visit(LoopFrameworkVisitor(), pattern.variant()); +} + +static MaybeLoopFramework SqueezeLoopFramework( + const MaybeLoopFramework& loop_framework) { + MaybeLoopFramework result; + for (int i = 0; i < loop_framework.size(); i++) { + if (loop_framework[i] == 1) { + continue; // skip 1 + } else { + result.push_back(loop_framework[i]); + } + } + return result; +} + +template +bool IsLoopFrameworkEqual(const StmtPattern& lhs, + const StmtPattern& rhs) { + auto lhs_loop = GetLoopFramework(lhs); + auto rhs_loop = GetLoopFramework(rhs); + VLOG(4) << "lhs loop range is:" << utils::Join(lhs_loop, ","); + VLOG(4) << "rhs loop range is:" << utils::Join(rhs_loop, ","); + return SqueezeLoopFramework(lhs_loop) == SqueezeLoopFramework(rhs_loop); +} + template bool IsReducePattern(const StmtPattern& pattern) { return std::holds_alternative>(pattern); @@ -159,6 +266,57 @@ StmtPattern MergePatternImpl(const ReduceTreePattern& upstream, return result; } +inline auto GetPaddingVector(const MaybeLoopFramework& first, + const MaybeLoopFramework& second) { + // two pointer to get the padding body. + std::vector padding_f; + std::vector padding_s; + VLOG(4) << "GetPaddingVector for: " << utils::Join(first, ",") << " vs " + << utils::Join(second, ","); + std::function RecursivePadding = + [&first, &second, &padding_f, &padding_s, &RecursivePadding]( + int pf, int ps, int padding_size) { + if (pf == first.size() && ps == second.size()) { + return; + } else if (pf == first.size()) { + PADDLE_ENFORCE(second[ps] == 1, "second[ps] must be '1' to padding."); + padding_f.push_back(padding_size); + RecursivePadding(pf, ps + 1, padding_size + 1); + } else if (ps == second.size()) { + PADDLE_ENFORCE(first[pf] == 1, "second[ps] must be '1' to padding."); + padding_s.push_back(padding_size); + RecursivePadding(pf + 1, ps, padding_size + 1); + } else if (second[ps] == first[pf]) { + RecursivePadding(pf + 1, ps + 1, padding_size + 1); + } else if (second[ps] == 1) { + padding_f.push_back(padding_size); + RecursivePadding(pf, ps + 1, padding_size + 1); + } else if (first[ps] == 1) { + padding_s.push_back(padding_size); + RecursivePadding(pf + 1, ps, padding_size + 1); + } else { + PADDLE_THROW("Padding Error."); + } + }; + RecursivePadding(0, 0, 0); + VLOG(4) << "GetPaddingVector result: " << utils::Join(padding_f, ",") + << " vs " << utils::Join(padding_s, ","); + return std::tuple(padding_f, padding_s); +} + +template +StmtPattern MergePatternImpl(const HorizontalFusionPattern& first, + const HorizontalFusionPattern& second) { + const auto& [f, s] = + GetPaddingVector(GetLoopFramework(StmtPattern(first)), + GetLoopFramework(StmtPattern(second))); + typename HorizontalFusionPattern::PaddingStmtPattern pad_first = {first, + f}; + typename HorizontalFusionPattern::PaddingStmtPattern pad_second = {second, + s}; + return HorizontalFusionPattern({pad_first, pad_second}); +} + template StmtPattern MergePatternImpl(const ReduceTreePattern& first, const TrivialPattern& second); diff --git a/paddle/cinn/operator_fusion/pattern_graph.cc 
b/paddle/cinn/operator_fusion/pattern_graph.cc index a8ab68cf809b3..d3a2a92f6e940 100644 --- a/paddle/cinn/operator_fusion/pattern_graph.cc +++ b/paddle/cinn/operator_fusion/pattern_graph.cc @@ -101,12 +101,15 @@ template void PatternGraph::HorizontalFusion() { GraphTransformer>, + Or>, + StmtPatternGraphMatcher>, + StmtPatternGraphMatcher>, + StmtPatternGraphMatcher>>, LiftToHorizontalFusionPatternOperation>(this); GraphTransformer, HorizontalFusionOperation>(this); } @@ -213,6 +216,7 @@ std::string PatternGraph::GraphInfo() const { for (const auto& v : all_pattern_nodes_) { ss << "\n" << v->DebugStr(); ss << "\n IsOutput: " << IsOutputNodeMatcher()(*this, v); + ss << "\n Loop Framework is: " << GetLoopFramework(v->stmt_pattern()); } ss << "\n==============================="; return ss.str(); diff --git a/paddle/cinn/operator_fusion/pattern_graph.h b/paddle/cinn/operator_fusion/pattern_graph.h index e6ba134262349..e113af683b53d 100644 --- a/paddle/cinn/operator_fusion/pattern_graph.h +++ b/paddle/cinn/operator_fusion/pattern_graph.h @@ -241,8 +241,9 @@ struct MergeTrivialPatternOperation { struct LiftToHorizontalFusionPatternOperation { template void operator()(PatternGraph* graph, PatternNodePtr node) { - node->set_stmt_pattern( - HorizontalFusionPattern({node->stmt_pattern()})); + node->set_stmt_pattern(HorizontalFusionPattern( + {typename HorizontalFusionPattern::PaddingStmtPattern( + node->stmt_pattern(), {})})); } }; @@ -251,6 +252,7 @@ struct HorizontalFusionOperation { void operator()(PatternGraph* graph, const PatternNodePtr& i, const PatternNodePtr& j) { + VLOG(4) << "Start HorizontalFusionOperation"; PADDLE_ENFORCE_EQ( GetPatternName(i->stmt_pattern()), HorizontalFusionPattern::name(), @@ -265,9 +267,13 @@ struct HorizontalFusionOperation { "The pattern of the second node should be HorizontalFusionPattern, " "but got %s.", GetPatternName(j->stmt_pattern()))); - graph->MergeNode(i, j, MergePattern); + auto merged_node = graph->MergeNode(i, j, MergePattern); + VLOG(4) << "MergeHorizontalPattern: \ni " << i->DebugStr() << "\nj " + << j->DebugStr() << "\nmerged " << merged_node->DebugStr(); graph->RemoveNode(i); graph->RemoveNode(j); + VLOG(4) << "After HorizontalFusionOperation, Graph is" + << graph->GraphInfo(); } }; @@ -322,29 +328,25 @@ struct CanFuseReduceTreeAndTrivialMatcher { } }; +template struct HorizontalFusionConstrain { - template bool operator()(const PatternGraph& graph, - const PatternNodePtr& first, - const PatternNodePtr& second) { - if (!StmtPatternGraphMatcher>()(graph, first)) { + const PatternNodePtr& lhs, + const PatternNodePtr& rhs) { + if (!StmtPatternGraphMatcher>()(graph, lhs)) { return false; } - if (!StmtPatternGraphMatcher>()(graph, second)) { + if (!StmtPatternGraphMatcher>()(graph, rhs)) { return false; } - const auto& first_dim = first->sink_op() - ->result(0) - .type() - .template dyn_cast() - .dims(); - const auto& second_dim = second->sink_op() - ->result(0) - .type() - .template dyn_cast() - .dims(); - return graph.topo_manager().CanFuse(first, second) && - first_dim == second_dim; + const auto& lhs_pattern = + std::get>(lhs->stmt_pattern()); + const auto& rhs_pattern = + std::get>(rhs->stmt_pattern()); + + return graph.topo_manager().CanFuse(lhs, rhs) && + IsLoopFrameworkEqual(lhs_pattern.padding_patterns_.back().pattern, + rhs_pattern.padding_patterns_.back().pattern); } }; @@ -387,11 +389,22 @@ struct And { } }; -template -struct Or { +template +struct Or {}; + +template +struct Or { + template + bool operator()(const PatternGraph& graph, 
const PatternNodePtr& node) { + return A()(graph, node); + } +}; + +template +struct Or { template bool operator()(const PatternGraph& graph, const PatternNodePtr& node) { - return A()(graph, node) || B()(graph, node); + return A()(graph, node) || Or()(graph, node); } }; diff --git a/paddle/cinn/operator_fusion/policy/dim_relation.h b/paddle/cinn/operator_fusion/policy/dim_relation.h index 02a7036109b1c..730e9296f913b 100644 --- a/paddle/cinn/operator_fusion/policy/dim_relation.h +++ b/paddle/cinn/operator_fusion/policy/dim_relation.h @@ -73,7 +73,11 @@ struct DimUsage { oss << ", Index: " << idx_; oss << ", UsageIdx: " << usage_idx_; oss << ", "; - v_.defining_op()->Print(oss); + if (v_.defining_op()) { + v_.defining_op()->Print(oss); + } else { + oss << "No defining op."; + } return oss.str(); } }; diff --git a/paddle/cinn/operator_fusion/utils.h b/paddle/cinn/operator_fusion/utils.h index 5f364f94c1ca1..cc19499b4f17a 100644 --- a/paddle/cinn/operator_fusion/utils.h +++ b/paddle/cinn/operator_fusion/utils.h @@ -199,6 +199,115 @@ std::vector UniqueConcatVector(const std::vector& first, return result; } +struct ValueDim { + pir::Value v_; + size_t idx_; + std::weak_ptr shape_analysis_; + ValueDim(pir::Value v, size_t idx) : v_(v), idx_(idx) { + // Just get a related op to get the shape analysis. It can be value's + // upstream op (defining op) or downstream op (user op). + const auto GetRelatedOpFromValue = + [](const pir::Value& v) -> pir::Operation* { + if (v.defining_op() != nullptr) { + return v.defining_op(); + } + // For inputs of the program, the defining_op is nullptr, we use it's user + // as the related op. + PADDLE_ENFORCE_EQ(v.use_empty(), + false, + phi::errors::PreconditionNotMet( + "Value is an input value, it should have a use.")); + return v.first_use().owner(); + }; + shape_analysis_ = pir::ShapeAnalysisManager::Instance() + .Get(GetRelatedOpFromValue(v)->GetParentProgram()) + .shared_from_this(); + } + ValueDim() = default; + ValueDim(const ValueDim& v) = default; + bool operator==(const ValueDim& v) const { + return (idx_ == v.idx_) && (v_ == v.v_); + } + + symbol::DimExpr GetSymbolicDim() const { + PADDLE_ENFORCE_NOT_NULL(v_.impl(), "Empty value is not expected."); + return shape_analysis().GetProductDimExpr(v_, {static_cast(idx_)}); + } + + bool SymbolicEqualTo(const ValueDim& other) const { + return shape_analysis().IsEqual(GetSymbolicDim(), other.GetSymbolicDim()); + } + + std::string DebugStr() const { + std::ostringstream oss; + oss << "ValueDim: "; + oss << "Index: " << idx_; + oss << ", "; + v_.defining_op()->Print(oss); + return oss.str(); + } + + pir::ShapeConstraintIRAnalysis& shape_analysis() const { + auto shape_analysis_ptr = shape_analysis_.lock(); + PADDLE_ENFORCE_NOT_NULL( + shape_analysis_ptr, + ::common::errors::PreconditionNotMet("shape_analysis_ptr is nullptr.")); + return *shape_analysis_ptr; + } +}; + +static std::vector GetAllValueDimFromValue(const pir::Value& v) { + std::vector value_dims; + size_t rank = GetCompitableRank(v); + for (size_t i = 0; i < rank; ++i) { + value_dims.emplace_back(v, i); + } + return value_dims; +} + +struct ValueDimHash { + std::size_t operator()(const ValueDim& p) const { + auto h1 = std::hash{}(p.idx_); + auto h2 = std::hash{}(p.v_); + // Mainly for demonstration purposes, i.e. works but is overly simple + // In the real world, use sth. 
like boost.hash_combine + return h1 ^ (h2 << 1); + } +}; + +static std::vector GetDimExprsFromValue(pir::Value value) { + const auto& value_dims = GetAllValueDimFromValue(value); + VLOG(4) << "Start Print:"; + std::function func = + [](const ValueDim& value_dim) { + const auto& symbolic_dim = value_dim.GetSymbolicDim(); + VLOG(4) << symbolic_dim; + return symbolic_dim; + }; + return MapVector(value_dims, func); +} + +template +std::vector GatherVector(const std::vector& inp, + std::vector gathers) { + std::vector result; + for (auto i : gathers) { + result.push_back(inp[i]); + } + return result; +} + +template +std::vector ExcludeIndex(int n, std::vector excludes) { + std::vector result; + for (int i = 0; i < n; ++i) { + if (std::find(excludes.begin(), excludes.end(), i) == excludes.end()) { + result.push_back(i); + } + } + return result; +} + template std::vector GatherVectorExcept(const std::vector& source, const std::vector& idx) { @@ -211,4 +320,29 @@ std::vector GatherVectorExcept(const std::vector& source, return result; } +template +std::vector SliceVector(const std::vector& inp, int start, int end) { + if (start < 0) { + start = inp.size() + start; + } + if (end < 0) { + end = inp.size() + end; + } + std::vector result; + for (int i = start; i < end; ++i) { + result.push_back(inp[i]); + } + return result; +} + +template +std::vector VectorFlatMap( + const std::vector& inp, + const std::function(const T&)>& func) { + std::vector result; + for (const auto& i : inp) { + result = ConcatVector(result, func(i)); + } + return result; +} } // namespace cinn::fusion diff --git a/test/cpp/pir/cinn/pir_all_path_test.cc b/test/cpp/pir/cinn/pir_all_path_test.cc index 7568855e1f71c..0131b2694ceaa 100644 --- a/test/cpp/pir/cinn/pir_all_path_test.cc +++ b/test/cpp/pir/cinn/pir_all_path_test.cc @@ -20,7 +20,6 @@ #include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" #include "paddle/cinn/hlir/dialect/operator/ir/op_dialect.h" #include "paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.h" -#include "paddle/cinn/hlir/dialect/operator/transforms/add_store_in_fusion_op_pass.h" #include "paddle/cinn/hlir/dialect/operator/transforms/cinn_group_cluster_pass.h" #include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/divide_group_op_to_fusion_op_pass.h" #include "paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/lower_cinn_fusion_op_pass.h" @@ -75,7 +74,8 @@ static void RunAndCheckResult(::pir::Program* program, pir::PassManager stage_2_pm(ctx); stage_2_pm.AddPass(cinn::dialect::ir::CreateCinnGroupClusterPass()); - stage_2_pm.AddPass(cinn::dialect::ir::CreateAddStoreInFusionOpPass()); + // stage_2_pm.AddPass(cinn::dialect::ir::CreateAddStoreInFusionOpPass()); // + // (@xiongkun) we remove yield store in new fusion strategy. 
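[Editor's aside] For quick reference, rough Python equivalents of the new helpers added to utils.h (GatherVector, ExcludeIndex, SliceVector, VectorFlatMap). Like the C++ version, SliceVector lets negative bounds wrap from the end, which is what the ReduceTreePlusTrivialPattern loop concatenation relies on. This is an illustrative sketch, not part of the patch:

def gather_vector(inp, idxs):
    return [inp[i] for i in idxs]

def exclude_index(n, excludes):
    return [i for i in range(n) if i not in excludes]

def slice_vector(inp, start, end):
    start = start if start >= 0 else len(inp) + start
    end = end if end >= 0 else len(inp) + end
    return [inp[i] for i in range(start, end)]

def vector_flat_map(inp, func):
    return [y for x in inp for y in func(x)]

assert gather_vector([10, 20, 30, 40], [1, 3]) == [20, 40]
assert exclude_index(5, [0, 2]) == [1, 3, 4]
assert slice_vector([4, 8, 16, 32], -2, 4) == [16, 32]
assert vector_flat_map([1, 2], lambda x: [x, 10 * x]) == [1, 10, 2, 20]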
stage_2_pm.AddPass(pir::CreateDeadCodeEliminationPass()); stage_2_pm.AddPass(cinn::dialect::ir::CreateLowerCinnFusionOpPass()); stage_2_pm.EnableIRPrinting(); diff --git a/test/ir/pir/cinn/inference/test_llama_postprocess.py b/test/ir/pir/cinn/inference/test_llama_postprocess.py index cfff921719f95..1600a3a794409 100644 --- a/test/ir/pir/cinn/inference/test_llama_postprocess.py +++ b/test/ir/pir/cinn/inference/test_llama_postprocess.py @@ -90,8 +90,8 @@ def prepare_data(self): self.input_ids = paddle.randint(0, 512, [1, 32], dtype="int64") def check_jit_kernel_info(self, static_fn): - utils.check_jit_kernel_number(static_fn, 7) - utils.check_jit_kernel_structure(static_fn, {utils.JIT_KERNEL_NAME: 7}) + utils.check_jit_kernel_number(static_fn, 4) + utils.check_jit_kernel_structure(static_fn, {utils.JIT_KERNEL_NAME: 4}) def eval(self, use_cinn): paddle.seed(2024) diff --git a/test/ir/pir/cinn/test_fusion_softmax_subgraph.py b/test/ir/pir/cinn/test_fusion_softmax_subgraph.py new file mode 100644 index 0000000000000..a73eca5f04458 --- /dev/null +++ b/test/ir/pir/cinn/test_fusion_softmax_subgraph.py @@ -0,0 +1,125 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import unittest + +import numpy + +os.environ['FLAGS_cinn_new_group_scheduler'] = '1' +os.environ['FLAGS_group_schedule_tiling_first'] = '1' +os.environ['FLAGS_prim_all'] = 'true' +os.environ['FLAGS_prim_enable_dynamic'] = 'true' +os.environ['FLAGS_print_ir'] = '1' +os.environ['FLAGS_enable_pir_api'] = '1' +os.environ['FLAGS_use_cinn'] = '1' +os.environ['FLAGS_cinn_bucket_compile'] = '1' + +from utils import check_jit_kernel_number + +import paddle + +build_strategy = paddle.static.BuildStrategy() +build_strategy.build_cinn_pass = True + + +class TestFusion(unittest.TestCase): + def setUp(self): + pass + + def tearDown(self): + pass + + def compare_result( + self, dy_compute, data_init, input_spec, expect_fusion_num + ): + static_compute = paddle.jit.to_static( + full_graph=True, + build_strategy=build_strategy, + input_spec=input_spec(), + )(dy_compute) + inputs = data_init() + dy_out = dy_compute(*inputs) + st_out = static_compute(*inputs) + if isinstance(dy_out, paddle.Tensor): + numpy.testing.assert_allclose(dy_out, st_out, atol=1e-5, rtol=1e-6) + return + for d, s in zip(dy_out, st_out): + numpy.testing.assert_allclose(d, s, atol=1e-5, rtol=1e-6) + check_jit_kernel_number(static_compute, expect_fusion_num) + + def test_softmax(self): + def func(var_40, var_106): + var_109 = paddle.cast(var_40, dtype='float32') + var_110 = var_106 + var_109 + var_111 = paddle.cast(var_110, dtype='float32') + var_112 = paddle.max(var_111, keepdim=True, axis=[-1]) + var_113 = var_111 - var_112 + var_114 = paddle.exp(var_113) + var_115 = paddle.sum(var_114, keepdim=True, axis=[-1]) + var_116 = var_114 / var_115 + var_117 = paddle.cast(var_116, dtype='float32') + return var_116, var_117 + + def init(): + var_40 = paddle.rand([1, 1, 17, 17]) + var_40 = paddle.cast(var_40, 
'float64') + var_106 = paddle.rand([1, 32, 17, 17]) + return var_40, var_106 + + def input_spec(): + return [ + paddle.static.InputSpec( + shape=[1, 1, 17, 17], dtype='float64', name='var_40' + ), + paddle.static.InputSpec( + shape=[1, 32, 17, 17], dtype='float32', name='var_106' + ), + ] + + self.compare_result(func, init, input_spec, 1) + + def test_horizontal_1(self): + def func(x): + ret1 = x * 2 * 2 + ret2 = paddle.reshape(ret1, [1, 1, 17, 17]) + return ret1, ret2 + + def init(): + x = paddle.rand([17, 17]) + return (x,) + + def input_spec(): + return None + + self.compare_result(func, init, input_spec, 1) + + def test_horizontal_2(self): + def func(x): + ret1 = x * 2 * 2 + ret2 = paddle.reshape(ret1, [1, 17, 1, 1, 17, 1, 1]) + return ret1, ret2 + + def init(): + x = paddle.rand([17, 17]) + return (x,) + + def input_spec(): + return None + + self.compare_result(func, init, input_spec, 1) + + +if __name__ == "__main__": + pass diff --git a/test/ir/pir/cinn/test_horizontal_fusion.py b/test/ir/pir/cinn/test_horizontal_fusion.py new file mode 100644 index 0000000000000..38a073e55323a --- /dev/null +++ b/test/ir/pir/cinn/test_horizontal_fusion.py @@ -0,0 +1,63 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +import numpy as np +import utils + +import paddle + + +class HorizontalSubGraph(paddle.nn.Layer): + def __init__(self): + super().__init__() + + def forward(self, x): + tmp1 = paddle.sum(x, axis=-1) + tmp2 = paddle.sum(x * x, axis=-1) + return tmp1 * tmp2 + + +class TestHorizontalGraph(unittest.TestCase): + def setUp(self): + paddle.seed(2024) + self.prepare_data() + + def prepare_data(self): + self.x = paddle.randn([256, 128], dtype="float32") + self.x.stop_gradient = True + + def check_jit_kernel_info(self, static_fn): + utils.check_jit_kernel_number(static_fn, 1) + utils.check_jit_kernel_structure(static_fn, {utils.JIT_KERNEL_NAME: 1}) + + def eval(self, use_cinn): + net = HorizontalSubGraph() + net.eval() + net = utils.apply_to_static(net, use_cinn) + out = net(self.x) + if use_cinn: + self.check_jit_kernel_info(net.forward) + return out + + def test_eval(self): + cinn_out = self.eval(use_cinn=True) + dy_out = self.eval(use_cinn=False) + np.testing.assert_allclose( + cinn_out.numpy(), dy_out.numpy(), atol=1e-5, rtol=1e-5 + ) + + +if __name__ == '__main__': + unittest.main()
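[Editor's closing note] The horizontal-fusion tests above exercise the new padding logic: GetPaddingVector walks both loop frameworks with two pointers and records, for each side, the positions where an extent-1 loop has to be inserted so the frameworks match. A minimal Python sketch of that alignment, fed with the loop extents that the reshape in test_horizontal_2 would suggest ([17, 17] fused with its reshape to [1, 17, 1, 1, 17, 1, 1]); this is an illustration of the idea, not the patched C++:

def padding_positions(first, second):
    pf = ps = pos = 0
    pad_first, pad_second = [], []
    while pf < len(first) or ps < len(second):
        if pf < len(first) and ps < len(second) and first[pf] == second[ps]:
            pf += 1
            ps += 1
        elif ps < len(second) and second[ps] == 1:
            pad_first.append(pos)      # first needs an extent-1 loop here
            ps += 1
        elif pf < len(first) and first[pf] == 1:
            pad_second.append(pos)     # second needs an extent-1 loop here
            pf += 1
        else:
            raise ValueError("loop frameworks cannot be aligned by padding")
        pos += 1
    return pad_first, pad_second

print(padding_positions([17, 17], [1, 17, 1, 1, 17, 1, 1]))
# -> ([0, 2, 3, 5, 6], [])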