diff --git a/include/tvm/ir/adt.h b/include/tvm/ir/adt.h index 9d45dc10800e..9b45c66dc76b 100644 --- a/include/tvm/ir/adt.h +++ b/include/tvm/ir/adt.h @@ -91,7 +91,7 @@ class Constructor : public RelayExpr { * \param inputs The input types. * \param belong_to The data type var the constructor will construct. */ - TVM_DLL Constructor(std::string name_hint, Array inputs, GlobalTypeVar belong_to); + TVM_DLL Constructor(String name_hint, Array inputs, GlobalTypeVar belong_to); TVM_DEFINE_OBJECT_REF_METHODS(Constructor, RelayExpr, ConstructorNode); }; diff --git a/include/tvm/ir/env_func.h b/include/tvm/ir/env_func.h index 320d6e38e610..2f803672a20b 100644 --- a/include/tvm/ir/env_func.h +++ b/include/tvm/ir/env_func.h @@ -92,7 +92,7 @@ class EnvFunc : public ObjectRef { * \return The created global function. * \note The function can be unique */ - TVM_DLL static EnvFunc Get(const std::string& name); + TVM_DLL static EnvFunc Get(const String& name); /*! \brief specify container node */ using ContainerType = EnvFuncNode; }; diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 717ffb1b4826..6797f16c3829 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -188,7 +188,7 @@ class GlobalVar; class GlobalVarNode : public RelayExprNode { public: /*! \brief The name of the variable, this only acts as a hint. */ - std::string name_hint; + String name_hint; void VisitAttrs(AttrVisitor* v) { v->Visit("name_hint", &name_hint); @@ -216,7 +216,7 @@ class GlobalVarNode : public RelayExprNode { */ class GlobalVar : public RelayExpr { public: - TVM_DLL explicit GlobalVar(std::string name_hint); + TVM_DLL explicit GlobalVar(String name_hint); TVM_DEFINE_OBJECT_REF_METHODS(GlobalVar, RelayExpr, GlobalVarNode); }; diff --git a/include/tvm/ir/op.h b/include/tvm/ir/op.h index 7fafb5a69421..aeda4fa7dd2b 100644 --- a/include/tvm/ir/op.h +++ b/include/tvm/ir/op.h @@ -185,7 +185,7 @@ class Op : public RelayExpr { * \param op_name Name of the operator. * \return Pointer to a Op, valid throughout program lifetime. */ - TVM_DLL static const Op& Get(const std::string& op_name); + TVM_DLL static const Op& Get(const String& op_name); /*! \brief specify container node */ using ContainerType = OpNode; @@ -196,13 +196,13 @@ class Op : public RelayExpr { * \param key The attribute key * \return reference to GenericOpMap */ - TVM_DLL static const GenericOpMap& GetGenericAttr(const std::string& key); + TVM_DLL static const GenericOpMap& GetGenericAttr(const String& key); /*! * \brief Checks if the key is present in the registry * \param key The attribute key * \return bool True if the key is present */ - TVM_DLL static bool HasGenericAttr(const std::string& key); + TVM_DLL static bool HasGenericAttr(const String& key); }; /*! @@ -303,7 +303,8 @@ class OpRegistry { // return internal pointer to op. inline OpNode* get(); // update the attribute OpMap - TVM_DLL void UpdateAttr(const std::string& key, runtime::TVMRetValue value, int plevel); + + TVM_DLL void UpdateAttr(const String& key, runtime::TVMRetValue value, int plevel); }; /*! diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h index 558d2da79361..a825b95294e9 100644 --- a/include/tvm/ir/transform.h +++ b/include/tvm/ir/transform.h @@ -224,7 +224,7 @@ class PassInfo : public ObjectRef { * \param name Name of the pass. * \param required The passes that are required to perform the current pass. */ - TVM_DLL PassInfo(int opt_level, std::string name, Array required); + TVM_DLL PassInfo(int opt_level, String name, Array required); TVM_DEFINE_OBJECT_REF_METHODS(PassInfo, ObjectRef, PassInfoNode); }; @@ -327,7 +327,7 @@ class Sequential : public Pass { * This allows users to only provide a list of passes and execute them * under a given context. */ - TVM_DLL Sequential(Array passes, std::string name = "sequential"); + TVM_DLL Sequential(Array passes, String name = "sequential"); Sequential() = default; explicit Sequential(ObjectPtr n) : Pass(n) {} @@ -348,7 +348,7 @@ class Sequential : public Pass { */ TVM_DLL Pass CreateModulePass(const runtime::TypedPackedFunc& pass_func, - int opt_level, const std::string& name, const Array& required); + int opt_level, const String& name, const Array& required); /*! * \brief A special trace pass that prints the header and IR to LOG(INFO). @@ -356,7 +356,7 @@ CreateModulePass(const runtime::TypedPackedFunc * \param show_meta_data Whether should we show meta data. * \return The pass. */ -TVM_DLL Pass PrintIR(std::string header = "", bool show_meta_data = false); +TVM_DLL Pass PrintIR(String header = "", bool show_meta_data = false); } // namespace transform } // namespace tvm diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h index ed648411266c..65b454f08b52 100644 --- a/include/tvm/ir/type.h +++ b/include/tvm/ir/type.h @@ -267,7 +267,7 @@ class GlobalTypeVarNode : public TypeNode { * this only acts as a hint to the user, * and is not used for equality. */ - std::string name_hint; + String name_hint; /*! \brief The kind of type parameter */ TypeKind kind; @@ -301,7 +301,7 @@ class GlobalTypeVar : public Type { * \param name_hint The name of the type var. * \param kind The kind of the type var. */ - TVM_DLL GlobalTypeVar(std::string name_hint, TypeKind kind); + TVM_DLL GlobalTypeVar(String name_hint, TypeKind kind); TVM_DEFINE_OBJECT_REF_METHODS(GlobalTypeVar, Type, GlobalTypeVarNode); }; diff --git a/include/tvm/runtime/container.h b/include/tvm/runtime/container.h index 49c005e36d7d..e2f2453933a5 100644 --- a/include/tvm/runtime/container.h +++ b/include/tvm/runtime/container.h @@ -564,6 +564,15 @@ inline String String::operator=(std::string other) { return Downcast(*this); } +inline String operator+(const std::string lhs, const String& rhs) { + return lhs + rhs.operator std::string(); +} + +inline std::ostream& operator<<(std::ostream& out, const String& input) { + out.write(input.data(), input.size()); + return out; +} + inline int String::memncmp(const char* lhs, const char* rhs, size_t lhs_count, size_t rhs_count) { if (lhs == rhs && lhs_count == rhs_count) return 0; diff --git a/python/tvm/ir/json_compact.py b/python/tvm/ir/json_compact.py index fcea9d821222..a3ff499ad9d9 100644 --- a/python/tvm/ir/json_compact.py +++ b/python/tvm/ir/json_compact.py @@ -111,7 +111,7 @@ def _convert(item, nodes): "EnvFunc": _update_global_key, "relay.Op": _update_global_key, "relay.TypeVar": [_ftype_var, _update_from_std_str("name_hint")], - "relay.GlobalTypeVar": _ftype_var, + "relay.GlobalTypeVar": [_ftype_var, _update_from_std_str("name_hint")], "relay.Type": _rename("Type"), "relay.TupleType": _rename("TupleType"), "relay.TypeConstraint": _rename("TypeConstraint"), @@ -122,7 +122,7 @@ def _convert(item, nodes): "relay.Module": _rename("IRModule"), "relay.SourceName": _rename("SourceName"), "relay.Span": _rename("Span"), - "relay.GlobalVar": _rename("GlobalVar"), + "relay.GlobalVar": [_rename("GlobalVar"), _update_from_std_str("name_hint")], "relay.Pass": _rename("transform.Pass"), "relay.PassInfo": _rename("transform.PassInfo"), "relay.PassContext": _rename("transform.PassContext"), diff --git a/python/tvm/relay/testing/py_converter.py b/python/tvm/relay/testing/py_converter.py index 61a04ec392dd..89c3393408f5 100644 --- a/python/tvm/relay/testing/py_converter.py +++ b/python/tvm/relay/testing/py_converter.py @@ -190,7 +190,7 @@ def convert_func_node(self, func: Function, name_var=None): if name_var is None: func_name = self.generate_function_name('_anon_func') if isinstance(name_var, GlobalVar): - func_name = name_var.name_hint + func_name = str(name_var.name_hint) if isinstance(name_var, Var): func_name = self.get_var_name(name_var) @@ -411,7 +411,7 @@ def visit_var(self, var: Expr): def visit_global_var(self, gvar: Expr): # we don't need to add numbers to global var names because # the *names* are checked for uniqueness in the mod - return (Name(gvar.name_hint, Load()), []) + return (Name(str(gvar.name_hint), Load()), []) def visit_let(self, letexp: Expr): diff --git a/src/ir/adt.cc b/src/ir/adt.cc index 957905ded3cf..f0ce859f3f87 100644 --- a/src/ir/adt.cc +++ b/src/ir/adt.cc @@ -26,7 +26,7 @@ namespace tvm { -Constructor::Constructor(std::string name_hint, tvm::Array inputs, GlobalTypeVar belong_to) { +Constructor::Constructor(String name_hint, tvm::Array inputs, GlobalTypeVar belong_to) { ObjectPtr n = make_object(); n->name_hint = std::move(name_hint); n->inputs = std::move(inputs); @@ -37,7 +37,7 @@ Constructor::Constructor(std::string name_hint, tvm::Array inputs, GlobalT TVM_REGISTER_NODE_TYPE(ConstructorNode); TVM_REGISTER_GLOBAL("ir.Constructor") - .set_body_typed([](std::string name_hint, tvm::Array inputs, GlobalTypeVar belong_to) { + .set_body_typed([](String name_hint, tvm::Array inputs, GlobalTypeVar belong_to) { return Constructor(name_hint, inputs, belong_to); }); diff --git a/src/ir/env_func.cc b/src/ir/env_func.cc index 7deff903cc1f..7b0d6e6f09c2 100644 --- a/src/ir/env_func.cc +++ b/src/ir/env_func.cc @@ -45,7 +45,7 @@ ObjectPtr CreateEnvNode(const std::string& name) { return n; } -EnvFunc EnvFunc::Get(const std::string& name) { return EnvFunc(CreateEnvNode(name)); } +EnvFunc EnvFunc::Get(const String& name) { return EnvFunc(CreateEnvNode(name)); } TVM_REGISTER_GLOBAL("ir.EnvFuncGet").set_body_typed(EnvFunc::Get); diff --git a/src/ir/expr.cc b/src/ir/expr.cc index 000305b61c26..8b2656b28fc7 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -137,7 +137,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "range(min=" << op->min << ", ext=" << op->extent << ')'; }); -GlobalVar::GlobalVar(std::string name_hint) { +GlobalVar::GlobalVar(String name_hint) { ObjectPtr n = make_object(); n->name_hint = std::move(name_hint); data_ = std::move(n); @@ -145,7 +145,7 @@ GlobalVar::GlobalVar(std::string name_hint) { TVM_REGISTER_NODE_TYPE(GlobalVarNode); -TVM_REGISTER_GLOBAL("ir.GlobalVar").set_body_typed([](std::string name) { +TVM_REGISTER_GLOBAL("ir.GlobalVar").set_body_typed([](String name) { return GlobalVar(name); }); diff --git a/src/ir/function.cc b/src/ir/function.cc index 57d62b4f17b5..c0cda704c424 100644 --- a/src/ir/function.cc +++ b/src/ir/function.cc @@ -38,7 +38,7 @@ TVM_REGISTER_GLOBAL("ir.BaseFunc_Attrs").set_body_typed([](BaseFunc func) { retu TVM_REGISTER_GLOBAL("ir.BaseFuncCopy").set_body_typed([](BaseFunc func) { return func; }); TVM_REGISTER_GLOBAL("ir.BaseFuncWithAttr") - .set_body_typed([](BaseFunc func, std::string key, ObjectRef value) -> BaseFunc { + .set_body_typed([](BaseFunc func, String key, ObjectRef value) -> BaseFunc { if (func->IsInstance()) { return WithAttr(Downcast(std::move(func)), key, value); } else if (func->IsInstance()) { diff --git a/src/ir/op.cc b/src/ir/op.cc index 8f587686d7c4..3a6bcbccb9f9 100644 --- a/src/ir/op.cc +++ b/src/ir/op.cc @@ -61,7 +61,7 @@ struct OpManager { }; // find operator by name -const Op& Op::Get(const std::string& name) { +const Op& Op::Get(const String& name) { const OpRegistry* reg = dmlc::Registry::Find(name); CHECK(reg != nullptr) << "Operator " << name << " is not registered"; return reg->op(); @@ -75,7 +75,7 @@ OpRegistry::OpRegistry() { } // Get attribute map by key -const GenericOpMap& Op::GetGenericAttr(const std::string& key) { +const GenericOpMap& Op::GetGenericAttr(const String& key) { OpManager* mgr = OpManager::Global(); std::lock_guard lock(mgr->mutex); auto it = mgr->attr.find(key); @@ -86,7 +86,7 @@ const GenericOpMap& Op::GetGenericAttr(const std::string& key) { } // Check if a key is present in the registry. -bool Op::HasGenericAttr(const std::string& key) { +bool Op::HasGenericAttr(const String& key) { OpManager* mgr = OpManager::Global(); std::lock_guard lock(mgr->mutex); auto it = mgr->attr.find(key); @@ -110,7 +110,7 @@ void OpRegistry::reset_attr(const std::string& key) { } } -void OpRegistry::UpdateAttr(const std::string& key, TVMRetValue value, int plevel) { +void OpRegistry::UpdateAttr(const String& key, TVMRetValue value, int plevel) { OpManager* mgr = OpManager::Global(); std::lock_guard lock(mgr->mutex); std::unique_ptr& op_map = mgr->attr[key]; @@ -141,7 +141,7 @@ TVM_REGISTER_GLOBAL("relay.op._ListOpNames").set_body_typed([]() { return ret; }); -TVM_REGISTER_GLOBAL("relay.op._GetOp").set_body_typed([](std::string name) -> Op { +TVM_REGISTER_GLOBAL("relay.op._GetOp").set_body_typed([](String name) -> Op { return Op::Get(name); }); diff --git a/src/ir/transform.cc b/src/ir/transform.cc index d7d9b063aa12..59e0c1c85276 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -201,7 +201,7 @@ class SequentialNode : public PassNode { TVM_DECLARE_FINAL_OBJECT_INFO(SequentialNode, PassNode); }; -PassInfo::PassInfo(int opt_level, std::string name, tvm::Array required) { +PassInfo::PassInfo(int opt_level, String name, tvm::Array required) { auto pass_info = make_object(); pass_info->opt_level = opt_level; pass_info->name = std::move(name); @@ -238,7 +238,7 @@ Sequential::Sequential(tvm::Array passes, PassInfo pass_info) { data_ = std::move(n); } -Sequential::Sequential(tvm::Array passes, std::string name) { +Sequential::Sequential(tvm::Array passes, String name) { auto n = make_object(); n->passes = std::move(passes); PassInfo pass_info = PassInfo(2, std::move(name), {}); @@ -282,10 +282,10 @@ bool SequentialNode::PassEnabled(const PassInfo& info) const { return ctx->opt_level >= info->opt_level; } -Pass GetPass(const std::string& pass_name) { +Pass GetPass(const String& pass_name) { using tvm::runtime::Registry; const runtime::PackedFunc* f = nullptr; - if (pass_name.find("transform.") != std::string::npos) { + if (pass_name.operator std::string().find("transform.") != std::string::npos) { f = Registry::Get(pass_name); } else if ((f = Registry::Get("transform." + pass_name))) { // pass @@ -313,7 +313,7 @@ IRModule SequentialNode::operator()(IRModule mod, const PassContext& pass_ctx) c } Pass CreateModulePass(const runtime::TypedPackedFunc& pass_func, - int opt_level, const std::string& name, + int opt_level, const String& name, const tvm::Array& required) { PassInfo pass_info = PassInfo(opt_level, name, required); return ModulePass(pass_func, pass_info); @@ -322,7 +322,7 @@ Pass CreateModulePass(const runtime::TypedPackedFunc required) { + .set_body_typed([](int opt_level, String name, tvm::Array required) { return PassInfo(opt_level, name, required); }); @@ -439,7 +439,7 @@ TVM_REGISTER_GLOBAL("transform.EnterPassContext").set_body_typed(PassContext::In TVM_REGISTER_GLOBAL("transform.ExitPassContext").set_body_typed(PassContext::Internal::ExitScope); -Pass PrintIR(std::string header, bool show_meta_data) { +Pass PrintIR(String header, bool show_meta_data) { auto pass_func = [header, show_meta_data](IRModule mod, const PassContext& ctx) { LOG(INFO) << "PrintIR(" << header << "):\n" << AsText(mod, show_meta_data); return mod; diff --git a/src/ir/type.cc b/src/ir/type.cc index 212a6e5ea1bc..38a6ec3e6805 100644 --- a/src/ir/type.cc +++ b/src/ir/type.cc @@ -81,7 +81,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "TypeVar(" << node->name_hint << ", " << node->kind << ")"; }); -GlobalTypeVar::GlobalTypeVar(std::string name, TypeKind kind) { +GlobalTypeVar::GlobalTypeVar(String name, TypeKind kind) { ObjectPtr n = make_object(); n->name_hint = std::move(name); n->kind = std::move(kind); @@ -90,7 +90,7 @@ GlobalTypeVar::GlobalTypeVar(std::string name, TypeKind kind) { TVM_REGISTER_NODE_TYPE(GlobalTypeVarNode); -TVM_REGISTER_GLOBAL("ir.GlobalTypeVar").set_body_typed([](std::string name, int kind) { +TVM_REGISTER_GLOBAL("ir.GlobalTypeVar").set_body_typed([](String name, int kind) { return GlobalTypeVar(name, static_cast(kind)); }); diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc index 3c545ef5488e..5166a489e22f 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -446,7 +446,9 @@ Doc RelayTextPrinter::VisitExpr_(const FunctionNode* op) { return PrintFunc(Doc::Text("fn "), GetRef(op)); } -Doc RelayTextPrinter::VisitExpr_(const GlobalVarNode* op) { return Doc::Text('@' + op->name_hint); } +Doc RelayTextPrinter::VisitExpr_(const GlobalVarNode* op) { + return Doc::Text('@' + op->name_hint.operator std::string()); +} Doc RelayTextPrinter::VisitExpr_(const OpNode* op) { return Doc::Text(op->name); }