Skip to content

Commit

Permalink
kExternalSymbol -> kGlobalSymbol
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiics committed Apr 2, 2020
1 parent 0449966 commit 023409b
Show file tree
Hide file tree
Showing 8 changed files with 58 additions and 34 deletions.
6 changes: 3 additions & 3 deletions include/tvm/runtime/container.h
Original file line number Diff line number Diff line change
Expand Up @@ -512,12 +512,12 @@ class String : public ObjectRef {
#endif
}

TVM_DEFINE_OBJECT_REF_METHODS(String, ObjectRef, StringObj);

private:
/*! \return the internal StringObj pointer */
const StringObj* get() const { return operator->(); }

TVM_DEFINE_OBJECT_REF_METHODS(String, ObjectRef, StringObj);

private:
/*!
* \brief Compare two char sequence
*
Expand Down
18 changes: 18 additions & 0 deletions python/tvm/runtime/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,4 +109,22 @@ def tuple_object(fields=None):
return _Tuple(*fields)


@tvm._ffi.register_object("runtime.String")
class String(Object):
"""The string object.
Parameters
----------
string : Str
The string used to construct a runtime String object
Returns
-------
ret : String
The created object.
"""
def __init__(self, string):
self.__init_handle_by_constructor__(_String, string)


tvm._ffi._init_api("tvm.runtime.container")
9 changes: 5 additions & 4 deletions src/relay/backend/compile_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <tvm/te/operation.h>
#include <tvm/te/schedule_pass.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/container.h>
#include <tvm/relay/attrs/device_copy.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr.h>
Expand Down Expand Up @@ -622,10 +623,10 @@ class CompileEngineImpl : public CompileEngineNode {
if (ext_mods.find(code_gen->value) == ext_mods.end()) {
ext_mods[code_gen->value] = IRModule({}, {});
}
auto symbol_name = src_func->GetAttr<tir::StringImm>(attr::kExternalSymbol);
auto symbol_name = src_func->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
CHECK(symbol_name.defined()) << "No external symbol is set for:\n"
<< AsText(src_func, false);
auto gv = GlobalVar(symbol_name->value);
auto gv = GlobalVar(std::string(symbol_name));
ext_mods[code_gen->value]->Add(gv, src_func);
cached_ext_funcs.push_back(it.first);
}
Expand Down Expand Up @@ -693,10 +694,10 @@ class CompileEngineImpl : public CompileEngineNode {
if (key->source_func->GetAttr<tir::StringImm>(attr::kCompiler).defined()) {
auto cache_node = make_object<CachedFuncNode>();
const auto name_node =
key->source_func->GetAttr<tir::StringImm>(attr::kExternalSymbol);
key->source_func->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
CHECK(name_node.defined())
<< "External function has not been attached a name yet.";
cache_node->func_name = name_node->value;
cache_node->func_name = std::string(name_node);
cache_node->target = tvm::target::ext_dev();
value->cached_func = CachedFunc(cache_node);
return value;
Expand Down
6 changes: 3 additions & 3 deletions src/relay/backend/contrib/codegen_c/codegen_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>
#include <tvm/relay/function.h>
#include <tvm/runtime/container.h>
#include <sstream>
#include <string>
#include <utility>
Expand Down Expand Up @@ -69,10 +70,9 @@ class CSourceModuleCodegenBase {
*/
std::string GetExtSymbol(const Function& func) const {
const auto name_node =
func->GetAttr<tir::StringImm>(attr::kExternalSymbol);
func->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
CHECK(name_node.defined()) << "Fail to retrieve external symbol.";
std::string ext_symbol = name_node->value;
return ext_symbol;
return std::string(name_node);
}
};

Expand Down
5 changes: 3 additions & 2 deletions src/relay/transforms/partition_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/container.h>

#include <unordered_map>
#include <unordered_set>
Expand Down Expand Up @@ -239,8 +240,8 @@ class Partitioner : public ExprMutator {
std::string target = call->attrs.as<CompilerAttrs>()->compiler;
std::string name = target + "_" + std::to_string(region->GetID());

global_region_func = WithAttr(std::move(global_region_func), attr::kExternalSymbol,
tir::StringImmNode::make(name));
global_region_func = WithAttr(std::move(global_region_func), tvm::attr::kGlobalSymbol,
runtime::String(name));
global_region_func =
WithAttr(std::move(global_region_func), attr::kPrimitive, tvm::Integer(1));
global_region_func = WithAttr(std::move(global_region_func), attr::kCompiler,
Expand Down
6 changes: 6 additions & 0 deletions src/runtime/container.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,13 @@ TVM_REGISTER_GLOBAL("runtime.container._ADT")
*rv = ADT(tag, fields);
});

TVM_REGISTER_GLOBAL("runtime.container._String")
.set_body_typed([](std::string str) {
return String(std::move(str));
});

TVM_REGISTER_OBJECT_TYPE(ADTObj);
TVM_REGISTER_OBJECT_TYPE(StringObj);
TVM_REGISTER_OBJECT_TYPE(ClosureObj);

} // namespace runtime
Expand Down
3 changes: 2 additions & 1 deletion tests/python/relay/test_external_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ def check_graph_runtime_result():
def set_external_func_attr(func, compiler, ext_symbol):
func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
func = func.with_attr("Compiler", tvm.tir.StringImm(compiler))
func = func.with_attr("ExternalSymbol", tvm.tir.StringImm(ext_symbol))
func = func.with_attr("global_symbol",
runtime.container.String(ext_symbol))
return func


