From 5200038348fec42680fec06d5ba54c97aecd0bc6 Mon Sep 17 00:00:00 2001 From: jiahy0825 Date: Mon, 6 May 2024 11:11:40 +0000 Subject: [PATCH 1/3] fix typo --- .../dialect/operator/transforms/pir_to_py_code_converter.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/cinn/hlir/dialect/operator/transforms/pir_to_py_code_converter.cc b/paddle/cinn/hlir/dialect/operator/transforms/pir_to_py_code_converter.cc index a8139f9cffc8b..32f6d67d75d46 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/pir_to_py_code_converter.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/pir_to_py_code_converter.cc @@ -583,7 +583,7 @@ struct PirToPyCodeConverterHelper { std::string operator()(AdtTypeId<::pir::VectorType>) { std::stringstream ss; - const auto& name = ::pir::DenseTensorType::name(); + const auto& name = ::pir::VectorType::name(); const auto& vec_type = type.dyn_cast<::pir::VectorType>(); ss << "self." << name << "("; for (int i = 0; i < vec_type.size(); ++i) { From f027356e75be77567e3dd8897831c6eb9af68cfb Mon Sep 17 00:00:00 2001 From: jiahy0825 Date: Thu, 9 May 2024 08:19:46 +0000 Subject: [PATCH 2/3] refactor hash and operator== for ShapeOrDataDimExprs --- .../shape/ir/shape_attribute_storage.h | 20 ++--------- .../include/dialect/shape/utils/dim_expr.h | 13 +++++++ .../dialect/shape/utils/shape_or_data_expr.h | 36 +++++++++++++++++-- 3 files changed, 48 insertions(+), 21 deletions(-) diff --git a/paddle/pir/include/dialect/shape/ir/shape_attribute_storage.h b/paddle/pir/include/dialect/shape/ir/shape_attribute_storage.h index 055a045f292da..42885e12d9b56 100644 --- a/paddle/pir/include/dialect/shape/ir/shape_attribute_storage.h +++ b/paddle/pir/include/dialect/shape/ir/shape_attribute_storage.h @@ -40,26 +40,10 @@ struct SymbolAttributeStorage : public AttributeStorage { } static std::size_t HashValue(const ParamKey &key) { - std::size_t hash_value = 0; - for (size_t i = 0; i < key.shape().size(); ++i) { - hash_value = detail::hash_combine( - hash_value, - std::hash()(symbol::ToString(key.shape()[i]))); - } - if (key.data().has_value()) { - for (size_t i = 0; i < key.data().value().size(); ++i) { - hash_value = detail::hash_combine( - hash_value, - std::hash()(symbol::ToString(key.data().value()[i]))); - } - } - - return hash_value; + return std::hash()(key); } - bool operator==(const ParamKey &key) const { - return data_.shape() == key.shape() && data_.data() == key.data(); - } + bool operator==(const ParamKey &key) const { return data_ == key; } ParamKey data() const { return data_; } diff --git a/paddle/pir/include/dialect/shape/utils/dim_expr.h b/paddle/pir/include/dialect/shape/utils/dim_expr.h index a45ba01538ae7..448fb6b777eab 100644 --- a/paddle/pir/include/dialect/shape/utils/dim_expr.h +++ b/paddle/pir/include/dialect/shape/utils/dim_expr.h @@ -26,6 +26,7 @@ #include "paddle/common/enforce.h" #include "paddle/common/overloaded.h" #include "paddle/pir/include/core/dll_decl.h" +#include "paddle/pir/include/core/utils.h" namespace symbol { @@ -241,4 +242,16 @@ struct hash { } }; +template <> +struct hash> { + std::size_t operator()(const std::vector& dim_exprs) const { + std::size_t hash_value = 0; + const auto hash_func = std::hash(); + for (const auto& dim_expr : dim_exprs) { + hash_value = pir::detail::hash_combine(hash_value, hash_func(dim_expr)); + } + return hash_value; + } +}; + } // namespace std diff --git a/paddle/pir/include/dialect/shape/utils/shape_or_data_expr.h b/paddle/pir/include/dialect/shape/utils/shape_or_data_expr.h index a9d393342d25d..1089caab8ed45 100644 --- a/paddle/pir/include/dialect/shape/utils/shape_or_data_expr.h +++ b/paddle/pir/include/dialect/shape/utils/shape_or_data_expr.h @@ -193,13 +193,43 @@ IR_API std::ostream& operator<<(std::ostream&, const ShapeOrDataDimExprs& dim_expr); } // namespace symbol + namespace std { + +template <> +struct hash { + std::size_t operator()(const symbol::TensorShapeOrDataDimExprs& obj) const { + const auto hash_func = std::hash>(); + std::size_t ret = hash_func(obj.shape()); + ret = pir::detail::hash_combine(ret, obj.data().has_value()); + if (obj.data().has_value()) { + ret = pir::detail::hash_combine(ret, hash_func(obj.data().value())); + } + return ret; + } +}; + +template <> +struct hash { + std::size_t operator()( + const symbol::TensorListShapeOrDataDimExprs& obj) const { + const auto hash_func = std::hash(); + std::size_t ret = 0; + for (const auto& shape_or_data : obj) { + ret = pir::detail::hash_combine(ret, hash_func(shape_or_data)); + } + return ret; + } +}; + template <> struct hash { std::size_t operator()(const symbol::ShapeOrDataDimExprs& obj) const { - std::ostringstream os; - os << obj; - return std::hash()(os.str()); + return obj.Match([](const auto& impl) { + using T = std::decay_t; + return std::hash()(impl); + }); } }; + } // namespace std From 99a73957f1bd6db3e6cbde28ede8862b8f7a0778 Mon Sep 17 00:00:00 2001 From: jiahy0825 Date: Thu, 16 May 2024 08:35:10 +0000 Subject: [PATCH 3/3] dump null Types and null symbols to pir py code --- .../operator/transforms/add_cinn_pass.cc | 12 +- .../transforms/pir_to_py_code_converter.cc | 139 ++++++++++++------ .../transforms/pir_to_py_code_converter.h | 21 ++- .../operator/transforms/type_adt_type_id.cc | 3 + .../operator/transforms/type_adt_type_id.h | 10 +- paddle/common/flags.cc | 5 + 6 files changed, 140 insertions(+), 50 deletions(-) diff --git a/paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc index 8199cc3fa740e..7f8a3a7cf02fc 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc @@ -219,13 +219,21 @@ int64_t GetOpCount(const ::pir::Operation* op) { void ApplyCinnPass(::pir::Program* program, const std::function()>& CreatePassManager) { + PirToPyCodeConverter(program) + .file_name("original_programs.py") + .dump_symbolic_shape(false) + .SaveIfFlagEnabled(); ApplyPdToCinnPass(program, CreatePassManager); ApplyCinnPreprocessPass(program, CreatePassManager); ApplyBuildGroupOpPass(program, CreatePassManager); - PirToPyCodeConverter().SaveIfFlagEnabled("group_op_programs", *program); + PirToPyCodeConverter(program) + .file_name("group_op_programs.py") + .SaveIfFlagEnabled(); ApplyGroupOpPass(program, CreatePassManager); ApplyDivideGroupOpToFusionOpPass(program, CreatePassManager); - PirToPyCodeConverter().SaveIfFlagEnabled("fusion_op_programs", *program); + PirToPyCodeConverter(program) + .file_name("fusion_op_programs.py") + .SaveIfFlagEnabled(); LOG(INFO) << "FusionOp count before lowering : *****[ " << GetOpCount(program->module_op()) << " ]*****"; diff --git a/paddle/cinn/hlir/dialect/operator/transforms/pir_to_py_code_converter.cc b/paddle/cinn/hlir/dialect/operator/transforms/pir_to_py_code_converter.cc index 6da5b3cc6eebd..78703dfe2c618 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/pir_to_py_code_converter.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/pir_to_py_code_converter.cc @@ -37,6 +37,7 @@ #include "paddle/pir/include/dialect/shape/utils/shape_analysis.h" COMMON_DECLARE_string(logging_pir_py_code_dir); +COMMON_DECLARE_bool(logging_trunc_pir_py_code); namespace cinn::dialect::ir { @@ -92,13 +93,20 @@ int64_t GetAutoIncrementalId() { return seq_no++; } +using ShapeAnalysisGetterT = + std::function( + const pir::Program*)>; + } // namespace struct PirToPyCodeConverterHelper { - explicit PirToPyCodeConverterHelper(const pir::Program* program) + explicit PirToPyCodeConverterHelper( + const pir::Program* program, + const ShapeAnalysisGetterT& ShapeAnalysisGetter) : program_(program), indent_size_(kDefaultIndentSize), - seq_no_(GetAutoIncrementalId()) {} + seq_no_(GetAutoIncrementalId()), + ShapeAnalysisGetter_(ShapeAnalysisGetter) {} std::string Convert() { return Convert(*program_); } @@ -106,6 +114,7 @@ struct PirToPyCodeConverterHelper { const pir::Program* program_; const int indent_size_; int64_t seq_no_; + ShapeAnalysisGetterT ShapeAnalysisGetter_; std::string Convert(const pir::Program& program) { auto istrings = ConvertMethodsToPyClass(program.module_op(), [&]() { @@ -147,7 +156,8 @@ struct PirToPyCodeConverterHelper { template void VisitEachEQCstr(const DoEachEQCstrT& DoEachEQCstr) { const auto& constraints_mgr = GetConstraintsMgr(); - for (const auto& [lhs, rhs] : constraints_mgr.equals().GetMap()) { + if (!constraints_mgr.has_value()) return; + for (const auto& [lhs, rhs] : constraints_mgr.value()->equals().GetMap()) { if (lhs == rhs) continue; DoEachEQCstr(lhs, rhs); } @@ -165,7 +175,8 @@ struct PirToPyCodeConverterHelper { template void VisitEachGtOneCstr(const DoEachGtOneCstrT& DoEachGtOneCstr) { const auto& constraints_mgr = GetConstraintsMgr(); - for (const auto& dim_expr : constraints_mgr.gtones()) { + if (!constraints_mgr.has_value()) return; + for (const auto& dim_expr : constraints_mgr.value()->gtones()) { DoEachGtOneCstr(dim_expr); } } @@ -181,7 +192,8 @@ struct PirToPyCodeConverterHelper { void VisitEachBroadcastableCstr( const DoEachBroadcastableCstrT& DoEachBroadcastableCstr) { const auto& constraints_mgr = GetConstraintsMgr(); - const auto& broadcastables = constraints_mgr.broadcastables(); + if (!constraints_mgr.has_value()) return; + const auto& broadcastables = constraints_mgr.value()->broadcastables(); for (const auto& broadcastable : broadcastables) { const auto& [lhs, rhs] = *broadcastable; if (lhs == rhs) continue; @@ -198,9 +210,10 @@ struct PirToPyCodeConverterHelper { return ss.str(); } - const symbol::ConstraintsManager& GetConstraintsMgr() { - auto& shape_analysis = pir::ShapeAnalysisManager::Instance().Get(program_); - return shape_analysis.constraints_manager(); + std::optional GetConstraintsMgr() { + const auto& shape_analysis = ShapeAnalysisGetter_(program_); + if (!shape_analysis.has_value()) return std::nullopt; + return &shape_analysis.value()->constraints_manager(); } IStrings ConvertModuleOp(const pir::ModuleOp& module) { @@ -402,6 +415,12 @@ struct PirToPyCodeConverterHelper { } ss << attr_name << "=" << ConvertAttr(attr); }); + VisitSymbolicAttrs(op, [&](const auto& attr_name, const auto& attrs) { + if (i++ > 0) { + ss << ", "; + } + ss << attr_name << "=" << ConvertSymbolicAttrs(attrs); + }); return ss.str(); } @@ -410,6 +429,25 @@ struct PirToPyCodeConverterHelper { return std::visit(AttrConverter{attr}, adt_type_id.variant()); } + static std::string ConvertSymbolicAttrs( + const std::vector>& attrs) { + std::ostringstream ss; + ss << "self.a_array("; + int i = 0; + for (const auto& attr : attrs) { + if (i++ > 0) { + ss << ", "; + } + if (!attr.has_value()) { + ss << "self.a_symbol(self.s_null())"; + } else { + ss << ConvertAttr(attr.value()); + } + } + ss << ")"; + return ss.str(); + } + static std::string ConvertShapeOrData( const symbol::ShapeOrDataDimExprs& shape_or_data) { return shape_or_data.Match( @@ -781,25 +819,20 @@ struct PirToPyCodeConverterHelper { if (attr_name == "sym_shape_str") continue; DoEachAttr(attr_name, attr); } - DoEachAttr("__operands_symbols_signature__", - GetOpOperandsSymbolsSignature(op)); - DoEachAttr("__results_symbols_signature__", - GetOpResultsSymbolsSignature(op)); - } - - pir::Attribute GetOpOperandsSymbolsSignature(const pir::Operation* op) { - std::vector attrs = GetOpOperandsSymbolDimsAttributes(op); - return pir::ArrayAttribute::get(pir::IrContext::Instance(), attrs); } - pir::Attribute GetOpResultsSymbolsSignature(const pir::Operation* op) { - std::vector attrs = GetOpResultsSymbolDimsAttributes(op); - return pir::ArrayAttribute::get(pir::IrContext::Instance(), attrs); + template + void VisitSymbolicAttrs(const pir::Operation* op, + const DoEachAttrT& DoEachAttr) { + DoEachAttr("__operands_symbols_signature__", + GetOpOperandsSymbolDimsAttributes(op)); + DoEachAttr("__results_symbols_signature__", + GetOpResultsSymbolDimsAttributes(op)); } - std::vector GetOpOperandsSymbolDimsAttributes( + std::vector> GetOpOperandsSymbolDimsAttributes( const pir::Operation* op) { - std::vector attrs; + std::vector> attrs; attrs.reserve(op->num_operands()); for (int i = 0; i < op->num_operands(); ++i) { attrs.push_back(GetValueSymbolDimsAttribute(op->operand_source(i))); @@ -807,9 +840,9 @@ struct PirToPyCodeConverterHelper { return attrs; } - std::vector GetOpResultsSymbolDimsAttributes( + std::vector> GetOpResultsSymbolDimsAttributes( const pir::Operation* op) { - std::vector attrs; + std::vector> attrs; attrs.reserve(op->num_results()); for (int i = 0; i < op->num_results(); ++i) { attrs.push_back(GetValueSymbolDimsAttribute(op->result(i))); @@ -817,19 +850,22 @@ struct PirToPyCodeConverterHelper { return attrs; } - pir::Attribute GetValueSymbolDimsAttribute(pir::Value value) { + std::optional GetValueSymbolDimsAttribute(pir::Value value) { auto* ctx = pir::IrContext::Instance(); using SymbolAttr = pir::shape::SymbolAttribute; if (!value) { - return SymbolAttr::get(ctx, symbol::TensorShapeOrDataDimExprs{}); + return std::nullopt; } - const auto* shape_or_data = GetShapeOrDataDimExprs(value); - return SymbolAttr::get(ctx, *shape_or_data); + const auto& shape_or_data = GetShapeOrDataDimExprs(value); + if (!shape_or_data.has_value()) return std::nullopt; + return SymbolAttr::get(ctx, *shape_or_data.value()); } - const symbol::ShapeOrDataDimExprs* GetShapeOrDataDimExprs(pir::Value value) { - auto& shape_analysis = pir::ShapeAnalysisManager::Instance().Get(program_); - return &shape_analysis.GetShapeOrDataForValue(value); + std::optional GetShapeOrDataDimExprs( + pir::Value value) { + const auto& shape_analysis = ShapeAnalysisGetter_(program_); + if (!shape_analysis.has_value()) return std::nullopt; + return &shape_analysis.value()->GetShapeOrDataForValue(value); } std::string ConvertInputTypes(const pir::Operation* op) { @@ -869,6 +905,10 @@ struct PirToPyCodeConverterHelper { template using AdtTypeId = ::common::AdtTypeId; + std::string operator()(AdtTypeId) { + return "self.t_null"; + } + std::string operator()(AdtTypeId<::pir::VectorType>) { std::stringstream ss; const auto& name = ::pir::VectorType::name(); @@ -1130,22 +1170,39 @@ struct PirToPyCodeConverterHelper { } }; +std::optional GetShapeAnalysisFromManager( + const pir::Program* program) { + return &pir::ShapeAnalysisManager::Instance().Get(program); +} + +std::optional GetNullShapeAnalysis( + const pir::Program* program) { + return std::nullopt; +} + } // namespace -void PirToPyCodeConverter::SaveIfFlagEnabled( - const std::string& tag, const pir::Program& program) const { +void PirToPyCodeConverter::SaveIfFlagEnabled() const { + if (program_ == nullptr) return; + if (file_name_.empty()) return; if (FLAGS_logging_pir_py_code_dir == "") return; const std::string file_path = - FLAGS_logging_pir_py_code_dir + "/" + tag + ".py"; - const std::string content = PirToPyCodeConverterHelper(&program).Convert(); + FLAGS_logging_pir_py_code_dir + "/" + file_name_; + ShapeAnalysisGetterT ShapeAnalysisGetter = + (dump_symbolic_shape_ ? GetShapeAnalysisFromManager + : GetNullShapeAnalysis); + PirToPyCodeConverterHelper converter_helper(program_, ShapeAnalysisGetter); + const std::string content = converter_helper.Convert(); static std::mutex mutex; std::unique_lock lock(mutex); - static std::unordered_map once_flags; - std::call_once(once_flags[file_path], [&] { - std::ofstream ofs; - ofs.open(file_path.c_str(), std::ios::out | std::ios::trunc); - ofs.close(); - }); + if (FLAGS_logging_trunc_pir_py_code) { + static std::unordered_map once_flags; + std::call_once(once_flags[file_path], [&] { + std::ofstream ofs; + ofs.open(file_path.c_str(), std::ios::out | std::ios::trunc); + ofs.close(); + }); + } std::ofstream ofs; ofs.open(file_path.c_str(), std::ios::out | std::ios::app); if (!ofs.is_open()) return; diff --git a/paddle/cinn/hlir/dialect/operator/transforms/pir_to_py_code_converter.h b/paddle/cinn/hlir/dialect/operator/transforms/pir_to_py_code_converter.h index bbb36acd526a6..b72c4ef56c579 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/pir_to_py_code_converter.h +++ b/paddle/cinn/hlir/dialect/operator/transforms/pir_to_py_code_converter.h @@ -26,12 +26,27 @@ namespace cinn::dialect::ir { class PirToPyCodeConverter { public: - PirToPyCodeConverter() = default; + explicit PirToPyCodeConverter(pir::Program* program) + : program_(program), file_name_(), dump_symbolic_shape_(true) {} PirToPyCodeConverter(const PirToPyCodeConverter&) = delete; PirToPyCodeConverter(PirToPyCodeConverter&&) = delete; - void SaveIfFlagEnabled(const std::string& tag, - const pir::Program& program) const; + PirToPyCodeConverter& file_name(const std::string& file_name) { + file_name_ = file_name; + return *this; + } + + PirToPyCodeConverter& dump_symbolic_shape(bool val) { + dump_symbolic_shape_ = val; + return *this; + } + + void SaveIfFlagEnabled() const; + + private: + pir::Program* program_; + std::string file_name_; + bool dump_symbolic_shape_; }; } // namespace cinn::dialect::ir diff --git a/paddle/cinn/hlir/dialect/operator/transforms/type_adt_type_id.cc b/paddle/cinn/hlir/dialect/operator/transforms/type_adt_type_id.cc index da8bfba19fcf9..82ca2cafb7539 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/type_adt_type_id.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/type_adt_type_id.cc @@ -24,6 +24,9 @@ namespace cinn::dialect::ir { TypeAdtTypeId GetTypeAdtTypeId(const pir::Type& type) { + if (!type) { + return ::common::AdtTypeId{}; + } #define RETURN_TYPE_TYPE_ID_IF_MATCH(cls) \ if (type.isa()) return ::common::AdtTypeId{}; FOR_EACH_PIR_ALTERNATIVE_TYPLE(RETURN_TYPE_TYPE_ID_IF_MATCH) diff --git a/paddle/cinn/hlir/dialect/operator/transforms/type_adt_type_id.h b/paddle/cinn/hlir/dialect/operator/transforms/type_adt_type_id.h index 41b70c97be9aa..f23bda9bef331 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/type_adt_type_id.h +++ b/paddle/cinn/hlir/dialect/operator/transforms/type_adt_type_id.h @@ -18,7 +18,6 @@ namespace pir { class Type; - class VectorType; class DenseTensorType; class BFloat16Type; @@ -58,13 +57,16 @@ class Complex128Type; namespace cinn::dialect::ir { +class NullType; class UnclassifiedType; -using TypeAdtTypeIdBase = ::common::AdtBaseTypeId< +using TypeAdtTypeIdBase = + ::common::AdtBaseTypeId; + UnclassifiedType>; struct TypeAdtTypeId : public TypeAdtTypeIdBase { using TypeAdtTypeIdBase::TypeAdtTypeIdBase; diff --git a/paddle/common/flags.cc b/paddle/common/flags.cc index 2ccd3d7265166..23a3013cf7834 100644 --- a/paddle/common/flags.cc +++ b/paddle/common/flags.cc @@ -1448,6 +1448,11 @@ PHI_DEFINE_EXPORTED_string(logging_pir_py_code_dir, "", "the logging directory to save pir py code"); +PHI_DEFINE_EXPORTED_bool(logging_trunc_pir_py_code, + true, + "whether truncate the logging files under directory " + "FLAGS_logging_pir_py_code_dir"); + /** * Using PIR API in Python * Name: enable_pir_api