From 25db66ad4e0cd077945e007595036c728bcedc3d Mon Sep 17 00:00:00 2001 From: Hongqing-work Date: Mon, 20 May 2024 12:49:21 +0000 Subject: [PATCH 1/4] [CINN]delete redundant SetShapeOrDataForValue --- paddle/cinn/common/broadcast_tree.cc | 14 +++++- .../hlir/dialect/operator/ir/manual_op.cc | 18 ++++++++ .../cinn/hlir/dialect/operator/ir/manual_op.h | 5 +- .../add_broadcast_to_elementwise_pass.cc | 10 ---- .../transforms/add_store_in_fusion_op_pass.cc | 4 -- .../transforms/cinn_group_cluster_pass.cc | 4 ++ ...e_shape_ops_into_generate_shape_op_pass.cc | 2 - .../divide_group_op_to_fusion_op_pass.cc | 2 + .../group_merge/generate_shape_util.cc | 7 --- .../group_merge/generate_shape_util.h | 2 - ...ove_generate_shape_ops_to_prologue_pass.cc | 14 +----- .../group_merge/single_op_fallback_to_phi.cc | 3 -- .../transforms/insert_broadcast_pass.cc | 2 - .../transforms/replace_dynamic_expand_pass.cc | 11 ----- .../include/dialect/shape/utils/dim_expr.h | 3 ++ .../pir/src/dialect/shape/utils/dim_expr.cc | 46 +++++++++++++++++++ 16 files changed, 91 insertions(+), 56 deletions(-) diff --git a/paddle/cinn/common/broadcast_tree.cc b/paddle/cinn/common/broadcast_tree.cc index 4b14b41af3ae4..1bfa82927b6e9 100644 --- a/paddle/cinn/common/broadcast_tree.cc +++ b/paddle/cinn/common/broadcast_tree.cc @@ -120,8 +120,18 @@ using Pattern2Placement = std::unordered_map; Pattern2Placement ConstructCstrLhsEqRhsReplacement( const symbol::Broadcastable& broadcastable_condition) { auto [lhs, rhs] = *broadcastable_condition; - if (rhs.isa()) return Pattern2Placement{{rhs, lhs}}; - if (lhs.isa()) return Pattern2Placement{{lhs, rhs}}; + if (rhs.isa()) { + if (GetSymbolSet(lhs).count(ToString(rhs)) != 0) { + return Pattern2Placement{{lhs, rhs}}; + } + return Pattern2Placement{{rhs, lhs}}; + } + if (lhs.isa()) { + if (GetSymbolSet(rhs).count(ToString(lhs)) != 0) { + return Pattern2Placement{{rhs, lhs}}; + } + return Pattern2Placement{{lhs, rhs}}; + } return Pattern2Placement{{lhs, rhs}}; } diff --git a/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc b/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc index da8cb090b7a2d..c1ca06771988d 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc +++ b/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc @@ -197,6 +197,24 @@ void FusionOp::Print(pir::IrPrinter& printer) { os << printer.indentation() << "}"; } +bool FusionOp::InferSymbolicShape( + ::pir::InferSymbolicShapeContext* infer_context) { + ::pir::InferSymExprForBlock(*block(), infer_context); + + for (uint32_t rst_idx = 0; rst_idx < num_results(); rst_idx++) { + auto inner_yield_value = block()->back().operand_source(rst_idx); + const auto& shape = + infer_context->GetShapeOrDataForValue(inner_yield_value); + infer_context->SetShapeOrDataForValue(result(rst_idx), shape); + } + + if (VLOG_IS_ON(4)) { + ::std::cerr << ">>>>>>>>>>>>>>>>>>>> cinn_op.fusion(op_id: op_" + << block()->back().id() << ") END." << ::std::endl; + } + return true; +} + void YieldStoreOp::Build(pir::Builder& builder, pir::OperationArgument& argument, pir::Value x, diff --git a/paddle/cinn/hlir/dialect/operator/ir/manual_op.h b/paddle/cinn/hlir/dialect/operator/ir/manual_op.h index 396f9929ecb35..ba43fd7e325f4 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/manual_op.h +++ b/paddle/cinn/hlir/dialect/operator/ir/manual_op.h @@ -61,7 +61,8 @@ class IR_API GroupOp // FusionOp represents a subgraphs that can be fused to one kernel. // Every GroupOp can be lowered to at least one FusionOp -class IR_API FusionOp : public pir::Op { +class IR_API FusionOp + : public pir::Op { public: using Op::Op; static const char *name() { return "cinn_op.fusion"; } @@ -81,6 +82,8 @@ class IR_API FusionOp : public pir::Op { std::vector GetOperators() const; + bool InferSymbolicShape(pir::InferSymbolicShapeContext *infer_context); + void VerifySig(); void Print(pir::IrPrinter &printer); // NOLINT }; diff --git a/paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.cc index 10cc7ae94f80b..49edddd8518ff 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.cc @@ -130,8 +130,6 @@ bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) { .dyn_cast() .data()); op->operand(0).set_source(new_full->result(0)); - shape_analysis.SetShapeOrDataForValue( - new_full.result(0), symbol::TensorShapeOrDataDimExprs(out_dim)); } else { auto new_transpose_op = rewriter->Build( op->operand_source(0), @@ -139,9 +137,6 @@ bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) { output_shape); op->operand(0).set_source(new_transpose_op->result(0)); - shape_analysis.SetShapeOrDataForValue( - new_transpose_op.result(0), - symbol::TensorShapeOrDataDimExprs(out_dim)); } } @@ -160,8 +155,6 @@ bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) { .data()); op->operand(1).set_source(new_full->result(0)); - shape_analysis.SetShapeOrDataForValue( - new_full.result(0), symbol::TensorShapeOrDataDimExprs(out_dim)); } else { auto new_transpose_op = rewriter->Build( op->operand_source(1), @@ -169,9 +162,6 @@ bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) { output_shape); op->operand(1).set_source(new_transpose_op->result(0)); - shape_analysis.SetShapeOrDataForValue( - new_transpose_op.result(0), - symbol::TensorShapeOrDataDimExprs(out_dim)); } } diff --git a/paddle/cinn/hlir/dialect/operator/transforms/add_store_in_fusion_op_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/add_store_in_fusion_op_pass.cc index 7e4bf74065fbb..a0a59b27c3cea 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/add_store_in_fusion_op_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/add_store_in_fusion_op_pass.cc @@ -44,10 +44,6 @@ class AddYieldStoreInFusionOpPattern op->operand_source(i), op->operand_source(i).type()); auto orignal_base = op->operand_source(i); op->operand(i).set_source(store_op.result(0)); - - shape_analysis.SetShapeOrDataForValue( - store_op.result(0), - shape_analysis.GetShapeOrDataForValue(orignal_base)); } return true; diff --git a/paddle/cinn/hlir/dialect/operator/transforms/cinn_group_cluster_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/cinn_group_cluster_pass.cc index c3bf60c601b7d..83bb4912d0121 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/cinn_group_cluster_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/cinn_group_cluster_pass.cc @@ -178,6 +178,8 @@ ::pir::GroupOpsVec CloneOps( auto& alignment_schedule_info = node.alignment_schedule_info; for (auto op : group_ops) { auto new_op = op->Clone(*ir_mapping, clone_options); + // TODO(Hongqing-work): delete this after fix bug of + // cinn_dynamic_reshape_op_pass auto& shape_analysis = pir::ShapeAnalysisManager::Instance().Get(op->GetParentProgram()); @@ -350,6 +352,8 @@ class CinnGroupClusterPattern auto new_group_op = ReplaceWithGroupOp( &rewriter, uniq_ops, node, output_values, &ir_mapping); + // TODO(Hongqing-work): delete this after fix bug of + // cinn_dynamic_reshape_op_pass auto& shape_analysis = pir::ShapeAnalysisManager::Instance().Get( group_op->GetParentProgram()); // update ir mapping diff --git a/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc index 63cc4cf04b68c..effa5573891be 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc @@ -357,8 +357,6 @@ bool ReplaceShapeOpsToGenerateShape( GetOutOfRewrittenGenerateShapeOp( shape_operand, rewriter, ShapeOrDataDimExprs4Value); if (!opt_generated_shape.has_value()) return false; - shape_analysis->SetShapeOrDataForValue( - opt_generated_shape.value(), ShapeOrDataDimExprs4Value(shape_operand)); rewriter->ReplaceAllUsesWith(shape_operand, opt_generated_shape.value()); return true; } diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/divide_group_op_to_fusion_op_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/divide_group_op_to_fusion_op_pass.cc index 8f64980baf1c8..21f5c9721f68d 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/divide_group_op_to_fusion_op_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/divide_group_op_to_fusion_op_pass.cc @@ -163,6 +163,8 @@ class GroupOpPattern : public pir::OpRewritePattern { << "fusion_yield_values already has key!"; const auto& shape_expr = shape_analysis.GetShapeOrDataForValue(vec_outs[i]); + // TODO(Hongqing-work): delete this after fix bug of + // cinn_dynamic_reshape_op_pass shape_analysis.SetShapeOrDataForValue(fusion_op.result(i), shape_expr); auto find_it = output_value2id.find(vec_outs[i]); if (find_it != output_value2id.end()) { diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/generate_shape_util.cc b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/generate_shape_util.cc index 662b7b36c37bb..30b470d42ca2a 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/generate_shape_util.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/generate_shape_util.cc @@ -90,12 +90,6 @@ std::optional InsertGenerateShapeOpToRunFirst( return std::nullopt; } -void CloneDimExprInfo(pir::Value from, - pir::Value to, - const ShapeOrDataDimExprsAccessor& ctx) { - ctx.SetShapeOrDataDimExprs(to, ctx.GetShapeOrDataDimExprs(from)); -} - void ReplaceAllUses(pir::Value from, pir::Value to) { from.ReplaceAllUsesWith(to); } @@ -119,7 +113,6 @@ bool RewriteOneGenerateShapeOpToRunFirst( std::optional new_shape = InsertGenerateShapeOpToRunFirst( &builder, block_args, op.out(), dim_exprs_accessor); if (!new_shape.has_value()) continue; - CloneDimExprInfo(op.out(), new_shape.value(), dim_exprs_accessor); ReplaceAllUses(op.out(), new_shape.value()); EraseGenerateShapeOp(op_iter, block); return true; diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/generate_shape_util.h b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/generate_shape_util.h index 93edec2c4c72d..d7f99d03493a1 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/generate_shape_util.h +++ b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/generate_shape_util.h @@ -31,8 +31,6 @@ namespace cinn::dialect { struct ShapeOrDataDimExprsAccessor { std::function GetShapeOrDataDimExprs; - std::function - SetShapeOrDataDimExprs; }; // Returns true if at least one GenerateShapeOp rewrote. diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/move_generate_shape_ops_to_prologue_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/move_generate_shape_ops_to_prologue_pass.cc index f395a1fb3e28b..6e47c1ac9a4ec 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/move_generate_shape_ops_to_prologue_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/move_generate_shape_ops_to_prologue_pass.cc @@ -56,12 +56,7 @@ class GroupOpGenerateShapeOpsPattern .GetShapeOrDataDimExprs = [&](pir::Value value) -> const symbol::ShapeOrDataDimExprs& { return shape_analysis.GetShapeOrDataForValue(value); - }, - .SetShapeOrDataDimExprs = - [&](pir::Value value, - const symbol::ShapeOrDataDimExprs& dim_exprs) { - shape_analysis.SetShapeOrDataForValue(value, dim_exprs); - }}; + }}; return MoveGenerateShapeOpsToPrologue( ctx, group_op.block(), dim_exprs_accessor); } @@ -82,12 +77,7 @@ class MoveGenerateShapeOpsToProloguePass : public pir::Pass { .GetShapeOrDataDimExprs = [&](pir::Value value) -> const symbol::ShapeOrDataDimExprs& { return shape_analysis.GetShapeOrDataForValue(value); - }, - .SetShapeOrDataDimExprs = - [&](pir::Value value, - const symbol::ShapeOrDataDimExprs& dim_exprs) { - shape_analysis.SetShapeOrDataForValue(value, dim_exprs); - }}; + }}; MoveGenerateShapeOpsToPrologue(ctx, group_op.block(), dim_exprs_accessor); } diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/single_op_fallback_to_phi.cc b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/single_op_fallback_to_phi.cc index 85a6a9c0677a0..5e431d3d7fb00 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/single_op_fallback_to_phi.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/single_op_fallback_to_phi.cc @@ -64,9 +64,6 @@ class FusionOpPattern : public pir::OpRewritePattern { for (size_t i = 0; i < fusion_op.num_results(); ++i) { rewriter.ReplaceAllUsesWith(fusion_op.result(i), paddle_op.value()->result(i)); - shape_analysis.SetShapeOrDataForValue( - paddle_op.value()->result(i), - shape_analysis.GetShapeOrDataForValue(fusion_op.result(i))); } rewriter.EraseOp(fusion_op); diff --git a/paddle/cinn/hlir/dialect/operator/transforms/insert_broadcast_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/insert_broadcast_pass.cc index 56a2aa07d7096..13074afc4d761 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/insert_broadcast_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/insert_broadcast_pass.cc @@ -73,13 +73,11 @@ bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) { pir::Value broadcasted_x = rewriter->Build(x, output_dim_tensor).out(); op->operand(0).set_source(broadcasted_x); - shape_analysis.SetShapeOrDataForValue(broadcasted_x, out_shape); } if (y_shape.shape() != out_shape.shape()) { pir::Value broadcasted_y = rewriter->Build(y, output_dim_tensor).out(); op->operand(1).set_source(broadcasted_y); - shape_analysis.SetShapeOrDataForValue(broadcasted_y, out_shape); } return true; } diff --git a/paddle/cinn/hlir/dialect/operator/transforms/replace_dynamic_expand_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/replace_dynamic_expand_pass.cc index 3adf8cc6110ec..340b7d4ee158f 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/replace_dynamic_expand_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/replace_dynamic_expand_pass.cc @@ -69,12 +69,6 @@ class DynamicExpandOpPattern op->operand_source(0), broadcast_axes, out_shape); }(); - auto& shape_analysis = - pir::ShapeAnalysisManager::Instance().Get(op->GetParentProgram()); - shape_analysis.SetShapeOrDataForValue( - broadcast->result(0), - shape_analysis.GetShapeOrDataForValue(op.result(0))); - if (auto pre_full = broadcast->operand_source(0) .defining_op() ->dyn_cast()) { @@ -82,11 +76,6 @@ class DynamicExpandOpPattern .type() .dyn_cast() .dims(); - if (input_dim.size() == 1 && input_dim[0] == 1) { - shape_analysis.SetShapeOrDataForValue( - pre_full->result(0), - shape_analysis.GetShapeOrDataForValue(op.result(0))); - } } rewriter.ReplaceAllUsesWith(op->result(0), broadcast->result(0)); diff --git a/paddle/pir/include/dialect/shape/utils/dim_expr.h b/paddle/pir/include/dialect/shape/utils/dim_expr.h index a45ba01538ae7..83bd1cd28e991 100644 --- a/paddle/pir/include/dialect/shape/utils/dim_expr.h +++ b/paddle/pir/include/dialect/shape/utils/dim_expr.h @@ -19,6 +19,7 @@ #include #include #include +#include #include #include @@ -223,6 +224,8 @@ using DimExprConstraint = std::variant, Broadcastable>; IR_API std::string ToString(const DimExpr& dim_expr); +IR_API std::unordered_set GetSymbolSet(const DimExpr& dim_expr); + IR_API std::ostream& operator<<(std::ostream&, const DimExpr& dim_expr); IR_API std::ostream& operator<<(std::ostream&, diff --git a/paddle/pir/src/dialect/shape/utils/dim_expr.cc b/paddle/pir/src/dialect/shape/utils/dim_expr.cc index 0b4e041cd6b47..0218eb5ef0955 100644 --- a/paddle/pir/src/dialect/shape/utils/dim_expr.cc +++ b/paddle/pir/src/dialect/shape/utils/dim_expr.cc @@ -142,8 +142,54 @@ std::string ListDimExprToString(const List& dim_exprs, } return ret; } +std::unordered_set ListDimExprGetSymbolSet( + const List& dim_exprs) { + std::unordered_set vars; + for (std::size_t i = 0; i < dim_exprs->size(); ++i) { + const auto& inner_vars = GetSymbolSet(dim_exprs->at(i)); + vars.insert(inner_vars.begin(), inner_vars.end()); + } + return vars; +} } // namespace +std::unordered_set GetSymbolSet(const DimExpr& dim_expr) { + std::unordered_set vars; + auto lambdas = common::Overloaded{ + [&](std::int64_t dim_expr) { return; }, + [&](const std::string& dim_expr) { vars.insert(dim_expr); }, + [&](const Negative& dim_expr) { + const auto& inner_vars = GetSymbolSet(dim_expr->data); + vars.insert(inner_vars.begin(), inner_vars.end()); + }, + [&](const Reciprocal& dim_expr) { + const auto& inner_vars = GetSymbolSet(dim_expr->data); + vars.insert(inner_vars.begin(), inner_vars.end()); + }, + [&](const Add& dim_expr) { + const auto& inner_vars = ListDimExprGetSymbolSet(dim_expr.operands); + vars.insert(inner_vars.begin(), inner_vars.end()); + }, + [&](const Mul& dim_expr) { + const auto& inner_vars = ListDimExprGetSymbolSet(dim_expr.operands); + vars.insert(inner_vars.begin(), inner_vars.end()); + }, + [&](const Max& dim_expr) { + const auto& inner_vars = ListDimExprGetSymbolSet(dim_expr.operands); + vars.insert(inner_vars.begin(), inner_vars.end()); + }, + [&](const Min& dim_expr) { + const auto& inner_vars = ListDimExprGetSymbolSet(dim_expr.operands); + vars.insert(inner_vars.begin(), inner_vars.end()); + }, + [&](const Broadcast& dim_expr) { + const auto& inner_vars = ListDimExprGetSymbolSet(dim_expr.operands); + vars.insert(inner_vars.begin(), inner_vars.end()); + }}; + std::visit(lambdas, dim_expr.variant()); + return vars; +} + std::string ToString(const DimExpr& dim_expr) { auto lambdas = common::Overloaded{ [](std::int64_t dim_expr) { return std::to_string(dim_expr); }, From e4145f65c17e41f652dbaa8f67d052b3816e245e Mon Sep 17 00:00:00 2001 From: Hongqing-work Date: Tue, 21 May 2024 04:44:28 +0000 Subject: [PATCH 2/4] update broadcast_tree Substitute policy --- paddle/cinn/common/broadcast_tree.cc | 12 ++--- .../hlir/dialect/operator/ir/manual_op.cc | 5 -- .../include/dialect/shape/utils/dim_expr.h | 3 -- .../pir/src/dialect/shape/utils/dim_expr.cc | 46 ------------------- 4 files changed, 4 insertions(+), 62 deletions(-) diff --git a/paddle/cinn/common/broadcast_tree.cc b/paddle/cinn/common/broadcast_tree.cc index 1bfa82927b6e9..001c9bc23ed0a 100644 --- a/paddle/cinn/common/broadcast_tree.cc +++ b/paddle/cinn/common/broadcast_tree.cc @@ -120,18 +120,14 @@ using Pattern2Placement = std::unordered_map; Pattern2Placement ConstructCstrLhsEqRhsReplacement( const symbol::Broadcastable& broadcastable_condition) { auto [lhs, rhs] = *broadcastable_condition; - if (rhs.isa()) { - if (GetSymbolSet(lhs).count(ToString(rhs)) != 0) { - return Pattern2Placement{{lhs, rhs}}; - } + if (SubstituteDimExpr(rhs, {lhs, rhs}) != rhs) { return Pattern2Placement{{rhs, lhs}}; } - if (lhs.isa()) { - if (GetSymbolSet(rhs).count(ToString(lhs)) != 0) { - return Pattern2Placement{{rhs, lhs}}; - } + if (SubstituteDimExpr(lhs, {rhs, lhs}) != lhs) { return Pattern2Placement{{lhs, rhs}}; } + if (rhs.isa()) return Pattern2Placement{{rhs, lhs}}; + if (lhs.isa()) return Pattern2Placement{{lhs, rhs}}; return Pattern2Placement{{lhs, rhs}}; } diff --git a/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc b/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc index c1ca06771988d..43b9d865e2ab6 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc +++ b/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc @@ -207,11 +207,6 @@ bool FusionOp::InferSymbolicShape( infer_context->GetShapeOrDataForValue(inner_yield_value); infer_context->SetShapeOrDataForValue(result(rst_idx), shape); } - - if (VLOG_IS_ON(4)) { - ::std::cerr << ">>>>>>>>>>>>>>>>>>>> cinn_op.fusion(op_id: op_" - << block()->back().id() << ") END." << ::std::endl; - } return true; } diff --git a/paddle/pir/include/dialect/shape/utils/dim_expr.h b/paddle/pir/include/dialect/shape/utils/dim_expr.h index 83bd1cd28e991..a45ba01538ae7 100644 --- a/paddle/pir/include/dialect/shape/utils/dim_expr.h +++ b/paddle/pir/include/dialect/shape/utils/dim_expr.h @@ -19,7 +19,6 @@ #include #include #include -#include #include #include @@ -224,8 +223,6 @@ using DimExprConstraint = std::variant, Broadcastable>; IR_API std::string ToString(const DimExpr& dim_expr); -IR_API std::unordered_set GetSymbolSet(const DimExpr& dim_expr); - IR_API std::ostream& operator<<(std::ostream&, const DimExpr& dim_expr); IR_API std::ostream& operator<<(std::ostream&, diff --git a/paddle/pir/src/dialect/shape/utils/dim_expr.cc b/paddle/pir/src/dialect/shape/utils/dim_expr.cc index 0218eb5ef0955..0b4e041cd6b47 100644 --- a/paddle/pir/src/dialect/shape/utils/dim_expr.cc +++ b/paddle/pir/src/dialect/shape/utils/dim_expr.cc @@ -142,54 +142,8 @@ std::string ListDimExprToString(const List& dim_exprs, } return ret; } -std::unordered_set ListDimExprGetSymbolSet( - const List& dim_exprs) { - std::unordered_set vars; - for (std::size_t i = 0; i < dim_exprs->size(); ++i) { - const auto& inner_vars = GetSymbolSet(dim_exprs->at(i)); - vars.insert(inner_vars.begin(), inner_vars.end()); - } - return vars; -} } // namespace -std::unordered_set GetSymbolSet(const DimExpr& dim_expr) { - std::unordered_set vars; - auto lambdas = common::Overloaded{ - [&](std::int64_t dim_expr) { return; }, - [&](const std::string& dim_expr) { vars.insert(dim_expr); }, - [&](const Negative& dim_expr) { - const auto& inner_vars = GetSymbolSet(dim_expr->data); - vars.insert(inner_vars.begin(), inner_vars.end()); - }, - [&](const Reciprocal& dim_expr) { - const auto& inner_vars = GetSymbolSet(dim_expr->data); - vars.insert(inner_vars.begin(), inner_vars.end()); - }, - [&](const Add& dim_expr) { - const auto& inner_vars = ListDimExprGetSymbolSet(dim_expr.operands); - vars.insert(inner_vars.begin(), inner_vars.end()); - }, - [&](const Mul& dim_expr) { - const auto& inner_vars = ListDimExprGetSymbolSet(dim_expr.operands); - vars.insert(inner_vars.begin(), inner_vars.end()); - }, - [&](const Max& dim_expr) { - const auto& inner_vars = ListDimExprGetSymbolSet(dim_expr.operands); - vars.insert(inner_vars.begin(), inner_vars.end()); - }, - [&](const Min& dim_expr) { - const auto& inner_vars = ListDimExprGetSymbolSet(dim_expr.operands); - vars.insert(inner_vars.begin(), inner_vars.end()); - }, - [&](const Broadcast& dim_expr) { - const auto& inner_vars = ListDimExprGetSymbolSet(dim_expr.operands); - vars.insert(inner_vars.begin(), inner_vars.end()); - }}; - std::visit(lambdas, dim_expr.variant()); - return vars; -} - std::string ToString(const DimExpr& dim_expr) { auto lambdas = common::Overloaded{ [](std::int64_t dim_expr) { return std::to_string(dim_expr); }, From 3669f1e2cae3ca1b3f8a6b770bed9d852a58310d Mon Sep 17 00:00:00 2001 From: Hongqing-work Date: Tue, 21 May 2024 07:26:57 +0000 Subject: [PATCH 3/4] fix --- paddle/cinn/common/broadcast_tree.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/cinn/common/broadcast_tree.cc b/paddle/cinn/common/broadcast_tree.cc index 001c9bc23ed0a..be8af13fc47d3 100644 --- a/paddle/cinn/common/broadcast_tree.cc +++ b/paddle/cinn/common/broadcast_tree.cc @@ -120,10 +120,10 @@ using Pattern2Placement = std::unordered_map; Pattern2Placement ConstructCstrLhsEqRhsReplacement( const symbol::Broadcastable& broadcastable_condition) { auto [lhs, rhs] = *broadcastable_condition; - if (SubstituteDimExpr(rhs, {lhs, rhs}) != rhs) { + if (SubstituteDimExpr(rhs, Pattern2Placement{{rhs, lhs}}) != rhs) { return Pattern2Placement{{rhs, lhs}}; } - if (SubstituteDimExpr(lhs, {rhs, lhs}) != lhs) { + if (SubstituteDimExpr(lhs, Pattern2Placement{{rhs, lhs}}) != lhs) { return Pattern2Placement{{lhs, rhs}}; } if (rhs.isa()) return Pattern2Placement{{rhs, lhs}}; From 494af5a477db4b69fca5723ba07c7ae8da01e8e2 Mon Sep 17 00:00:00 2001 From: Hongqing-work Date: Wed, 22 May 2024 02:37:43 +0000 Subject: [PATCH 4/4] fix --- paddle/cinn/common/broadcast_tree.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/cinn/common/broadcast_tree.cc b/paddle/cinn/common/broadcast_tree.cc index be8af13fc47d3..74ed4aff42798 100644 --- a/paddle/cinn/common/broadcast_tree.cc +++ b/paddle/cinn/common/broadcast_tree.cc @@ -120,7 +120,7 @@ using Pattern2Placement = std::unordered_map; Pattern2Placement ConstructCstrLhsEqRhsReplacement( const symbol::Broadcastable& broadcastable_condition) { auto [lhs, rhs] = *broadcastable_condition; - if (SubstituteDimExpr(rhs, Pattern2Placement{{rhs, lhs}}) != rhs) { + if (SubstituteDimExpr(rhs, Pattern2Placement{{lhs, rhs}}) != rhs) { return Pattern2Placement{{rhs, lhs}}; } if (SubstituteDimExpr(lhs, Pattern2Placement{{rhs, lhs}}) != lhs) {