diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index b0776dee661f..d113860ddbce 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -28,7 +28,7 @@ #include #include #include - +#include #include #include #include @@ -131,21 +131,21 @@ class IRModuleNode : public Object { * \param name The variable name. * \returns true if contains, otherise false. */ - TVM_DLL bool ContainGlobalVar(const std::string& name) const; + TVM_DLL bool ContainGlobalVar(const String& name) const; /*! * \brief Check if the global_type_var_map_ contains a global type variable. * \param name The variable name. * \returns true if contains, otherise false. */ - TVM_DLL bool ContainGlobalTypeVar(const std::string& name) const; + TVM_DLL bool ContainGlobalTypeVar(const String& name) const; /*! * \brief Lookup a global function by its variable. * \param str The unique string specifying the global variable. * \returns The global variable. */ - TVM_DLL GlobalVar GetGlobalVar(const std::string& str) const; + TVM_DLL GlobalVar GetGlobalVar(const String& str) const; /*! * \brief Collect all global vars defined in this module. @@ -158,7 +158,7 @@ class IRModuleNode : public Object { * \param str The unique string specifying the global variable. * \returns The global variable. */ - TVM_DLL GlobalTypeVar GetGlobalTypeVar(const std::string& str) const; + TVM_DLL GlobalTypeVar GetGlobalTypeVar(const String& str) const; /*! * \brief Collect all global type vars defined in this module. @@ -172,7 +172,7 @@ class IRModuleNode : public Object { * \param cons name of the constructor * \returns Constructor of ADT, error if not found */ - TVM_DLL Constructor GetConstructor(const std::string& adt, const std::string& cons) const; + TVM_DLL Constructor GetConstructor(const String& adt, const String& cons) const; /*! * \brief Look up a global function by its variable. @@ -186,7 +186,7 @@ class IRModuleNode : public Object { * \param name The name of the function. * \returns The function named by the argument. */ - TVM_DLL BaseFunc Lookup(const std::string& name) const; + TVM_DLL BaseFunc Lookup(const String& name) const; /*! * \brief Look up a global type definition by its variable. @@ -200,7 +200,7 @@ class IRModuleNode : public Object { * \param var The name of the global type definition. * \return The type definition. */ - TVM_DLL TypeData LookupTypeDef(const std::string& var) const; + TVM_DLL TypeData LookupTypeDef(const String& var) const; /*! * \brief Look up a constructor by its tag. @@ -225,18 +225,18 @@ class IRModuleNode : public Object { * relative it will be resovled against the current * working directory. */ - TVM_DLL void Import(const std::string& path); + TVM_DLL void Import(const String& path); /*! * \brief Import Relay code from the file at path, relative to the standard library. * \param path The path of the Relay code to import. */ - TVM_DLL void ImportFromStd(const std::string& path); + TVM_DLL void ImportFromStd(const String& path); /*! * \brief The set of imported files. */ - TVM_DLL std::unordered_set Imports() const; + TVM_DLL std::unordered_set Imports() const; static constexpr const char* _type_key = "IRModule"; static constexpr const bool _type_has_method_sequal_reduce = true; @@ -265,7 +265,7 @@ class IRModuleNode : public Object { /*! \brief The files previously imported, required to ensure importing is idempotent for each module. */ - std::unordered_set import_set_; + std::unordered_set import_set_; friend class IRModule; }; @@ -283,7 +283,7 @@ class IRModule : public ObjectRef { */ TVM_DLL explicit IRModule(Map functions, Map type_definitions = {}, - std::unordered_set import_set = {}); + std::unordered_set import_set = {}); /*! \brief default constructor */ IRModule() {} /*! @@ -329,7 +329,7 @@ class IRModule : public ObjectRef { * \param source_path The path to the source file. * \return A Relay module. */ - TVM_DLL static IRModule FromText(const std::string& text, const std::string& source_path); + TVM_DLL static IRModule FromText(const String& text, const String& source_path); /*! \brief Declare the container type. */ using ContainerType = IRModuleNode; @@ -346,7 +346,7 @@ class IRModule : public ObjectRef { * Use AsText if you want to store the text. * \sa AsText. */ -TVM_DLL std::string PrettyPrint(const ObjectRef& node); +TVM_DLL String PrettyPrint(const ObjectRef& node); /*! * \brief Render the node as a string in the text format. @@ -362,8 +362,8 @@ TVM_DLL std::string PrettyPrint(const ObjectRef& node); * \sa PrettyPrint. * \return The text representation. */ -TVM_DLL std::string AsText(const ObjectRef& node, +TVM_DLL String AsText(const ObjectRef& node, bool show_meta_data = true, - runtime::TypedPackedFunc annotate = nullptr); + runtime::TypedPackedFunc annotate = nullptr); } // namespace tvm #endif // TVM_IR_MODULE_H_ diff --git a/src/ir/module.cc b/src/ir/module.cc index 6262150556c7..1be58f3caded 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -40,7 +40,7 @@ namespace tvm { IRModule::IRModule(tvm::Map functions, tvm::Map type_definitions, - std::unordered_set import_set) { + std::unordered_set import_set) { auto n = make_object(); n->functions = std::move(functions); n->type_definitions = std::move(type_definitions); @@ -111,15 +111,15 @@ void IRModuleNode::SHashReduce(SHashReducer hash_reduce) const { reduce_temp(); } -bool IRModuleNode::ContainGlobalVar(const std::string& name) const { +bool IRModuleNode::ContainGlobalVar(const String& name) const { return global_var_map_.find(name) != global_var_map_.end(); } -bool IRModuleNode::ContainGlobalTypeVar(const std::string& name) const { +bool IRModuleNode::ContainGlobalTypeVar(const String& name) const { return global_type_var_map_.find(name) != global_type_var_map_.end(); } -GlobalVar IRModuleNode::GetGlobalVar(const std::string& name) const { +GlobalVar IRModuleNode::GetGlobalVar(const String& name) const { auto it = global_var_map_.find(name); if (it == global_var_map_.end()) { std::ostringstream msg; @@ -146,7 +146,7 @@ tvm::Array IRModuleNode::GetGlobalVars() const { return tvm::Array(global_vars); } -GlobalTypeVar IRModuleNode::GetGlobalTypeVar(const std::string& name) const { +GlobalTypeVar IRModuleNode::GetGlobalTypeVar(const String& name) const { CHECK(global_type_var_map_.defined()); auto it = global_type_var_map_.find(name); CHECK(it != global_type_var_map_.end()) @@ -154,7 +154,7 @@ GlobalTypeVar IRModuleNode::GetGlobalTypeVar(const std::string& name) const { return (*it).second; } -Constructor IRModuleNode::GetConstructor(const std::string& adt, const std::string& cons) const { +Constructor IRModuleNode::GetConstructor(const String& adt, const String& cons) const { TypeData typeDef = this->LookupTypeDef(adt); for (Constructor c : typeDef->constructors) { if (cons.compare(c->name_hint) == 0) { @@ -315,7 +315,7 @@ BaseFunc IRModuleNode::Lookup(const GlobalVar& var) const { return (*it).second; } -BaseFunc IRModuleNode::Lookup(const std::string& name) const { +BaseFunc IRModuleNode::Lookup(const String& name) const { GlobalVar id = this->GetGlobalVar(name); return this->Lookup(id); } @@ -327,7 +327,7 @@ TypeData IRModuleNode::LookupTypeDef(const GlobalTypeVar& var) const { return (*it).second; } -TypeData IRModuleNode::LookupTypeDef(const std::string& name) const { +TypeData IRModuleNode::LookupTypeDef(const String& name) const { GlobalTypeVar id = this->GetGlobalTypeVar(name); return this->LookupTypeDef(id); } @@ -379,7 +379,7 @@ IRModule IRModule::FromExpr( return mod; } -void IRModuleNode::Import(const std::string& path) { +void IRModuleNode::Import(const String& path) { if (this->import_set_.count(path) == 0) { this->import_set_.insert(path); DLOG(INFO) << "Importing: " << path; @@ -392,18 +392,18 @@ void IRModuleNode::Import(const std::string& path) { } } -void IRModuleNode::ImportFromStd(const std::string& path) { +void IRModuleNode::ImportFromStd(const String& path) { auto* f = tvm::runtime::Registry::Get("tvm.relay.std_path"); CHECK(f != nullptr) << "The Relay std_path is not set, please register tvm.relay.std_path."; std::string std_path = (*f)(); - return this->Import(std_path + "/" + path); + this->Import(std_path + "/" + path.operator std::string()); } -std::unordered_set IRModuleNode::Imports() const { +std::unordered_set IRModuleNode::Imports() const { return this->import_set_; } -IRModule IRModule::FromText(const std::string& text, const std::string& source_path) { +IRModule IRModule::FromText(const String& text, const String& source_path) { auto* f = tvm::runtime::Registry::Get("relay.fromtext"); CHECK(f != nullptr) << "The Relay std_path is not set, please register tvm.relay.std_path."; IRModule mod = (*f)(text, source_path); @@ -467,7 +467,7 @@ TVM_REGISTER_GLOBAL("ir.Module_Lookup") }); TVM_REGISTER_GLOBAL("ir.Module_Lookup_str") -.set_body_typed([](IRModule mod, std::string var) { +.set_body_typed([](IRModule mod, String var) { return mod->Lookup(var); }); @@ -477,7 +477,7 @@ TVM_REGISTER_GLOBAL("ir.Module_LookupDef") }); TVM_REGISTER_GLOBAL("ir.Module_LookupDef_str") -.set_body_typed([](IRModule mod, std::string var) { +.set_body_typed([](IRModule mod, String var) { return mod->LookupTypeDef(var); }); @@ -499,12 +499,12 @@ TVM_REGISTER_GLOBAL("ir.Module_Update") }); TVM_REGISTER_GLOBAL("ir.Module_Import") -.set_body_typed([](IRModule mod, std::string path) { +.set_body_typed([](IRModule mod, String path) { mod->Import(path); }); TVM_REGISTER_GLOBAL("ir.Module_ImportFromStd") -.set_body_typed([](IRModule mod, std::string path) { +.set_body_typed([](IRModule mod, String path) { mod->ImportFromStd(path); });; diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc index bda997a59d4d..2e675c8ed8f4 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -918,22 +918,28 @@ static const char* kSemVer = "v0.0.4"; // - Implements AsText // - relay_text_printer.cc (specific printing logics for relay) // - tir_text_printer.cc (specific printing logics for TIR) -std::string PrettyPrint(const ObjectRef& node) { +String PrettyPrint(const ObjectRef& node) { Doc doc; doc << relay::RelayTextPrinter(false, nullptr).PrintFinal(node); return doc.str(); } -std::string AsText(const ObjectRef& node, +String AsText(const ObjectRef& node, bool show_meta_data, - runtime::TypedPackedFunc annotate) { + runtime::TypedPackedFunc annotate) { Doc doc; doc << kSemVer << Doc::NewLine(); - doc << relay::RelayTextPrinter(show_meta_data, annotate).PrintFinal(node); + runtime::TypedPackedFunc ftyped = nullptr; + if (annotate != nullptr) { + ftyped = runtime::TypedPackedFunc( + [&annotate](const ObjectRef& expr) -> std::string { + return annotate(expr); + }); + } + doc << relay::RelayTextPrinter(show_meta_data, ftyped).PrintFinal(node); return doc.str(); } - TVM_REGISTER_GLOBAL("ir.PrettyPrint") .set_body_typed(PrettyPrint);