Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Refactor][std::string --> String] IRModule is updated with String #5523

Merged
merged 4 commits into from
May 7, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 17 additions & 17 deletions include/tvm/ir/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
#include <tvm/ir/expr.h>
#include <tvm/ir/function.h>
#include <tvm/ir/adt.h>

#include <tvm/node/container.h>
#include <string>
#include <vector>
#include <unordered_map>
Expand Down Expand Up @@ -131,21 +131,21 @@ class IRModuleNode : public Object {
* \param name The variable name.
* \returns true if contains, otherise false.
*/
TVM_DLL bool ContainGlobalVar(const std::string& name) const;
TVM_DLL bool ContainGlobalVar(const String& name) const;

/*!
* \brief Check if the global_type_var_map_ contains a global type variable.
* \param name The variable name.
* \returns true if contains, otherise false.
*/
TVM_DLL bool ContainGlobalTypeVar(const std::string& name) const;
TVM_DLL bool ContainGlobalTypeVar(const String& name) const;

/*!
* \brief Lookup a global function by its variable.
* \param str The unique string specifying the global variable.
* \returns The global variable.
*/
TVM_DLL GlobalVar GetGlobalVar(const std::string& str) const;
TVM_DLL GlobalVar GetGlobalVar(const String& str) const;

/*!
* \brief Collect all global vars defined in this module.
Expand All @@ -158,7 +158,7 @@ class IRModuleNode : public Object {
* \param str The unique string specifying the global variable.
* \returns The global variable.
*/
TVM_DLL GlobalTypeVar GetGlobalTypeVar(const std::string& str) const;
TVM_DLL GlobalTypeVar GetGlobalTypeVar(const String& str) const;

/*!
* \brief Collect all global type vars defined in this module.
Expand All @@ -172,7 +172,7 @@ class IRModuleNode : public Object {
* \param cons name of the constructor
* \returns Constructor of ADT, error if not found
*/
TVM_DLL Constructor GetConstructor(const std::string& adt, const std::string& cons) const;
TVM_DLL Constructor GetConstructor(const String& adt, const String& cons) const;

/*!
* \brief Look up a global function by its variable.
Expand All @@ -186,7 +186,7 @@ class IRModuleNode : public Object {
* \param name The name of the function.
* \returns The function named by the argument.
*/
TVM_DLL BaseFunc Lookup(const std::string& name) const;
TVM_DLL BaseFunc Lookup(const String& name) const;

/*!
* \brief Look up a global type definition by its variable.
Expand All @@ -200,7 +200,7 @@ class IRModuleNode : public Object {
* \param var The name of the global type definition.
* \return The type definition.
*/
TVM_DLL TypeData LookupTypeDef(const std::string& var) const;
TVM_DLL TypeData LookupTypeDef(const String& var) const;

/*!
* \brief Look up a constructor by its tag.
Expand All @@ -225,18 +225,18 @@ class IRModuleNode : public Object {
* relative it will be resovled against the current
* working directory.
*/
TVM_DLL void Import(const std::string& path);
TVM_DLL void Import(const String& path);

/*!
* \brief Import Relay code from the file at path, relative to the standard library.
* \param path The path of the Relay code to import.
*/
TVM_DLL void ImportFromStd(const std::string& path);
TVM_DLL void ImportFromStd(const String& path);

/*!
* \brief The set of imported files.
*/
TVM_DLL std::unordered_set<std::string> Imports() const;
TVM_DLL std::unordered_set<String> Imports() const;

static constexpr const char* _type_key = "IRModule";
static constexpr const bool _type_has_method_sequal_reduce = true;
Expand Down Expand Up @@ -265,7 +265,7 @@ class IRModuleNode : public Object {
/*! \brief The files previously imported, required to ensure
importing is idempotent for each module.
*/
std::unordered_set<std::string> import_set_;
std::unordered_set<String> import_set_;
friend class IRModule;
};

Expand All @@ -283,7 +283,7 @@ class IRModule : public ObjectRef {
*/
TVM_DLL explicit IRModule(Map<GlobalVar, BaseFunc> functions,
Map<GlobalTypeVar, TypeData> type_definitions = {},
std::unordered_set<std::string> import_set = {});
std::unordered_set<String> import_set = {});
/*! \brief default constructor */
IRModule() {}
/*!
Expand Down Expand Up @@ -329,7 +329,7 @@ class IRModule : public ObjectRef {
* \param source_path The path to the source file.
* \return A Relay module.
*/
TVM_DLL static IRModule FromText(const std::string& text, const std::string& source_path);
TVM_DLL static IRModule FromText(const String& text, const String& source_path);

/*! \brief Declare the container type. */
using ContainerType = IRModuleNode;
Expand All @@ -346,7 +346,7 @@ class IRModule : public ObjectRef {
* Use AsText if you want to store the text.
* \sa AsText.
*/
TVM_DLL std::string PrettyPrint(const ObjectRef& node);
TVM_DLL String PrettyPrint(const ObjectRef& node);

/*!
* \brief Render the node as a string in the text format.
Expand All @@ -362,8 +362,8 @@ TVM_DLL std::string PrettyPrint(const ObjectRef& node);
* \sa PrettyPrint.
* \return The text representation.
*/
TVM_DLL std::string AsText(const ObjectRef& node,
TVM_DLL String AsText(const ObjectRef& node,
bool show_meta_data = true,
runtime::TypedPackedFunc<std::string(ObjectRef)> annotate = nullptr);
runtime::TypedPackedFunc<String(ObjectRef)> annotate = nullptr);
} // namespace tvm
#endif // TVM_IR_MODULE_H_
34 changes: 17 additions & 17 deletions src/ir/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ namespace tvm {

IRModule::IRModule(tvm::Map<GlobalVar, BaseFunc> functions,
tvm::Map<GlobalTypeVar, TypeData> type_definitions,
std::unordered_set<std::string> import_set) {
std::unordered_set<String> import_set) {
auto n = make_object<IRModuleNode>();
n->functions = std::move(functions);
n->type_definitions = std::move(type_definitions);
Expand Down Expand Up @@ -111,15 +111,15 @@ void IRModuleNode::SHashReduce(SHashReducer hash_reduce) const {
reduce_temp();
}

bool IRModuleNode::ContainGlobalVar(const std::string& name) const {
bool IRModuleNode::ContainGlobalVar(const String& name) const {
return global_var_map_.find(name) != global_var_map_.end();
}

bool IRModuleNode::ContainGlobalTypeVar(const std::string& name) const {
bool IRModuleNode::ContainGlobalTypeVar(const String& name) const {
return global_type_var_map_.find(name) != global_type_var_map_.end();
}

GlobalVar IRModuleNode::GetGlobalVar(const std::string& name) const {
GlobalVar IRModuleNode::GetGlobalVar(const String& name) const {
auto it = global_var_map_.find(name);
if (it == global_var_map_.end()) {
std::ostringstream msg;
Expand All @@ -146,15 +146,15 @@ tvm::Array<GlobalVar> IRModuleNode::GetGlobalVars() const {
return tvm::Array<GlobalVar>(global_vars);
}

GlobalTypeVar IRModuleNode::GetGlobalTypeVar(const std::string& name) const {
GlobalTypeVar IRModuleNode::GetGlobalTypeVar(const String& name) const {
CHECK(global_type_var_map_.defined());
auto it = global_type_var_map_.find(name);
CHECK(it != global_type_var_map_.end())
<< "Cannot find global type var " << name << " in the Module";
return (*it).second;
}

Constructor IRModuleNode::GetConstructor(const std::string& adt, const std::string& cons) const {
Constructor IRModuleNode::GetConstructor(const String& adt, const String& cons) const {
TypeData typeDef = this->LookupTypeDef(adt);
for (Constructor c : typeDef->constructors) {
if (cons.compare(c->name_hint) == 0) {
Expand Down Expand Up @@ -315,7 +315,7 @@ BaseFunc IRModuleNode::Lookup(const GlobalVar& var) const {
return (*it).second;
}

BaseFunc IRModuleNode::Lookup(const std::string& name) const {
BaseFunc IRModuleNode::Lookup(const String& name) const {
GlobalVar id = this->GetGlobalVar(name);
return this->Lookup(id);
}
Expand All @@ -327,7 +327,7 @@ TypeData IRModuleNode::LookupTypeDef(const GlobalTypeVar& var) const {
return (*it).second;
}

TypeData IRModuleNode::LookupTypeDef(const std::string& name) const {
TypeData IRModuleNode::LookupTypeDef(const String& name) const {
GlobalTypeVar id = this->GetGlobalTypeVar(name);
return this->LookupTypeDef(id);
}
Expand Down Expand Up @@ -379,7 +379,7 @@ IRModule IRModule::FromExpr(
return mod;
}

void IRModuleNode::Import(const std::string& path) {
void IRModuleNode::Import(const String& path) {
if (this->import_set_.count(path) == 0) {
this->import_set_.insert(path);
DLOG(INFO) << "Importing: " << path;
Expand All @@ -392,18 +392,18 @@ void IRModuleNode::Import(const std::string& path) {
}
}

void IRModuleNode::ImportFromStd(const std::string& path) {
void IRModuleNode::ImportFromStd(const String& path) {
auto* f = tvm::runtime::Registry::Get("tvm.relay.std_path");
CHECK(f != nullptr) << "The Relay std_path is not set, please register tvm.relay.std_path.";
std::string std_path = (*f)();
return this->Import(std_path + "/" + path);
this->Import(std_path + "/" + path.operator std::string());
}

std::unordered_set<std::string> IRModuleNode::Imports() const {
std::unordered_set<String> IRModuleNode::Imports() const {
return this->import_set_;
}

IRModule IRModule::FromText(const std::string& text, const std::string& source_path) {
IRModule IRModule::FromText(const String& text, const String& source_path) {
auto* f = tvm::runtime::Registry::Get("relay.fromtext");
CHECK(f != nullptr) << "The Relay std_path is not set, please register tvm.relay.std_path.";
IRModule mod = (*f)(text, source_path);
Expand Down Expand Up @@ -467,7 +467,7 @@ TVM_REGISTER_GLOBAL("ir.Module_Lookup")
});

TVM_REGISTER_GLOBAL("ir.Module_Lookup_str")
.set_body_typed([](IRModule mod, std::string var) {
.set_body_typed([](IRModule mod, String var) {
return mod->Lookup(var);
});

Expand All @@ -477,7 +477,7 @@ TVM_REGISTER_GLOBAL("ir.Module_LookupDef")
});

TVM_REGISTER_GLOBAL("ir.Module_LookupDef_str")
.set_body_typed([](IRModule mod, std::string var) {
.set_body_typed([](IRModule mod, String var) {
return mod->LookupTypeDef(var);
});

Expand All @@ -499,12 +499,12 @@ TVM_REGISTER_GLOBAL("ir.Module_Update")
});

TVM_REGISTER_GLOBAL("ir.Module_Import")
.set_body_typed([](IRModule mod, std::string path) {
.set_body_typed([](IRModule mod, String path) {
mod->Import(path);
});

TVM_REGISTER_GLOBAL("ir.Module_ImportFromStd")
.set_body_typed([](IRModule mod, std::string path) {
.set_body_typed([](IRModule mod, String path) {
mod->ImportFromStd(path);
});;

Expand Down
16 changes: 11 additions & 5 deletions src/printer/relay_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -918,22 +918,28 @@ static const char* kSemVer = "v0.0.4";
// - Implements AsText
// - relay_text_printer.cc (specific printing logics for relay)
// - tir_text_printer.cc (specific printing logics for TIR)
std::string PrettyPrint(const ObjectRef& node) {
String PrettyPrint(const ObjectRef& node) {
Doc doc;
doc << relay::RelayTextPrinter(false, nullptr).PrintFinal(node);
return doc.str();
}

std::string AsText(const ObjectRef& node,
String AsText(const ObjectRef& node,
bool show_meta_data,
runtime::TypedPackedFunc<std::string(ObjectRef)> annotate) {
runtime::TypedPackedFunc<String(ObjectRef)> annotate) {
Doc doc;
doc << kSemVer << Doc::NewLine();
doc << relay::RelayTextPrinter(show_meta_data, annotate).PrintFinal(node);
runtime::TypedPackedFunc<std::string(ObjectRef)> ftyped = nullptr;
if (annotate != nullptr) {
ftyped = runtime::TypedPackedFunc<std::string(ObjectRef)>(
[&annotate](const ObjectRef& expr) -> std::string {
return annotate(expr);
});
}
doc << relay::RelayTextPrinter(show_meta_data, ftyped).PrintFinal(node);
return doc.str();
}


TVM_REGISTER_GLOBAL("ir.PrettyPrint")
.set_body_typed(PrettyPrint);

Expand Down