diff --git a/paddle/cinn/common/broadcast_tree.cc b/paddle/cinn/common/broadcast_tree.cc
index 4b14b41af3ae4..74ed4aff42798 100644
--- a/paddle/cinn/common/broadcast_tree.cc
+++ b/paddle/cinn/common/broadcast_tree.cc
@@ -120,6 +120,12 @@ using Pattern2Placement = std::unordered_map<symbol::DimExpr, symbol::DimExpr>;
 Pattern2Placement ConstructCstrLhsEqRhsReplacement(
     const symbol::Broadcastable<symbol::DimExpr>& broadcastable_condition) {
   auto [lhs, rhs] = *broadcastable_condition;
+  if (SubstituteDimExpr(rhs, Pattern2Placement{{lhs, rhs}}) != rhs) {
+    return Pattern2Placement{{rhs, lhs}};
+  }
+  if (SubstituteDimExpr(lhs, Pattern2Placement{{rhs, lhs}}) != lhs) {
+    return Pattern2Placement{{lhs, rhs}};
+  }
   if (rhs.isa<std::string>()) return Pattern2Placement{{rhs, lhs}};
   if (lhs.isa<std::string>()) return Pattern2Placement{{lhs, rhs}};
   return Pattern2Placement{{lhs, rhs}};
diff --git a/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc b/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc
index da8cb090b7a2d..43b9d865e2ab6 100644
--- a/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc
+++ b/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc
@@ -197,6 +197,19 @@ void FusionOp::Print(pir::IrPrinter& printer) {
   os << printer.indentation() << "}";
 }
 
+bool FusionOp::InferSymbolicShape(
+    ::pir::InferSymbolicShapeContext* infer_context) {
+  ::pir::InferSymExprForBlock(*block(), infer_context);
+
+  for (uint32_t rst_idx = 0; rst_idx < num_results(); rst_idx++) {
+    auto inner_yield_value = block()->back().operand_source(rst_idx);
+    const auto& shape =
+        infer_context->GetShapeOrDataForValue(inner_yield_value);
+    infer_context->SetShapeOrDataForValue(result(rst_idx), shape);
+  }
+  return true;
+}
+
 void YieldStoreOp::Build(pir::Builder& builder,
                          pir::OperationArgument& argument,
                          pir::Value x,
diff --git a/paddle/cinn/hlir/dialect/operator/ir/manual_op.h b/paddle/cinn/hlir/dialect/operator/ir/manual_op.h
index 396f9929ecb35..ba43fd7e325f4 100644
--- a/paddle/cinn/hlir/dialect/operator/ir/manual_op.h
+++ b/paddle/cinn/hlir/dialect/operator/ir/manual_op.h
@@ -61,7 +61,8 @@ class IR_API GroupOp
 
 // FusionOp represents a subgraphs that can be fused to one kernel.
 // Every GroupOp can be lowered to at least one FusionOp
-class IR_API FusionOp : public pir::Op<FusionOp> {
+class IR_API FusionOp
+    : public pir::Op<FusionOp, paddle::dialect::InferSymbolicShapeInterface> {
  public:
   using Op::Op;
   static const char *name() { return "cinn_op.fusion"; }
@@ -81,6 +82,8 @@ class IR_API FusionOp : public pir::Op<FusionOp> {
 
   std::vector<pir::Operation *> GetOperators() const;
 
+  bool InferSymbolicShape(pir::InferSymbolicShapeContext *infer_context);
+
   void VerifySig();
   void Print(pir::IrPrinter &printer);  // NOLINT
 };
diff --git a/paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.cc
index 10cc7ae94f80b..49edddd8518ff 100644
--- a/paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.cc
+++ b/paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.cc
@@ -130,8 +130,6 @@ bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) {
               .dyn_cast<paddle::dialect::ScalarAttribute>()
               .data());
       op->operand(0).set_source(new_full->result(0));
-      shape_analysis.SetShapeOrDataForValue(
-          new_full.result(0), symbol::TensorShapeOrDataDimExprs(out_dim));
     } else {
       auto new_transpose_op = rewriter->Build(
           op->operand_source(0),
@@ -139,9 +137,6 @@ bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) {
          output_shape);
       op->operand(0).set_source(new_transpose_op->result(0));
-      shape_analysis.SetShapeOrDataForValue(
-          new_transpose_op.result(0),
-          symbol::TensorShapeOrDataDimExprs(out_dim));
     }
   }
@@ -160,8 +155,6 @@ bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) {
              .data());
       op->operand(1).set_source(new_full->result(0));
-      shape_analysis.SetShapeOrDataForValue(
-          new_full.result(0), symbol::TensorShapeOrDataDimExprs(out_dim));
     } else {
       auto new_transpose_op = rewriter->Build(
           op->operand_source(1),
@@ -169,9 +162,6 @@ bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) {
          output_shape);
       op->operand(1).set_source(new_transpose_op->result(0));
-      shape_analysis.SetShapeOrDataForValue(
-          new_transpose_op.result(0),
-          symbol::TensorShapeOrDataDimExprs(out_dim));
     }
   }
diff --git a/paddle/cinn/hlir/dialect/operator/transforms/add_store_in_fusion_op_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/add_store_in_fusion_op_pass.cc
index 7e4bf74065fbb..a0a59b27c3cea 100644
--- a/paddle/cinn/hlir/dialect/operator/transforms/add_store_in_fusion_op_pass.cc
+++ b/paddle/cinn/hlir/dialect/operator/transforms/add_store_in_fusion_op_pass.cc
@@ -44,10 +44,6 @@ class AddYieldStoreInFusionOpPattern
           op->operand_source(i), op->operand_source(i).type());
       auto orignal_base = op->operand_source(i);
       op->operand(i).set_source(store_op.result(0));
-
-      shape_analysis.SetShapeOrDataForValue(
-          store_op.result(0),
-          shape_analysis.GetShapeOrDataForValue(orignal_base));
     }
 
     return true;
diff --git a/paddle/cinn/hlir/dialect/operator/transforms/cinn_group_cluster_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/cinn_group_cluster_pass.cc
index c3bf60c601b7d..83bb4912d0121 100644
--- a/paddle/cinn/hlir/dialect/operator/transforms/cinn_group_cluster_pass.cc
+++ b/paddle/cinn/hlir/dialect/operator/transforms/cinn_group_cluster_pass.cc
@@ -178,6 +178,8 @@ ::pir::GroupOpsVec CloneOps(
   auto& alignment_schedule_info = node.alignment_schedule_info;
   for (auto op : group_ops) {
     auto new_op = op->Clone(*ir_mapping, clone_options);
+    // TODO(Hongqing-work): delete this after fix bug of
+    // cinn_dynamic_reshape_op_pass
     auto& shape_analysis =
         pir::ShapeAnalysisManager::Instance().Get(op->GetParentProgram());
@@ -350,6 +352,8 @@ class CinnGroupClusterPattern
       auto new_group_op = ReplaceWithGroupOp(
           &rewriter, uniq_ops, node, output_values, &ir_mapping);
+      // TODO(Hongqing-work): delete this after fix bug of
+      // cinn_dynamic_reshape_op_pass
       auto& shape_analysis = pir::ShapeAnalysisManager::Instance().Get(
           group_op->GetParentProgram());
       // update ir mapping
diff --git a/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc
index 63cc4cf04b68c..effa5573891be 100644
--- a/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc
+++ b/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc
@@ -357,8 +357,6 @@ bool ReplaceShapeOpsToGenerateShape(
       GetOutOfRewrittenGenerateShapeOp(
           shape_operand, rewriter, ShapeOrDataDimExprs4Value);
   if (!opt_generated_shape.has_value()) return false;
-  shape_analysis->SetShapeOrDataForValue(
-      opt_generated_shape.value(), ShapeOrDataDimExprs4Value(shape_operand));
   rewriter->ReplaceAllUsesWith(shape_operand, opt_generated_shape.value());
   return true;
 }
diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/divide_group_op_to_fusion_op_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/divide_group_op_to_fusion_op_pass.cc
index 8f64980baf1c8..21f5c9721f68d 100644
--- a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/divide_group_op_to_fusion_op_pass.cc
+++ b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/divide_group_op_to_fusion_op_pass.cc
@@ -163,6 +163,8 @@ class GroupOpPattern : public pir::OpRewritePattern<cinn::dialect::GroupOp> {
           << "fusion_yield_values already has key!";
       const auto& shape_expr =
           shape_analysis.GetShapeOrDataForValue(vec_outs[i]);
+      // TODO(Hongqing-work): delete this after fix bug of
+      // cinn_dynamic_reshape_op_pass
       shape_analysis.SetShapeOrDataForValue(fusion_op.result(i), shape_expr);
       auto find_it = output_value2id.find(vec_outs[i]);
       if (find_it != output_value2id.end()) {
diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/generate_shape_util.cc b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/generate_shape_util.cc
index 662b7b36c37bb..30b470d42ca2a 100644
--- a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/generate_shape_util.cc
+++ b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/generate_shape_util.cc
@@ -90,12 +90,6 @@ std::optional<pir::Value> InsertGenerateShapeOpToRunFirst(
   return std::nullopt;
 }
 
-void CloneDimExprInfo(pir::Value from,
-                      pir::Value to,
-                      const ShapeOrDataDimExprsAccessor& ctx) {
-  ctx.SetShapeOrDataDimExprs(to, ctx.GetShapeOrDataDimExprs(from));
-}
-
 void ReplaceAllUses(pir::Value from, pir::Value to) {
   from.ReplaceAllUsesWith(to);
 }
@@ -119,7 +113,6 @@ bool RewriteOneGenerateShapeOpToRunFirst(
     std::optional<pir::Value> new_shape = InsertGenerateShapeOpToRunFirst(
         &builder, block_args, op.out(), dim_exprs_accessor);
     if (!new_shape.has_value()) continue;
-    CloneDimExprInfo(op.out(), new_shape.value(), dim_exprs_accessor);
     ReplaceAllUses(op.out(), new_shape.value());
     EraseGenerateShapeOp(op_iter, block);
     return true;
diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/generate_shape_util.h b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/generate_shape_util.h
index 93edec2c4c72d..d7f99d03493a1 100644
--- a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/generate_shape_util.h
+++ b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/generate_shape_util.h
@@ -31,8 +31,6 @@ namespace cinn::dialect {
 
 struct ShapeOrDataDimExprsAccessor {
   std::function<const symbol::ShapeOrDataDimExprs&(pir::Value)>
       GetShapeOrDataDimExprs;
-  std::function<void(pir::Value, const symbol::ShapeOrDataDimExprs&)>
-      SetShapeOrDataDimExprs;
 };
 
 // Returns true if at least one GenerateShapeOp rewrote.
diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/move_generate_shape_ops_to_prologue_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/move_generate_shape_ops_to_prologue_pass.cc
index f395a1fb3e28b..6e47c1ac9a4ec 100644
--- a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/move_generate_shape_ops_to_prologue_pass.cc
+++ b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/move_generate_shape_ops_to_prologue_pass.cc
@@ -56,12 +56,7 @@ class GroupOpGenerateShapeOpsPattern
         .GetShapeOrDataDimExprs =
             [&](pir::Value value) -> const symbol::ShapeOrDataDimExprs& {
           return shape_analysis.GetShapeOrDataForValue(value);
-        },
-        .SetShapeOrDataDimExprs =
-            [&](pir::Value value,
-                const symbol::ShapeOrDataDimExprs& dim_exprs) {
-              shape_analysis.SetShapeOrDataForValue(value, dim_exprs);
-            }};
+        }};
     return MoveGenerateShapeOpsToPrologue(
         ctx, group_op.block(), dim_exprs_accessor);
   }
@@ -82,12 +77,6 @@ class MoveGenerateShapeOpsToProloguePass : public pir::Pass {
         .GetShapeOrDataDimExprs =
             [&](pir::Value value) -> const symbol::ShapeOrDataDimExprs& {
           return shape_analysis.GetShapeOrDataForValue(value);
-        },
-        .SetShapeOrDataDimExprs =
-            [&](pir::Value value,
-                const symbol::ShapeOrDataDimExprs& dim_exprs) {
-              shape_analysis.SetShapeOrDataForValue(value, dim_exprs);
-            }};
+        }};
     MoveGenerateShapeOpsToPrologue(ctx, group_op.block(), dim_exprs_accessor);
   }
diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/single_op_fallback_to_phi.cc b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/single_op_fallback_to_phi.cc
index 85a6a9c0677a0..5e431d3d7fb00 100644
--- a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/single_op_fallback_to_phi.cc
+++ b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/single_op_fallback_to_phi.cc
@@ -64,9 +64,6 @@ class FusionOpPattern : public pir::OpRewritePattern<cinn::dialect::FusionOp> {
     for (size_t i = 0; i < fusion_op.num_results(); ++i) {
       rewriter.ReplaceAllUsesWith(fusion_op.result(i),
                                   paddle_op.value()->result(i));
-      shape_analysis.SetShapeOrDataForValue(
-          paddle_op.value()->result(i),
-          shape_analysis.GetShapeOrDataForValue(fusion_op.result(i)));
     }
 
     rewriter.EraseOp(fusion_op);
diff --git a/paddle/cinn/hlir/dialect/operator/transforms/insert_broadcast_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/insert_broadcast_pass.cc
index 56a2aa07d7096..13074afc4d761 100644
--- a/paddle/cinn/hlir/dialect/operator/transforms/insert_broadcast_pass.cc
+++ b/paddle/cinn/hlir/dialect/operator/transforms/insert_broadcast_pass.cc
@@ -73,13 +73,11 @@ bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) {
     pir::Value broadcasted_x =
         rewriter->Build<paddle::dialect::ExpandOp>(x, output_dim_tensor).out();
     op->operand(0).set_source(broadcasted_x);
-    shape_analysis.SetShapeOrDataForValue(broadcasted_x, out_shape);
   }
   if (y_shape.shape() != out_shape.shape()) {
     pir::Value broadcasted_y =
         rewriter->Build<paddle::dialect::ExpandOp>(y, output_dim_tensor).out();
     op->operand(1).set_source(broadcasted_y);
-    shape_analysis.SetShapeOrDataForValue(broadcasted_y, out_shape);
   }
   return true;
 }
diff --git a/paddle/cinn/hlir/dialect/operator/transforms/replace_dynamic_expand_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/replace_dynamic_expand_pass.cc
index 3adf8cc6110ec..340b7d4ee158f 100644
--- a/paddle/cinn/hlir/dialect/operator/transforms/replace_dynamic_expand_pass.cc
+++ b/paddle/cinn/hlir/dialect/operator/transforms/replace_dynamic_expand_pass.cc
@@ -69,12 +69,6 @@ class DynamicExpandOpPattern
           op->operand_source(0), broadcast_axes, out_shape);
     }();
 
-    auto& shape_analysis =
-        pir::ShapeAnalysisManager::Instance().Get(op->GetParentProgram());
-    shape_analysis.SetShapeOrDataForValue(
-        broadcast->result(0),
-        shape_analysis.GetShapeOrDataForValue(op.result(0)));
-
     if (auto pre_full = broadcast->operand_source(0)
                             .defining_op()
                             ->dyn_cast<paddle::dialect::FullOp>()) {
@@ -82,11 +76,6 @@
               .type()
               .dyn_cast<paddle::dialect::DenseTensorType>()
               .dims();
-      if (input_dim.size() == 1 && input_dim[0] == 1) {
-        shape_analysis.SetShapeOrDataForValue(
-            pre_full->result(0),
-            shape_analysis.GetShapeOrDataForValue(op.result(0)));
-      }
     }
 
     rewriter.ReplaceAllUsesWith(op->result(0), broadcast->result(0));
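
Note: taken together, these changes replace per-pass shape bookkeeping with interface-driven inference. FusionOp now implements InferSymbolicShape, so result shapes are recomputed from the fused block's trailing yield operands on demand, instead of every rewrite pattern calling shape_analysis.SetShapeOrDataForValue() by hand (the one remaining call, in divide_group_op_to_fusion_op_pass.cc, is explicitly marked TODO for removal). The standalone sketch below models only that propagation step; ToyValue, ToyContext, and ToyFusionOp are invented stand-ins for pir::Value, pir::InferSymbolicShapeContext, and the fusion op, not Paddle APIs.

#include <cassert>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

using ToyValue = int;                      // stand-in for pir::Value
using DimExpr = std::string;               // stand-in for symbol::DimExpr
using ShapeOrData = std::vector<DimExpr>;  // stand-in for ShapeOrDataDimExprs

// Stand-in for pir::InferSymbolicShapeContext: a value-to-shape table.
struct ToyContext {
  std::unordered_map<ToyValue, ShapeOrData> shape_of;
  const ShapeOrData& Get(ToyValue v) const { return shape_of.at(v); }
  void Set(ToyValue v, ShapeOrData s) { shape_of[v] = std::move(s); }
};

// Stand-in for a fusion op: results pair 1:1 with the operands of the
// trailing yield op of its body block.
struct ToyFusionOp {
  std::vector<ToyValue> yield_operands;
  std::vector<ToyValue> results;
};

// Mirrors the loop added in FusionOp::InferSymbolicShape: once the body has
// been inferred, forward each yield operand's shape to the matching result.
bool InferSymbolicShape(const ToyFusionOp& op, ToyContext* ctx) {
  for (size_t i = 0; i < op.results.size(); ++i) {
    ctx->Set(op.results[i], ctx->Get(op.yield_operands[i]));
  }
  return true;
}

int main() {
  ToyContext ctx;
  ToyFusionOp fusion{/*yield_operands=*/{1}, /*results=*/{2}};
  ctx.Set(1, {"S0", "256"});  // shape inferred for the body's yield operand
  InferSymbolicShape(fusion, &ctx);
  assert((ctx.Get(2) == ShapeOrData{"S0", "256"}));  // forwarded to result
  return 0;
}

Because the forwarding is an index-for-index copy, it stays correct no matter which passes rewrite the fusion body, which is what lets the SetShapeOrDataForValue calls above be deleted.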