From 505a40b961eaf9fd35b29c9df28380d751ffe3f3 Mon Sep 17 00:00:00 2001 From: tc20042008 <156998525+tc20042008@users.noreply.github.com> Date: Fri, 17 May 2024 10:05:09 +0800 Subject: [PATCH] Dump original pir code to original_programs.py (#64373) * fix typo * refactor hash and operator== for ShapeOrDataDimExprs * dump null Types and null symbols to pir py code --------- Co-authored-by: jiahy0825 --- .../operator/transforms/add_cinn_pass.cc | 12 +- .../transforms/pir_to_py_code_converter.cc | 139 ++++++++++++------ .../transforms/pir_to_py_code_converter.h | 21 ++- .../operator/transforms/type_adt_type_id.cc | 3 + .../operator/transforms/type_adt_type_id.h | 10 +- paddle/common/flags.cc | 5 + 6 files changed, 140 insertions(+), 50 deletions(-) diff --git a/paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc index 8199cc3fa740ea..7f8a3a7cf02fc5 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc @@ -219,13 +219,21 @@ int64_t GetOpCount(const ::pir::Operation* op) { void ApplyCinnPass(::pir::Program* program, const std::function()>& CreatePassManager) { + PirToPyCodeConverter(program) + .file_name("original_programs.py") + .dump_symbolic_shape(false) + .SaveIfFlagEnabled(); ApplyPdToCinnPass(program, CreatePassManager); ApplyCinnPreprocessPass(program, CreatePassManager); ApplyBuildGroupOpPass(program, CreatePassManager); - PirToPyCodeConverter().SaveIfFlagEnabled("group_op_programs", *program); + PirToPyCodeConverter(program) + .file_name("group_op_programs.py") + .SaveIfFlagEnabled(); ApplyGroupOpPass(program, CreatePassManager); ApplyDivideGroupOpToFusionOpPass(program, CreatePassManager); - PirToPyCodeConverter().SaveIfFlagEnabled("fusion_op_programs", *program); + PirToPyCodeConverter(program) + .file_name("fusion_op_programs.py") + .SaveIfFlagEnabled(); LOG(INFO) << "FusionOp count before lowering : *****[ " << GetOpCount(program->module_op()) << " ]*****"; diff --git a/paddle/cinn/hlir/dialect/operator/transforms/pir_to_py_code_converter.cc b/paddle/cinn/hlir/dialect/operator/transforms/pir_to_py_code_converter.cc index 6da5b3cc6eebdd..78703dfe2c618a 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/pir_to_py_code_converter.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/pir_to_py_code_converter.cc @@ -37,6 +37,7 @@ #include "paddle/pir/include/dialect/shape/utils/shape_analysis.h" COMMON_DECLARE_string(logging_pir_py_code_dir); +COMMON_DECLARE_bool(logging_trunc_pir_py_code); namespace cinn::dialect::ir { @@ -92,13 +93,20 @@ int64_t GetAutoIncrementalId() { return seq_no++; } +using ShapeAnalysisGetterT = + std::function( + const pir::Program*)>; + } // namespace struct PirToPyCodeConverterHelper { - explicit PirToPyCodeConverterHelper(const pir::Program* program) + explicit PirToPyCodeConverterHelper( + const pir::Program* program, + const ShapeAnalysisGetterT& ShapeAnalysisGetter) : program_(program), indent_size_(kDefaultIndentSize), - seq_no_(GetAutoIncrementalId()) {} + seq_no_(GetAutoIncrementalId()), + ShapeAnalysisGetter_(ShapeAnalysisGetter) {} std::string Convert() { return Convert(*program_); } @@ -106,6 +114,7 @@ struct PirToPyCodeConverterHelper { const pir::Program* program_; const int indent_size_; int64_t seq_no_; + ShapeAnalysisGetterT ShapeAnalysisGetter_; std::string Convert(const pir::Program& program) { auto istrings = ConvertMethodsToPyClass(program.module_op(), [&]() { @@ -147,7 +156,8 @@ struct PirToPyCodeConverterHelper { template void VisitEachEQCstr(const DoEachEQCstrT& DoEachEQCstr) { const auto& constraints_mgr = GetConstraintsMgr(); - for (const auto& [lhs, rhs] : constraints_mgr.equals().GetMap()) { + if (!constraints_mgr.has_value()) return; + for (const auto& [lhs, rhs] : constraints_mgr.value()->equals().GetMap()) { if (lhs == rhs) continue; DoEachEQCstr(lhs, rhs); } @@ -165,7 +175,8 @@ struct PirToPyCodeConverterHelper { template void VisitEachGtOneCstr(const DoEachGtOneCstrT& DoEachGtOneCstr) { const auto& constraints_mgr = GetConstraintsMgr(); - for (const auto& dim_expr : constraints_mgr.gtones()) { + if (!constraints_mgr.has_value()) return; + for (const auto& dim_expr : constraints_mgr.value()->gtones()) { DoEachGtOneCstr(dim_expr); } } @@ -181,7 +192,8 @@ struct PirToPyCodeConverterHelper { void VisitEachBroadcastableCstr( const DoEachBroadcastableCstrT& DoEachBroadcastableCstr) { const auto& constraints_mgr = GetConstraintsMgr(); - const auto& broadcastables = constraints_mgr.broadcastables(); + if (!constraints_mgr.has_value()) return; + const auto& broadcastables = constraints_mgr.value()->broadcastables(); for (const auto& broadcastable : broadcastables) { const auto& [lhs, rhs] = *broadcastable; if (lhs == rhs) continue; @@ -198,9 +210,10 @@ struct PirToPyCodeConverterHelper { return ss.str(); } - const symbol::ConstraintsManager& GetConstraintsMgr() { - auto& shape_analysis = pir::ShapeAnalysisManager::Instance().Get(program_); - return shape_analysis.constraints_manager(); + std::optional GetConstraintsMgr() { + const auto& shape_analysis = ShapeAnalysisGetter_(program_); + if (!shape_analysis.has_value()) return std::nullopt; + return &shape_analysis.value()->constraints_manager(); } IStrings ConvertModuleOp(const pir::ModuleOp& module) { @@ -402,6 +415,12 @@ struct PirToPyCodeConverterHelper { } ss << attr_name << "=" << ConvertAttr(attr); }); + VisitSymbolicAttrs(op, [&](const auto& attr_name, const auto& attrs) { + if (i++ > 0) { + ss << ", "; + } + ss << attr_name << "=" << ConvertSymbolicAttrs(attrs); + }); return ss.str(); } @@ -410,6 +429,25 @@ struct PirToPyCodeConverterHelper { return std::visit(AttrConverter{attr}, adt_type_id.variant()); } + static std::string ConvertSymbolicAttrs( + const std::vector>& attrs) { + std::ostringstream ss; + ss << "self.a_array("; + int i = 0; + for (const auto& attr : attrs) { + if (i++ > 0) { + ss << ", "; + } + if (!attr.has_value()) { + ss << "self.a_symbol(self.s_null())"; + } else { + ss << ConvertAttr(attr.value()); + } + } + ss << ")"; + return ss.str(); + } + static std::string ConvertShapeOrData( const symbol::ShapeOrDataDimExprs& shape_or_data) { return shape_or_data.Match( @@ -781,25 +819,20 @@ struct PirToPyCodeConverterHelper { if (attr_name == "sym_shape_str") continue; DoEachAttr(attr_name, attr); } - DoEachAttr("__operands_symbols_signature__", - GetOpOperandsSymbolsSignature(op)); - DoEachAttr("__results_symbols_signature__", - GetOpResultsSymbolsSignature(op)); - } - - pir::Attribute GetOpOperandsSymbolsSignature(const pir::Operation* op) { - std::vector attrs = GetOpOperandsSymbolDimsAttributes(op); - return pir::ArrayAttribute::get(pir::IrContext::Instance(), attrs); } - pir::Attribute GetOpResultsSymbolsSignature(const pir::Operation* op) { - std::vector attrs = GetOpResultsSymbolDimsAttributes(op); - return pir::ArrayAttribute::get(pir::IrContext::Instance(), attrs); + template + void VisitSymbolicAttrs(const pir::Operation* op, + const DoEachAttrT& DoEachAttr) { + DoEachAttr("__operands_symbols_signature__", + GetOpOperandsSymbolDimsAttributes(op)); + DoEachAttr("__results_symbols_signature__", + GetOpResultsSymbolDimsAttributes(op)); } - std::vector GetOpOperandsSymbolDimsAttributes( + std::vector> GetOpOperandsSymbolDimsAttributes( const pir::Operation* op) { - std::vector attrs; + std::vector> attrs; attrs.reserve(op->num_operands()); for (int i = 0; i < op->num_operands(); ++i) { attrs.push_back(GetValueSymbolDimsAttribute(op->operand_source(i))); @@ -807,9 +840,9 @@ struct PirToPyCodeConverterHelper { return attrs; } - std::vector GetOpResultsSymbolDimsAttributes( + std::vector> GetOpResultsSymbolDimsAttributes( const pir::Operation* op) { - std::vector attrs; + std::vector> attrs; attrs.reserve(op->num_results()); for (int i = 0; i < op->num_results(); ++i) { attrs.push_back(GetValueSymbolDimsAttribute(op->result(i))); @@ -817,19 +850,22 @@ struct PirToPyCodeConverterHelper { return attrs; } - pir::Attribute GetValueSymbolDimsAttribute(pir::Value value) { + std::optional GetValueSymbolDimsAttribute(pir::Value value) { auto* ctx = pir::IrContext::Instance(); using SymbolAttr = pir::shape::SymbolAttribute; if (!value) { - return SymbolAttr::get(ctx, symbol::TensorShapeOrDataDimExprs{}); + return std::nullopt; } - const auto* shape_or_data = GetShapeOrDataDimExprs(value); - return SymbolAttr::get(ctx, *shape_or_data); + const auto& shape_or_data = GetShapeOrDataDimExprs(value); + if (!shape_or_data.has_value()) return std::nullopt; + return SymbolAttr::get(ctx, *shape_or_data.value()); } - const symbol::ShapeOrDataDimExprs* GetShapeOrDataDimExprs(pir::Value value) { - auto& shape_analysis = pir::ShapeAnalysisManager::Instance().Get(program_); - return &shape_analysis.GetShapeOrDataForValue(value); + std::optional GetShapeOrDataDimExprs( + pir::Value value) { + const auto& shape_analysis = ShapeAnalysisGetter_(program_); + if (!shape_analysis.has_value()) return std::nullopt; + return &shape_analysis.value()->GetShapeOrDataForValue(value); } std::string ConvertInputTypes(const pir::Operation* op) { @@ -869,6 +905,10 @@ struct PirToPyCodeConverterHelper { template using AdtTypeId = ::common::AdtTypeId; + std::string operator()(AdtTypeId) { + return "self.t_null"; + } + std::string operator()(AdtTypeId<::pir::VectorType>) { std::stringstream ss; const auto& name = ::pir::VectorType::name(); @@ -1130,22 +1170,39 @@ struct PirToPyCodeConverterHelper { } }; +std::optional GetShapeAnalysisFromManager( + const pir::Program* program) { + return &pir::ShapeAnalysisManager::Instance().Get(program); +} + +std::optional GetNullShapeAnalysis( + const pir::Program* program) { + return std::nullopt; +} + } // namespace -void PirToPyCodeConverter::SaveIfFlagEnabled( - const std::string& tag, const pir::Program& program) const { +void PirToPyCodeConverter::SaveIfFlagEnabled() const { + if (program_ == nullptr) return; + if (file_name_.empty()) return; if (FLAGS_logging_pir_py_code_dir == "") return; const std::string file_path = - FLAGS_logging_pir_py_code_dir + "/" + tag + ".py"; - const std::string content = PirToPyCodeConverterHelper(&program).Convert(); + FLAGS_logging_pir_py_code_dir + "/" + file_name_; + ShapeAnalysisGetterT ShapeAnalysisGetter = + (dump_symbolic_shape_ ? GetShapeAnalysisFromManager + : GetNullShapeAnalysis); + PirToPyCodeConverterHelper converter_helper(program_, ShapeAnalysisGetter); + const std::string content = converter_helper.Convert(); static std::mutex mutex; std::unique_lock lock(mutex); - static std::unordered_map once_flags; - std::call_once(once_flags[file_path], [&] { - std::ofstream ofs; - ofs.open(file_path.c_str(), std::ios::out | std::ios::trunc); - ofs.close(); - }); + if (FLAGS_logging_trunc_pir_py_code) { + static std::unordered_map once_flags; + std::call_once(once_flags[file_path], [&] { + std::ofstream ofs; + ofs.open(file_path.c_str(), std::ios::out | std::ios::trunc); + ofs.close(); + }); + } std::ofstream ofs; ofs.open(file_path.c_str(), std::ios::out | std::ios::app); if (!ofs.is_open()) return; diff --git a/paddle/cinn/hlir/dialect/operator/transforms/pir_to_py_code_converter.h b/paddle/cinn/hlir/dialect/operator/transforms/pir_to_py_code_converter.h index bbb36acd526a6d..b72c4ef56c579a 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/pir_to_py_code_converter.h +++ b/paddle/cinn/hlir/dialect/operator/transforms/pir_to_py_code_converter.h @@ -26,12 +26,27 @@ namespace cinn::dialect::ir { class PirToPyCodeConverter { public: - PirToPyCodeConverter() = default; + explicit PirToPyCodeConverter(pir::Program* program) + : program_(program), file_name_(), dump_symbolic_shape_(true) {} PirToPyCodeConverter(const PirToPyCodeConverter&) = delete; PirToPyCodeConverter(PirToPyCodeConverter&&) = delete; - void SaveIfFlagEnabled(const std::string& tag, - const pir::Program& program) const; + PirToPyCodeConverter& file_name(const std::string& file_name) { + file_name_ = file_name; + return *this; + } + + PirToPyCodeConverter& dump_symbolic_shape(bool val) { + dump_symbolic_shape_ = val; + return *this; + } + + void SaveIfFlagEnabled() const; + + private: + pir::Program* program_; + std::string file_name_; + bool dump_symbolic_shape_; }; } // namespace cinn::dialect::ir diff --git a/paddle/cinn/hlir/dialect/operator/transforms/type_adt_type_id.cc b/paddle/cinn/hlir/dialect/operator/transforms/type_adt_type_id.cc index da8bfba19fcf9a..82ca2cafb7539f 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/type_adt_type_id.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/type_adt_type_id.cc @@ -24,6 +24,9 @@ namespace cinn::dialect::ir { TypeAdtTypeId GetTypeAdtTypeId(const pir::Type& type) { + if (!type) { + return ::common::AdtTypeId{}; + } #define RETURN_TYPE_TYPE_ID_IF_MATCH(cls) \ if (type.isa()) return ::common::AdtTypeId{}; FOR_EACH_PIR_ALTERNATIVE_TYPLE(RETURN_TYPE_TYPE_ID_IF_MATCH) diff --git a/paddle/cinn/hlir/dialect/operator/transforms/type_adt_type_id.h b/paddle/cinn/hlir/dialect/operator/transforms/type_adt_type_id.h index 41b70c97be9aac..f23bda9bef3317 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/type_adt_type_id.h +++ b/paddle/cinn/hlir/dialect/operator/transforms/type_adt_type_id.h @@ -18,7 +18,6 @@ namespace pir { class Type; - class VectorType; class DenseTensorType; class BFloat16Type; @@ -58,13 +57,16 @@ class Complex128Type; namespace cinn::dialect::ir { +class NullType; class UnclassifiedType; -using TypeAdtTypeIdBase = ::common::AdtBaseTypeId< +using TypeAdtTypeIdBase = + ::common::AdtBaseTypeId; + UnclassifiedType>; struct TypeAdtTypeId : public TypeAdtTypeIdBase { using TypeAdtTypeIdBase::TypeAdtTypeIdBase; diff --git a/paddle/common/flags.cc b/paddle/common/flags.cc index 9689d3b75b7bad..c9b3b29115d757 100644 --- a/paddle/common/flags.cc +++ b/paddle/common/flags.cc @@ -1448,6 +1448,11 @@ PHI_DEFINE_EXPORTED_string(logging_pir_py_code_dir, "", "the logging directory to save pir py code"); +PHI_DEFINE_EXPORTED_bool(logging_trunc_pir_py_code, + true, + "whether truncate the logging files under directory " + "FLAGS_logging_pir_py_code_dir"); + /** * Using PIR API in Python * Name: enable_pir_api