Expand Down
39 changes: 18 additions & 21 deletions tests/python/relay/test_pass_partition_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import tvm.relay.testing
from tvm import relay
from tvm import runtime
from tvm.runtime import container
from tvm.relay import transform
from tvm.contrib import util
from tvm.relay.op.annotation import compiler_begin, compiler_end
Expand Down Expand Up @@ -305,10 +306,8 @@ def expected():
func = relay.Function([x0, y0], add)
func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1))
func = func.with_attr("Compiler",
tvm.tir.StringImm("ccompiler"))
func = func.with_attr("ExternalSymbol",
tvm.tir.StringImm("ccompiler_0"))
func = func.with_attr("Compiler", tvm.tir.StringImm("ccompiler"))
func = func.with_attr("global_symbol", container.String("ccompiler_0"))
glb_0 = relay.GlobalVar("ccompiler_0")
mod[glb_0] = func
add_call = relay.Call(glb_0, [x, y])
Expand All @@ -319,7 +318,7 @@ def expected():
concat = relay.concatenate([log, exp], axis=0)
fused_func = relay.Function([p0], concat)
fused_func = fused_func.with_attr("Primitive",
tvm.tir.IntImm("int32", 1))
tvm.tir.IntImm("int32", 1))
fused_call = relay.Call(fused_func, [add_call])
main = relay.Function([x, y], fused_call)
mod["main"] = main
Expand Down Expand Up @@ -393,8 +392,7 @@ def expected():
func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1))
func = func.with_attr("Compiler", tvm.tir.StringImm("dnnl"))
func = func.with_attr("ExternalSymbol",
tvm.tir.StringImm("dnnl_0"))
func = func.with_attr("global_symbol", container.String("dnnl_0"))
glb_var = relay.GlobalVar("dnnl_0")
mod = tvm.IRModule()
mod[glb_var] = func
Expand Down Expand Up @@ -520,8 +518,8 @@ def expected():
func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
func0 = func0.with_attr("Compiler",
tvm.tir.StringImm("test_compiler"))
func0 = func0.with_attr("ExternalSymbol",
tvm.tir.StringImm("test_compiler_0"))
func0 = func0.with_attr("global_symbol",
container.String("test_compiler_0"))
gv0 = relay.GlobalVar("test_compiler_0")
mod[gv0] = func0

Expand All @@ -539,8 +537,8 @@ def expected():
func1 = func1.with_attr("Inline", tvm.tir.IntImm("int32", 1))
func1 = func1.with_attr("Compiler",
tvm.tir.StringImm("test_compiler"))
func1 = func1.with_attr("ExternalSymbol",
tvm.tir.StringImm("test_compiler_1"))
func1 = func1.with_attr("global_symbol",
container.String("test_compiler_1"))
gv1 = relay.GlobalVar("test_compiler_1")
mod[gv1] = func1

Expand Down Expand Up @@ -613,8 +611,8 @@ def expected():
func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
func0 = func0.with_attr("Compiler",
tvm.tir.StringImm("test_compiler"))
func0 = func0.with_attr("ExternalSymbol",
tvm.tir.StringImm("test_compiler_0"))
func0 = func0.with_attr("global_symbol",
container.String("test_compiler_0"))

# main function
data = relay.var("data", relay.TensorType((1, 16, 224, 224), "float32"))
Expand Down Expand Up @@ -649,8 +647,7 @@ def expected():
func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1))
func = func.with_attr("Compiler", tvm.tir.StringImm("ccompiler"))
func = func.with_attr("ExternalSymbol",
tvm.tir.StringImm("ccompiler_0"))
func = func.with_attr("global_symbol", container.String("ccompiler_0"))
glb_0 = relay.GlobalVar("ccompiler_0")
mod[glb_0] = func
add_call = relay.Call(glb_0, [y])
Expand Down Expand Up @@ -751,8 +748,8 @@ def expected():
func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
func0 = func0.with_attr("Compiler",
tvm.tir.StringImm("test_target"))
func0 = func0.with_attr("ExternalSymbol",
tvm.tir.StringImm("test_target_2"))
func0 = func0.with_attr("global_symbol",
container.String("test_target_2"))
gv0 = relay.GlobalVar("test_target_2")
mod[gv0] = func0

Expand Down Expand Up @@ -819,8 +816,8 @@ def expected():
func1 = func1.with_attr("Inline", tvm.tir.IntImm("int32", 1))
func1 = func1.with_attr("Compiler",
tvm.tir.StringImm("test_target"))
func1 = func1.with_attr("ExternalSymbol",
tvm.tir.StringImm("test_target_1"))
func1 = func1.with_attr("global_symbol",
container.String("test_target_1"))
gv1 = relay.GlobalVar("test_target_1")
mod[gv1] = func1

Expand All @@ -834,8 +831,8 @@ def expected():
func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
func0 = func0.with_attr("Compiler",
tvm.tir.StringImm("test_target"))
func0 = func0.with_attr("ExternalSymbol",
tvm.tir.StringImm("test_target_0"))
func0 = func0.with_attr("global_symbol",
container.String("test_target_0"))
gv0 = relay.GlobalVar("test_target_0")
mod[gv0] = func0

Expand Down

0 comments on commit 023409b

Please sign in to comment.