diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h
index a683fd66743df..6822159cf119b 100644
--- a/include/tvm/ir/expr.h
+++ b/include/tvm/ir/expr.h
@@ -107,11 +107,12 @@ class PrimExpr : public BaseExpr {
    * \param value The value to be constructed.
    */
   TVM_DLL PrimExpr(float value);  // NOLINT(*)
+
   /*!
-   * \brief construct from string.
-   * \param str The value to be constructed.
+   * \brief construct from runtime String.
+   * \param value The value to be constructed.
    */
-  TVM_DLL PrimExpr(std::string str);  // NOLINT(*)
+  TVM_DLL PrimExpr(runtime::String value);  // NOLINT(*)

   /*! \return the data type of this expression. */
   DataType dtype() const {
diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h
index ecd234a93f763..3a9913fba33d0 100644
--- a/include/tvm/ir/transform.h
+++ b/include/tvm/ir/transform.h
@@ -57,6 +57,7 @@
 #define TVM_IR_TRANSFORM_H_

 #include
+#include
 #include
 #include
 #include
@@ -95,9 +96,9 @@ class PassContextNode : public Object {
   int fallback_device{static_cast(kDLCPU)};

   /*! \brief The list of required passes. */
-  Array required_pass;
+  Array required_pass;
   /*! \brief The list of disabled passes. */
-  Array disabled_pass;
+  Array disabled_pass;

   TraceFunc trace_func;

@@ -197,7 +198,7 @@ class PassInfoNode : public Object {
   std::string name;

   /*! \brief The passes that are required to perform the current pass. */
-  Array required;
+  Array required;

   PassInfoNode() = default;

@@ -226,7 +227,7 @@ class PassInfo : public ObjectRef {
   */
   TVM_DLL PassInfo(int opt_level,
                    std::string name,
-                   Array required);
+                   Array required);

   TVM_DEFINE_OBJECT_REF_METHODS(PassInfo, ObjectRef, PassInfoNode);
 };

@@ -346,7 +347,7 @@ Pass CreateModulePass(
     const runtime::TypedPackedFunc& pass_func,
     int opt_level,
     const std::string& name,
-    const Array& required);
+    const Array& required);

 }  // namespace transform
 }  // namespace tvm
diff --git a/include/tvm/node/node.h b/include/tvm/node/node.h
index 04f477bdd0247..b39e3b4034213 100644
--- a/include/tvm/node/node.h
+++ b/include/tvm/node/node.h
@@ -35,6 +35,7 @@
 #define TVM_NODE_NODE_H_

 #include
+#include
 #include
 #include
 #include
@@ -62,6 +63,7 @@ using runtime::make_object;
 using runtime::PackedFunc;
 using runtime::TVMArgs;
 using runtime::TVMRetValue;
+using runtime::String;
 }  // namespace tvm
 #endif  // TVM_NODE_NODE_H_
diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h
index deb084c65d546..2dcf7f31e2d05 100644
--- a/include/tvm/relay/transform.h
+++ b/include/tvm/relay/transform.h
@@ -24,6 +24,7 @@
 #ifndef TVM_RELAY_TRANSFORM_H_
 #define TVM_RELAY_TRANSFORM_H_

+#include
 #include
 #include
 #include
@@ -59,7 +60,7 @@ TVM_DLL Pass CreateFunctionPass(const runtime::TypedPackedFunc<
                                 Function(Function, IRModule, PassContext)>& pass_func,
                                 int opt_level,
                                 const std::string& name,
-                                const tvm::Array& required);
+                                const tvm::Array& required);

 /*! \brief Remove expressions which does not effect the program result.
  *
@@ -355,7 +356,7 @@ TVM_DLL Pass Inline();
  *
  * \return The pass.
  */
-TVM_DLL Pass RemoveUnusedFunctions(Array entry_functions);
+TVM_DLL Pass RemoveUnusedFunctions(Array entry_functions);

 }  // namespace transform
diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h
index 16292094c8893..c2b77b65b92f8 100644
--- a/include/tvm/target/target.h
+++ b/include/tvm/target/target.h
@@ -52,11 +52,11 @@ class TargetNode : public Object {
   /*! \brief The warp size that should be used by the LowerThreadAllreduce pass */
   int thread_warp_size = 1;
   /*! \brief Keys for this target */
-  Array keys_array;
+  Array keys_array;
   /*! \brief Options for this target */
-  Array options_array;
+  Array options_array;
   /*! \brief Collection of imported libs */
-  Array libs_array;
+  Array libs_array;
   /*! \return the full device string to pass to codegen::Build */
   TVM_DLL const std::string& str() const;
diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h
index 682402221b3af..ad5c5cd60d318 100644
--- a/include/tvm/tir/stmt_functor.h
+++ b/include/tvm/tir/stmt_functor.h
@@ -326,7 +326,7 @@ class StmtExprMutator :
  *        won't do further recursion.
  * \param postorder The function called after recursive mutation.
  *        The recursive mutation result is passed to postorder for further mutation.
- * \param only_enable List of StringImm.
+ * \param only_enable List of runtime::String.
  *        If it is empty, all IRNode will call preorder/postorder
  *        If it is not empty, preorder/postorder will only be called
  *        when the IRNode's type key is in the list.
@@ -334,7 +334,7 @@ class StmtExprMutator :
 TVM_DLL Stmt IRTransform(Stmt node,
                          const runtime::PackedFunc& preorder,
                          const runtime::PackedFunc& postorder,
-                         const Array& only_enable = {});
+                         const Array& only_enable = {});

 /*!
  * \brief recursively visit the ir in post DFS order node, apply fvisit
diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
index 860014d774a4c..5ad40a30ecd05 100644
--- a/include/tvm/tir/transform.h
+++ b/include/tvm/tir/transform.h
@@ -56,7 +56,7 @@ TVM_DLL Pass CreatePrimFuncPass(const runtime::TypedPackedFunc<
                                 PrimFunc(PrimFunc, IRModule, PassContext)>& pass_func,
                                 int opt_level,
                                 const std::string& name,
-                                const tvm::Array& required);
+                                const tvm::Array& required);

 /*!
  * \brief Transform the high-level PrimFunc to a low-level version
@@ -100,7 +100,7 @@ TVM_DLL Pass MakePackedAPI(int num_unpacked_args);
 *
 * \return The pass.
 */
-TVM_DLL Pass RemapThreadAxis(Map axis_map);
+TVM_DLL Pass RemapThreadAxis(Map axis_map);

 /*!
diff --git a/python/tvm/autotvm/task/task.py b/python/tvm/autotvm/task/task.py
index ddee149c479ea..00b667670c650 100644
--- a/python/tvm/autotvm/task/task.py
+++ b/python/tvm/autotvm/task/task.py
@@ -24,6 +24,7 @@
 import numpy as np

 from tvm import target as _target
+from tvm import runtime
 from tvm.ir import container
 from tvm.tir import expr
 from tvm.te import tensor, placeholder
@@ -55,6 +56,8 @@ def _encode(x):
         return x
     if isinstance(x, (expr.StringImm, expr.IntImm, expr.FloatImm)):
         return x.value
+    if isinstance(x, runtime.container.String):
+        return str(x)
     if x is None:
         return None
     raise RuntimeError('Do not support type "%s" in argument. Consider to use'
diff --git a/python/tvm/relay/backend/graph_runtime_codegen.py b/python/tvm/relay/backend/graph_runtime_codegen.py
index 3e5f0157b32f5..8210f27732be3 100644
--- a/python/tvm/relay/backend/graph_runtime_codegen.py
+++ b/python/tvm/relay/backend/graph_runtime_codegen.py
@@ -84,8 +84,7 @@ def codegen(self, func):
         lowered_func = self._get_irmodule()
         param_names = self._list_params_name()
         params = {}
-        for name in param_names:
-            key = name.value
+        for key in param_names:
             arr = self._get_param_by_name(key)
             param = empty(arr.shape, dtype=arr.dtype, ctx=arr.ctx)
             arr.copyto(param)
diff --git a/python/tvm/runtime/container.py b/python/tvm/runtime/container.py
index dd59011b1eca6..a719dcd4eaf06 100644
--- a/python/tvm/runtime/container.py
+++ b/python/tvm/runtime/container.py
@@ -16,8 +16,9 @@
 # under the License.
"""Runtime container structures.""" import tvm._ffi - +from tvm._ffi.base import string_types from tvm.runtime import Object, ObjectTypes +from tvm.runtime import _ffi_api def getitem_helper(obj, elem_getter, length, idx): """Helper function to implement a pythonic getitem function. @@ -75,18 +76,19 @@ def __init__(self, tag, fields): for f in fields: assert isinstance(f, ObjectTypes), "Expect object or " \ "tvm NDArray type, but received : {0}".format(type(f)) - self.__init_handle_by_constructor__(_ADT, tag, *fields) + self.__init_handle_by_constructor__(_ffi_api.ADT, tag, + *fields) @property def tag(self): - return _GetADTTag(self) + return _ffi_api.GetADTTag(self) def __getitem__(self, idx): return getitem_helper( - self, _GetADTFields, len(self), idx) + self, _ffi_api.GetADTFields, len(self), idx) def __len__(self): - return _GetADTSize(self) + return _ffi_api.GetADTSize(self) def tuple_object(fields=None): @@ -106,7 +108,7 @@ def tuple_object(fields=None): for f in fields: assert isinstance(f, ObjectTypes), "Expect object or tvm " \ "NDArray type, but received : {0}".format(type(f)) - return _Tuple(*fields) + return _ffi_api.Tuple(*fields) @tvm._ffi.register_object("runtime.String") @@ -115,7 +117,7 @@ class String(Object): Parameters ---------- - string : Str + string : str The string used to construct a runtime String object Returns @@ -124,7 +126,50 @@ class String(Object): The created object. """ def __init__(self, string): - self.__init_handle_by_constructor__(_String, string) + self.__init_handle_by_constructor__(_ffi_api.String, string) + + def __str__(self): + return _ffi_api.GetStdString(self) + + def __len__(self): + return _ffi_api.GetStringSize(self) + + def __hash__(self): + return _ffi_api.StringHash(self) + + def __eq__(self, other): + if isinstance(other, string_types): + return self.__str__() == other + + if not isinstance(other, String): + return False + + return _ffi_api.CompareString(self, other) == 0 + + def __ne__(self, other): + return not self.__eq__(other) + + def __gt__(self, other): + return _ffi_api.CompareString(self, other) > 0 + + def __lt__(self, other): + return _ffi_api.CompareString(self, other) < 0 + + def __getitem__(self, key): + return self.__str__()[key] + + def startswith(self, string): + """Check if the runtime string starts with a given string + Parameters + ---------- + string : str + The provided string -tvm._ffi._init_api("tvm.runtime.container") + Returns + ------- + ret : boolean + Return true if the runtime string starts with the given string, + otherwise, false. + """ + return self.__str__().startswith(string) diff --git a/python/tvm/runtime/object_generic.py b/python/tvm/runtime/object_generic.py index 22354db7737c5..a7716df831894 100644 --- a/python/tvm/runtime/object_generic.py +++ b/python/tvm/runtime/object_generic.py @@ -19,7 +19,7 @@ from numbers import Number, Integral from tvm._ffi.base import string_types -from . import _ffi_node_api +from . 
 from .object import ObjectBase, _set_class_object_generic
 from .ndarray import NDArrayBase
 from .packed_func import PackedFuncBase, convert_to_tvm_func
@@ -56,7 +56,7 @@ def convert_to_object(value):
     if isinstance(value, Number):
         return const(value)
     if isinstance(value, string_types):
-        return _ffi_node_api.String(value)
+        return _ffi_api.String(value)
     if isinstance(value, (list, tuple)):
         value = [convert_to_object(x) for x in value]
         return _ffi_node_api.Array(*value)
diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py
index e6046cef1839b..1410672faead5 100644
--- a/python/tvm/target/target.py
+++ b/python/tvm/target/target.py
@@ -48,26 +48,26 @@ def __new__(cls):
     @property
     def keys(self):
         if not self._keys:
-            self._keys = [k.value for k in self.keys_array]
+            self._keys = [str(k) for k in self.keys_array]
         return self._keys

     @property
     def options(self):
         if not self._options:
-            self._options = [o.value for o in self.options_array]
+            self._options = [str(o) for o in self.options_array]
         return self._options

     @property
     def libs(self):
         if not self._libs:
-            self._libs = [l.value for l in self.libs_array]
+            self._libs = [str(l) for l in self.libs_array]
         return self._libs

     @property
     def model(self):
         for opt in self.options_array:
-            if opt.value.startswith('-model='):
-                return opt.value[7:]
+            if opt.startswith('-model='):
+                return opt[7:]
         return 'unknown'

     @property
diff --git a/src/autotvm/touch_extractor.cc b/src/autotvm/touch_extractor.cc
index b5bf2ed609648..fbd0829c8a60a 100644
--- a/src/autotvm/touch_extractor.cc
+++ b/src/autotvm/touch_extractor.cc
@@ -252,9 +252,9 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array > >
   for (auto var : vars) {
     Array > feature_row;
     ItervarFeature &fea = touch_analyzer.itervar_map[var];
-    feature_row.push_back(Array{std::string("_itervar_"), var});
+    feature_row.push_back(Array{tvm::tir::StringImmNode::make("_itervar_"), var});

-    Array attr{std::string("_attr_"),
+    Array attr{tvm::tir::StringImmNode::make("_attr_"),
                FloatImm(DataType::Float(32), trans(fea.length)),
                IntImm(DataType::Int(32), fea.nest_level),
                FloatImm(DataType::Float(32), trans(fea.topdown_product)),
@@ -267,7 +267,7 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array > >
     feature_row.push_back(attr);

     // arithmetic
-    feature_row.push_back(Array{std::string("_arith_"),
+    feature_row.push_back(Array{tvm::tir::StringImmNode::make("_arith_"),
                FloatImm(DataType::Float(32), trans(fea.add_ct)),
                FloatImm(DataType::Float(32), trans(fea.mul_ct)),
                FloatImm(DataType::Float(32), trans(fea.div_ct)),
@@ -282,7 +282,7 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array > >
     for (auto k : bufs) {
       TouchPattern &v = fea.touch_feature[k];
       feature_row.push_back(
-          Array{k,
+          Array{tvm::tir::StringImmNode::make(k),
                 FloatImm(DataType::Float(32), trans(v.stride)),
                 FloatImm(DataType::Float(32), trans(v.mod)),
                 FloatImm(DataType::Float(32), trans(v.count)),
diff --git a/src/ir/attrs.cc b/src/ir/attrs.cc
index 066b8f99ea7c2..444af7eca2575 100644
--- a/src/ir/attrs.cc
+++ b/src/ir/attrs.cc
@@ -42,7 +42,7 @@ void DictAttrsNode::InitByPackedArgs(
     if (val.IsObjectRef()) {
       dict.Set(key, val.operator ObjectRef());
     } else if (val.type_code() == kTVMStr) {
-      dict.Set(key, PrimExpr(val.operator std::string()));
+      dict.Set(key, runtime::String(val.operator std::string()));
     } else {
       dict.Set(key, val.operator PrimExpr());
     }
diff --git a/src/ir/expr.cc b/src/ir/expr.cc
index b07f04aa69749..1f0337e579f34 100644
--- a/src/ir/expr.cc
+++ b/src/ir/expr.cc
@@ -40,8 +40,8 @@ PrimExpr::PrimExpr(int32_t value)
 PrimExpr::PrimExpr(float value)
     : PrimExpr(FloatImm(DataType::Float(32), value)) {}

-PrimExpr::PrimExpr(std::string str)
-    : PrimExpr(tir::StringImmNode::make(str)) {}
+PrimExpr::PrimExpr(runtime::String value)
+    : PrimExpr(tir::StringImmNode::make(value)) {}

 PrimExpr PrimExpr::FromObject_(ObjectPtr ptr) {
   using runtime::ObjectTypeChecker;
@@ -51,6 +51,9 @@ PrimExpr PrimExpr::FromObject_(ObjectPtr ptr) {
   if (ptr->IsInstance()) {
     return te::Tensor(ptr)();
   }
+  if (ptr->IsInstance()) {
+    return tir::StringImmNode::make(runtime::String(ptr));
+  }
   CHECK(ObjectTypeChecker::Check(ptr.get()))
       << "Expect type " << ObjectTypeChecker::TypeName()
       << " but get " << ptr->GetTypeKey();
diff --git a/src/ir/op.cc b/src/ir/op.cc
index 54374eb8a526a..e5011fbf43aef 100644
--- a/src/ir/op.cc
+++ b/src/ir/op.cc
@@ -24,6 +24,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
@@ -140,10 +141,10 @@ void OpRegistry::UpdateAttr(const std::string& key,
 // Frontend APIs
 TVM_REGISTER_GLOBAL("relay.op._ListOpNames")
 .set_body_typed([]() {
-    Array ret;
+    Array ret;
     for (const std::string& name : dmlc::Registry::ListAllNames()) {
-      ret.push_back(tvm::PrimExpr(name));
+      ret.push_back(runtime::String(name));
     }
     return ret;
   });
diff --git a/src/ir/transform.cc b/src/ir/transform.cc
index 61c1fc240b6f5..6e38aac92ec0c 100644
--- a/src/ir/transform.cc
+++ b/src/ir/transform.cc
@@ -23,6 +23,7 @@
 */
 #include
 #include
+#include
 #include
 #include
 #include
@@ -212,7 +213,7 @@ class SequentialNode : public PassNode {
 PassInfo::PassInfo(int opt_level,
                    std::string name,
-                   tvm::Array required) {
+                   tvm::Array required) {
   auto pass_info = make_object();
   pass_info->opt_level = opt_level;
   pass_info->name = std::move(name);
@@ -274,12 +275,10 @@ void SequentialNode::ResolveDependency(const IRModule& mod) {
 }

 // linearly scan the pass array to match pass_name
-inline bool PassArrayContains(const Array& pass_array,
+inline bool PassArrayContains(const Array& pass_array,
                               const std::string& pass_name) {
   for (auto x : pass_array) {
-    auto* str_name = x.as();
-    CHECK(str_name) << "pass name must be str";
-    if (str_name->value == pass_name) return true;
+    if (x == pass_name) return true;
   }
   return false;
 }
@@ -324,9 +323,7 @@ IRModule SequentialNode::operator()(const IRModule& module,
     if (!PassEnabled(pass_info)) continue;
     // resolve dependencies
     for (const auto& it : pass_info->required) {
-      const auto* name = it.as();
-      CHECK(name);
-      mod = GetPass(name->value)(mod, pass_ctx);
+      mod = GetPass(it)(mod, pass_ctx);
     }
     mod = pass(mod, pass_ctx);
   }
@@ -337,7 +334,7 @@ Pass CreateModulePass(
     const runtime::TypedPackedFunc& pass_func,
     int opt_level,
     const std::string& name,
-    const tvm::Array& required) {
+    const tvm::Array& required) {
   PassInfo pass_info = PassInfo(opt_level, name, required);
   return ModulePass(pass_func, pass_info);
 }
@@ -345,7 +342,7 @@ Pass CreateModulePass(
 TVM_REGISTER_NODE_TYPE(PassInfoNode);

 TVM_REGISTER_GLOBAL("transform.PassInfo")
-.set_body_typed([](int opt_level, std::string name, tvm::Array required) {
+.set_body_typed([](int opt_level, std::string name, tvm::Array required) {
   return PassInfo(opt_level, name, required);
 });
@@ -363,8 +360,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
   p->stream << "opt_level: " << node->opt_level;
   p->stream << "required passes: [" << "\n";
   for (const auto& it : node->required) {
-    const auto* str = it.as();
-    p->stream << str->value << ", ";
+    p->stream << it << ", ";
   }
   p->stream << "]\n";
 });
@@ -401,7 +397,7 @@ TVM_REGISTER_GLOBAL("transform.Sequential")
   tvm::Array passes = args[0];
   int opt_level = args[1];
   std::string name = args[2];
-  tvm::Array required = args[3];
+  tvm::Array required = args[3];
   PassInfo pass_info = PassInfo(opt_level, name, required);
   *ret = Sequential(passes, pass_info);
 });
@@ -427,8 +423,8 @@ TVM_REGISTER_GLOBAL("transform.PassContext")
   auto pctx = PassContext::Create();
   int opt_level = args[0];
   int fallback_device = args[1];
-  tvm::Array required = args[2];
-  tvm::Array disabled = args[3];
+  tvm::Array required = args[2];
+  tvm::Array disabled = args[3];
   TraceFunc trace_func = args[4];
   pctx->opt_level = opt_level;
   pctx->fallback_device = fallback_device;
diff --git a/src/node/container.cc b/src/node/container.cc
index 8fff151ce6059..1c33ca2ed10d1 100644
--- a/src/node/container.cc
+++ b/src/node/container.cc
@@ -48,7 +48,10 @@ struct StringObjTrait {
   }
 };

-TVM_REGISTER_REFLECTION_VTABLE(runtime::StringObj, StringObjTrait);
+TVM_REGISTER_REFLECTION_VTABLE(runtime::StringObj, StringObjTrait)
+.set_creator([](const std::string&) -> ObjectPtr {
+  return ::tvm::runtime::make_object();
+});

 struct ADTObjTrait {
   static constexpr const std::nullptr_t VisitAttrs = nullptr;
diff --git a/src/node/serialization.cc b/src/node/serialization.cc
index 11c9e8fc8cb63..c661b34dd5e20 100644
--- a/src/node/serialization.cc
+++ b/src/node/serialization.cc
@@ -102,7 +102,7 @@ class NodeIndexer : public AttrVisitor {
       for (const auto& kv : n->data) {
         MakeIndex(const_cast(kv.second.get()));
       }
-    } else {
+    } else if (!node->IsInstance()) {
       reflection_->VisitAttrs(node, this);
     }
   }
@@ -242,6 +242,8 @@ class JSONAttrGetter : public AttrVisitor {
         node_->data.push_back(
             node_index_->at(const_cast(kv.second.get())));
       }
+    } else if (node->IsInstance()) {
+      node_->data.push_back(node_index_->at(node));
     } else {
       // recursively index normal object.
       reflection_->VisitAttrs(node, this);
     }
   }
@@ -335,7 +337,7 @@ class JSONAttrSetter : public AttrVisitor {
         n->data[node_->keys[i]]
             = ObjectRef(node_list_->at(node_->data[i]));
       }
-    } else {
+    } else if (!node->IsInstance()) {
       reflection_->VisitAttrs(node, this);
     }
   }
diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc
index bda997a59d4d7..685a8636a37cd 100644
--- a/src/printer/relay_text_printer.cc
+++ b/src/printer/relay_text_printer.cc
@@ -141,7 +141,12 @@ class RelayTextPrinter :
     } else {
       // default module.
       std::ostringstream os;
-      os << node;
+      if (node->IsInstance()) {
+        runtime::String str = Downcast(node);
+        os << "\"" << str.c_str() << "\"";
+      } else {
+        os << node;
+      }
       return Doc::RawText(os.str());
     }
   }
diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc
index eaf78bc1b0f78..bc0685b5d9053 100644
--- a/src/relay/backend/build_module.cc
+++ b/src/relay/backend/build_module.cc
@@ -86,9 +86,9 @@ struct GraphCodegen {
   std::unordered_map GetParams() {
     std::unordered_map ret;
-    auto names = CallFunc>("list_params_name", nullptr);
+    auto names = CallFunc>("list_params_name", nullptr);
     for (auto expr : names) {
-      auto key = expr.as()->value;
+      auto key = expr.operator std::string();
       ret[key] = CallFunc("get_param_by_name", key);
     }
     return ret;
   }
@@ -191,12 +191,12 @@ class RelayBuildModule : public runtime::ModuleNode {
   /*!
    * \brief List all paramter names
    *
-   * \return Array names of params
+   * \return Array names of params
    */
-  Array ListParamNames() {
-    Array ret;
+  Array ListParamNames() {
+    Array ret;
     for (const auto& kv : params_) {
-      ret.push_back(tir::StringImmNode::make(kv.first));
+      ret.push_back(runtime::String(kv.first));
     }
     return ret;
   }
@@ -272,7 +272,7 @@ class RelayBuildModule : public runtime::ModuleNode {
     }

     Array pass_seqs;
-    Array entry_functions{tvm::PrimExpr{"main"}};
+    Array entry_functions{runtime::String("main")};
     pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions));

     // Run all dialect legalization passes.
diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc
index f75da0731242b..6a06f83e5a24e 100644
--- a/src/relay/backend/compile_engine.cc
+++ b/src/relay/backend/compile_engine.cc
@@ -617,17 +617,18 @@ class CompileEngineImpl : public CompileEngineNode {
     for (const auto& it : cache_) {
       auto src_func = it.first->source_func;
       CHECK(src_func.defined());
-      if (src_func->GetAttr(attr::kCompiler).defined()) {
-        auto code_gen = src_func->GetAttr(attr::kCompiler);
+      if (src_func->GetAttr(attr::kCompiler).defined()) {
+        auto code_gen = src_func->GetAttr(attr::kCompiler);
         CHECK(code_gen.defined()) << "No external codegen is set";
-        if (ext_mods.find(code_gen->value) == ext_mods.end()) {
-          ext_mods[code_gen->value] = IRModule({}, {});
+        std::string code_gen_name = code_gen.operator std::string();
+        if (ext_mods.find(code_gen_name) == ext_mods.end()) {
+          ext_mods[code_gen_name] = IRModule({}, {});
         }
         auto symbol_name = src_func->GetAttr(tvm::attr::kGlobalSymbol);
         CHECK(symbol_name.defined()) << "No external symbol is set for:\n"
                                      << AsText(src_func, false);
         auto gv = GlobalVar(std::string(symbol_name));
-        ext_mods[code_gen->value]->Add(gv, src_func);
+        ext_mods[code_gen_name]->Add(gv, src_func);
         cached_ext_funcs.push_back(it.first);
       }
     }
@@ -691,7 +692,7 @@ class CompileEngineImpl : public CompileEngineNode {
     }
     // No need to lower external functions for now. We will invoke the external
    // codegen tool once and lower all functions together.
-    if (key->source_func->GetAttr(attr::kCompiler).defined()) {
+    if (key->source_func->GetAttr(attr::kCompiler).defined()) {
       auto cache_node = make_object();
       const auto name_node =
           key->source_func->GetAttr(tvm::attr::kGlobalSymbol);
diff --git a/src/relay/backend/graph_runtime_codegen.cc b/src/relay/backend/graph_runtime_codegen.cc
index c7f1be82c3710..c126017f982f4 100644
--- a/src/relay/backend/graph_runtime_codegen.cc
+++ b/src/relay/backend/graph_runtime_codegen.cc
@@ -419,7 +419,7 @@ class GraphRuntimeCodegen
     auto pf1 = GetPackedFunc("relay.backend._CompileEngineLower");
     Target target;
     // Handle external function
-    if (func->GetAttr(attr::kCompiler).defined()) {
+    if (func->GetAttr(attr::kCompiler).defined()) {
       target = tvm::target::ext_dev();
       CCacheKey key = (*pf0)(func, target);
       CachedFunc ext_func = (*pf1)(compile_engine_, key);
@@ -482,7 +482,7 @@ class GraphRuntimeCodegen
     return {};
   }
   std::vector VisitExpr_(const FunctionNode* op) override {
-    CHECK(op->GetAttr(attr::kCompiler).defined())
+    CHECK(op->GetAttr(attr::kCompiler).defined())
         << "Only functions supported by custom codegen";
     return {};
   }
@@ -633,10 +633,9 @@ class GraphRuntimeCodegenModule : public runtime::ModuleNode {
       });
     } else if (name == "list_params_name") {
       return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
-        Array ret;
+        Array ret;
         for (const auto &kv : this->output_.params) {
-          tvm::PrimExpr name = tir::StringImmNode::make(kv.first);
-          ret.push_back(name);
+          ret.push_back(runtime::String(kv.first));
         }
         *rv = ret;
       });
diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc
index 78ebb0fc5383c..1a55b504f29e9 100644
--- a/src/relay/backend/vm/compiler.cc
+++ b/src/relay/backend/vm/compiler.cc
@@ -475,7 +475,7 @@ class VMFunctionCompiler : ExprFunctor {
       Target target;

-      if (func->GetAttr(attr::kCompiler).defined()) {
+      if (func->GetAttr(attr::kCompiler).defined()) {
         target = tvm::target::ext_dev();
       } else {
         // Next generate the invoke instruction.
@@ -493,7 +493,7 @@ class VMFunctionCompiler : ExprFunctor {
       auto cfunc = engine_->Lower(key);
       auto op_index = -1;
-      if (func->GetAttr(attr::kCompiler).defined()) {
+      if (func->GetAttr(attr::kCompiler).defined()) {
         op_index = context_->cached_funcs.size();
         context_->cached_funcs.push_back(cfunc);
       } else {
@@ -873,7 +873,7 @@ void VMCompiler::Lower(IRModule mod,
 IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targets) {
   Array pass_seqs;
-  Array entry_functions{tvm::PrimExpr{"main"}};
+  Array entry_functions{runtime::String{"main"}};
   pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions));

   // Run all dialect legalization passes.
   pass_seqs.push_back(relay::qnn::transform::Legalize());
diff --git a/src/relay/backend/vm/inline_primitives.cc b/src/relay/backend/vm/inline_primitives.cc
index 74b2a47634c85..9ecea40f24fd2 100644
--- a/src/relay/backend/vm/inline_primitives.cc
+++ b/src/relay/backend/vm/inline_primitives.cc
@@ -122,7 +122,7 @@ struct PrimitiveInliner : ExprMutator {
       auto global = pair.first;
       auto base_func = pair.second;
       if (auto* n = base_func.as()) {
-        if (n->GetAttr(attr::kCompiler).defined()) continue;
+        if (n->GetAttr(attr::kCompiler).defined()) continue;
         auto func = GetRef(n);

         DLOG(INFO) << "Before inlining primitives: " << global
diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc
index 80745e1a11145..0cd17a97fb03a 100644
--- a/src/relay/backend/vm/lambda_lift.cc
+++ b/src/relay/backend/vm/lambda_lift.cc
@@ -190,7 +190,7 @@ class LambdaLifter : public ExprMutator {
     auto glob_funcs = module_->functions;
     for (auto pair : glob_funcs) {
       if (auto* n = pair.second.as()) {
-        if (n->GetAttr(attr::kCompiler).defined()) continue;
+        if (n->GetAttr(attr::kCompiler).defined()) continue;
         auto func = GetRef(n);
         func = Function(func->params,
                         VisitExpr(func->body),
diff --git a/src/relay/backend/vm/removed_unused_funcs.cc b/src/relay/backend/vm/removed_unused_funcs.cc
index dd11fce5cc42e..c2fe37f15453f 100644
--- a/src/relay/backend/vm/removed_unused_funcs.cc
+++ b/src/relay/backend/vm/removed_unused_funcs.cc
@@ -87,11 +87,10 @@ struct CallTracer : ExprVisitor {
  * \return The module with dead functions removed.
  */
 IRModule RemoveUnusedFunctions(const IRModule& module,
-                               Array entry_funcs) {
+                               Array entry_funcs) {
   std::unordered_set called_funcs{};
   for (auto entry : entry_funcs) {
-    auto* str_name = entry.as();
-    auto funcs = CallTracer(module).Trace(str_name->value);
+    auto funcs = CallTracer(module).Trace(entry);
     called_funcs.insert(funcs.cbegin(), funcs.cend());
   }
   auto existing_functions = module->functions;
@@ -108,7 +107,7 @@ IRModule RemoveUnusedFunctions(const IRModule& module,

 namespace transform {

-Pass RemoveUnusedFunctions(Array entry_functions) {
+Pass RemoveUnusedFunctions(Array entry_functions) {
   runtime::TypedPackedFunc pass_func =
     [=](IRModule m, PassContext pc) {
       return relay::vm::RemoveUnusedFunctions(m, entry_functions);
diff --git a/src/relay/ir/transform.cc b/src/relay/ir/transform.cc
index a4bab36d3fe50..d5f0fbac29c8d 100644
--- a/src/relay/ir/transform.cc
+++ b/src/relay/ir/transform.cc
@@ -145,14 +145,14 @@ IRModule FunctionPassNode::operator()(const IRModule& mod,
 bool FunctionPassNode::SkipFunction(const Function& func) const {
   return func->GetAttr(attr::kSkipOptimization, 0)->value != 0 ||
-         (func->GetAttr(attr::kCompiler).defined());
+         (func->GetAttr(attr::kCompiler).defined());
 }

 Pass CreateFunctionPass(
     const runtime::TypedPackedFunc& pass_func,
     int opt_level,
     const std::string& name,
-    const tvm::Array& required) {
+    const tvm::Array& required) {
   PassInfo pass_info = PassInfo(opt_level, name, required);
   return FunctionPass(pass_func, pass_info);
 }
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 87b4602095a20..7aa8bf1863a12 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -1177,7 +1177,6 @@ Array ArangeCompute(const Attrs& attrs,
   te::Tensor start = inputs[0];
   te::Tensor stop = inputs[1];
   te::Tensor step = inputs[2];
-  Array empty = {0};
   return { DynamicArange(start, stop, step, param->dtype) };
 }
diff --git a/src/relay/transforms/alter_op_layout.cc b/src/relay/transforms/alter_op_layout.cc
index 63c1cb96886d8..59cf9f98288f9 100644
--- a/src/relay/transforms/alter_op_layout.cc
+++ b/src/relay/transforms/alter_op_layout.cc
@@ -126,7 +126,7 @@ Pass AlterOpLayout() {
     return Downcast(relay::alter_op_layout::AlterOpLayout(f));
   };
   return CreateFunctionPass(pass_func, 3, "AlterOpLayout",
-                            {tir::StringImmNode::make("InferType")});
+                            {runtime::String("InferType")});
 }

 TVM_REGISTER_GLOBAL("relay._transform.AlterOpLayout")
diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc
index c3d34cb9ab7cd..0863cd33183c4 100644
--- a/src/relay/transforms/annotate_target.cc
+++ b/src/relay/transforms/annotate_target.cc
@@ -59,11 +59,12 @@ class AnnotateTargetWrapper : public ExprMutator {
       // handle composite functions
       Function func = Downcast(call->op);
       CHECK(func.defined());
-      auto comp_name = func->GetAttr(attr::kComposite);
+      auto comp_name = func->GetAttr(attr::kComposite);
       if (comp_name.defined()) {
-        size_t i = comp_name->value.find('.');
+        std::string comp_name_str = comp_name;
+        size_t i = comp_name_str.find('.');
         if (i != std::string::npos) {
-          std::string target = comp_name->value.substr(0, i);
+          std::string target = comp_name_str.substr(0, i);
           if (target == target_) return true;
         }
       }
@@ -147,7 +148,7 @@ class AnnotateTargetWrapper : public ExprMutator {
     Function func;
     Expr new_body;
     // don't step into composite functions
-    if (fn->GetAttr(attr::kComposite).defined()) {
+    if (fn->GetAttr(attr::kComposite).defined()) {
       func = GetRef(fn);
       new_body = func->body;
     } else {
@@ -225,7 +226,7 @@ Pass AnnotateTarget(const std::string& target) {
     return Downcast(relay::annotate_target::AnnotateTarget(f, target));
   };
   auto func_pass = CreateFunctionPass(pass_func, 0, "AnnotateTargetFunc",
-                                      {tir::StringImmNode::make("InferType")});
+                                      {runtime::String("InferType")});
   return transform::Sequential({func_pass, InferType()}, "AnnotateTarget");
 }
diff --git a/src/relay/transforms/canonicalize_cast.cc b/src/relay/transforms/canonicalize_cast.cc
index 759a4aea741d3..4b35ba219b674 100644
--- a/src/relay/transforms/canonicalize_cast.cc
+++ b/src/relay/transforms/canonicalize_cast.cc
@@ -134,7 +134,7 @@ Pass CanonicalizeCast() {
     return Downcast(CanonicalizeCast(f));
   };
   return CreateFunctionPass(pass_func, 3, "CanonicalizeCast",
-                            {tir::StringImmNode::make("InferType")});
+                            {runtime::String("InferType")});
 }

 TVM_REGISTER_GLOBAL("relay._transform.CanonicalizeCast")
diff --git a/src/relay/transforms/canonicalize_ops.cc b/src/relay/transforms/canonicalize_ops.cc
index 97a128db65211..44140a902a2fd 100644
--- a/src/relay/transforms/canonicalize_ops.cc
+++ b/src/relay/transforms/canonicalize_ops.cc
@@ -75,7 +75,7 @@ Pass CanonicalizeOps() {
     return Downcast(CanonicalizeOps(f));
   };
   return CreateFunctionPass(pass_func, 3, "CanonicalizeOps",
-                            {tir::StringImmNode::make("InferType")});
+                            {runtime::String("InferType")});
 }

 TVM_REGISTER_GLOBAL("relay._transform.CanonicalizeOps")
diff --git a/src/relay/transforms/combine_parallel_conv2d.cc b/src/relay/transforms/combine_parallel_conv2d.cc
index 3884dacbb22c4..3c8eea04d28f3 100644
--- a/src/relay/transforms/combine_parallel_conv2d.cc
+++ b/src/relay/transforms/combine_parallel_conv2d.cc
@@ -221,7 +221,7 @@ Pass CombineParallelConv2D(uint64_t min_num_branches) {
     return Downcast(CombineParallelConv2D(f, min_num_branches));
   };
   return CreateFunctionPass(pass_func, 4, "CombineParallelConv2d",
-                            {tir::StringImmNode::make("InferType")});
+                            {runtime::String("InferType")});
 }
 TVM_REGISTER_GLOBAL("relay._transform.CombineParallelConv2D")
diff --git a/src/relay/transforms/combine_parallel_dense.cc b/src/relay/transforms/combine_parallel_dense.cc
index 612dae5ef00cf..2dc8321e517b2 100644
--- a/src/relay/transforms/combine_parallel_dense.cc
+++ b/src/relay/transforms/combine_parallel_dense.cc
@@ -81,7 +81,7 @@ Pass CombineParallelDense(uint64_t min_num_branches) {
     return Downcast(CombineParallelDense(f, min_num_branches));
   };
   return CreateFunctionPass(pass_func, 4, "CombineParallelDense",
-                            {tir::StringImmNode::make("InferType")});
+                            {runtime::String("InferType")});
 }

 TVM_REGISTER_GLOBAL("relay._transform.CombineParallelDense")
diff --git a/src/relay/transforms/combine_parallel_op_batch.cc b/src/relay/transforms/combine_parallel_op_batch.cc
index 55ca3f62bec01..f63f169be4086 100644
--- a/src/relay/transforms/combine_parallel_op_batch.cc
+++ b/src/relay/transforms/combine_parallel_op_batch.cc
@@ -194,7 +194,7 @@ Pass CombineParallelOpBatch(const std::string& op_name,
                                            min_num_branches));
   };
   return CreateFunctionPass(pass_func, 4, "CombineParallelOpBatch",
-                            {tir::StringImmNode::make("InferType")});
+                            {runtime::String("InferType")});
 }

 TVM_REGISTER_GLOBAL("relay._transform.CombineParallelOpBatch")
diff --git a/src/relay/transforms/convert_layout.cc b/src/relay/transforms/convert_layout.cc
index 871969dd1f37c..d43a0851e0997 100644
--- a/src/relay/transforms/convert_layout.cc
+++ b/src/relay/transforms/convert_layout.cc
@@ -134,8 +134,8 @@ Pass ConvertLayout(const std::string& desired_layout) {
   };
   return CreateFunctionPass(
       pass_func, 3, "ConvertLayout",
-      {tir::StringImmNode::make("InferType"),
-       tir::StringImmNode::make("CanonicalizeOps")});
+      {runtime::String("InferType"),
+       runtime::String("CanonicalizeOps")});
 }

 TVM_REGISTER_GLOBAL("relay._transform.ConvertLayout").set_body_typed(ConvertLayout);
diff --git a/src/relay/transforms/device_annotation.cc b/src/relay/transforms/device_annotation.cc
index b4d61f1108325..9955ef6ee7d2c 100644
--- a/src/relay/transforms/device_annotation.cc
+++ b/src/relay/transforms/device_annotation.cc
@@ -574,7 +574,7 @@ Pass RewriteAnnotatedOps(int fallback_device) {
     return Downcast(relay::RewriteAnnotatedOps(f, fallback_device));
   };
   return CreateFunctionPass(pass_func, 1, "RewriteAnnotatedOps",
-                            {tir::StringImmNode::make("InferType")});
+                            {runtime::String("InferType")});
 }

 TVM_REGISTER_GLOBAL("relay._transform.RewriteDeviceAnnotation")
diff --git a/src/relay/transforms/eliminate_common_subexpr.cc b/src/relay/transforms/eliminate_common_subexpr.cc
index f905ba55719d8..696e83a7db538 100644
--- a/src/relay/transforms/eliminate_common_subexpr.cc
+++ b/src/relay/transforms/eliminate_common_subexpr.cc
@@ -92,7 +92,7 @@ Pass EliminateCommonSubexpr(PackedFunc fskip) {
     return Downcast(EliminateCommonSubexpr(f, fskip));
   };
   return CreateFunctionPass(pass_func, 3, "EliminateCommonSubexpr",
-                            {tir::StringImmNode::make("InferType")});
+                            {runtime::String("InferType")});
 }

 TVM_REGISTER_GLOBAL("relay._transform.EliminateCommonSubexpr")
diff --git a/src/relay/transforms/fast_math.cc b/src/relay/transforms/fast_math.cc
index cf00a89fff549..668982e561e88 100644
--- a/src/relay/transforms/fast_math.cc
+++ b/src/relay/transforms/fast_math.cc
@@ -71,7 +71,7 @@ Pass FastMath() {
     return Downcast(FastMath(f));
   };
   return CreateFunctionPass(pass_func, 4, "FastMath",
-                            {tir::StringImmNode::make("InferType")});
+                            {runtime::String("InferType")});
 }

 TVM_REGISTER_GLOBAL("relay._transform.FastMath")
diff --git a/src/relay/transforms/fold_scale_axis.cc b/src/relay/transforms/fold_scale_axis.cc
index 49f6e3fd01cd5..11325f6526b84 100644
--- a/src/relay/transforms/fold_scale_axis.cc
+++ b/src/relay/transforms/fold_scale_axis.cc
@@ -961,7 +961,7 @@ Pass ForwardFoldScaleAxis() {
         relay::fold_scale_axis::ForwardFoldScaleAxis(f));
   };
   return CreateFunctionPass(pass_func, 3, "ForwardFoldScaleAxis",
-                            {tir::StringImmNode::make("InferType")});
+                            {runtime::String("InferType")});
 }

 TVM_REGISTER_GLOBAL("relay._transform.ForwardFoldScaleAxis")
@@ -974,7 +974,7 @@ Pass BackwardFoldScaleAxis() {
         relay::fold_scale_axis::BackwardFoldScaleAxis(f));
   };
   return CreateFunctionPass(pass_func, 3, "BackwardFoldScaleAxis",
-                            {tir::StringImmNode::make("InferType")});
+                            {runtime::String("InferType")});
 }

 TVM_REGISTER_GLOBAL("relay._transform.BackwardFoldScaleAxis")
diff --git a/src/relay/transforms/fuse_ops.cc b/src/relay/transforms/fuse_ops.cc
index 9168898cae360..cdd29394a2047 100644
--- a/src/relay/transforms/fuse_ops.cc
+++ b/src/relay/transforms/fuse_ops.cc
@@ -981,7 +981,7 @@ Pass FuseOps(int fuse_opt_level) {
     return Downcast(FuseOps(f, opt_level, m));
   };
   return CreateFunctionPass(pass_func, 1, "FuseOps",
-                            {tir::StringImmNode::make("InferType")});
+                            {runtime::String("InferType")});
 }

 TVM_REGISTER_GLOBAL("relay._transform.FuseOps")
diff --git a/src/relay/transforms/inline.cc b/src/relay/transforms/inline.cc
index ef3c51f86105b..8400b3fb3088e 100644
--- a/src/relay/transforms/inline.cc
+++ b/src/relay/transforms/inline.cc
@@ -131,7 +131,7 @@ class Inliner : ExprMutator {
                     fn->attrs);
     // Inline the function body to the caller if this function uses default
     // compiler, i.e. no external codegen is needed.
-    if (!func->GetAttr(attr::kCompiler).defined()) {
+    if (!func->GetAttr(attr::kCompiler).defined()) {
       CHECK_EQ(func->params.size(), args.size())
           << "Mismatch found in the number of parameters and call args";
       // Bind the parameters with call args.
diff --git a/src/relay/transforms/legalize.cc b/src/relay/transforms/legalize.cc
index 250dd69cd62fa..b7a0945951295 100644
--- a/src/relay/transforms/legalize.cc
+++ b/src/relay/transforms/legalize.cc
@@ -102,7 +102,7 @@ Pass Legalize(const std::string& legalize_map_attr_name) {
       [=](Function f, IRModule m, PassContext pc) {
         return Downcast(relay::legalize::Legalize(f, legalize_map_attr_name));
       };
-  return CreateFunctionPass(pass_func, 1, "Legalize", {tir::StringImmNode::make("InferType")});
+  return CreateFunctionPass(pass_func, 1, "Legalize", {runtime::String("InferType")});
 }

 TVM_REGISTER_GLOBAL("relay._transform.Legalize").set_body_typed(Legalize);
diff --git a/src/relay/transforms/merge_composite.cc b/src/relay/transforms/merge_composite.cc
index 35b93dced90d0..1fb1dea93d83f 100644
--- a/src/relay/transforms/merge_composite.cc
+++ b/src/relay/transforms/merge_composite.cc
@@ -159,9 +159,9 @@ class MergeCompositeWrapper : public ExprMutator {
     if (call->op->IsInstance()) {
       Function func = Downcast(call->op);
       CHECK(func.defined());
-      const auto name_node = func->GetAttr(attr::kComposite);
+      auto name_node = func->GetAttr(attr::kComposite);
       // don't step into existing composite functions
-      if (name_node.defined() && name_node->value != "") {
+      if (name_node.defined() && name_node != "") {
         tvm::Array new_args;
         for (const auto& arg : call->args) {
           auto new_e = this->Mutate(arg);
@@ -185,7 +185,7 @@ class MergeCompositeWrapper : public ExprMutator {
       auto free_vars = FreeVars(extract);
       // make the composite function
       auto f = Function(free_vars, extract, call->checked_type_, {}, DictAttrs());
-      f = WithAttr(std::move(f), attr::kComposite, tir::StringImmNode::make(pattern_name_));
+      f = WithAttr(std::move(f), attr::kComposite, runtime::String(pattern_name_));
       // find the expressions associated with the free vars using the args_map
       // this tells us which expressions should be given as inputs to the composite function
       Array args;
@@ -207,16 +207,14 @@ class MergeCompositeWrapper : public ExprMutator {
   PackedFunc check_;
 };

-Expr MergeComposite(const Expr& expr, const Array& pattern_names,
+Expr MergeComposite(const Expr& expr, const Array& pattern_names,
                     const Array& patterns, const std::vector& checks) {
   CHECK_EQ(pattern_names.size(), patterns.size());
   Expr merged_expr = expr;
   // merge the patterns one-by-one in order
   for (size_t i = 0; i < patterns.size(); i++) {
-    std::string pattern_name = pattern_names[i]->value;
-    Expr pattern = patterns[i];
-    PackedFunc check = checks[i];
-    merged_expr = MergeCompositeWrapper(pattern_name, pattern, check).Mutate(merged_expr);
+    merged_expr =
+        MergeCompositeWrapper(pattern_names[i], patterns[i], checks[i]).Mutate(merged_expr);
   }
   return merged_expr;
 }
@@ -225,7 +223,7 @@ Expr MergeComposite(const Expr& expr, const Array& pattern_names,

 namespace transform {

-Pass MergeComposite(const tvm::Array& pattern_names,
+Pass MergeComposite(const tvm::Array& pattern_names,
                     const tvm::Array& patterns,
                     const std::vector& checks) {
   runtime::TypedPackedFunc pass_func =
       [=](Function f, IRModule m, PassContext pc) {
@@ -236,8 +234,9 @@ Pass MergeComposite(const tvm::Array& pattern_names,
   return func_pass;
 }

-TVM_REGISTER_GLOBAL("relay._transform.MergeComposite").set_body([](TVMArgs args, TVMRetValue* rv) {
-  tvm::Array pattern_names = args[0];
+TVM_REGISTER_GLOBAL("relay._transform.MergeComposite")
+.set_body([](TVMArgs args, TVMRetValue* rv) {
+  tvm::Array pattern_names = args[0];
   tvm::Array patterns = args[1];
   std::vector checks;
   for (int i = 2; i < args.size(); i++) {
diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc
index a4e38634bf9d3..8eeac1748a43e 100644
--- a/src/relay/transforms/partition_graph.cc
+++ b/src/relay/transforms/partition_graph.cc
@@ -245,7 +245,7 @@ class Partitioner : public ExprMutator {
       global_region_func =
           WithAttr(std::move(global_region_func), attr::kPrimitive, tvm::Integer(1));
       global_region_func = WithAttr(std::move(global_region_func), attr::kCompiler,
-                                    tvm::tir::StringImmNode::make(target));
+                                    tvm::runtime::String(target));
       global_region_func =
           WithAttr(std::move(global_region_func), attr::kInline, tvm::Integer(1));
diff --git a/src/relay/transforms/simplify_inference.cc b/src/relay/transforms/simplify_inference.cc
index bc7c15e8dea41..b33799a26b430 100644
--- a/src/relay/transforms/simplify_inference.cc
+++ b/src/relay/transforms/simplify_inference.cc
@@ -205,7 +205,7 @@ Pass SimplifyInference() {
     return Downcast(SimplifyInference(f));
   };
   return CreateFunctionPass(pass_func, 0, "SimplifyInference",
-                            {tir::StringImmNode::make("InferType")});
+                            {runtime::String("InferType")});
 }

 TVM_REGISTER_GLOBAL("relay._transform.SimplifyInference")
diff --git a/src/relay/transforms/to_a_normal_form.cc b/src/relay/transforms/to_a_normal_form.cc
index 6e35dfbcb1582..2f18be851e286 100644
--- a/src/relay/transforms/to_a_normal_form.cc
+++ b/src/relay/transforms/to_a_normal_form.cc
@@ -299,7 +299,7 @@ IRModule ToANormalForm(const IRModule& m) {
   for (const auto& it : funcs) {
     CHECK_EQ(FreeVars(it.second).size(), 0);
     if (const auto* n = it.second.as()) {
-      if (n->GetAttr(attr::kCompiler).defined()) continue;
+      if (n->GetAttr(attr::kCompiler).defined()) continue;
     }
     Expr ret =
         TransformF([&](const Expr& e) {
diff --git a/src/runtime/container.cc b/src/runtime/container.cc
index 400f6469615fb..81dfd3d4e2521 100644
--- a/src/runtime/container.cc
+++ b/src/runtime/container.cc
@@ -32,14 +32,14 @@ namespace runtime {

 using namespace vm;

-TVM_REGISTER_GLOBAL("runtime.container._GetADTTag")
+TVM_REGISTER_GLOBAL("runtime.GetADTTag")
 .set_body([](TVMArgs args, TVMRetValue* rv) {
   ObjectRef obj = args[0];
   const auto& adt = Downcast(obj);
   *rv = static_cast(adt.tag());
 });

-TVM_REGISTER_GLOBAL("runtime.container._GetADTSize")
+TVM_REGISTER_GLOBAL("runtime.GetADTSize")
 .set_body([](TVMArgs args, TVMRetValue* rv) {
   ObjectRef obj = args[0];
   const auto& adt = Downcast(obj);
@@ -47,7 +47,7 @@ TVM_REGISTER_GLOBAL("runtime.container._GetADTSize")
 });

-TVM_REGISTER_GLOBAL("runtime.container._GetADTFields")
+TVM_REGISTER_GLOBAL("runtime.GetADTFields")
 .set_body([](TVMArgs args, TVMRetValue* rv) {
   ObjectRef obj = args[0];
   int idx = args[1];
@@ -56,7 +56,7 @@ TVM_REGISTER_GLOBAL("runtime.container._GetADTFields")
   *rv = adt[idx];
 });

-TVM_REGISTER_GLOBAL("runtime.container._Tuple")
+TVM_REGISTER_GLOBAL("runtime.Tuple")
 .set_body([](TVMArgs args, TVMRetValue* rv) {
   std::vector fields;
   for (auto i = 0; i < args.size(); ++i) {
@@ -65,7 +65,7 @@ TVM_REGISTER_GLOBAL("runtime.container._Tuple")
   *rv = ADT::Tuple(fields);
 });

-TVM_REGISTER_GLOBAL("runtime.container._ADT")
+TVM_REGISTER_GLOBAL("runtime.ADT")
 .set_body([](TVMArgs args, TVMRetValue* rv) {
   int itag = args[0];
   size_t tag = static_cast(itag);
@@ -76,11 +76,31 @@ TVM_REGISTER_GLOBAL("runtime.container._ADT")
   *rv = ADT(tag, fields);
 });

-TVM_REGISTER_GLOBAL("runtime.container._String")
+TVM_REGISTER_GLOBAL("runtime.String")
 .set_body_typed([](std::string str) {
   return String(std::move(str));
 });

+TVM_REGISTER_GLOBAL("runtime.GetStringSize")
+.set_body_typed([](String str) {
+  return static_cast(str.size());
+});
+
+TVM_REGISTER_GLOBAL("runtime.GetStdString")
+.set_body_typed([](String str) {
+  return std::string(str);
+});
+
+TVM_REGISTER_GLOBAL("runtime.CompareString")
+.set_body_typed([](String lhs, String rhs) {
+  return lhs.compare(rhs);
+});
+
+TVM_REGISTER_GLOBAL("runtime.StringHash")
+.set_body_typed([](String str) {
+  return static_cast(std::hash()(str));
+});
+
 TVM_REGISTER_OBJECT_TYPE(ADTObj);
 TVM_REGISTER_OBJECT_TYPE(StringObj);
 TVM_REGISTER_OBJECT_TYPE(ClosureObj);
diff --git a/src/target/generic_func.cc b/src/target/generic_func.cc
index 8eef4b75ff408..8a4b4e20e6400 100644
--- a/src/target/generic_func.cc
+++ b/src/target/generic_func.cc
@@ -22,6 +22,7 @@
 #include

 #include
+#include
 #include
 #include
 #include
@@ -150,12 +151,12 @@ TVM_REGISTER_GLOBAL("target.GenericFuncRegisterFunc")
   GenericFunc generic_func = args[0];
   // Intentionally copy and not de-allocate it, to avoid free pyobject during shutdown
   PackedFunc* func = new PackedFunc(args[1].operator PackedFunc());
-  Array tags = args[2];
+  Array tags = args[2];
   bool allow_override = args[3];

   std::vector tags_vector;
   for (auto& tag : tags) {
-    tags_vector.push_back(tag.as()->value);
+    tags_vector.push_back(std::string(tag));
   }

   generic_func
diff --git a/src/target/source/codegen_vhls.cc b/src/target/source/codegen_vhls.cc
index 6c1c3b9d22fc4..482d5a25a1a3e 100644
--- a/src/target/source/codegen_vhls.cc
+++ b/src/target/source/codegen_vhls.cc
@@ -147,7 +147,7 @@ runtime::Module BuildSDAccel(IRModule mod, std::string target_str) {
   std::string whole_code = cg.Finish();

   // Generate source code for compilation.
-  Array > kernel_info;
+  Array > kernel_info;

   for (auto kv : mod->functions) {
     CHECK(kv.second->IsInstance())
@@ -164,8 +164,8 @@ runtime::Module BuildSDAccel(IRModule mod, std::string target_str) {
     auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol);
     CHECK(global_symbol.defined())
         << "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
-    std::string func_name = global_symbol;
-    kernel_info.push_back(Array({func_name, code}));
+    runtime::String func_name(global_symbol);
+    kernel_info.push_back(Array({func_name, runtime::String(code)}));
   }

   std::string xclbin;
diff --git a/src/target/target.cc b/src/target/target.cc
index ab2077db584cb..cb0d214edbe54 100644
--- a/src/target/target.cc
+++ b/src/target/target.cc
@@ -62,39 +62,39 @@ Target CreateTarget(const std::string& target_name,
   std::string device_flag = "-device=";
   std::string keys_flag = "-keys=";
   for (auto& item : options) {
-    t->options_array.push_back(tir::StringImmNode::make(item));
+    t->options_array.push_back(runtime::String(item));

     if (item.find(libs_flag) == 0) {
       std::stringstream ss(item.substr(libs_flag.length()));
       std::string lib_item;
       while (std::getline(ss, lib_item, ',')) {
-        t->libs_array.push_back(tir::StringImmNode::make(lib_item));
+        t->libs_array.push_back(runtime::String(lib_item));
       }
     } else if (item.find(device_flag) == 0) {
       t->device_name = item.substr(device_flag.length());
-      t->keys_array.push_back(tir::StringImmNode::make(t->device_name));
+      t->keys_array.push_back(runtime::String(t->device_name));
     } else if (item.find(keys_flag) == 0) {
       std::stringstream ss(item.substr(keys_flag.length()));
       std::string key_item;
       while (std::getline(ss, key_item, ',')) {
-        t->keys_array.push_back(tir::StringImmNode::make(key_item));
+        t->keys_array.push_back(runtime::String(key_item));
       }
     }
   }

   if (t->device_name.length() > 0) {
-    t->keys_array.push_back(tir::StringImmNode::make(t->device_name));
+    t->keys_array.push_back(runtime::String(t->device_name));
   }
   t->device_type = kDLCPU;
   t->thread_warp_size = 1;
   if (target_name == "c" && t->device_name == "micro_dev") {
     t->device_type = kDLMicroDev;
   } else if (target_name == "c" || target_name == "llvm") {
-    t->keys_array.push_back(tir::StringImmNode::make("cpu"));
+    t->keys_array.push_back(runtime::String("cpu"));
   } else if (target_name == "cuda" || target_name == "nvptx") {
     t->device_type = kDLGPU;
-    t->keys_array.push_back(tir::StringImmNode::make("cuda"));
-    t->keys_array.push_back(tir::StringImmNode::make("gpu"));
+    t->keys_array.push_back(runtime::String("cuda"));
+    t->keys_array.push_back(runtime::String("gpu"));
     t->max_num_threads = 1024;
     t->thread_warp_size = 32;
   } else if (target_name == "rocm" || target_name == "opencl") {
@@ -104,8 +104,8 @@ Target CreateTarget(const std::string& target_name,
     } else {
       t->device_type = kDLROCM;
     }
-    t->keys_array.push_back(tir::StringImmNode::make(target_name));
-    t->keys_array.push_back(tir::StringImmNode::make("gpu"));
+    t->keys_array.push_back(runtime::String(target_name));
+    t->keys_array.push_back(runtime::String("gpu"));
     t->max_num_threads = 256;
     if (t->device_name == "intel_graphics") {
       t->thread_warp_size = 16;
@@ -116,20 +116,20 @@ Target CreateTarget(const std::string& target_name,
     } else {
       t->device_type = kDLVulkan;
     }
-    t->keys_array.push_back(tir::StringImmNode::make(target_name));
-    t->keys_array.push_back(tir::StringImmNode::make("gpu"));
+    t->keys_array.push_back(runtime::String(target_name));
+    t->keys_array.push_back(runtime::String("gpu"));
     t->max_num_threads = 256;
   } else if (target_name == "sdaccel") {
     t->device_type = kDLOpenCL;
-    t->keys_array.push_back(tir::StringImmNode::make("sdaccel"));
-    t->keys_array.push_back(tir::StringImmNode::make("hls"));
+    t->keys_array.push_back(runtime::String("sdaccel"));
+    t->keys_array.push_back(runtime::String("hls"));
   } else if (target_name == "aocl" || target_name == "aocl_sw_emu") {
     t->device_type = kDLAOCL;
-    t->keys_array.push_back(tir::StringImmNode::make("aocl"));
-    t->keys_array.push_back(tir::StringImmNode::make("hls"));
+    t->keys_array.push_back(runtime::String("aocl"));
+    t->keys_array.push_back(runtime::String("hls"));
   } else if (target_name == "opengl") {
     t->device_type = kOpenGL;
-    t->keys_array.push_back(tir::StringImmNode::make("opengl"));
+    t->keys_array.push_back(runtime::String("opengl"));
   } else if (target_name == "stackvm") {
     t->device_type = kDLCPU;
   } else if (target_name == "ext_dev") {
@@ -165,7 +165,7 @@ TVM_REGISTER_GLOBAL("target.TargetFromString")
 std::vector TargetNode::keys() const {
   std::vector result;
   for (auto& expr : keys_array) {
-    result.push_back(expr.as()->value);
+    result.push_back(expr);
   }
   return result;
 }
@@ -173,7 +173,7 @@ std::vector TargetNode::keys() const {
 std::vector TargetNode::options() const {
   std::vector result;
   for (auto& expr : options_array) {
-    result.push_back(expr.as()->value);
+    result.push_back(expr);
   }
   return result;
 }
@@ -181,7 +181,7 @@ std::vector TargetNode::options() const {
 std::unordered_set TargetNode::libs() const {
   std::unordered_set result;
   for (auto& expr : libs_array) {
-    result.insert(expr.as()->value);
+    result.insert(expr);
   }
   return result;
 }
diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc
index 891d13723d9a6..0efa33ad6a65b 100644
--- a/src/tir/ir/expr.cc
+++ b/src/tir/ir/expr.cc
@@ -47,7 +47,6 @@ Var::Var(std::string name_hint, Type type_annotation) {
   data_ = std::move(n);
 }

-
 Var Var::copy_with_suffix(const std::string& suffix) const {
   const VarNode* node = get();
   ObjectPtr new_ptr;
@@ -826,20 +825,28 @@ TVM_REGISTER_GLOBAL("tir.Load")
   }
 });

-
-
 TVM_REGISTER_GLOBAL("tir.Call")
 .set_body_typed([](
   DataType type, std::string name,
-  Array args, int call_type,
+  Array args, int call_type,
   FunctionRef func, int value_index
 ) {
+  Array prim_expr_args;
+  for (const auto& it : args) {
+    CHECK(it->IsInstance() ||
+          it->IsInstance());
+    if (const auto* str = it.as()) {
+      prim_expr_args.push_back(StringImmNode::make(str->data));
+    } else {
+      prim_expr_args.push_back(Downcast(it));
+    }
+  }
   return CallNode::make(type,
-                        name,
-                        args,
-                        static_cast(call_type),
-                        func,
-                        value_index);
+                        name,
+                        prim_expr_args,
+                        static_cast(call_type),
+                        func,
+                        value_index);
 });

 }  // namespace tir
diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc
index ea199821b2369..96fc4354aa947 100644
--- a/src/tir/ir/stmt_functor.cc
+++ b/src/tir/ir/stmt_functor.cc
@@ -120,10 +120,10 @@ class IRTransformer final :
 Stmt IRTransform(Stmt ir_node,
                  const runtime::PackedFunc& f_preorder,
                  const runtime::PackedFunc& f_postorder,
-                 const Array& only_enable) {
+                 const Array& only_enable) {
   std::unordered_set only_type_index;
-  for (PrimExpr s : only_enable) {
-    only_type_index.insert(Object::TypeKey2Index(s.as()->value.c_str()));
+  for (auto s : only_enable) {
+    only_type_index.insert(Object::TypeKey2Index(s.c_str()));
   }
   IRTransformer transform(f_preorder, f_postorder, only_type_index);
   return transform(std::move(ir_node));
diff --git a/src/tir/ir/transform.cc b/src/tir/ir/transform.cc
index 773c67d79269e..001c7cfc5d05c 100644
--- a/src/tir/ir/transform.cc
+++ b/src/tir/ir/transform.cc
@@ -124,7 +124,7 @@ Pass CreatePrimFuncPass(
     const runtime::TypedPackedFunc& pass_func,
     int opt_level,
     const std::string& name,
-    const tvm::Array& required) {
+    const tvm::Array& required) {
   PassInfo pass_info = PassInfo(opt_level, name, required);
   return PrimFuncPass(pass_func, pass_info);
 }
diff --git a/src/tir/pass/arg_binder.cc b/src/tir/pass/arg_binder.cc
index 30542eaf6c1c0..c684b9e680387 100644
--- a/src/tir/pass/arg_binder.cc
+++ b/src/tir/pass/arg_binder.cc
@@ -42,7 +42,8 @@ void BinderAddAssert(PrimExpr cond,
   if (!is_one(scond)) {
     std::ostringstream os;
     os << "Argument " << arg_name << " has an unsatisfied constraint";
-    asserts->emplace_back(AssertStmtNode::make(scond, os.str(), EvaluateNode::make(0)));
+    asserts->emplace_back(AssertStmtNode::make(scond, tvm::tir::StringImmNode::make(os.str()),
+                                               EvaluateNode::make(0)));
   }
 }
@@ -173,7 +174,8 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
   ndim_err_msg << arg_name
                << ".ndim is expected to equal "
                << buffer->shape.size();
-  asserts_.emplace_back(AssertStmtNode::make(a_ndim == v_ndim, ndim_err_msg.str(), nop));
+  auto msg = tvm::tir::StringImmNode::make(ndim_err_msg.str());
+  asserts_.emplace_back(AssertStmtNode::make(a_ndim == v_ndim, msg, nop));
   // type checks
   DataType dtype = buffer->dtype;
   std::ostringstream type_err_msg;
@@ -187,7 +189,9 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
   if (!(dtype == DataType::Int(4) ||
         dtype == DataType::UInt(4) ||
         dtype == DataType::Int(1))) {
-    asserts_.emplace_back(AssertStmtNode::make(cond, type_err_msg.str(), nop));
+    auto type_msg = tvm::tir::StringImmNode::make(type_err_msg.str());
+    asserts_.emplace_back(AssertStmtNode::make(a_ndim == v_ndim, msg, nop));
+    asserts_.emplace_back(AssertStmtNode::make(cond, type_msg, nop));
   }
   // data field
   if (Bind_(buffer->data, TVMArrayGet(DataType::Handle(), handle, intrinsic::kArrData),
@@ -245,9 +249,10 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
       stride_err_msg << arg_name << ".strides:"
                      << " expected to be compact array";
       if (conds.size() != 0) {
+        auto stride_msg = tvm::tir::StringImmNode::make(stride_err_msg.str());
         Stmt check =
             AssertStmtNode::make(arith::ComputeReduce(conds, PrimExpr()),
-                                 stride_err_msg.str(), EvaluateNode::make(0));
+                                 stride_msg, EvaluateNode::make(0));
         check = IfThenElseNode::make(NotNode::make(is_null), check, Stmt());
         asserts_.emplace_back(SeqStmt({check, EvaluateNode::make(0)}));
       }
@@ -269,9 +274,8 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
     } else {
       std::ostringstream stride_null_err_msg;
       stride_null_err_msg << arg_name << ".strides: expected non-null strides.";
-      asserts_.emplace_back(
-          AssertStmtNode::make(
-              NotNode::make(is_null), stride_null_err_msg.str(), nop));
+      asserts_.emplace_back(AssertStmtNode::make(
+          NotNode::make(is_null), tvm::tir::StringImmNode::make(stride_null_err_msg.str()), nop));

       for (size_t k = 0; k < buffer->strides.size(); ++k) {
         std::ostringstream field_name;
diff --git a/src/tir/pass/hoist_if_then_else.cc b/src/tir/pass/hoist_if_then_else.cc
index 1fd43ff72ffea..f7c75c0ea3173 100644
--- a/src/tir/pass/hoist_if_then_else.cc
+++ b/src/tir/pass/hoist_if_then_else.cc
@@ -160,7 +160,7 @@ Stmt update_for(const Stmt& parent_for_stmt, const Stmt& new_if_stmt) {
   });

   return IRTransform(parent_for_stmt, nullptr, replace_target_for,
-                     {PrimExpr("For")});
+                     {runtime::String("For")});
 }

 // Remove IfThenElse node from a For node.
@@ -187,10 +187,10 @@ std::pair RemoveIf(const Stmt& for_stmt, const Stmt& if_stmt) {
   });

   then_for = IRTransform(for_stmt, nullptr, replace_then_case,
-                         {PrimExpr("IfThenElse")});
+                         {runtime::String("IfThenElse")});
   if (if_stmt.as()->else_case.defined()) {
     else_for = IRTransform(for_stmt, nullptr, replace_else_case,
-                           {PrimExpr("IfThenElse")});
+                           {runtime::String("IfThenElse")});
   }

   return std::make_pair(then_for, else_for);
@@ -411,7 +411,7 @@ Stmt IfThenElseHoist::PostOrderMutate(const Stmt& stmt) {
       *ret = new_for;
     }
   });
-  return IRTransform(stmt, nullptr, replace_top_for, {PrimExpr("For")});
+  return IRTransform(stmt, nullptr, replace_top_for, {runtime::String("For")});
 }

 Stmt HoistIfThenElse(Stmt stmt) {
diff --git a/src/tir/pass/tensor_core.cc b/src/tir/pass/tensor_core.cc
index 88f749646d522..dc2df985a8ee7 100644
--- a/src/tir/pass/tensor_core.cc
+++ b/src/tir/pass/tensor_core.cc
@@ -860,7 +860,7 @@ class TensorCoreIRMutator : public StmtExprMutator {
       auto it = matrix_abc_.find(simplify_name(node->name));
       CHECK(it != matrix_abc_.end())
           << "Cannot find matrix info for " << node->name;
-      auto matrix_abc = "wmma." + it->second;
+      auto matrix_abc = tvm::tir::StringImmNode::make("wmma." + it->second);
+ it->second); Stmt body = this->VisitStmt(op->body); return AttrStmtNode::make(op->node, op->attr_key, diff --git a/src/tir/transforms/bind_device_type.cc b/src/tir/transforms/bind_device_type.cc index 486f21c907f97..952d6635f582a 100644 --- a/src/tir/transforms/bind_device_type.cc +++ b/src/tir/transforms/bind_device_type.cc @@ -47,7 +47,8 @@ class DeviceTypeBinder: public StmtExprMutator { var_ = nullptr; std::ostringstream os; os << "device_type need to be " << device_type_; - return AssertStmtNode::make(op->value == value, os.str(), body); + return AssertStmtNode::make(op->value == value, tvm::tir::StringImmNode::make(os.str()), + body); } } return StmtExprMutator::VisitStmt_(op); diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index c49b04442b2ff..4b75c46452bd9 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -41,7 +41,8 @@ namespace tvm { namespace tir { inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) { - return AssertStmtNode::make(lhs == rhs, msg, EvaluateNode::make(0)); + return AssertStmtNode::make(lhs == rhs, tvm::tir::StringImmNode::make(msg), + EvaluateNode::make(0)); } PrimFunc MakePackedAPI(PrimFunc&& func, @@ -140,17 +141,19 @@ PrimFunc MakePackedAPI(PrimFunc&& func, AssertStmtNode::make(tcode == kTVMOpaqueHandle || tcode == kTVMNDArrayHandle || tcode == kTVMDLTensorHandle || - tcode == kTVMNullptr, msg.str(), nop)); + tcode == kTVMNullptr, + tvm::tir::StringImmNode::make(msg.str()), nop)); } else if (t.is_int() || t.is_uint()) { std::ostringstream msg; msg << name_hint << ": Expect arg[" << i << "] to be int"; - seq_check.emplace_back(AssertStmtNode::make(tcode == kDLInt, msg.str(), nop)); + seq_check.emplace_back( + AssertStmtNode::make(tcode == kDLInt, tvm::tir::StringImmNode::make(msg.str()), nop)); } else { CHECK(t.is_float()); std::ostringstream msg; msg << name_hint << ": Expect arg[" << i << "] to be float"; seq_check.emplace_back( - AssertStmtNode::make(tcode == kDLFloat, msg.str(), nop)); + AssertStmtNode::make(tcode == kDLFloat, tvm::tir::StringImmNode::make(msg.str()), nop)); } } else { args.push_back(v_arg); diff --git a/src/tir/transforms/remap_thread_axis.cc b/src/tir/transforms/remap_thread_axis.cc index f695b3c777aaa..f3663532e56ba 100644 --- a/src/tir/transforms/remap_thread_axis.cc +++ b/src/tir/transforms/remap_thread_axis.cc @@ -76,12 +76,10 @@ class ThreadAxisRewriter : private StmtExprMutator { }; -PrimFunc RemapThreadAxis(PrimFunc&& f, Map thread_map) { +PrimFunc RemapThreadAxis(PrimFunc&& f, Map thread_map) { std::unordered_map tmap; for (const auto& kv : thread_map) { - const StringImmNode* str = kv.first.as(); - CHECK(str != nullptr); - tmap[str->value] = kv.second; + tmap[kv.first] = kv.second; } auto thread_axis = f->GetAttr >(tir::attr::kDeviceThreadAxis); @@ -101,7 +99,7 @@ PrimFunc RemapThreadAxis(PrimFunc&& f, Map thread_map) { namespace transform { -Pass RemapThreadAxis(Map thread_map) { +Pass RemapThreadAxis(Map thread_map) { auto pass_func = [thread_map](PrimFunc f, IRModule m, PassContext ctx) { return RemapThreadAxis(std::move(f), thread_map); }; diff --git a/tests/python/relay/test_annotate_target.py b/tests/python/relay/test_annotate_target.py index 0a2abd73d5eb2..59fef7ced340a 100644 --- a/tests/python/relay/test_annotate_target.py +++ b/tests/python/relay/test_annotate_target.py @@ -230,7 +230,7 @@ def before(): add_node = relay.add(in_1, in_2) relu_node = relay.nn.relu(add_node) add_relu = relay.Function([in_1, in_2], 
relu_node) - add_relu = add_relu.with_attr("Composite", tvm.tir.StringImm("test.add_relu")) + add_relu = add_relu.with_attr("Composite", "test.add_relu") # merged function r = relay.Call(add_relu, [a, b]) @@ -248,7 +248,7 @@ def after(): add_node = relay.add(in_1, in_2) relu_node = relay.nn.relu(add_node) add_relu = relay.Function([in_1, in_2], relu_node) - add_relu = add_relu.with_attr("Composite", tvm.tir.StringImm("test.add_relu")) + add_relu = add_relu.with_attr("Composite", "test.add_relu") # merged function cb_1 = relay.annotation.compiler_begin(a, "test") diff --git a/tests/python/relay/test_call_graph.py b/tests/python/relay/test_call_graph.py index 0af55d282b8f5..bae077c6a646c 100644 --- a/tests/python/relay/test_call_graph.py +++ b/tests/python/relay/test_call_graph.py @@ -134,7 +134,7 @@ def test_recursive_func(): func = relay.Function([i], sb.get(), ret_type=relay.TensorType([], 'int32')) - func = func.with_attr("Compiler", tvm.tir.StringImm("a")) + func = func.with_attr("Compiler", "a") mod[sum_up] = func iarg = relay.var('i', shape=[], dtype='int32') mod["main"] = relay.Function([iarg], sum_up(iarg)) diff --git a/tests/python/relay/test_external_codegen.py b/tests/python/relay/test_external_codegen.py index 724e81d65af76..b4496bb044bad 100644 --- a/tests/python/relay/test_external_codegen.py +++ b/tests/python/relay/test_external_codegen.py @@ -79,9 +79,8 @@ def check_graph_runtime_result(): def set_external_func_attr(func, compiler, ext_symbol): func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) - func = func.with_attr("Compiler", tvm.tir.StringImm(compiler)) - func = func.with_attr("global_symbol", - runtime.container.String(ext_symbol)) + func = func.with_attr("Compiler", compiler) + func = func.with_attr("global_symbol", ext_symbol) return func diff --git a/tests/python/relay/test_ir_nodes.py b/tests/python/relay/test_ir_nodes.py index dbd5934c38acc..dc73e3516e2da 100644 --- a/tests/python/relay/test_ir_nodes.py +++ b/tests/python/relay/test_ir_nodes.py @@ -96,7 +96,7 @@ def test_function(): body = relay.Tuple(tvm.runtime.convert([])) type_params = tvm.runtime.convert([]) fn = relay.Function(params, body, ret_type, type_params) - fn = fn.with_attr("test_attribute", tvm.tir.StringImm("value")) + fn = fn.with_attr("test_attribute", "value") assert fn.params == params assert fn.body == body assert fn.type_params == type_params diff --git a/tests/python/relay/test_ir_structural_equal_hash.py b/tests/python/relay/test_ir_structural_equal_hash.py index 271960ec4ab8a..e1a0a01772c48 100644 --- a/tests/python/relay/test_ir_structural_equal_hash.py +++ b/tests/python/relay/test_ir_structural_equal_hash.py @@ -356,7 +356,7 @@ def test_function_attr(): p00 = relay.subtract(z00, w01) q00 = relay.multiply(p00, w02) func0 = relay.Function([x0, w00, w01, w02], q00) - func0 = func0.with_attr("FuncName", tvm.runtime.container.String("a")) + func0 = func0.with_attr("FuncName", "a") x1 = relay.var('x1', shape=(10, 10)) w10 = relay.var('w10', shape=(10, 10)) @@ -366,7 +366,7 @@ def test_function_attr(): p10 = relay.subtract(z10, w11) q10 = relay.multiply(p10, w12) func1 = relay.Function([x1, w10, w11, w12], q10) - func1 = func1.with_attr("FuncName", tvm.runtime.container.String("b")) + func1 = func1.with_attr("FuncName", "b") assert not consistent_equal(func0, func1) @@ -698,7 +698,7 @@ def test_fn_attribute(): d = relay.var('d', shape=(10, 10)) add_1 = relay.add(c, d) add_1_fn = relay.Function([c, d], add_1) - add_1_fn = add_1_fn.with_attr("TestAttribute", 
tvm.runtime.container.String("test")) + add_1_fn = add_1_fn.with_attr("TestAttribute", "test") add_1_fn = run_opt_pass(add_1_fn, relay.transform.InferType()) assert not consistent_equal(add_1_fn, add_fn) diff --git a/tests/python/relay/test_pass_inline.py b/tests/python/relay/test_pass_inline.py index 0f6d539768fee..3b41f079d826b 100644 --- a/tests/python/relay/test_pass_inline.py +++ b/tests/python/relay/test_pass_inline.py @@ -209,7 +209,7 @@ def get_mod(): g11 = relay.GlobalVar("g11") fn11 = relay.Function([x11], x11) fn11 = fn11.with_attr("Inline", tvm.tir.IntImm("int32", 1)) - fn11 = fn11.with_attr("Compiler", tvm.tir.StringImm("a")) + fn11 = fn11.with_attr("Compiler", "a") mod[g11] = fn11 x1 = relay.var("x1", shape=(3, 5)) @@ -244,7 +244,7 @@ def expected(): x11 = relay.var("x11", shape=(3, 5)) fn11 = relay.Function([x11], x11) fn11 = fn11.with_attr("Inline", tvm.tir.IntImm("int32", 1)) - fn11 = fn11.with_attr("Compiler", tvm.tir.StringImm("a")) + fn11 = fn11.with_attr("Compiler", "a") x2 = relay.var("x2", shape=(3, 5)) y2 = relay.var("y2", shape=(3, 5)) @@ -367,7 +367,7 @@ def get_mod(): x1 = relay.var("x1", shape=(2, 2)) fn1 = relay.Function([x1], x1) fn1 = fn1.with_attr("Inline", tvm.tir.IntImm("int32", 1)) - fn1 = fn1.with_attr("Compiler", tvm.tir.StringImm("a")) + fn1 = fn1.with_attr("Compiler", "a") g1 = relay.GlobalVar("g1") mod[g1] = fn1 mod["main"] = relay.Function([x, y], x + y + g1(x)) @@ -380,7 +380,7 @@ def expected(): x1 = relay.var("x1", shape=(2, 2)) fn1 = relay.Function([x1], x1) fn1 = fn1.with_attr("Inline", tvm.tir.IntImm("int32", 1)) - fn1 = fn1.with_attr("Compiler", tvm.tir.StringImm("a")) + fn1 = fn1.with_attr("Compiler", "a") mod["main"] = relay.Function([x, y], x + y + fn1(x)) return mod @@ -446,7 +446,7 @@ def get_mod(): sb.ret(x1 + y1) fn1 = relay.Function([x1, y1], sb.get()) fn1 = fn1.with_attr("Inline", tvm.tir.IntImm("int32", 1)) - fn1 = fn1.with_attr("Compiler", tvm.tir.StringImm("a")) + fn1 = fn1.with_attr("Compiler", "a") g1 = relay.GlobalVar("g1") mod[g1] = fn1 @@ -456,7 +456,7 @@ def get_mod(): sb1.ret(x2 - y2) fn2 = relay.Function([x2, y2], sb1.get()) fn2 = fn2.with_attr("Inline", tvm.tir.IntImm("int32", 1)) - fn2 = fn2.with_attr("Compiler", tvm.tir.StringImm("b")) + fn2 = fn2.with_attr("Compiler", "b") g2 = relay.GlobalVar("g2") mod[g2] = fn2 @@ -478,7 +478,7 @@ def expected(): sb.ret(x1 + y1) fn1 = relay.Function([x1, y1], sb.get()) fn1 = fn1.with_attr("Inline", tvm.tir.IntImm("int32", 1)) - fn1 = fn1.with_attr("Compiler", tvm.tir.StringImm("a")) + fn1 = fn1.with_attr("Compiler", "a") x2 = relay.var("x2", shape=(3, 5)) y2 = relay.var("y2", shape=(3, 5)) @@ -486,7 +486,7 @@ def expected(): sb1.ret(x2 - y2) fn2 = relay.Function([x2, y2], sb1.get()) fn2 = fn2.with_attr("Inline", tvm.tir.IntImm("int32", 1)) - fn2 = fn2.with_attr("Compiler", tvm.tir.StringImm("b")) + fn2 = fn2.with_attr("Compiler", "b") p0 = relay.var("p0", shape=(3, 5)) p1 = relay.var("p1", shape=(3, 5)) @@ -539,10 +539,10 @@ def get_mod(): mod = tvm.IRModule({}) fn1 = relay.Function([], relay.const(1)) fn1 = fn1.with_attr("Inline", tvm.tir.IntImm("int32", 1)) - fn1 = fn1.with_attr("Compiler", tvm.tir.StringImm("a")) + fn1 = fn1.with_attr("Compiler", "a") fn2 = relay.Function([], relay.const(2)) fn2 = fn2.with_attr("Inline", tvm.tir.IntImm("int32", 1)) - fn2 = fn2.with_attr("Compiler", tvm.tir.StringImm("b")) + fn2 = fn2.with_attr("Compiler", "b") g1 = relay.GlobalVar('g1') g2 = relay.GlobalVar('g2') mod[g1] = fn1 @@ -555,10 +555,10 @@ def expected(): mod = tvm.IRModule({}) fn1 = 
relay.Function([], relay.const(1)) fn1 = fn1.with_attr("Inline", tvm.tir.IntImm("int32", 1)) - fn1 = fn1.with_attr("Compiler", tvm.tir.StringImm("a")) + fn1 = fn1.with_attr("Compiler", "a") fn2 = relay.Function([], relay.const(2)) fn2 = fn2.with_attr("Inline", tvm.tir.IntImm("int32", 1)) - fn2 = fn2.with_attr("Compiler", tvm.tir.StringImm("b")) + fn2 = fn2.with_attr("Compiler", "b") p = relay.var('p', 'bool') mod['main'] = relay.Function([p], relay.Call( relay.If(p, fn1, fn2), [])) @@ -787,7 +787,7 @@ def get_mod(): y0 = relay.var("y0", shape=(3, 5)) fn0 = relay.Function([x0, y0], x0 * y0) fn0 = fn0.with_attr("Inline", tvm.tir.IntImm("int32", 1)) - fn0 = fn0.with_attr("Compiler", tvm.tir.StringImm("aa")) + fn0 = fn0.with_attr("Compiler", "aa") g0 = relay.GlobalVar("g0") mod[g0] = fn0 @@ -811,7 +811,7 @@ def expected(): y0 = relay.var("y0", shape=(3, 5)) fn0 = relay.Function([x0, y0], x0 * y0) fn0 = fn0.with_attr("Inline", tvm.tir.IntImm("int32", 1)) - fn0 = fn0.with_attr("Compiler", tvm.tir.StringImm("aa")) + fn0 = fn0.with_attr("Compiler", "aa") x1 = relay.var("x1", shape=(3, 5)) y1 = relay.var("y1", shape=(3, 5)) diff --git a/tests/python/relay/test_pass_merge_composite.py b/tests/python/relay/test_pass_merge_composite.py index 110d855216e4d..e3c8991c8ebc3 100644 --- a/tests/python/relay/test_pass_merge_composite.py +++ b/tests/python/relay/test_pass_merge_composite.py @@ -184,7 +184,7 @@ def expected(): add_node = relay.add(in_1, in_2) relu_node = relay.nn.relu(add_node) add_relu = relay.Function([in_1, in_2], relu_node) - add_relu = add_relu.with_attr("Composite", tir.StringImm("add_relu")) + add_relu = add_relu.with_attr("Composite", "add_relu") # merged function r = relay.Call(add_relu, [a, b]) @@ -249,8 +249,7 @@ def expected(): sub_node = relay.subtract(in_1, in_2) mul_node = relay.multiply(add_node, sub_node) add_sub_mul = relay.Function([in_1, in_2], mul_node) - add_sub_mul = add_sub_mul.with_attr("Composite", - tir.StringImm("add_sub_mul")) + add_sub_mul = add_sub_mul.with_attr("Composite", "add_sub_mul") # add_sub_mul1 function in_3 = relay.var('in_3', shape=(10, 10)) @@ -259,8 +258,7 @@ def expected(): sub_node_1 = relay.subtract(in_3, in_4) mul_node_1 = relay.multiply(add_node_1, sub_node_1) add_sub_mul_1 = relay.Function([in_3, in_4], mul_node_1) - add_sub_mul_1 = add_sub_mul_1.with_attr("Composite", - tir.StringImm("add_sub_mul")) + add_sub_mul_1 = add_sub_mul_1.with_attr("Composite", "add_sub_mul") # merged function m_add_sub_mul_1 = relay.Call(add_sub_mul, [a, b]) @@ -319,8 +317,7 @@ def expected(): add_node_1 = relay.add(in_1, add_node) add_node_2 = relay.add(add_node_1, add_node) add_add_add = relay.Function([in_1, in_2], add_node_2) - add_add_add = add_add_add.with_attr("Composite", - tir.StringImm("add_add_add")) + add_add_add = add_add_add.with_attr("Composite", "add_add_add") # merged function sub_node = relay.subtract(a, b) @@ -404,7 +401,7 @@ def expected(): r = relay.nn.relu(bias_node) conv_bias_add_relu = relay.Function([in_1, in_2, in_3], r) conv_bias_add_relu = conv_bias_add_relu.with_attr("Composite", - tir.StringImm("conv2d_bias_relu")) + "conv2d_bias_relu") # add_relu function in_4 = relay.var('in_4', shape=(1, 256, 28, 28)) @@ -412,7 +409,7 @@ def expected(): add_node = relay.add(in_4, in_5) r = relay.nn.relu(add_node) add_relu = relay.Function([in_4, in_5], r) - add_relu = add_relu.with_attr("Composite", tir.StringImm("add_relu")) + add_relu = add_relu.with_attr("Composite", "add_relu") # merged function conv_bias_add_relu_1 = 
relay.Call(conv_bias_add_relu, [data, kernel, bias]) @@ -481,8 +478,7 @@ def after_A_priority(composite_name): out = relay.abs(out) out = relay.nn.relu(out) merged_func = relay.Function([x, y], out) - merged_func = merged_func.with_attr('Composite', - tir.StringImm(composite_name)) + merged_func = merged_func.with_attr('Composite', composite_name) ret = relay.Call(merged_func, [input_1, input_2]) return relay.Function([input_1, input_2], ret) @@ -547,13 +543,13 @@ def after(): y = relay.var('y') branch_1 = relay.multiply(relay.add(x, y), relay.subtract(x, y)) func_1 = relay.Function([x, y], branch_1) - func_1 = func_1.with_attr('Composite', tir.StringImm("add_sub_mul")) + func_1 = func_1.with_attr('Composite', "add_sub_mul") call_1 = relay.Call(func_1, [input_1, input_2]) x1 = relay.var('x1') y1 = relay.var('y1') branch_2 = relay.multiply(relay.add(x1, y1), relay.subtract(x1, y1)) func_2 = relay.Function([x1, y1], branch_2) - func_2 = func_2.with_attr('Composite', tir.StringImm("add_sub_mul")) + func_2 = func_2.with_attr('Composite', "add_sub_mul") call_2 = relay.Call(func_2, [input_1, input_2]) out = relay.multiply(call_1, call_2) return relay.Function([input_1, input_2], out) @@ -632,14 +628,14 @@ def after_A(): add_relu_1 = relay.add(x, y) add_relu_1 = relay.nn.relu(add_relu_1) add_relu_1 = relay.Function([x, y], add_relu_1) - add_relu_1 = add_relu_1.with_attr('Composite', tir.StringImm('add_relu')) + add_relu_1 = add_relu_1.with_attr('Composite', 'add_relu') add_relu_call_1 = relay.Call(add_relu_1, [inputs[0], inputs[1]]) x1 = relay.var('x1') y1 = relay.var('y1') add_relu_2 = relay.add(x1, y1) add_relu_2 = relay.nn.relu(add_relu_2) add_relu_2 = relay.Function([x1, y1], add_relu_2) - add_relu_2 = add_relu_2.with_attr('Composite', tir.StringImm('add_relu')) + add_relu_2 = add_relu_2.with_attr('Composite', 'add_relu') add_relu_call_2 = relay.Call(add_relu_2, [inputs[2], inputs[3]]) x2 = relay.var('x2') y2 = relay.var('y2') @@ -647,7 +643,7 @@ def after_A(): sub = relay.subtract(x2, y2) add_sub_mul = relay.multiply(add, sub) add_sub_mul = relay.Function([x2, y2], add_sub_mul) - add_sub_mul = add_sub_mul.with_attr('Composite', tir.StringImm('add_sub_mul')) + add_sub_mul = add_sub_mul.with_attr('Composite', 'add_sub_mul') add_sub_mul_call = relay.Call(add_sub_mul, [add_relu_call_1, add_relu_call_2]) return relay.Function(inputs, add_sub_mul_call) @@ -660,7 +656,7 @@ def after_B(): add_relu = relay.add(x, y) add_relu = relay.nn.relu(add_relu) add_relu = relay.Function([x, y], add_relu) - add_relu = add_relu.with_attr('Composite', tir.StringImm('add_relu')) + add_relu = add_relu.with_attr('Composite', 'add_relu') add_relu_call = relay.Call(add_relu, [inputs[i*2], inputs[i*2+1]]) add_relu_calls.append(add_relu_call) @@ -720,7 +716,7 @@ def expected(): tuple_get_item_node = bn_node[0] relu_node = relay.nn.relu(tuple_get_item_node) bn_relu = relay.Function([in_1, in_2, in_3, in_4, in_5], relu_node) - bn_relu = bn_relu.with_attr("Composite", tir.StringImm("bn_relu")) + bn_relu = bn_relu.with_attr("Composite", "bn_relu") # merged function r = relay.Call(bn_relu, [x, gamma, beta, moving_mean, moving_var]) diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py index ab9f47e775853..b042b615c9984 100644 --- a/tests/python/relay/test_pass_partition_graph.py +++ b/tests/python/relay/test_pass_partition_graph.py @@ -23,7 +23,6 @@ import tvm.relay.testing from tvm import relay from tvm import runtime -from tvm.runtime import container from tvm.relay 
import transform from tvm.contrib import util from tvm.relay.op.annotation import compiler_begin, compiler_end @@ -306,8 +305,8 @@ def expected(): func = relay.Function([x0, y0], add) func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1)) - func = func.with_attr("Compiler", tvm.tir.StringImm("ccompiler")) - func = func.with_attr("global_symbol", container.String("ccompiler_0")) + func = func.with_attr("Compiler", "ccompiler") + func = func.with_attr("global_symbol", "ccompiler_0") glb_0 = relay.GlobalVar("ccompiler_0") mod[glb_0] = func add_call = relay.Call(glb_0, [x, y]) @@ -391,8 +390,8 @@ def expected(): func = relay.Function([data0, input0, input1], out) func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1)) - func = func.with_attr("Compiler", tvm.tir.StringImm("dnnl")) - func = func.with_attr("global_symbol", container.String("dnnl_0")) + func = func.with_attr("Compiler", "dnnl") + func = func.with_attr("global_symbol", "dnnl_0") glb_var = relay.GlobalVar("dnnl_0") mod = tvm.IRModule() mod[glb_var] = func @@ -516,10 +515,8 @@ def expected(): bn.astuple()) func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1)) - func0 = func0.with_attr("Compiler", - tvm.tir.StringImm("test_compiler")) - func0 = func0.with_attr("global_symbol", - container.String("test_compiler_0")) + func0 = func0.with_attr("Compiler", "test_compiler") + func0 = func0.with_attr("global_symbol", "test_compiler_0") gv0 = relay.GlobalVar("test_compiler_0") mod[gv0] = func0 @@ -535,10 +532,8 @@ def expected(): func1 = relay.Function([data1, weight1], conv) func1 = func1.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) func1 = func1.with_attr("Inline", tvm.tir.IntImm("int32", 1)) - func1 = func1.with_attr("Compiler", - tvm.tir.StringImm("test_compiler")) - func1 = func1.with_attr("global_symbol", - container.String("test_compiler_1")) + func1 = func1.with_attr("Compiler", "test_compiler") + func1 = func1.with_attr("global_symbol", "test_compiler_1") gv1 = relay.GlobalVar("test_compiler_1") mod[gv1] = func1 @@ -609,10 +604,8 @@ def expected(): bn.astuple()) func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1)) - func0 = func0.with_attr("Compiler", - tvm.tir.StringImm("test_compiler")) - func0 = func0.with_attr("global_symbol", - container.String("test_compiler_0")) + func0 = func0.with_attr("Compiler", "test_compiler") + func0 = func0.with_attr("global_symbol", "test_compiler_0") # main function data = relay.var("data", relay.TensorType((1, 16, 224, 224), "float32")) @@ -646,8 +639,8 @@ def expected(): func = relay.Function([y0], add) func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1)) - func = func.with_attr("Compiler", tvm.tir.StringImm("ccompiler")) - func = func.with_attr("global_symbol", container.String("ccompiler_0")) + func = func.with_attr("Compiler", "ccompiler") + func = func.with_attr("global_symbol", "ccompiler_0") glb_0 = relay.GlobalVar("ccompiler_0") mod[glb_0] = func add_call = relay.Call(glb_0, [y]) @@ -746,10 +739,8 @@ def expected(): bn_mean, bn_var], tuple_o) func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1)) - func0 = func0.with_attr("Compiler", - tvm.tir.StringImm("test_target")) 
- func0 = func0.with_attr("global_symbol", - container.String("test_target_2")) + func0 = func0.with_attr("Compiler", "test_target") + func0 = func0.with_attr("global_symbol", "test_target_2") gv0 = relay.GlobalVar("test_target_2") mod[gv0] = func0 @@ -814,10 +805,8 @@ def expected(): func1 = func1.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) func1 = func1.with_attr("Inline", tvm.tir.IntImm("int32", 1)) - func1 = func1.with_attr("Compiler", - tvm.tir.StringImm("test_target")) - func1 = func1.with_attr("global_symbol", - container.String("test_target_1")) + func1 = func1.with_attr("Compiler", "test_target") + func1 = func1.with_attr("global_symbol", "test_target_1") gv1 = relay.GlobalVar("test_target_1") mod[gv1] = func1 @@ -829,10 +818,8 @@ def expected(): func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1)) - func0 = func0.with_attr("Compiler", - tvm.tir.StringImm("test_target")) - func0 = func0.with_attr("global_symbol", - container.String("test_target_0")) + func0 = func0.with_attr("Compiler", "test_target") + func0 = func0.with_attr("global_symbol", "test_target_0") gv0 = relay.GlobalVar("test_target_0") mod[gv0] = func0 diff --git a/tests/python/unittest/test_ir_attrs.py b/tests/python/unittest/test_ir_attrs.py index 8f2e9bb8a80dd..48495f48dc5ad 100644 --- a/tests/python/unittest/test_ir_attrs.py +++ b/tests/python/unittest/test_ir_attrs.py @@ -41,7 +41,7 @@ def test_dict_attrs(): dattr = tvm.ir.make_node("DictAttrs", x=1, y=10, name="xyz", padding=(0,0)) assert dattr.x.value == 1 datrr = tvm.ir.load_json(tvm.ir.save_json(dattr)) - assert dattr.name.value == "xyz" + assert dattr.name == "xyz" assert isinstance(dattr, tvm.ir.DictAttrs) assert "name" in dattr assert dattr["x"].value == 1 diff --git a/topi/include/topi/contrib/cublas.h b/topi/include/topi/contrib/cublas.h index 66b8a10baf7d2..ee18deae0781a 100644 --- a/topi/include/topi/contrib/cublas.h +++ b/topi/include/topi/contrib/cublas.h @@ -53,7 +53,7 @@ inline Tensor cublas_matmul(const Tensor& lhs, { { n, m } }, { lhs->dtype }, { lhs, rhs }, [&](Array ins, Array outs) { return call_packed({ - PrimExpr("tvm.contrib.cublas.matmul"), + runtime::String("tvm.contrib.cublas.matmul"), pack_buffer(ins[0]), pack_buffer(ins[1]), pack_buffer(outs[0]), @@ -85,7 +85,7 @@ inline Tensor cublas_batch_matmul(const Tensor& lhs, { { b, n, m } }, { lhs->dtype }, { lhs, rhs }, [&](Array ins, Array outs) { return call_packed({ - PrimExpr("tvm.contrib.cublas.batch_matmul"), + runtime::String("tvm.contrib.cublas.batch_matmul"), pack_buffer(ins[0]), pack_buffer(ins[1]), pack_buffer(outs[0]), diff --git a/topi/include/topi/contrib/rocblas.h b/topi/include/topi/contrib/rocblas.h index 2fcafc7ed251d..9fe1825fe65ee 100644 --- a/topi/include/topi/contrib/rocblas.h +++ b/topi/include/topi/contrib/rocblas.h @@ -52,7 +52,7 @@ inline Tensor rocblas_matmul(const Tensor& lhs, { { n, m } }, { lhs->dtype }, { lhs, rhs }, [&](Array ins, Array outs) { return call_packed({ - PrimExpr("tvm.contrib.rocblas.matmul"), + runtime::String("tvm.contrib.rocblas.matmul"), pack_buffer(ins[0]), pack_buffer(ins[1]), pack_buffer(outs[0]),
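Note on the Python-facing effect of the changes above: string-valued function attributes and pass options are now passed as plain Python strings and are returned as runtime String objects that compare directly against str, so callers no longer wrap values in tvm.tir.StringImm or tvm.runtime.container.String. The snippet below is a minimal sketch, not part of the patch, assuming a TVM build that includes these changes; the attribute names and values ("Compiler", "ccompiler") are taken from the updated tests and are only illustrative.

import tvm
from tvm import relay

# Function attributes now accept a plain Python str ...
x = relay.var("x", shape=(3, 5))
fn = relay.Function([x], x)
fn = fn.with_attr("Compiler", "ccompiler")   # previously tvm.tir.StringImm("ccompiler")

# ... and the stored value is expected to behave like a str on the way back out.
assert fn.attrs["Compiler"] == "ccompiler"

# DictAttrs behave the same way: no .value indirection is needed any more.
dattr = tvm.ir.make_node("DictAttrs", name="xyz")
assert dattr.name == "xyz"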