Skip to content

Commit

Permalink
[Refactor][std::string --> String] IR is updated with String (#5547)
Browse files Browse the repository at this point in the history
* [std::string --> String] GlobalTypeVar is updated with String

* [std::string --> String] GlobalVar is updated with String

* [std::string --> String][IR] ADT is updated with String

* [std::string --> String][IR] OP is updated with String

* [std::string --> String][IR] Attrs is updated with String input

* [std::string --> String][IR] GlobalVar is updated with String

* [std::string --> String][Test] Pyconverter is updated with String change
  • Loading branch information
ANSHUMAN TRIPATHY authored May 11, 2020
1 parent ad2ee97 commit 1a0f44d
Show file tree
Hide file tree
Showing 17 changed files with 51 additions and 39 deletions.
2 changes: 1 addition & 1 deletion include/tvm/ir/adt.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ class Constructor : public RelayExpr {
* \param inputs The input types.
* \param belong_to The data type var the constructor will construct.
*/
TVM_DLL Constructor(std::string name_hint, Array<Type> inputs, GlobalTypeVar belong_to);
TVM_DLL Constructor(String name_hint, Array<Type> inputs, GlobalTypeVar belong_to);

TVM_DEFINE_OBJECT_REF_METHODS(Constructor, RelayExpr, ConstructorNode);
};
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/ir/env_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ class EnvFunc : public ObjectRef {
* \return The created global function.
* \note The function can be unique
*/
TVM_DLL static EnvFunc Get(const std::string& name);
TVM_DLL static EnvFunc Get(const String& name);
/*! \brief specify container node */
using ContainerType = EnvFuncNode;
};
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ class GlobalVar;
class GlobalVarNode : public RelayExprNode {
public:
/*! \brief The name of the variable, this only acts as a hint. */
std::string name_hint;
String name_hint;

void VisitAttrs(AttrVisitor* v) {
v->Visit("name_hint", &name_hint);
Expand Down Expand Up @@ -216,7 +216,7 @@ class GlobalVarNode : public RelayExprNode {
*/
class GlobalVar : public RelayExpr {
public:
TVM_DLL explicit GlobalVar(std::string name_hint);
TVM_DLL explicit GlobalVar(String name_hint);

TVM_DEFINE_OBJECT_REF_METHODS(GlobalVar, RelayExpr, GlobalVarNode);
};
Expand Down
9 changes: 5 additions & 4 deletions include/tvm/ir/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ class Op : public RelayExpr {
* \param op_name Name of the operator.
* \return Pointer to a Op, valid throughout program lifetime.
*/
TVM_DLL static const Op& Get(const std::string& op_name);
TVM_DLL static const Op& Get(const String& op_name);

/*! \brief specify container node */
using ContainerType = OpNode;
Expand All @@ -196,13 +196,13 @@ class Op : public RelayExpr {
* \param key The attribute key
* \return reference to GenericOpMap
*/
TVM_DLL static const GenericOpMap& GetGenericAttr(const std::string& key);
TVM_DLL static const GenericOpMap& GetGenericAttr(const String& key);
/*!
* \brief Checks if the key is present in the registry
* \param key The attribute key
* \return bool True if the key is present
*/
TVM_DLL static bool HasGenericAttr(const std::string& key);
TVM_DLL static bool HasGenericAttr(const String& key);
};

/*!
Expand Down Expand Up @@ -303,7 +303,8 @@ class OpRegistry {
// return internal pointer to op.
inline OpNode* get();
// update the attribute OpMap
TVM_DLL void UpdateAttr(const std::string& key, runtime::TVMRetValue value, int plevel);

TVM_DLL void UpdateAttr(const String& key, runtime::TVMRetValue value, int plevel);
};

/*!
Expand Down
8 changes: 4 additions & 4 deletions include/tvm/ir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ class PassInfo : public ObjectRef {
* \param name Name of the pass.
* \param required The passes that are required to perform the current pass.
*/
TVM_DLL PassInfo(int opt_level, std::string name, Array<runtime::String> required);
TVM_DLL PassInfo(int opt_level, String name, Array<runtime::String> required);

TVM_DEFINE_OBJECT_REF_METHODS(PassInfo, ObjectRef, PassInfoNode);
};
Expand Down Expand Up @@ -327,7 +327,7 @@ class Sequential : public Pass {
* This allows users to only provide a list of passes and execute them
* under a given context.
*/
TVM_DLL Sequential(Array<Pass> passes, std::string name = "sequential");
TVM_DLL Sequential(Array<Pass> passes, String name = "sequential");

Sequential() = default;
explicit Sequential(ObjectPtr<Object> n) : Pass(n) {}
Expand All @@ -348,15 +348,15 @@ class Sequential : public Pass {
*/
TVM_DLL Pass
CreateModulePass(const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func,
int opt_level, const std::string& name, const Array<runtime::String>& required);
int opt_level, const String& name, const Array<runtime::String>& required);

/*!
* \brief A special trace pass that prints the header and IR to LOG(INFO).
* \param header The header to be attached to the output.
* \param show_meta_data Whether should we show meta data.
* \return The pass.
*/
TVM_DLL Pass PrintIR(std::string header = "", bool show_meta_data = false);
TVM_DLL Pass PrintIR(String header = "", bool show_meta_data = false);

} // namespace transform
} // namespace tvm
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/ir/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ class GlobalTypeVarNode : public TypeNode {
* this only acts as a hint to the user,
* and is not used for equality.
*/
std::string name_hint;
String name_hint;
/*! \brief The kind of type parameter */
TypeKind kind;

Expand Down Expand Up @@ -301,7 +301,7 @@ class GlobalTypeVar : public Type {
* \param name_hint The name of the type var.
* \param kind The kind of the type var.
*/
TVM_DLL GlobalTypeVar(std::string name_hint, TypeKind kind);
TVM_DLL GlobalTypeVar(String name_hint, TypeKind kind);

TVM_DEFINE_OBJECT_REF_METHODS(GlobalTypeVar, Type, GlobalTypeVarNode);
};
Expand Down
9 changes: 9 additions & 0 deletions include/tvm/runtime/container.h
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,15 @@ inline String String::operator=(std::string other) {
return Downcast<String>(*this);
}

inline String operator+(const std::string lhs, const String& rhs) {
return lhs + rhs.operator std::string();
}

inline std::ostream& operator<<(std::ostream& out, const String& input) {
out.write(input.data(), input.size());
return out;
}

inline int String::memncmp(const char* lhs, const char* rhs, size_t lhs_count, size_t rhs_count) {
if (lhs == rhs && lhs_count == rhs_count) return 0;

Expand Down
4 changes: 2 additions & 2 deletions python/tvm/ir/json_compact.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def _convert(item, nodes):
"EnvFunc": _update_global_key,
"relay.Op": _update_global_key,
"relay.TypeVar": [_ftype_var, _update_from_std_str("name_hint")],
"relay.GlobalTypeVar": _ftype_var,
"relay.GlobalTypeVar": [_ftype_var, _update_from_std_str("name_hint")],
"relay.Type": _rename("Type"),
"relay.TupleType": _rename("TupleType"),
"relay.TypeConstraint": _rename("TypeConstraint"),
Expand All @@ -122,7 +122,7 @@ def _convert(item, nodes):
"relay.Module": _rename("IRModule"),
"relay.SourceName": _rename("SourceName"),
"relay.Span": _rename("Span"),
"relay.GlobalVar": _rename("GlobalVar"),
"relay.GlobalVar": [_rename("GlobalVar"), _update_from_std_str("name_hint")],
"relay.Pass": _rename("transform.Pass"),
"relay.PassInfo": _rename("transform.PassInfo"),
"relay.PassContext": _rename("transform.PassContext"),
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/relay/testing/py_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def convert_func_node(self, func: Function, name_var=None):
if name_var is None:
func_name = self.generate_function_name('_anon_func')
if isinstance(name_var, GlobalVar):
func_name = name_var.name_hint
func_name = str(name_var.name_hint)
if isinstance(name_var, Var):
func_name = self.get_var_name(name_var)

Expand Down Expand Up @@ -411,7 +411,7 @@ def visit_var(self, var: Expr):
def visit_global_var(self, gvar: Expr):
# we don't need to add numbers to global var names because
# the *names* are checked for uniqueness in the mod
return (Name(gvar.name_hint, Load()), [])
return (Name(str(gvar.name_hint), Load()), [])


def visit_let(self, letexp: Expr):
Expand Down
4 changes: 2 additions & 2 deletions src/ir/adt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

namespace tvm {

Constructor::Constructor(std::string name_hint, tvm::Array<Type> inputs, GlobalTypeVar belong_to) {
Constructor::Constructor(String name_hint, tvm::Array<Type> inputs, GlobalTypeVar belong_to) {
ObjectPtr<ConstructorNode> n = make_object<ConstructorNode>();
n->name_hint = std::move(name_hint);
n->inputs = std::move(inputs);
Expand All @@ -37,7 +37,7 @@ Constructor::Constructor(std::string name_hint, tvm::Array<Type> inputs, GlobalT
TVM_REGISTER_NODE_TYPE(ConstructorNode);

TVM_REGISTER_GLOBAL("ir.Constructor")
.set_body_typed([](std::string name_hint, tvm::Array<Type> inputs, GlobalTypeVar belong_to) {
.set_body_typed([](String name_hint, tvm::Array<Type> inputs, GlobalTypeVar belong_to) {
return Constructor(name_hint, inputs, belong_to);
});

Expand Down
2 changes: 1 addition & 1 deletion src/ir/env_func.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ ObjectPtr<Object> CreateEnvNode(const std::string& name) {
return n;
}

EnvFunc EnvFunc::Get(const std::string& name) { return EnvFunc(CreateEnvNode(name)); }
EnvFunc EnvFunc::Get(const String& name) { return EnvFunc(CreateEnvNode(name)); }

TVM_REGISTER_GLOBAL("ir.EnvFuncGet").set_body_typed(EnvFunc::Get);

Expand Down
4 changes: 2 additions & 2 deletions src/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -137,15 +137,15 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << "range(min=" << op->min << ", ext=" << op->extent << ')';
});

GlobalVar::GlobalVar(std::string name_hint) {
GlobalVar::GlobalVar(String name_hint) {
ObjectPtr<GlobalVarNode> n = make_object<GlobalVarNode>();
n->name_hint = std::move(name_hint);
data_ = std::move(n);
}

TVM_REGISTER_NODE_TYPE(GlobalVarNode);

TVM_REGISTER_GLOBAL("ir.GlobalVar").set_body_typed([](std::string name) {
TVM_REGISTER_GLOBAL("ir.GlobalVar").set_body_typed([](String name) {
return GlobalVar(name);
});

Expand Down
2 changes: 1 addition & 1 deletion src/ir/function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ TVM_REGISTER_GLOBAL("ir.BaseFunc_Attrs").set_body_typed([](BaseFunc func) { retu
TVM_REGISTER_GLOBAL("ir.BaseFuncCopy").set_body_typed([](BaseFunc func) { return func; });

TVM_REGISTER_GLOBAL("ir.BaseFuncWithAttr")
.set_body_typed([](BaseFunc func, std::string key, ObjectRef value) -> BaseFunc {
.set_body_typed([](BaseFunc func, String key, ObjectRef value) -> BaseFunc {
if (func->IsInstance<tir::PrimFuncNode>()) {
return WithAttr(Downcast<tir::PrimFunc>(std::move(func)), key, value);
} else if (func->IsInstance<relay::FunctionNode>()) {
Expand Down
10 changes: 5 additions & 5 deletions src/ir/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ struct OpManager {
};

// find operator by name
const Op& Op::Get(const std::string& name) {
const Op& Op::Get(const String& name) {
const OpRegistry* reg = dmlc::Registry<OpRegistry>::Find(name);
CHECK(reg != nullptr) << "Operator " << name << " is not registered";
return reg->op();
Expand All @@ -75,7 +75,7 @@ OpRegistry::OpRegistry() {
}

// Get attribute map by key
const GenericOpMap& Op::GetGenericAttr(const std::string& key) {
const GenericOpMap& Op::GetGenericAttr(const String& key) {
OpManager* mgr = OpManager::Global();
std::lock_guard<std::mutex> lock(mgr->mutex);
auto it = mgr->attr.find(key);
Expand All @@ -86,7 +86,7 @@ const GenericOpMap& Op::GetGenericAttr(const std::string& key) {
}

// Check if a key is present in the registry.
bool Op::HasGenericAttr(const std::string& key) {
bool Op::HasGenericAttr(const String& key) {
OpManager* mgr = OpManager::Global();
std::lock_guard<std::mutex> lock(mgr->mutex);
auto it = mgr->attr.find(key);
Expand All @@ -110,7 +110,7 @@ void OpRegistry::reset_attr(const std::string& key) {
}
}

void OpRegistry::UpdateAttr(const std::string& key, TVMRetValue value, int plevel) {
void OpRegistry::UpdateAttr(const String& key, TVMRetValue value, int plevel) {
OpManager* mgr = OpManager::Global();
std::lock_guard<std::mutex> lock(mgr->mutex);
std::unique_ptr<GenericOpMap>& op_map = mgr->attr[key];
Expand Down Expand Up @@ -141,7 +141,7 @@ TVM_REGISTER_GLOBAL("relay.op._ListOpNames").set_body_typed([]() {
return ret;
});

TVM_REGISTER_GLOBAL("relay.op._GetOp").set_body_typed([](std::string name) -> Op {
TVM_REGISTER_GLOBAL("relay.op._GetOp").set_body_typed([](String name) -> Op {
return Op::Get(name);
});

Expand Down
14 changes: 7 additions & 7 deletions src/ir/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ class SequentialNode : public PassNode {
TVM_DECLARE_FINAL_OBJECT_INFO(SequentialNode, PassNode);
};

PassInfo::PassInfo(int opt_level, std::string name, tvm::Array<runtime::String> required) {
PassInfo::PassInfo(int opt_level, String name, tvm::Array<runtime::String> required) {
auto pass_info = make_object<PassInfoNode>();
pass_info->opt_level = opt_level;
pass_info->name = std::move(name);
Expand Down Expand Up @@ -238,7 +238,7 @@ Sequential::Sequential(tvm::Array<Pass> passes, PassInfo pass_info) {
data_ = std::move(n);
}

Sequential::Sequential(tvm::Array<Pass> passes, std::string name) {
Sequential::Sequential(tvm::Array<Pass> passes, String name) {
auto n = make_object<SequentialNode>();
n->passes = std::move(passes);
PassInfo pass_info = PassInfo(2, std::move(name), {});
Expand Down Expand Up @@ -282,10 +282,10 @@ bool SequentialNode::PassEnabled(const PassInfo& info) const {
return ctx->opt_level >= info->opt_level;
}

Pass GetPass(const std::string& pass_name) {
Pass GetPass(const String& pass_name) {
using tvm::runtime::Registry;
const runtime::PackedFunc* f = nullptr;
if (pass_name.find("transform.") != std::string::npos) {
if (pass_name.operator std::string().find("transform.") != std::string::npos) {
f = Registry::Get(pass_name);
} else if ((f = Registry::Get("transform." + pass_name))) {
// pass
Expand Down Expand Up @@ -313,7 +313,7 @@ IRModule SequentialNode::operator()(IRModule mod, const PassContext& pass_ctx) c
}

Pass CreateModulePass(const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func,
int opt_level, const std::string& name,
int opt_level, const String& name,
const tvm::Array<runtime::String>& required) {
PassInfo pass_info = PassInfo(opt_level, name, required);
return ModulePass(pass_func, pass_info);
Expand All @@ -322,7 +322,7 @@ Pass CreateModulePass(const runtime::TypedPackedFunc<IRModule(IRModule, PassCont
TVM_REGISTER_NODE_TYPE(PassInfoNode);

TVM_REGISTER_GLOBAL("transform.PassInfo")
.set_body_typed([](int opt_level, std::string name, tvm::Array<runtime::String> required) {
.set_body_typed([](int opt_level, String name, tvm::Array<runtime::String> required) {
return PassInfo(opt_level, name, required);
});

Expand Down Expand Up @@ -439,7 +439,7 @@ TVM_REGISTER_GLOBAL("transform.EnterPassContext").set_body_typed(PassContext::In

TVM_REGISTER_GLOBAL("transform.ExitPassContext").set_body_typed(PassContext::Internal::ExitScope);

Pass PrintIR(std::string header, bool show_meta_data) {
Pass PrintIR(String header, bool show_meta_data) {
auto pass_func = [header, show_meta_data](IRModule mod, const PassContext& ctx) {
LOG(INFO) << "PrintIR(" << header << "):\n" << AsText(mod, show_meta_data);
return mod;
Expand Down
4 changes: 2 additions & 2 deletions src/ir/type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << "TypeVar(" << node->name_hint << ", " << node->kind << ")";
});

GlobalTypeVar::GlobalTypeVar(std::string name, TypeKind kind) {
GlobalTypeVar::GlobalTypeVar(String name, TypeKind kind) {
ObjectPtr<GlobalTypeVarNode> n = make_object<GlobalTypeVarNode>();
n->name_hint = std::move(name);
n->kind = std::move(kind);
Expand All @@ -90,7 +90,7 @@ GlobalTypeVar::GlobalTypeVar(std::string name, TypeKind kind) {

TVM_REGISTER_NODE_TYPE(GlobalTypeVarNode);

TVM_REGISTER_GLOBAL("ir.GlobalTypeVar").set_body_typed([](std::string name, int kind) {
TVM_REGISTER_GLOBAL("ir.GlobalTypeVar").set_body_typed([](String name, int kind) {
return GlobalTypeVar(name, static_cast<TypeKind>(kind));
});

Expand Down
4 changes: 3 additions & 1 deletion src/printer/relay_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,9 @@ Doc RelayTextPrinter::VisitExpr_(const FunctionNode* op) {
return PrintFunc(Doc::Text("fn "), GetRef<Function>(op));
}

Doc RelayTextPrinter::VisitExpr_(const GlobalVarNode* op) { return Doc::Text('@' + op->name_hint); }
Doc RelayTextPrinter::VisitExpr_(const GlobalVarNode* op) {
return Doc::Text('@' + op->name_hint.operator std::string());
}

Doc RelayTextPrinter::VisitExpr_(const OpNode* op) { return Doc::Text(op->name); }

Expand Down

0 comments on commit 1a0f44d

Please sign in to comment.