Skip to content

Commit

Permalink
Merge pull request #33 from slyubomirsky/value-objectref-ndarray
Browse files Browse the repository at this point in the history
[api update] Use ObjectRef instead of Value, NDArrays instead of TensorValue
  • Loading branch information
MarisaKirisame authored Jan 15, 2020
2 parents 2d1c9b2 + fd5c456 commit ab5772e
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 35 deletions.
4 changes: 1 addition & 3 deletions aot/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,13 @@ def convert(a, ctx):
elif isinstance(a, np.ndarray):
a = tvm.nd.array(a, ctx)
elif isinstance(a, tvm.ndarray.NDArray):
a = relay.backend.interpreter.TensorValue(a)
return a
elif isinstance(a, relay.Call):
assert isinstance(a.op, relay.Constructor)
a = (a.op, *a.args)
elif isinstance(a, tuple):
assert isinstance(a[0], relay.Constructor)
a = relay.backend.interpreter.ConstructorValue(a[0].tag, [convert(arg, ctx) for arg in a[1:]], a[0])
elif isinstance(a, relay.backend.interpreter.TensorValue):
return a
elif isinstance(a, relay.backend.interpreter.ConstructorValue):
return a
else:
Expand Down
62 changes: 30 additions & 32 deletions aot/to_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def visit_match(self, node):
for v in pattern_var_set:
bind_name = self.fresh_local_name()
self.name_map[v] = bind_name
stmt_str += f"Value {bind_name};\n"
stmt_str += f"ObjectRef {bind_name};\n"

# match data_name to pat, and fill the var accordingly.
# go to fail_label or ok_label base on failure/success.
Expand All @@ -143,7 +143,7 @@ def visit_pattern(pat, data_name, fail_label, ok_label):
for i, input_type in enumerate(pat.constructor.inputs):
bind_name = self.fresh_local_name()
bind_names.append(bind_name)
ok_case += f"Value {bind_name} = {data_name}->fields[{i}];\n"
ok_case += f"ObjectRef {bind_name} = {data_name}->fields[{i}];\n"
for bind_name, p in zip(bind_names, pat.patterns):
next_label = self.fresh_label_name()
ok_case += visit_pattern(p, bind_name, fail_label, next_label)
Expand All @@ -166,8 +166,8 @@ def visit_pattern(pat, data_name, fail_label, ok_label):

