Skip to content

Commit

Permalink
[Refactor][std::string --> String] IRModule is updated with String (a…
Browse files Browse the repository at this point in the history
…pache#5523)

* [std::string --> String] IRModule is updated with String

* [1] Packedfunction updated

* [2] Lint error fixed

* [3] Remove std::string variant
  • Loading branch information
ANSHUMAN TRIPATHY authored and trevor-m committed Jun 18, 2020
1 parent 8b6193c commit 4e80b2d
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 39 deletions.
34 changes: 17 additions & 17 deletions include/tvm/ir/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
#include <tvm/ir/expr.h>
#include <tvm/ir/function.h>
#include <tvm/ir/adt.h>

#include <tvm/node/container.h>
#include <string>
#include <vector>
#include <unordered_map>
Expand Down Expand Up @@ -131,21 +131,21 @@ class IRModuleNode : public Object {
* \param name The variable name.
* \returns true if contains, otherise false.
*/
TVM_DLL bool ContainGlobalVar(const std::string& name) const;
TVM_DLL bool ContainGlobalVar(const String& name) const;

/*!
* \brief Check if the global_type_var_map_ contains a global type variable.
* \param name The variable name.
* \returns true if contains, otherise false.
*/
TVM_DLL bool ContainGlobalTypeVar(const std::string& name) const;
TVM_DLL bool ContainGlobalTypeVar(const String& name) const;

/*!
* \brief Lookup a global function by its variable.
* \param str The unique string specifying the global variable.
* \returns The global variable.
*/
TVM_DLL GlobalVar GetGlobalVar(const std::string& str) const;
TVM_DLL GlobalVar GetGlobalVar(const String& str) const;

/*!
* \brief Collect all global vars defined in this module.
Expand All @@ -158,7 +158,7 @@ class IRModuleNode : public Object {
* \param str The unique string specifying the global variable.
* \returns The global variable.
*/
TVM_DLL GlobalTypeVar GetGlobalTypeVar(const std::string& str) const;
TVM_DLL GlobalTypeVar GetGlobalTypeVar(const String& str) const;

/*!
* \brief Collect all global type vars defined in this module.
Expand All @@ -172,7 +172,7 @@ class IRModuleNode : public Object {
* \param cons name of the constructor
* \returns Constructor of ADT, error if not found
*/
TVM_DLL Constructor GetConstructor(const std::string& adt, const std::string& cons) const;
TVM_DLL Constructor GetConstructor(const String& adt, const String& cons) const;

/*!
* \brief Look up a global function by its variable.
Expand All @@ -186,7 +186,7 @@ class IRModuleNode : public Object {
* \param name The name of the function.
* \returns The function named by the argument.
*/
TVM_DLL BaseFunc Lookup(const std::string& name) const;
TVM_DLL BaseFunc Lookup(const String& name) const;

/*!
* \brief Look up a global type definition by its variable.
Expand All @@ -200,7 +200,7 @@ class IRModuleNode : public Object {
* \param var The name of the global type definition.
* \return The type definition.
*/
TVM_DLL TypeData LookupTypeDef(const std::string& var) const;
TVM_DLL TypeData LookupTypeDef(const String& var) const;

/*!
* \brief Look up a constructor by its tag.
Expand All @@ -225,18 +225,18 @@ class IRModuleNode : public Object {
* relative it will be resovled against the current
* working directory.
*/
TVM_DLL void Import(const std::string& path);
TVM_DLL void Import(const String& path);

/*!
* \brief Import Relay code from the file at path, relative to the standard library.
* \param path The path of the Relay code to import.
*/
TVM_DLL void ImportFromStd(const std::string& path);
TVM_DLL void ImportFromStd(const String& path);

/*!
* \brief The set of imported files.
*/
TVM_DLL std::unordered_set<std::string> Imports() const;
TVM_DLL std::unordered_set<String> Imports() const;

static constexpr const char* _type_key = "IRModule";
static constexpr const bool _type_has_method_sequal_reduce = true;
Expand Down Expand Up @@ -265,7 +265,7 @@ class IRModuleNode : public Object {
/*! \brief The files previously imported, required to ensure
importing is idempotent for each module.
*/
std::unordered_set<std::string> import_set_;
std::unordered_set<String> import_set_;
friend class IRModule;
};

Expand All @@ -283,7 +283,7 @@ class IRModule : public ObjectRef {
*/
TVM_DLL explicit IRModule(Map<GlobalVar, BaseFunc> functions,
Map<GlobalTypeVar, TypeData> type_definitions = {},
std::unordered_set<std::string> import_set = {});
std::unordered_set<String> import_set = {});
/*! \brief default constructor */
IRModule() {}
/*!
Expand Down Expand Up @@ -329,7 +329,7 @@ class IRModule : public ObjectRef {
* \param source_path The path to the source file.
* \return A Relay module.
*/
TVM_DLL static IRModule FromText(const std::string& text, const std::string& source_path);
TVM_DLL static IRModule FromText(const String& text, const String& source_path);

/*! \brief Declare the container type. */
using ContainerType = IRModuleNode;
Expand All @@ -346,7 +346,7 @@ class IRModule : public ObjectRef {
* Use AsText if you want to store the text.
* \sa AsText.
*/
TVM_DLL std::string PrettyPrint(const ObjectRef& node);
TVM_DLL String PrettyPrint(const ObjectRef& node);

/*!
* \brief Render the node as a string in the text format.
Expand All @@ -362,8 +362,8 @@ TVM_DLL std::string PrettyPrint(const ObjectRef& node);
* \sa PrettyPrint.
* \return The text representation.
*/
TVM_DLL std::string AsText(const ObjectRef& node,
TVM_DLL String AsText(const ObjectRef& node,
bool show_meta_data = true,
runtime::TypedPackedFunc<std::string(ObjectRef)> annotate = nullptr);
runtime::TypedPackedFunc<String(ObjectRef)> annotate = nullptr);
} // namespace tvm
#endif // TVM_IR_MODULE_H_
34 changes: 17 additions & 17 deletions src/ir/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ namespace tvm {

IRModule::IRModule(tvm::Map<GlobalVar, BaseFunc> functions,
tvm::Map<GlobalTypeVar, TypeData> type_definitions,
std::unordered_set<std::string> import_set) {
std::unordered_set<String> import_set) {
auto n = make_object<IRModuleNode>();
n->functions = std::move(functions);
n->type_definitions = std::move(type_definitions);
Expand Down Expand Up @@ -111,15 +111,15 @@ void IRModuleNode::SHashReduce(SHashReducer hash_reduce) const {
reduce_temp();
}

bool IRModuleNode::ContainGlobalVar(const std::string& name) const {
bool IRModuleNode::ContainGlobalVar(const String& name) const {
return global_var_map_.find(name) != global_var_map_.end();
}

bool IRModuleNode::ContainGlobalTypeVar(const std::string& name) const {
bool IRModuleNode::ContainGlobalTypeVar(const String& name) const {
return global_type_var_map_.find(name) != global_type_var_map_.end();
}

GlobalVar IRModuleNode::GetGlobalVar(const std::string& name) const {
GlobalVar IRModuleNode::GetGlobalVar(const String& name) const {
auto it = global_var_map_.find(name);
if (it == global_var_map_.end()) {
std::ostringstream msg;
Expand All @@ -146,15 +146,15 @@ tvm::Array<GlobalVar> IRModuleNode::GetGlobalVars() const {
return tvm::Array<GlobalVar>(global_vars);
}

GlobalTypeVar IRModuleNode::GetGlobalTypeVar(const std::string& name) const {
GlobalTypeVar IRModuleNode::GetGlobalTypeVar(const String& name) const {
CHECK(global_type_var_map_.defined());
auto it = global_type_var_map_.find(name);
CHECK(it != global_type_var_map_.end())
<< "Cannot find global type var " << name << " in the Module";
return (*it).second;
}

Constructor IRModuleNode::GetConstructor(const std::string& adt, const std::string& cons) const {
Constructor IRModuleNode::GetConstructor(const String& adt, const String& cons) const {
TypeData typeDef = this->LookupTypeDef(adt);
for (Constructor c : typeDef->constructors) {
if (cons.compare(c->name_hint) == 0) {
Expand Down Expand Up @@ -315,7 +315,7 @@ BaseFunc IRModuleNode::Lookup(const GlobalVar& var) const {
return (*it).second;
}

BaseFunc IRModuleNode::Lookup(const std::string& name) const {
BaseFunc IRModuleNode::Lookup(const String& name) const {
GlobalVar id = this->GetGlobalVar(name);
return this->Lookup(id);
}
Expand All @@ -327,7 +327,7 @@ TypeData IRModuleNode::LookupTypeDef(const GlobalTypeVar& var) const {
return (*it).second;
}

TypeData IRModuleNode::LookupTypeDef(const std::string& name) const {
TypeData IRModuleNode::LookupTypeDef(const String& name) const {
GlobalTypeVar id = this->GetGlobalTypeVar(name);
return this->LookupTypeDef(id);
}
Expand Down Expand Up @@ -379,7 +379,7 @@ IRModule IRModule::FromExpr(
return mod;
}

void IRModuleNode::Import(const std::string& path) {
void IRModuleNode::Import(const String& path) {
if (this->import_set_.count(path) == 0) {
this->import_set_.insert(path);
DLOG(INFO) << "Importing: " << path;
Expand All @@ -392,18 +392,18 @@ void IRModuleNode::Import(const std::string& path) {
}
}

void IRModuleNode::ImportFromStd(const std::string& path) {
void IRModuleNode::ImportFromStd(const String& path) {
auto* f = tvm::runtime::Registry::Get("tvm.relay.std_path");
CHECK(f != nullptr) << "The Relay std_path is not set, please register tvm.relay.std_path.";
std::string std_path = (*f)();
return this->Import(std_path + "/" + path);
this->Import(std_path + "/" + path.operator std::string());
}

std::unordered_set<std::string> IRModuleNode::Imports() const {
std::unordered_set<String> IRModuleNode::Imports() const {
return this->import_set_;
}

IRModule IRModule::FromText(const std::string& text, const std::string& source_path) {
IRModule IRModule::FromText(const String& text, const String& source_path) {
auto* f = tvm::runtime::Registry::Get("relay.fromtext");
CHECK(f != nullptr) << "The Relay std_path is not set, please register tvm.relay.std_path.";
IRModule mod = (*f)(text, source_path);
Expand Down Expand Up @@ -467,7 +467,7 @@ TVM_REGISTER_GLOBAL("ir.Module_Lookup")
});

TVM_REGISTER_GLOBAL("ir.Module_Lookup_str")
.set_body_typed([](IRModule mod, std::string var) {
.set_body_typed([](IRModule mod, String var) {
return mod->Lookup(var);
});

Expand All @@ -477,7 +477,7 @@ TVM_REGISTER_GLOBAL("ir.Module_LookupDef")
});

TVM_REGISTER_GLOBAL("ir.Module_LookupDef_str")
.set_body_typed([](IRModule mod, std::string var) {
.set_body_typed([](IRModule mod, String var) {
return mod->LookupTypeDef(var);
});

Expand All @@ -499,12 +499,12 @@ TVM_REGISTER_GLOBAL("ir.Module_Update")
});

TVM_REGISTER_GLOBAL("ir.Module_Import")
.set_body_typed([](IRModule mod, std::string path) {
.set_body_typed([](IRModule mod, String path) {
mod->Import(path);
});

TVM_REGISTER_GLOBAL("ir.Module_ImportFromStd")
.set_body_typed([](IRModule mod, std::string path) {
.set_body_typed([](IRModule mod, String path) {
mod->ImportFromStd(path);
});;

Expand Down
16 changes: 11 additions & 5 deletions src/printer/relay_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -918,22 +918,28 @@ static const char* kSemVer = "v0.0.4";
// - Implements AsText
// - relay_text_printer.cc (specific printing logics for relay)
// - tir_text_printer.cc (specific printing logics for TIR)
std::string PrettyPrint(const ObjectRef& node) {
String PrettyPrint(const ObjectRef& node) {
Doc doc;
doc << relay::RelayTextPrinter(false, nullptr).PrintFinal(node);
return doc.str();
}

std::string AsText(const ObjectRef& node,
String AsText(const ObjectRef& node,
bool show_meta_data,
runtime::TypedPackedFunc<std::string(ObjectRef)> annotate) {
runtime::TypedPackedFunc<String(ObjectRef)> annotate) {
Doc doc;
doc << kSemVer << Doc::NewLine();
doc << relay::RelayTextPrinter(show_meta_data, annotate).PrintFinal(node);
runtime::TypedPackedFunc<std::string(ObjectRef)> ftyped = nullptr;
if (annotate != nullptr) {
ftyped = runtime::TypedPackedFunc<std::string(ObjectRef)>(
[&annotate](const ObjectRef& expr) -> std::string {
return annotate(expr);
});
}
doc << relay::RelayTextPrinter(show_meta_data, ftyped).PrintFinal(node);
return doc.str();
}


TVM_REGISTER_GLOBAL("ir.PrettyPrint")
.set_body_typed(PrettyPrint);

Expand Down

0 comments on commit 4e80b2d

Please sign in to comment.