diff --git a/aot/convert.py b/aot/convert.py index 420a0c0..65fa225 100644 --- a/aot/convert.py +++ b/aot/convert.py @@ -10,15 +10,13 @@ def convert(a, ctx): elif isinstance(a, np.ndarray): a = tvm.nd.array(a, ctx) elif isinstance(a, tvm.ndarray.NDArray): - a = relay.backend.interpreter.TensorValue(a) + return a elif isinstance(a, relay.Call): assert isinstance(a.op, relay.Constructor) a = (a.op, *a.args) elif isinstance(a, tuple): assert isinstance(a[0], relay.Constructor) a = relay.backend.interpreter.ConstructorValue(a[0].tag, [convert(arg, ctx) for arg in a[1:]], a[0]) - elif isinstance(a, relay.backend.interpreter.TensorValue): - return a elif isinstance(a, relay.backend.interpreter.ConstructorValue): return a else: diff --git a/aot/to_source.py b/aot/to_source.py index 9f673e7..84b88b1 100644 --- a/aot/to_source.py +++ b/aot/to_source.py @@ -130,7 +130,7 @@ def visit_match(self, node): for v in pattern_var_set: bind_name = self.fresh_local_name() self.name_map[v] = bind_name - stmt_str += f"Value {bind_name};\n" + stmt_str += f"ObjectRef {bind_name};\n" # match data_name to pat, and fill the var accordingly. # go to fail_label or ok_label base on failure/success. @@ -143,7 +143,7 @@ def visit_pattern(pat, data_name, fail_label, ok_label): for i, input_type in enumerate(pat.constructor.inputs): bind_name = self.fresh_local_name() bind_names.append(bind_name) - ok_case += f"Value {bind_name} = {data_name}->fields[{i}];\n" + ok_case += f"ObjectRef {bind_name} = {data_name}->fields[{i}];\n" for bind_name, p in zip(bind_names, pat.patterns): next_label = self.fresh_label_name() ok_case += visit_pattern(p, bind_name, fail_label, next_label) @@ -166,8 +166,8 @@ def visit_pattern(pat, data_name, fail_label, ok_label): in_name = self.fresh_local_name() out_name = self.fresh_local_name() - stmt_str += f"Value {in_name} = {vd.expr};\n" - stmt_str += f"Value {out_name};\n" + stmt_str += f"ObjectRef {in_name} = {vd.expr};\n" + stmt_str += f"ObjectRef {out_name};\n" match_finish_label = self.fresh_label_name() for c in node.clause: vc = self.visit(c[1]) @@ -203,10 +203,10 @@ def visit_if(self, node): vt = self.visit(node.true_branch) vf = self.visit(node.false_branch) ret_name = self.fresh_local_name() - stmt = f"Value {ret_name};" + stmt = f"ObjectRef {ret_name};" stmt += f""" {vc.stmt} - if (NDToBool(ValueToND({vc.expr}))) {{ + if (NDToBool(ObjectRefToND({vc.expr}))) {{ {vt.stmt} {ret_name} = {vt.expr}; }} else {{ @@ -220,7 +220,7 @@ def visit_constant(self, const): if const not in self.declare_map: name = self.fresh_global_name() self.declare_map[const] = name - self.declare += f"Value {name};\n" + self.declare += f"ObjectRef {name};\n" self.input_const.append((name, const.data.asnumpy())) return ExprWithStmt(self.declare_map[const]) @@ -247,7 +247,7 @@ def visit_args(self, args): def visit_invoke(self, invoke): args_str, stmt_str = self.visit_args(invoke.args) func = self.visit(invoke.call) - return ExprWithStmt(f"Apply({func.expr}, std::vector({{{args_str}}}))", stmt_str + func.stmt) + return ExprWithStmt(f"Apply({func.expr}, std::vector({{{args_str}}}))", stmt_str + func.stmt) def visit_decl(self, decl): source = "" @@ -256,7 +256,7 @@ def visit_decl(self, decl): self.name_map[var] = local_name vv = self.visit(value, name=local_name) source += vv.stmt - source += f"""Value {local_name} = {vv.expr};""" + source += f"""ObjectRef {local_name} = {vv.expr};""" vb = self.visit(decl.body) source += vb.stmt return ExprWithStmt(vb.expr, source) @@ -286,7 +286,7 @@ def visit_packed_call(self, call): args_str = [] def convert_input(ty, arg): if isinstance(ty, relay.ty.TensorType): - args_str.append(f"ValueToND({arg})") + args_str.append(f"{arg}") else: assert isinstance(ty, relay.ty.TupleType) tuple_name = self.fresh_local_name() @@ -302,8 +302,8 @@ def convert_output(ty): if isinstance(ty, relay.ty.TensorType): tensor_name = self.fresh_local_name() nonlocal decl_str - decl_str += f"TensorValue {tensor_name} = TensorValueNode::make(NDArray::Empty({self.nd_shape(ty)}, {self.nd_dtype(ty)}, context));\n" - args_str.append(f"{tensor_name}->data") + decl_str += f"NDArray {tensor_name} = NDArray::Empty({self.nd_shape(ty)}, {self.nd_dtype(ty)}, context);\n" + args_str.append(f"{tensor_name}") return tensor_name else: assert isinstance(ty, relay.ty.TupleType) @@ -324,12 +324,12 @@ def visit_cpp_function(self, func, local, name): for i, param in enumerate(func.params): pname = self.fresh_local_name(param) self.name_map[param] = pname - body += f"Value {pname} = {vec}.at({i});\n" + body += f"ObjectRef {pname} = {vec}.at({i});\n" - body += f"Value {name} = self;\n" + body += f"ObjectRef {name} = self;\n" vb = self.visit(func.body) body = body + vb.stmt + f"""return {vb.expr};""" - expr = f"""FunctionValueNode::make([=](const std::vector& {vec}, const Value& self) {{ + expr = f"""FunctionValueNode::make([=](const std::vector& {vec}, const ObjectRef& self) {{ {body} }}); """ @@ -340,11 +340,11 @@ def visit_cpp_function(self, func, local, name): if name is None: name = self.fresh_global_name() self.declare += f""" - static Value {name}_func() {{ - static Value ret = {expr}; + static ObjectRef {name}_func() {{ + static ObjectRef ret = {expr}; return ret; }} - Value {name} = {name}_func(); + ObjectRef {name} = {name}_func(); """ return ExprWithStmt(f"{name}") @@ -369,8 +369,8 @@ def mk_register_api(self, name: str, func) -> str: TVM_REGISTER_GLOBAL("{name}") .set_body([](TVMArgs args, TVMRetValue* ret) {{ {init} - std::initializer_list ilist = {{{args}}}; - *ret = Apply({vf.expr}, std::vector(ilist)); + std::initializer_list ilist = {{{args}}}; + *ret = Apply({vf.expr}, std::vector(ilist)); }}); """ return source @@ -387,7 +387,7 @@ def mk_file(body, ctx): return f""" #include #include - #include + #include #include #include @@ -411,13 +411,11 @@ def mk_file(body, ctx): return reinterpret_cast(cpu_array->data)[0]; }} - static NDArray ValueToND(const Value& v) {{ - const TensorValueNode* tv = v.as(); - CHECK(tv); - return tv->data; + static NDArray ObjectRefToND(const ObjectRef& v) {{ + return Downcast(v); }} - static ConstructorValue TagToCV(size_t tag, const tvm::Array& fields) {{ + static ConstructorValue TagToCV(size_t tag, const tvm::Array& fields) {{ ObjectPtr n = make_object(); ObjectPtr con = make_object(); con->tag = tag; @@ -430,8 +428,8 @@ def mk_file(body, ctx): /*! \\brief A Function value. */ class FunctionValue; - using function_value_t = std::function&, const Value&)>; - struct FunctionValueNode : ValueNode {{ + using function_value_t = std::function&, const ObjectRef&)>; + struct FunctionValueNode : Object {{ function_value_t f; FunctionValueNode() {{ }} @@ -441,12 +439,12 @@ class FunctionValue; TVM_DLL static FunctionValue make(const function_value_t& f); static constexpr const char* _type_key = "relay.FunctionValue"; - TVM_DECLARE_FINAL_OBJECT_INFO(FunctionValueNode, ValueNode); + TVM_DECLARE_FINAL_OBJECT_INFO(FunctionValueNode, Object); }}; - class FunctionValue : public Value {{ + class FunctionValue : public ObjectRef {{ public: - TVM_DEFINE_OBJECT_REF_METHODS(FunctionValue, Value, FunctionValueNode); + TVM_DEFINE_OBJECT_REF_METHODS(FunctionValue, ObjectRef, FunctionValueNode); }}; FunctionValue FunctionValueNode::make(const function_value_t& f) {{ @@ -455,7 +453,7 @@ class FunctionValue : public Value {{ return FunctionValue(n); }} - Value Apply(const Value& op, const std::vector& args) {{ + ObjectRef Apply(const ObjectRef& op, const std::vector& args) {{ return Downcast(op)->f(args, op); }}