diff --git a/include/tvm/node/container.h b/include/tvm/node/container.h index 461fa11b4f309..cf2ac260d6748 100644 --- a/include/tvm/node/container.h +++ b/include/tvm/node/container.h @@ -36,6 +36,8 @@ namespace tvm { +using runtime::String; +using runtime::StringObj; using runtime::Object; using runtime::ObjectPtr; using runtime::ObjectRef; diff --git a/src/node/serialization.cc b/src/node/serialization.cc index c661b34dd5e20..97436bae242c9 100644 --- a/src/node/serialization.cc +++ b/src/node/serialization.cc @@ -102,7 +102,7 @@ class NodeIndexer : public AttrVisitor { for (const auto& kv : n->data) { MakeIndex(const_cast(kv.second.get())); } - } else if (!node->IsInstance()) { + } else if (!node->IsInstance()) { reflection_->VisitAttrs(node, this); } } @@ -242,7 +242,7 @@ class JSONAttrGetter : public AttrVisitor { node_->data.push_back( node_index_->at(const_cast(kv.second.get()))); } - } else if (node->IsInstance()) { + } else if (node->IsInstance()) { node_->data.push_back(node_index_->at(node)); } else { // recursively index normal object. @@ -337,7 +337,11 @@ class JSONAttrSetter : public AttrVisitor { n->data[node_->keys[i]] = ObjectRef(node_list_->at(node_->data[i])); } - } else if (!node->IsInstance()) { + } else if (node->IsInstance()) { + StringObj* n = static_cast(node); + auto saved = node_list_->at(node_->data[0]); + saved = runtime::GetObjectPtr(n); + } else { reflection_->VisitAttrs(node, this); } } diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index 6a06f83e5a24e..36e0914b5b792 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -617,14 +617,14 @@ class CompileEngineImpl : public CompileEngineNode { for (const auto& it : cache_) { auto src_func = it.first->source_func; CHECK(src_func.defined()); - if (src_func->GetAttr(attr::kCompiler).defined()) { - auto code_gen = src_func->GetAttr(attr::kCompiler); + if (src_func->GetAttr(attr::kCompiler).defined()) { + auto code_gen = src_func->GetAttr(attr::kCompiler); CHECK(code_gen.defined()) << "No external codegen is set"; std::string code_gen_name = code_gen.operator std::string(); if (ext_mods.find(code_gen_name) == ext_mods.end()) { ext_mods[code_gen_name] = IRModule({}, {}); } - auto symbol_name = src_func->GetAttr(tvm::attr::kGlobalSymbol); + auto symbol_name = src_func->GetAttr(tvm::attr::kGlobalSymbol); CHECK(symbol_name.defined()) << "No external symbol is set for:\n" << AsText(src_func, false); auto gv = GlobalVar(std::string(symbol_name)); @@ -692,10 +692,10 @@ class CompileEngineImpl : public CompileEngineNode { } // No need to lower external functions for now. We will invoke the external // codegen tool once and lower all functions together. - if (key->source_func->GetAttr(attr::kCompiler).defined()) { + if (key->source_func->GetAttr(attr::kCompiler).defined()) { auto cache_node = make_object(); const auto name_node = - key->source_func->GetAttr(tvm::attr::kGlobalSymbol); + key->source_func->GetAttr(tvm::attr::kGlobalSymbol); CHECK(name_node.defined()) << "External function has not been attached a name yet."; cache_node->func_name = std::string(name_node); diff --git a/src/relay/backend/contrib/codegen_c/codegen_c.h b/src/relay/backend/contrib/codegen_c/codegen_c.h index 79d4d3fd3946c..1db3f20ef05b8 100644 --- a/src/relay/backend/contrib/codegen_c/codegen_c.h +++ b/src/relay/backend/contrib/codegen_c/codegen_c.h @@ -70,7 +70,7 @@ class CSourceModuleCodegenBase { */ std::string GetExtSymbol(const Function& func) const { const auto name_node = - func->GetAttr(tvm::attr::kGlobalSymbol); + func->GetAttr(tvm::attr::kGlobalSymbol); CHECK(name_node.defined()) << "Fail to retrieve external symbol."; return std::string(name_node); } diff --git a/src/relay/backend/graph_runtime_codegen.cc b/src/relay/backend/graph_runtime_codegen.cc index c126017f982f4..e4d632317966b 100644 --- a/src/relay/backend/graph_runtime_codegen.cc +++ b/src/relay/backend/graph_runtime_codegen.cc @@ -419,7 +419,7 @@ class GraphRuntimeCodegen auto pf1 = GetPackedFunc("relay.backend._CompileEngineLower"); Target target; // Handle external function - if (func->GetAttr(attr::kCompiler).defined()) { + if (func->GetAttr(attr::kCompiler).defined()) { target = tvm::target::ext_dev(); CCacheKey key = (*pf0)(func, target); CachedFunc ext_func = (*pf1)(compile_engine_, key); @@ -482,7 +482,7 @@ class GraphRuntimeCodegen return {}; } std::vector VisitExpr_(const FunctionNode* op) override { - CHECK(op->GetAttr(attr::kCompiler).defined()) + CHECK(op->GetAttr(attr::kCompiler).defined()) << "Only functions supported by custom codegen"; return {}; } diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 1a55b504f29e9..13bfb05eded86 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -475,7 +475,7 @@ class VMFunctionCompiler : ExprFunctor { Target target; - if (func->GetAttr(attr::kCompiler).defined()) { + if (func->GetAttr(attr::kCompiler).defined()) { target = tvm::target::ext_dev(); } else { // Next generate the invoke instruction. @@ -493,7 +493,7 @@ class VMFunctionCompiler : ExprFunctor { auto cfunc = engine_->Lower(key); auto op_index = -1; - if (func->GetAttr(attr::kCompiler).defined()) { + if (func->GetAttr(attr::kCompiler).defined()) { op_index = context_->cached_funcs.size(); context_->cached_funcs.push_back(cfunc); } else { diff --git a/src/relay/backend/vm/inline_primitives.cc b/src/relay/backend/vm/inline_primitives.cc index 9ecea40f24fd2..12113b0683f2b 100644 --- a/src/relay/backend/vm/inline_primitives.cc +++ b/src/relay/backend/vm/inline_primitives.cc @@ -122,7 +122,7 @@ struct PrimitiveInliner : ExprMutator { auto global = pair.first; auto base_func = pair.second; if (auto* n = base_func.as()) { - if (n->GetAttr(attr::kCompiler).defined()) continue; + if (n->GetAttr(attr::kCompiler).defined()) continue; auto func = GetRef(n); DLOG(INFO) << "Before inlining primitives: " << global diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc index 0cd17a97fb03a..59c549cabfee8 100644 --- a/src/relay/backend/vm/lambda_lift.cc +++ b/src/relay/backend/vm/lambda_lift.cc @@ -190,7 +190,7 @@ class LambdaLifter : public ExprMutator { auto glob_funcs = module_->functions; for (auto pair : glob_funcs) { if (auto* n = pair.second.as()) { - if (n->GetAttr(attr::kCompiler).defined()) continue; + if (n->GetAttr(attr::kCompiler).defined()) continue; auto func = GetRef(n); func = Function(func->params, VisitExpr(func->body), diff --git a/src/relay/ir/transform.cc b/src/relay/ir/transform.cc index d5f0fbac29c8d..fa709ebca061d 100644 --- a/src/relay/ir/transform.cc +++ b/src/relay/ir/transform.cc @@ -145,7 +145,7 @@ IRModule FunctionPassNode::operator()(const IRModule& mod, bool FunctionPassNode::SkipFunction(const Function& func) const { return func->GetAttr(attr::kSkipOptimization, 0)->value != 0 || - (func->GetAttr(attr::kCompiler).defined()); + (func->GetAttr(attr::kCompiler).defined()); } Pass CreateFunctionPass( diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc index 0863cd33183c4..65e17fcefddb8 100644 --- a/src/relay/transforms/annotate_target.cc +++ b/src/relay/transforms/annotate_target.cc @@ -59,7 +59,7 @@ class AnnotateTargetWrapper : public ExprMutator { // handle composite functions Function func = Downcast(call->op); CHECK(func.defined()); - auto comp_name = func->GetAttr(attr::kComposite); + auto comp_name = func->GetAttr(attr::kComposite); if (comp_name.defined()) { std::string comp_name_str = comp_name; size_t i = comp_name_str.find('.'); @@ -148,7 +148,7 @@ class AnnotateTargetWrapper : public ExprMutator { Function func; Expr new_body; // don't step into composite functions - if (fn->GetAttr(attr::kComposite).defined()) { + if (fn->GetAttr(attr::kComposite).defined()) { func = GetRef(fn); new_body = func->body; } else { diff --git a/src/relay/transforms/inline.cc b/src/relay/transforms/inline.cc index 8400b3fb3088e..ba0f5688ea9d8 100644 --- a/src/relay/transforms/inline.cc +++ b/src/relay/transforms/inline.cc @@ -131,7 +131,7 @@ class Inliner : ExprMutator { fn->attrs); // Inline the function body to the caller if this function uses default // compiler, i.e. no external codegen is needed. - if (!func->GetAttr(attr::kCompiler).defined()) { + if (!func->GetAttr(attr::kCompiler).defined()) { CHECK_EQ(func->params.size(), args.size()) << "Mismatch found in the number of parameters and call args"; // Bind the parameters with call args. diff --git a/src/relay/transforms/merge_composite.cc b/src/relay/transforms/merge_composite.cc index 1fb1dea93d83f..75d95f0378f1c 100644 --- a/src/relay/transforms/merge_composite.cc +++ b/src/relay/transforms/merge_composite.cc @@ -159,7 +159,7 @@ class MergeCompositeWrapper : public ExprMutator { if (call->op->IsInstance()) { Function func = Downcast(call->op); CHECK(func.defined()); - auto name_node = func->GetAttr(attr::kComposite); + auto name_node = func->GetAttr(attr::kComposite); // don't step into existing composite functions if (name_node.defined() && name_node != "") { tvm::Array new_args; diff --git a/src/relay/transforms/to_a_normal_form.cc b/src/relay/transforms/to_a_normal_form.cc index 2f18be851e286..21c516201dd73 100644 --- a/src/relay/transforms/to_a_normal_form.cc +++ b/src/relay/transforms/to_a_normal_form.cc @@ -299,7 +299,7 @@ IRModule ToANormalForm(const IRModule& m) { for (const auto& it : funcs) { CHECK_EQ(FreeVars(it.second).size(), 0); if (const auto* n = it.second.as()) { - if (n->GetAttr(attr::kCompiler).defined()) continue; + if (n->GetAttr(attr::kCompiler).defined()) continue; } Expr ret = TransformF([&](const Expr& e) { diff --git a/src/target/build_common.h b/src/target/build_common.h index fc45cef3a874e..5ba51da4ce672 100644 --- a/src/target/build_common.h +++ b/src/target/build_common.h @@ -57,7 +57,7 @@ ExtractFuncInfo(const IRModule& mod) { info.thread_axis_tags.push_back(thread_axis[i]->thread_tag); } } - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); fmap[static_cast(global_symbol)] = info; } return fmap; diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index f0b0a4b69cf69..a863056e82267 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -126,7 +126,7 @@ void CodeGenCPU::Init(const std::string& module_name, void CodeGenCPU::AddFunction(const PrimFunc& f) { CodeGenLLVM::AddFunction(f); if (f_tvm_register_system_symbol_ != nullptr) { - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); CHECK(global_symbol.defined()) << "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute"; export_system_symbols_.emplace_back( diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 28f4efd74c906..bb0b7e46baf8a 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -128,7 +128,7 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) { llvm::FunctionType* ftype = llvm::FunctionType::get( ret_void ? t_void_ : t_int_, param_types, false); - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); CHECK(global_symbol.defined()) << "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute"; CHECK(module_->getFunction(static_cast(global_symbol)) == nullptr) diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index 9ea77ac2d79f5..52dccbaf5eb61 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -214,7 +214,7 @@ class LLVMModuleNode final : public runtime::ModuleNode { << "Can only lower IR Module with PrimFuncs"; auto f = Downcast(kv.second); if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); CHECK(global_symbol.defined()); entry_func = global_symbol; } diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index 0cb47427ff1c1..a0e18a6120554 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -78,7 +78,7 @@ void CodeGenC::AddFunction(const PrimFunc& f) { // reserve keywords ReserveKeywordsAsUnique(); - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); CHECK(global_symbol.defined()) << "CodeGenC: Expect PrimFunc to have the global_symbol attribute"; bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias); diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc index 2f31a3e3adf10..715c0ae92ddca 100644 --- a/src/target/source/codegen_metal.cc +++ b/src/target/source/codegen_metal.cc @@ -56,7 +56,7 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) { GetUniqueName("_"); // add to alloc buffer type. - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); CHECK(global_symbol.defined()) << "CodeGenC: Expect PrimFunc to have the global_symbol attribute"; diff --git a/src/target/source/codegen_opengl.cc b/src/target/source/codegen_opengl.cc index 474859977fcbb..13d87d282e6cb 100644 --- a/src/target/source/codegen_opengl.cc +++ b/src/target/source/codegen_opengl.cc @@ -156,7 +156,7 @@ void CodeGenOpenGL::AddFunction(const PrimFunc& f) { arg_kinds.push_back(kind); } - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); CHECK(global_symbol.defined()) << "CodeGenOpenGL: Expect PrimFunc to have the global_symbol attribute"; diff --git a/src/target/source/codegen_vhls.cc b/src/target/source/codegen_vhls.cc index 482d5a25a1a3e..d1ce1a7c5bab5 100644 --- a/src/target/source/codegen_vhls.cc +++ b/src/target/source/codegen_vhls.cc @@ -161,7 +161,7 @@ runtime::Module BuildSDAccel(IRModule mod, std::string target_str) { code = (*f)(code).operator std::string(); } - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); CHECK(global_symbol.defined()) << "CodeGenC: Expect PrimFunc to have the global_symbol attribute"; runtime::String func_name(global_symbol); diff --git a/src/target/spirv/build_vulkan.cc b/src/target/spirv/build_vulkan.cc index b6f9b86fbdb35..58721414a6651 100644 --- a/src/target/spirv/build_vulkan.cc +++ b/src/target/spirv/build_vulkan.cc @@ -90,7 +90,7 @@ runtime::Module BuildSPIRV(IRModule mod) { CHECK(calling_conv.defined() && calling_conv->value == static_cast(CallingConv::kDeviceKernelLaunch)) << "CodeGenSPIRV: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); CHECK(global_symbol.defined()) << "CodeGenSPIRV: Expect PrimFunc to have the global_symbol attribute"; diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index 0241e2218d71e..db2a2f359aa48 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -78,7 +78,7 @@ std::vector CodeGenSPIRV::BuildFunction(const PrimFunc& f) { builder_->MakeInst(spv::OpReturn); builder_->MakeInst(spv::OpFunctionEnd); - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); CHECK(global_symbol.defined()) << "CodeGenSPIRV: Expect PrimFunc to have the global_symbol attribute"; diff --git a/src/target/stackvm/codegen_stackvm.cc b/src/target/stackvm/codegen_stackvm.cc index af8b34142ec90..da75a70e91232 100644 --- a/src/target/stackvm/codegen_stackvm.cc +++ b/src/target/stackvm/codegen_stackvm.cc @@ -536,7 +536,7 @@ runtime::Module BuildStackVM(const IRModule& mod) { CHECK(kv.second->IsInstance()) << "CodeGenStackVM: Can only take PrimFunc"; auto f = Downcast(kv.second); - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); CHECK(global_symbol.defined()) << "CodeGenStackVM: Expect PrimFunc to have the global_symbol attribute"; std::string f_name = global_symbol; diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index 4b75c46452bd9..b1dd235bce03f 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -47,7 +47,7 @@ inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) { PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) { - auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); CHECK(global_symbol.defined()) << "MakePackedAPI: Expect PrimFunc to have the global_symbol attribute"; std::string name_hint = global_symbol; diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index ae32bdcbadeac..5149d2882fb77 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -272,7 +272,7 @@ PrimFunc SplitHostDevice(PrimFunc&& func, IRModuleNode* device_mod) { auto target = func->GetAttr(tvm::attr::kTarget); CHECK(target.defined()) << "SplitHostDevice: Require the target attribute"; - auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); CHECK(global_symbol.defined()) << "SplitHostDevice: Expect PrimFunc to have the global_symbol attribute"; diff --git a/tests/python/relay/test_ir_nodes.py b/tests/python/relay/test_ir_nodes.py index dc73e3516e2da..f182cbaae6315 100644 --- a/tests/python/relay/test_ir_nodes.py +++ b/tests/python/relay/test_ir_nodes.py @@ -25,7 +25,9 @@ def check_json_roundtrip(node): json_str = tvm.ir.save_json(node) + print(node) back = tvm.ir.load_json(json_str) + print(back) assert tvm.ir.structural_equal(back, node, map_free_vars=True)