in_name = self.fresh_local_name()
out_name = self.fresh_local_name()
stmt_str += f"Value {in_name} = {vd.expr};\n"
stmt_str += f"Value {out_name};\n"
stmt_str += f"ObjectRef {in_name} = {vd.expr};\n"
stmt_str += f"ObjectRef {out_name};\n"
match_finish_label = self.fresh_label_name()
for c in node.clause:
vc = self.visit(c[1])
Expand Down Expand Up @@ -203,10 +203,10 @@ def visit_if(self, node):
vt = self.visit(node.true_branch)
vf = self.visit(node.false_branch)
ret_name = self.fresh_local_name()
stmt = f"Value {ret_name};"
stmt = f"ObjectRef {ret_name};"
stmt += f"""
{vc.stmt}
if (NDToBool(ValueToND({vc.expr}))) {{
if (NDToBool(ObjectRefToND({vc.expr}))) {{
{vt.stmt}
{ret_name} = {vt.expr};
}} else {{
Expand All @@ -220,7 +220,7 @@ def visit_constant(self, const):
if const not in self.declare_map:
name = self.fresh_global_name()
self.declare_map[const] = name
self.declare += f"Value {name};\n"
self.declare += f"ObjectRef {name};\n"
self.input_const.append((name, const.data.asnumpy()))
return ExprWithStmt(self.declare_map[const])

Expand All @@ -247,7 +247,7 @@ def visit_args(self, args):
def visit_invoke(self, invoke):
args_str, stmt_str = self.visit_args(invoke.args)
func = self.visit(invoke.call)
return ExprWithStmt(f"Apply({func.expr}, std::vector<Value>({{{args_str}}}))", stmt_str + func.stmt)
return ExprWithStmt(f"Apply({func.expr}, std::vector<ObjectRef>({{{args_str}}}))", stmt_str + func.stmt)

def visit_decl(self, decl):
source = ""
Expand All @@ -256,7 +256,7 @@ def visit_decl(self, decl):
self.name_map[var] = local_name
vv = self.visit(value, name=local_name)
source += vv.stmt
source += f"""Value {local_name} = {vv.expr};"""
source += f"""ObjectRef {local_name} = {vv.expr};"""
vb = self.visit(decl.body)
source += vb.stmt
return ExprWithStmt(vb.expr, source)
Expand Down Expand Up @@ -286,7 +286,7 @@ def visit_packed_call(self, call):
args_str = []
def convert_input(ty, arg):
if isinstance(ty, relay.ty.TensorType):
args_str.append(f"ValueToND({arg})")
args_str.append(f"{arg}")
else:
assert isinstance(ty, relay.ty.TupleType)
tuple_name = self.fresh_local_name()
Expand All @@ -302,8 +302,8 @@ def convert_output(ty):
if isinstance(ty, relay.ty.TensorType):
tensor_name = self.fresh_local_name()
nonlocal decl_str
decl_str += f"TensorValue {tensor_name} = TensorValueNode::make(NDArray::Empty({self.nd_shape(ty)}, {self.nd_dtype(ty)}, context));\n"
args_str.append(f"{tensor_name}->data")
decl_str += f"NDArray {tensor_name} = NDArray::Empty({self.nd_shape(ty)}, {self.nd_dtype(ty)}, context);\n"
args_str.append(f"{tensor_name}")
return tensor_name
else:
assert isinstance(ty, relay.ty.TupleType)
Expand All @@ -324,12 +324,12 @@ def visit_cpp_function(self, func, local, name):
for i, param in enumerate(func.params):
pname = self.fresh_local_name(param)
self.name_map[param] = pname
body += f"Value {pname} = {vec}.at({i});\n"
body += f"ObjectRef {pname} = {vec}.at({i});\n"

body += f"Value {name} = self;\n"
body += f"ObjectRef {name} = self;\n"
vb = self.visit(func.body)
body = body + vb.stmt + f"""return {vb.expr};"""
expr = f"""FunctionValueNode::make([=](const std::vector<Value>& {vec}, const Value& self) {{
expr = f"""FunctionValueNode::make([=](const std::vector<ObjectRef>& {vec}, const ObjectRef& self) {{
{body}
}});
"""
Expand All @@ -340,11 +340,11 @@ def visit_cpp_function(self, func, local, name):
if name is None:
name = self.fresh_global_name()
self.declare += f"""
static Value {name}_func() {{
static Value ret = {expr};
static ObjectRef {name}_func() {{
static ObjectRef ret = {expr};
return ret;
}}
Value {name} = {name}_func();
ObjectRef {name} = {name}_func();
"""
return ExprWithStmt(f"{name}")

Expand All @@ -369,8 +369,8 @@ def mk_register_api(self, name: str, func) -> str:
TVM_REGISTER_GLOBAL("{name}")
.set_body([](TVMArgs args, TVMRetValue* ret) {{
{init}
std::initializer_list<Value> ilist = {{{args}}};
*ret = Apply({vf.expr}, std::vector<Value>(ilist));
std::initializer_list<ObjectRef> ilist = {{{args}}};
*ret = Apply({vf.expr}, std::vector<ObjectRef>(ilist));
}});
"""
return source
Expand All @@ -387,7 +387,7 @@ def mk_file(body, ctx):
return f"""
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include <tvm/node/env_func.h>
#include <tvm/ir/env_func.h>
#include <tvm/relay/interpreter.h>
#include <iostream>
Expand All @@ -411,13 +411,11 @@ def mk_file(body, ctx):
return reinterpret_cast<uint8_t*>(cpu_array->data)[0];
}}
static NDArray ValueToND(const Value& v) {{
const TensorValueNode* tv = v.as<TensorValueNode>();
CHECK(tv);
return tv->data;
static NDArray ObjectRefToND(const ObjectRef& v) {{
return Downcast<runtime::NDArray>(v);
}}
static ConstructorValue TagToCV(size_t tag, const tvm::Array<Value>& fields) {{
static ConstructorValue TagToCV(size_t tag, const tvm::Array<ObjectRef>& fields) {{
ObjectPtr<ConstructorValueNode> n = make_object<ConstructorValueNode>();
ObjectPtr<ConstructorNode> con = make_object<ConstructorNode>();
con->tag = tag;
Expand All @@ -430,8 +428,8 @@ def mk_file(body, ctx):
/*! \\brief A Function value. */
class FunctionValue;
using function_value_t = std::function<Value(const std::vector<Value>&, const Value&)>;
struct FunctionValueNode : ValueNode {{
using function_value_t = std::function<ObjectRef(const std::vector<ObjectRef>&, const ObjectRef&)>;
struct FunctionValueNode : Object {{
function_value_t f;
FunctionValueNode() {{ }}
Expand All @@ -441,12 +439,12 @@ class FunctionValue;
TVM_DLL static FunctionValue make(const function_value_t& f);
static constexpr const char* _type_key = "relay.FunctionValue";
TVM_DECLARE_FINAL_OBJECT_INFO(FunctionValueNode, ValueNode);
TVM_DECLARE_FINAL_OBJECT_INFO(FunctionValueNode, Object);
}};
class FunctionValue : public Value {{
class FunctionValue : public ObjectRef {{
public:
TVM_DEFINE_OBJECT_REF_METHODS(FunctionValue, Value, FunctionValueNode);
TVM_DEFINE_OBJECT_REF_METHODS(FunctionValue, ObjectRef, FunctionValueNode);
}};
FunctionValue FunctionValueNode::make(const function_value_t& f) {{
Expand All @@ -455,7 +453,7 @@ class FunctionValue : public Value {{
return FunctionValue(n);
}}
Value Apply(const Value& op, const std::vector<Value>& args) {{
ObjectRef Apply(const ObjectRef& op, const std::vector<ObjectRef>& args) {{
return Downcast<FunctionValue>(op)->f(args, op);
}}
Expand Down

0 comments on commit ab5772e

Please sign in to comment.