diff --git a/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc index 9a9057e993be7..439bba580a6c0 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc @@ -139,9 +139,8 @@ bool ProcessOp(paddle::dialect::ExpandOp op, pir::PatternRewriter* rewriter) { pir::ShapeConstraintIRAnalysis& shape_analysis = pir::ShapeAnalysisManager::Instance().Get( op.x().defining_op()->GetParentProgram()); - CHECK(shape_analysis.value_id_to_shapeordata_.find(GetValueId(&value)) != - shape_analysis.value_id_to_shapeordata_.end()); - return shape_analysis.value_id_to_shapeordata_.at(GetValueId(&value)); + + return shape_analysis.GetShapeOrDataForValue(value); }; std::optional opt_generated_shape = GetOutOfRewritedGenerateShapeOp( diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc index cd324b5f05c69..5d4cc10b205ba 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc @@ -16,6 +16,7 @@ #include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" #include "paddle/pir/core/builtin_attribute.h" #include "paddle/pir/core/builtin_type.h" +#include "paddle/pir/core/builtin_type_interfaces.h" #include "paddle/pir/dialect/shape/ir/shape_attribute.h" namespace paddle::dialect { @@ -33,27 +34,25 @@ bool SameOperandsAndResultShape( pir::Value operand_source = op->operand_source(0); symbol::ShapeOrDataDimExprs operand_shape_or_data = - shape_analysis->value_to_shape_or_data_[operand_source]; + shape_analysis->GetShapeOrDataForValue(operand_source); op->set_attribute("symbolic_shape", pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), operand_shape_or_data)); pir::OpResult res = op->result(0); - shape_analysis->value_to_shape_or_data_[res] = operand_shape_or_data; + shape_analysis->SetShapeOrDataForValue(res, operand_shape_or_data); return true; } bool InferSymbolicShapeElementWiseBinary( pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { pir::Value operand_source_0 = op->operand_source(0); - std::string operand_source_0_id = pir::GetValueId(&operand_source_0); std::vector shape_0{ - shape_analysis->value_id_to_shapeordata_[operand_source_0_id].shape()}; + shape_analysis->GetShapeOrDataForValue(operand_source_0).shape()}; pir::Value operand_source_1 = op->operand_source(1); - std::string operand_source_1_id = pir::GetValueId(&operand_source_1); std::vector shape_1{ - shape_analysis->value_id_to_shapeordata_[operand_source_1_id].shape()}; + shape_analysis->GetShapeOrDataForValue(operand_source_1).shape()}; if (shape_0.size() > shape_1.size()) { for (size_t i = 0; i < shape_0.size() - shape_1.size(); i++) { @@ -75,9 +74,11 @@ bool InferSymbolicShapeElementWiseBinary( std::vector data; pir::OpResult res = op->result(0); - std::string res_id = pir::GetValueId(&res); symbol::ShapeOrDataDimExprs shape_data{shapes, data}; - shape_analysis->value_id_to_shapeordata_[res_id] = shape_data; + shape_analysis->SetShapeOrDataForValue(res, shape_data); + op->set_attribute( + "symbolic_shape", + pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data)); return true; } @@ -104,7 +105,7 @@ bool DataOpInferSymbolicShape(pir::Operation *op, std::vector sym_dims; for (auto dim : dims) { symbol::DimExpr dim_expr; - if (dim == -1) { + if (dim == pir::ShapedTypeInterface::kDynamic) { symbol::DimExpr symbolic_dim_expr(shape_analysis->GetNextSymName()); dim_expr = symbolic_dim_expr; } else { @@ -120,7 +121,7 @@ bool DataOpInferSymbolicShape(pir::Operation *op, pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data)); pir::OpResult res = op->result(0); - shape_analysis->value_to_shape_or_data_[res] = shape_data; + shape_analysis->SetShapeOrDataForValue(res, shape_data); return true; } @@ -171,13 +172,13 @@ bool ShapeOpInferSymbolicShape(pir::Operation *op, pir::OpResult res = op->result(0); symbol::ShapeOrDataDimExprs operand_shape_or_data = - shape_analysis->value_to_shape_or_data_[operand_source]; + shape_analysis->GetShapeOrDataForValue(operand_source); symbol::ShapeOrDataDimExprs extend_shape_or_data = symbol::ShapeOrDataDimExprs::MakeConsistentShapeOrData( operand_shape_or_data); - shape_analysis->value_to_shape_or_data_[res] = extend_shape_or_data; + shape_analysis->SetShapeOrDataForValue(res, extend_shape_or_data); op->set_attribute("symbolic_shape", pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), extend_shape_or_data)); @@ -193,7 +194,7 @@ bool StackOpInferSymbolicShape(pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { pir::Value operand_source = op->operand_source(0); symbol::ShapeOrDataDimExprs operand_shape_or_data = - shape_analysis->value_to_shape_or_data_[operand_source]; + shape_analysis->GetShapeOrDataForValue(operand_source); std::vector out_dims; if (operand_shape_or_data.data().has_value()) { @@ -213,7 +214,7 @@ bool StackOpInferSymbolicShape(pir::Operation *op, "symbolic_shape", pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data)); pir::OpResult res = op->result(0); - shape_analysis->value_to_shape_or_data_[res] = shape_data; + shape_analysis->SetShapeOrDataForValue(res, shape_data); return true; } @@ -222,7 +223,7 @@ bool ReshapeOpInferSymbolicShape( pir::Value operand_source_shape = op->operand_source(1); symbol::ShapeOrDataDimExprs operand_shape_or_data = - shape_analysis->value_to_shape_or_data_[operand_source_shape]; + shape_analysis->GetShapeOrDataForValue(operand_source_shape); std::vector out_dims; if (operand_shape_or_data.data().has_value()) { @@ -236,9 +237,9 @@ bool ReshapeOpInferSymbolicShape( pir::OpResult res0 = op->result(0); pir::OpResult res1 = op->result(1); - shape_analysis->value_to_shape_or_data_[res0] = shape_data; - shape_analysis->value_to_shape_or_data_[res1] = - shape_analysis->value_to_shape_or_data_[operand_source_shape]; + shape_analysis->SetShapeOrDataForValue(res0, shape_data); + shape_analysis->SetShapeOrDataForValue( + res1, shape_analysis->GetShapeOrDataForValue(operand_source_shape)); return true; } @@ -267,7 +268,7 @@ bool FullIntArrayOpInferSymbolicShape( pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data)); pir::OpResult res = op->result(0); - shape_analysis->value_to_shape_or_data_[res] = shape_data; + shape_analysis->SetShapeOrDataForValue(res, shape_data); return true; } @@ -286,7 +287,7 @@ bool SliceOpInferSymbolicShape(pir::Operation *op, // dialect. pir::Value operand_source = op->operand_source(0); symbol::ShapeOrDataDimExprs operand_shape_or_data = - shape_analysis->value_to_shape_or_data_[operand_source]; + shape_analysis->GetShapeOrDataForValue(operand_source); pir::AttributeMap attributes = op->attributes(); std::vector attr_starts = @@ -309,7 +310,7 @@ bool SliceOpInferSymbolicShape(pir::Operation *op, pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data)); pir::OpResult res = op->result(0); - shape_analysis->value_to_shape_or_data_[res] = shape_data; + shape_analysis->SetShapeOrDataForValue(res, shape_data); return true; } diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc index cc03b61d354d7..8ecaf902d39a7 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc +++ b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc @@ -3005,15 +3005,9 @@ bool ShapeBroadcastOp::InferSymbolicShape( pir::ShapeConstraintIRAnalysis *shape_analysis) { pir::Value x = operand_source(0); pir::Value y = operand_source(1); - std::string x_id = pir::GetValueId(&x); - std::string y_id = pir::GetValueId(&y); - - IR_ENFORCE(shape_analysis->value_id_to_shapeordata_.count(x_id) > 0, - "x_id does not exist."); - IR_ENFORCE(shape_analysis->value_id_to_shapeordata_.count(y_id) > 0, - "y_id does not exist."); - const auto &x_data_shape = shape_analysis->value_id_to_shapeordata_.at(x_id); - const auto &y_data_shape = shape_analysis->value_id_to_shapeordata_.at(y_id); + + const auto &x_data_shape = shape_analysis->GetShapeOrDataForValue(x); + const auto &y_data_shape = shape_analysis->GetShapeOrDataForValue(y); IR_ENFORCE(x_data_shape.data().has_value(), "Value x comes from ShapeOp, it must have data"); IR_ENFORCE(y_data_shape.data().has_value(), @@ -3028,10 +3022,9 @@ bool ShapeBroadcastOp::InferSymbolicShape( } pir::OpResult res = result(0); - std::string res_id = pir::GetValueId(&res); symbol::ShapeOrDataDimExprs output_data_shape = symbol::ShapeOrDataDimExprs::MakeConsistentShapeOrData(output_data); - shape_analysis->value_id_to_shapeordata_[res_id] = output_data_shape; + shape_analysis->SetShapeOrDataForValue(res, output_data_shape); return true; } diff --git a/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc b/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc index 3c56533c2d7c1..44774a59c2be8 100644 --- a/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc +++ b/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc @@ -62,11 +62,11 @@ struct CombineOpInferSymbolicShapeInterfaceModel } auto operand_source_1st_data = - shape_analysis->value_to_shape_or_data_[op->operand_source(0)].data(); + shape_analysis->GetShapeOrDataForValue(op->operand_source(0)).data(); if (operand_source_1st_data.has_value()) { for (auto operand_source : op->operands_source()) { auto source_data = - shape_analysis->value_to_shape_or_data_[operand_source] + shape_analysis->GetShapeOrDataForValue(operand_source) .data() .value(); out_dims.push_back(source_data[0]); @@ -83,7 +83,7 @@ struct CombineOpInferSymbolicShapeInterfaceModel pir::shape::SymbolAttribute::get( pir::IrContext::Instance(), shape_data)); auto res = op->result(0); - shape_analysis->value_to_shape_or_data_[res] = shape_data; + shape_analysis->SetShapeOrDataForValue(res, shape_data); return true; } diff --git a/paddle/fluid/pir/transforms/shape_optimization_pass.cc b/paddle/fluid/pir/transforms/shape_optimization_pass.cc index 33ee31b66b4d6..485bb8c15f8ba 100644 --- a/paddle/fluid/pir/transforms/shape_optimization_pass.cc +++ b/paddle/fluid/pir/transforms/shape_optimization_pass.cc @@ -51,7 +51,7 @@ void DebugPrintOpInfo( << "ShapeOrData: "; if (shape_analysis != nullptr) { - auto shape_data = shape_analysis->value_to_shape_or_data_[res]; + auto shape_data = shape_analysis->GetShapeOrDataForValue(res); print_stream << "shape: ["; for (size_t i = 0; i < shape_data.shape().size(); ++i) { @@ -94,7 +94,9 @@ void InferSymExprForAllValues(ModuleOp module_op) { if (infer_symbolic_shape_interface) { VLOG(3) << op.name() << " has InferSymbolicShapeInterface."; PADDLE_ENFORCE(infer_symbolic_shape_interface.InferSymbolicShape( - &shape_analysis)); + &shape_analysis), + "InferSymbolicShape for %s failed.", + op.name()); } DebugPrintOpInfo(&op, &shape_analysis); } diff --git a/paddle/pir/core/builtin_type_interfaces.cc b/paddle/pir/core/builtin_type_interfaces.cc index 1325069bf79f3..e5e5c199339f3 100644 --- a/paddle/pir/core/builtin_type_interfaces.cc +++ b/paddle/pir/core/builtin_type_interfaces.cc @@ -21,16 +21,8 @@ Type ShapedTypeInterface::GetElementType() const { return impl_->get_element_type(*this); } -std::vector ShapedTypeInterface::GetDyShape() const { - if (dy_shape_.size() == 0) { - auto ddim_vec = common::vectorize(impl_->get_shape(*this)); - dy_shape_ = ddim_vec; - std::replace(dy_shape_.begin(), - dy_shape_.end(), - (int64_t)-1, - ShapedTypeInterface::kDynamic); - } - return dy_shape_; +pir::DDim ShapedTypeInterface::GetShape() const { + return impl_->get_shape(*this); } } // namespace pir diff --git a/paddle/pir/core/builtin_type_interfaces.h b/paddle/pir/core/builtin_type_interfaces.h index 3f2357eb41fa0..b9c476f21e472 100644 --- a/paddle/pir/core/builtin_type_interfaces.h +++ b/paddle/pir/core/builtin_type_interfaces.h @@ -56,7 +56,7 @@ class IR_API ShapedTypeInterface /// /// \brief kDynamic /// - static constexpr int64_t kDynamic = std::numeric_limits::min(); + static constexpr int64_t kDynamic = std::int64_t(-1); ShapedTypeInterface(Type type, Concept *impl) : TypeInterfaceBase(type), impl_(impl) {} @@ -69,7 +69,7 @@ class IR_API ShapedTypeInterface /// /// \brief Get the shape of this type. /// - std::vector GetDyShape() const; + pir::DDim GetShape() const; /// /// \brief Check whether this type is ranked, currently return true. @@ -81,7 +81,7 @@ class IR_API ShapedTypeInterface /// int64_t GetRank() const { IR_ENFORCE((*this).HasRank(), "Cannot query rank of unranked shaped type."); - return (*this).GetDyShape().size(); + return (*this).GetShape().size(); } /// @@ -94,11 +94,10 @@ class IR_API ShapedTypeInterface /// dimension. /// bool IsDynamicShape() const { - auto size_vec = (*this).GetDyShape(); - return std::any_of( - size_vec.begin(), size_vec.end(), [](int64_t size_value) { - return IsDynamic(size_value); - }); + auto size_vec = common::vectorize(impl_->get_shape(*this)); + return std::any_of(size_vec.begin(), size_vec.end(), [](int64_t size_val) { + return IsDynamic(size_val); + }); } /// @@ -112,7 +111,7 @@ class IR_API ShapedTypeInterface /// bool IsDynamicDim(unsigned idx) const { IR_ENFORCE(idx < GetRank(), "Invalid index for shaped type."); - return ShapedTypeInterface::IsDynamic((*this).GetDyShape()[idx]); + return ShapedTypeInterface::IsDynamic((*this).GetShape()[idx]); } /// @@ -120,7 +119,7 @@ class IR_API ShapedTypeInterface /// Aborts for unranked types. /// int64_t GetNumDynamicDims() const { - auto shape_vec = (*this).GetDyShape(); + auto shape_vec = vectorize((*this).GetShape()); return std::count_if( shape_vec.begin(), shape_vec.end(), ShapedTypeInterface::IsDynamic); } @@ -131,12 +130,11 @@ class IR_API ShapedTypeInterface /// int64_t GetDimSize(unsigned idx) const { IR_ENFORCE(idx < GetRank(), "Invalid index for shaped type."); - return (*this).GetDyShape()[idx]; + return (*this).GetShape()[idx]; } private: Concept *impl_; - mutable std::vector dy_shape_; }; } // namespace pir diff --git a/paddle/pir/core/type_util.cc b/paddle/pir/core/type_util.cc index e681221afd0bd..cf95883eecf71 100644 --- a/paddle/pir/core/type_util.cc +++ b/paddle/pir/core/type_util.cc @@ -23,12 +23,12 @@ Type GetElementTypeOrSelf(Type type) { return type; } -bool VerifyCompatibleShape(const std::vector &lhs_shape, - const std::vector &rhs_shape) { +bool VerifyCompatibleShape(const pir::DDim &lhs_shape, + const pir::DDim &rhs_shape) { if (lhs_shape.size() != rhs_shape.size()) return false; - for (auto dim1 : lhs_shape) { - for (auto dim2 : rhs_shape) { + for (auto dim1 : common::vectorize(lhs_shape)) { + for (auto dim2 : common::vectorize(rhs_shape)) { if (!ShapedTypeInterface::IsDynamic(dim1) && !ShapedTypeInterface::IsDynamic(dim2) && dim1 != dim2) return false; @@ -47,8 +47,8 @@ bool VerifyCompatibleShape(Type lhs_type, Type rhs_type) { if (!lhs_shaped_type.HasRank() || !rhs_shaped_type.HasRank()) return true; - return VerifyCompatibleShape(lhs_shaped_type.GetDyShape(), - rhs_shaped_type.GetDyShape()); + return VerifyCompatibleShape(lhs_shaped_type.GetShape(), + rhs_shaped_type.GetShape()); } bool VerifyCompatibleDims(const std::vector &dims) { diff --git a/paddle/pir/dialect/shape/utils/shape_optimization_utils.cc b/paddle/pir/dialect/shape/utils/shape_optimization_utils.cc index f0f2539e352fb..cb7bf64ebcbe4 100644 --- a/paddle/pir/dialect/shape/utils/shape_optimization_utils.cc +++ b/paddle/pir/dialect/shape/utils/shape_optimization_utils.cc @@ -201,10 +201,9 @@ std::vector SymbolicDimMgr::CreateSymbolicDimsForRankedValue( std::vector symbols; auto dims = value.type().dyn_cast().dims(); for (int idx = 0; idx < dims.size(); ++idx) { - symbols.push_back( - (dims[idx] == ShapedTypeInterface::kDynamic || dims[idx] == -1) - ? NewSymbolicDim() - : NewConstantSymbolicDim(dims[idx])); + symbols.push_back(dims[idx] == ShapedTypeInterface::kDynamic + ? NewSymbolicDim() + : NewConstantSymbolicDim(dims[idx])); } return symbols; } diff --git a/paddle/pir/dialect/shape/utils/shape_utils.cc b/paddle/pir/dialect/shape/utils/shape_utils.cc index 06619f9ae2535..8cacb4d0de70d 100644 --- a/paddle/pir/dialect/shape/utils/shape_utils.cc +++ b/paddle/pir/dialect/shape/utils/shape_utils.cc @@ -60,7 +60,7 @@ bool ShapeConstraintIRAnalysis::IsShapeEqual(Value lhs, Value rhs) { return false; if (lhs_type.HasStaticShape() && rhs_type.HasStaticShape()) { - return lhs_type.GetDyShape() == rhs_type.GetDyShape(); + return vectorize(lhs_type.GetShape()) == vectorize(rhs_type.GetShape()); } auto lhs_it = value_to_sym_dims_.find(lhs); @@ -95,13 +95,13 @@ bool ShapeConstraintIRAnalysis::IsProductEqual(Value lhs, auto it = value_to_sym_dims_.find(value); if (!type || !type.HasRank()) return false; for (int idx : dim_idxs) { - if (type.GetDyShape()[idx] == ShapedTypeInterface::kDynamic) { + if (type.GetShape()[idx] == ShapedTypeInterface::kDynamic) { if (it == value_to_sym_dims_.end() || static_cast(it->second.size()) <= idx) return false; prod.symbols.push_back(it->second[idx]); } else { - prod.factor *= type.GetDyShape()[idx]; + prod.factor *= type.GetShape()[idx]; } } return true; @@ -148,23 +148,15 @@ ShapeConstraintIRAnalysis& ShapeAnalysisManager::Get(pir::Program* program) { return it->second; } -std::string GetValueId(Value* val) { - auto op_id = val->defining_op()->id(); - auto val_idx = val->dyn_cast().index(); - - return "op_" + std::to_string(op_id) + "_rst_" + std::to_string(val_idx); -} - const symbol::ShapeOrDataDimExprs& -ShapeConstraintIRAnalysis::GetShapeOrDataForValue(Value* val) { - auto val_id = GetValueId(val); - return value_id_to_shapeordata_[val_id]; +ShapeConstraintIRAnalysis::GetShapeOrDataForValue(Value val) { + CHECK(value_to_shape_or_data_.find(val) != value_to_shape_or_data_.end()); + return value_to_shape_or_data_[val]; } void ShapeConstraintIRAnalysis::SetShapeOrDataForValue( - Value* val, const symbol::ShapeOrDataDimExprs& shape_or_data) { - auto val_id = GetValueId(val); - value_id_to_shapeordata_[val_id] = shape_or_data; + Value val, const symbol::ShapeOrDataDimExprs& shape_or_data) { + value_to_shape_or_data_[val] = shape_or_data; } } // namespace pir diff --git a/paddle/pir/dialect/shape/utils/shape_utils.h b/paddle/pir/dialect/shape/utils/shape_utils.h index 09a2aba1d15f2..47ec6d58637dd 100644 --- a/paddle/pir/dialect/shape/utils/shape_utils.h +++ b/paddle/pir/dialect/shape/utils/shape_utils.h @@ -80,22 +80,16 @@ class IR_API ShapeConstraintIRAnalysis : public ShapeAnalysis { return "S" + std::to_string(next_sym_idx_++); } - const symbol::ShapeOrDataDimExprs& GetShapeOrDataForValue(Value* val); + const symbol::ShapeOrDataDimExprs& GetShapeOrDataForValue(Value val); - void SetShapeOrDataForValue(Value* val, + void SetShapeOrDataForValue(Value val, const symbol::ShapeOrDataDimExprs& shape_or_data); - // const symbol::ShapeOrData& GetShapeOrDataForValue() const; - symbol::DimExprBuilder CreateDimExprBuilder() override; - std::unordered_map - value_id_to_shapeordata_; - + private: std::unordered_map value_to_shape_or_data_; - - private: // The operation this analysis runs on. ModuleOp m_; // The `SymbolicDimMgr` this analysis holds. diff --git a/test/cpp/pir/core/type_interface_test.cc b/test/cpp/pir/core/type_interface_test.cc index 7a7af415823ee..2695a01a914fe 100644 --- a/test/cpp/pir/core/type_interface_test.cc +++ b/test/cpp/pir/core/type_interface_test.cc @@ -51,9 +51,8 @@ TEST(shapedtype_test, shapedtype_test) { EXPECT_EQ( dense_tensor_type_interface.GetElementType().isa(), true); - EXPECT_EQ(dense_tensor_type_interface.GetDyShape(), common::vectorize(dims)); - EXPECT_EQ(dense_tensor_type_interface.kDynamic, - std::numeric_limits::min()); + EXPECT_EQ(dense_tensor_type_interface.GetShape(), dims); + EXPECT_EQ(dense_tensor_type_interface.kDynamic, std::int64_t(-1)); EXPECT_EQ(dense_tensor_type_interface.GetRank(), 2); EXPECT_EQ(dense_tensor_type_interface.IsDynamic(2), false); EXPECT_EQ(dense_tensor_type_interface.IsDynamicShape(), false);