diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index 410a6df504c5..f75da0731242 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -622,10 +623,10 @@ class CompileEngineImpl : public CompileEngineNode { if (ext_mods.find(code_gen->value) == ext_mods.end()) { ext_mods[code_gen->value] = IRModule({}, {}); } - auto symbol_name = src_func->GetAttr(attr::kExternalSymbol); + auto symbol_name = src_func->GetAttr(tvm::attr::kGlobalSymbol); CHECK(symbol_name.defined()) << "No external symbol is set for:\n" << AsText(src_func, false); - auto gv = GlobalVar(symbol_name->value); + auto gv = GlobalVar(std::string(symbol_name)); ext_mods[code_gen->value]->Add(gv, src_func); cached_ext_funcs.push_back(it.first); } @@ -693,10 +694,10 @@ class CompileEngineImpl : public CompileEngineNode { if (key->source_func->GetAttr(attr::kCompiler).defined()) { auto cache_node = make_object(); const auto name_node = - key->source_func->GetAttr(attr::kExternalSymbol); + key->source_func->GetAttr(tvm::attr::kGlobalSymbol); CHECK(name_node.defined()) << "External function has not been attached a name yet."; - cache_node->func_name = name_node->value; + cache_node->func_name = std::string(name_node); cache_node->target = tvm::target::ext_dev(); value->cached_func = CachedFunc(cache_node); return value; diff --git a/src/relay/backend/contrib/codegen_c/codegen_c.h b/src/relay/backend/contrib/codegen_c/codegen_c.h index 60cecef0ce3c..79d4d3fd3946 100644 --- a/src/relay/backend/contrib/codegen_c/codegen_c.h +++ b/src/relay/backend/contrib/codegen_c/codegen_c.h @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -69,10 +70,9 @@ class CSourceModuleCodegenBase { */ std::string GetExtSymbol(const Function& func) const { const auto name_node = - func->GetAttr(attr::kExternalSymbol); + func->GetAttr(tvm::attr::kGlobalSymbol); CHECK(name_node.defined()) << "Fail to retrieve external symbol."; - std::string ext_symbol = name_node->value; - return ext_symbol; + return std::string(name_node); } }; diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc index d8e93ed232a5..a4e38634bf9d 100644 --- a/src/relay/transforms/partition_graph.cc +++ b/src/relay/transforms/partition_graph.cc @@ -35,6 +35,7 @@ #include #include #include +#include #include #include @@ -239,8 +240,8 @@ class Partitioner : public ExprMutator { std::string target = call->attrs.as()->compiler; std::string name = target + "_" + std::to_string(region->GetID()); - global_region_func = WithAttr(std::move(global_region_func), attr::kExternalSymbol, - tir::StringImmNode::make(name)); + global_region_func = WithAttr(std::move(global_region_func), tvm::attr::kGlobalSymbol, + runtime::String(name)); global_region_func = WithAttr(std::move(global_region_func), attr::kPrimitive, tvm::Integer(1)); global_region_func = WithAttr(std::move(global_region_func), attr::kCompiler, diff --git a/tests/python/relay/test_external_codegen.py b/tests/python/relay/test_external_codegen.py index bda590f563f0..724e81d65af7 100644 --- a/tests/python/relay/test_external_codegen.py +++ b/tests/python/relay/test_external_codegen.py @@ -80,7 +80,8 @@ def check_graph_runtime_result(): def set_external_func_attr(func, compiler, ext_symbol): func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) func = func.with_attr("Compiler", tvm.tir.StringImm(compiler)) - func = func.with_attr("ExternalSymbol", tvm.tir.StringImm(ext_symbol)) + func = func.with_attr("global_symbol", + runtime.container.String(ext_symbol)) return func diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py index 9d4d71179fd7..ab9f47e77585 100644 --- a/tests/python/relay/test_pass_partition_graph.py +++ b/tests/python/relay/test_pass_partition_graph.py @@ -23,6 +23,7 @@ import tvm.relay.testing from tvm import relay from tvm import runtime +from tvm.runtime import container from tvm.relay import transform from tvm.contrib import util from tvm.relay.op.annotation import compiler_begin, compiler_end @@ -305,10 +306,8 @@ def expected(): func = relay.Function([x0, y0], add) func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1)) - func = func.with_attr("Compiler", - tvm.tir.StringImm("ccompiler")) - func = func.with_attr("ExternalSymbol", - tvm.tir.StringImm("ccompiler_0")) + func = func.with_attr("Compiler", tvm.tir.StringImm("ccompiler")) + func = func.with_attr("global_symbol", container.String("ccompiler_0")) glb_0 = relay.GlobalVar("ccompiler_0") mod[glb_0] = func add_call = relay.Call(glb_0, [x, y]) @@ -319,7 +318,7 @@ def expected(): concat = relay.concatenate([log, exp], axis=0) fused_func = relay.Function([p0], concat) fused_func = fused_func.with_attr("Primitive", - tvm.tir.IntImm("int32", 1)) + tvm.tir.IntImm("int32", 1)) fused_call = relay.Call(fused_func, [add_call]) main = relay.Function([x, y], fused_call) mod["main"] = main @@ -393,8 +392,7 @@ def expected(): func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1)) func = func.with_attr("Compiler", tvm.tir.StringImm("dnnl")) - func = func.with_attr("ExternalSymbol", - tvm.tir.StringImm("dnnl_0")) + func = func.with_attr("global_symbol", container.String("dnnl_0")) glb_var = relay.GlobalVar("dnnl_0") mod = tvm.IRModule() mod[glb_var] = func @@ -520,8 +518,8 @@ def expected(): func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1)) func0 = func0.with_attr("Compiler", tvm.tir.StringImm("test_compiler")) - func0 = func0.with_attr("ExternalSymbol", - tvm.tir.StringImm("test_compiler_0")) + func0 = func0.with_attr("global_symbol", + container.String("test_compiler_0")) gv0 = relay.GlobalVar("test_compiler_0") mod[gv0] = func0 @@ -539,8 +537,8 @@ def expected(): func1 = func1.with_attr("Inline", tvm.tir.IntImm("int32", 1)) func1 = func1.with_attr("Compiler", tvm.tir.StringImm("test_compiler")) - func1 = func1.with_attr("ExternalSymbol", - tvm.tir.StringImm("test_compiler_1")) + func1 = func1.with_attr("global_symbol", + container.String("test_compiler_1")) gv1 = relay.GlobalVar("test_compiler_1") mod[gv1] = func1 @@ -613,8 +611,8 @@ def expected(): func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1)) func0 = func0.with_attr("Compiler", tvm.tir.StringImm("test_compiler")) - func0 = func0.with_attr("ExternalSymbol", - tvm.tir.StringImm("test_compiler_0")) + func0 = func0.with_attr("global_symbol", + container.String("test_compiler_0")) # main function data = relay.var("data", relay.TensorType((1, 16, 224, 224), "float32")) @@ -649,8 +647,7 @@ def expected(): func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1)) func = func.with_attr("Compiler", tvm.tir.StringImm("ccompiler")) - func = func.with_attr("ExternalSymbol", - tvm.tir.StringImm("ccompiler_0")) + func = func.with_attr("global_symbol", container.String("ccompiler_0")) glb_0 = relay.GlobalVar("ccompiler_0") mod[glb_0] = func add_call = relay.Call(glb_0, [y]) @@ -751,8 +748,8 @@ def expected(): func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1)) func0 = func0.with_attr("Compiler", tvm.tir.StringImm("test_target")) - func0 = func0.with_attr("ExternalSymbol", - tvm.tir.StringImm("test_target_2")) + func0 = func0.with_attr("global_symbol", + container.String("test_target_2")) gv0 = relay.GlobalVar("test_target_2") mod[gv0] = func0 @@ -819,8 +816,8 @@ def expected(): func1 = func1.with_attr("Inline", tvm.tir.IntImm("int32", 1)) func1 = func1.with_attr("Compiler", tvm.tir.StringImm("test_target")) - func1 = func1.with_attr("ExternalSymbol", - tvm.tir.StringImm("test_target_1")) + func1 = func1.with_attr("global_symbol", + container.String("test_target_1")) gv1 = relay.GlobalVar("test_target_1") mod[gv1] = func1 @@ -834,8 +831,8 @@ def expected(): func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1)) func0 = func0.with_attr("Compiler", tvm.tir.StringImm("test_target")) - func0 = func0.with_attr("ExternalSymbol", - tvm.tir.StringImm("test_target_0")) + func0 = func0.with_attr("global_symbol", + container.String("test_target_0")) gv0 = relay.GlobalVar("test_target_0") mod[gv0] = func0