diff --git a/include/tvm/runtime/registry.h b/include/tvm/runtime/registry.h index 50bb5c5b967d..40e1a520cb67 100644 --- a/include/tvm/runtime/registry.h +++ b/include/tvm/runtime/registry.h @@ -83,6 +83,169 @@ class Registry { Registry& set_body_typed(FLambda f) { return set_body(TypedPackedFunc(f).packed()); } + + /*! + * \brief set the body of the function to the given function pointer. + * Note that this doesn't work with lambdas, you need to + * explicitly give a type for those. + * Note that this will ignore default arg values and always require all arguments to be provided. + * + * \code + * + * int multiply(int x, int y) { + * return x * y; + * } + * + * TVM_REGISTER_API("multiply") + * .set_body_typed(multiply); // will have type int(int, int) + * + * \endcode + * + * \param f The function to forward to. + * \tparam R the return type of the function (inferred). + * \tparam Args the argument types of the function (inferred). + */ + template + Registry& set_body_typed(R (*f)(Args...)) { + return set_body(TypedPackedFunc(f)); + } + + /*! + * \brief set the body of the function to be the passed method pointer. + * Note that this will ignore default arg values and always require all arguments to be provided. + * + * \code + * + * // node subclass: + * struct Example { + * int doThing(int x); + * } + * TVM_REGISTER_API("Example_doThing") + * .set_body_method(&Example::doThing); // will have type int(Example, int) + * + * \endcode + * + * \param f the method pointer to forward to. + * \tparam T the type containing the method (inferred). + * \tparam R the return type of the function (inferred). + * \tparam Args the argument types of the function (inferred). + */ + template + Registry& set_body_method(R (T::*f)(Args...)) { + return set_body_typed([f](T target, Args... params) -> R { + // call method pointer + return (target.*f)(params...); + }); + } + + /*! + * \brief set the body of the function to be the passed method pointer. + * Note that this will ignore default arg values and always require all arguments to be provided. + * + * \code + * + * // node subclass: + * struct Example { + * int doThing(int x); + * } + * TVM_REGISTER_API("Example_doThing") + * .set_body_method(&Example::doThing); // will have type int(Example, int) + * + * \endcode + * + * \param f the method pointer to forward to. + * \tparam T the type containing the method (inferred). + * \tparam R the return type of the function (inferred). + * \tparam Args the argument types of the function (inferred). + */ + template + Registry& set_body_method(R (T::*f)(Args...) const) { + return set_body_typed([f](const T target, Args... params) -> R { + // call method pointer + return (target.*f)(params...); + }); + } + + /*! + * \brief set the body of the function to be the passed method pointer. + * Used when calling a method on a Node subclass through a NodeRef subclass. + * Note that this will ignore default arg values and always require all arguments to be provided. + * + * \code + * + * // node subclass: + * struct ExampleNode: BaseNode { + * int doThing(int x); + * } + * + * // noderef subclass + * struct Example; + * + * TVM_REGISTER_API("Example_doThing") + * .set_body_method(&ExampleNode::doThing); // will have type int(Example, int) + * + * // note that just doing: + * // .set_body_method(&ExampleNode::doThing); + * // wouldn't work, because ExampleNode can't be taken from a TVMArgValue. + * + * \endcode + * + * \param f the method pointer to forward to. + * \tparam TNodeRef the node reference type to call the method on + * \tparam TNode the node type containing the method (inferred). + * \tparam R the return type of the function (inferred). + * \tparam Args the argument types of the function (inferred). + */ + template::value>::type> + Registry& set_body_method(R (TNode::*f)(Args...)) { + return set_body_typed([f](TNodeRef ref, Args... params) { + TNode* target = ref.operator->(); + // call method pointer + return (target->*f)(params...); + }); + } + + /*! + * \brief set the body of the function to be the passed method pointer. + * Used when calling a method on a Node subclass through a NodeRef subclass. + * Note that this will ignore default arg values and always require all arguments to be provided. + * + * \code + * + * // node subclass: + * struct ExampleNode: BaseNode { + * int doThing(int x); + * } + * + * // noderef subclass + * struct Example; + * + * TVM_REGISTER_API("Example_doThing") + * .set_body_method(&ExampleNode::doThing); // will have type int(Example, int) + * + * // note that just doing: + * // .set_body_method(&ExampleNode::doThing); + * // wouldn't work, because ExampleNode can't be taken from a TVMArgValue. + * + * \endcode + * + * \param f the method pointer to forward to. + * \tparam TNodeRef the node reference type to call the method on + * \tparam TNode the node type containing the method (inferred). + * \tparam R the return type of the function (inferred). + * \tparam Args the argument types of the function (inferred). + */ + template::value>::type> + Registry& set_body_method(R (TNode::*f)(Args...) const) { + return set_body_typed([f](TNodeRef ref, Args... params) { + const TNode* target = ref.operator->(); + // call method pointer + return (target->*f)(params...); + }); + } + /*! * \brief Register a function with given name * \param name The name of the function. diff --git a/nnvm/src/compiler/compile_engine.cc b/nnvm/src/compiler/compile_engine.cc index a1422d7a2eee..542455969b8b 100644 --- a/nnvm/src/compiler/compile_engine.cc +++ b/nnvm/src/compiler/compile_engine.cc @@ -360,9 +360,7 @@ TVM_REGISTER_GLOBAL("nnvm.compiler.GraphKeyGetGraph") }); TVM_REGISTER_GLOBAL("nnvm.compiler.MakeGraphKey") -.set_body([](tvm::runtime::TVMArgs args, tvm::runtime::TVMRetValue *rv) { - *rv = GraphKeyNode::make(args[0], args[1], args[2]); - }); +.set_body_typed(GraphKeyNode::make); // This can be used to extract workloads from nnvm compiler TVM_REGISTER_GLOBAL("nnvm.compiler.CacheItem2ScheduleArgs") diff --git a/nnvm/src/compiler/graph_hash.cc b/nnvm/src/compiler/graph_hash.cc index e825ef4efe57..b76f99fa58d3 100644 --- a/nnvm/src/compiler/graph_hash.cc +++ b/nnvm/src/compiler/graph_hash.cc @@ -235,8 +235,6 @@ std::string GraphDeepCompare(const Graph& a, } TVM_REGISTER_GLOBAL("nnvm.graph.DeepCompare") -.set_body([](tvm::runtime::TVMArgs args, tvm::runtime::TVMRetValue *rv) { - *rv = GraphDeepCompare(args[0], args[1], args[2]); - }); +.set_body_typed(GraphDeepCompare); } // namespace compiler } // namespace nnvm diff --git a/src/api/api_arith.cc b/src/api/api_arith.cc index ca0bed18f554..fce73aabf6a7 100644 --- a/src/api/api_arith.cc +++ b/src/api/api_arith.cc @@ -31,73 +31,51 @@ namespace tvm { namespace arith { TVM_REGISTER_API("arith.intset_single_point") -.set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = IntSet::single_point(args[0]); - }); +.set_body_typed(IntSet::single_point); TVM_REGISTER_API("arith.intset_vector") -.set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = IntSet::vector(args[0]); - }); +.set_body_typed(IntSet::vector); TVM_REGISTER_API("arith.intset_interval") -.set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = IntSet::interval(args[0], args[1]); - }); +.set_body_typed(IntSet::interval); TVM_REGISTER_API("arith.DetectLinearEquation") -.set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = DetectLinearEquation(args[0], args[1]); - }); +.set_body_typed(DetectLinearEquation); TVM_REGISTER_API("arith.DetectClipBound") -.set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = DetectClipBound(args[0], args[1]); - }); +.set_body_typed(DetectClipBound); TVM_REGISTER_API("arith.DeduceBound") -.set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = DeduceBound(args[0], args[1], - args[2].operator Map(), - args[3].operator Map()); - }); +.set_body_typed, Map)>([]( + Expr v, Expr cond, + const Map hint_map, + const Map relax_map +) { + return DeduceBound(v, cond, hint_map, relax_map); +}); TVM_REGISTER_API("arith.DomainTouched") -.set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = DomainTouched(args[0], args[1], args[2], args[3]); - }); +.set_body_typed(DomainTouched); TVM_REGISTER_API("_IntervalSetGetMin") -.set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = args[0].operator IntSet().min(); - }); +.set_body_method(&IntSet::min); TVM_REGISTER_API("_IntervalSetGetMax") -.set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = args[0].operator IntSet().max(); - }); +.set_body_method(&IntSet::max); TVM_REGISTER_API("_IntSetIsNothing") -.set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = args[0].operator IntSet().is_nothing(); - }); +.set_body_method(&IntSet::is_nothing); TVM_REGISTER_API("_IntSetIsEverything") -.set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = args[0].operator IntSet().is_everything(); - }); +.set_body_method(&IntSet::is_everything); TVM_REGISTER_API("arith._make_ConstIntBound") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = ConstIntBoundNode::make(args[0], args[1]); - }); +.set_body_typed(ConstIntBoundNode::make); TVM_REGISTER_API("arith._make_ModularSet") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = ModularSetNode::make(args[0], args[1]); - }); +.set_body_typed(ModularSetNode::make); TVM_REGISTER_API("arith._CreateAnalyzer") .set_body([](TVMArgs args, TVMRetValue* ret) { diff --git a/src/api/api_base.cc b/src/api/api_base.cc index 23d1f5c67f7c..28ebb4d65005 100644 --- a/src/api/api_base.cc +++ b/src/api/api_base.cc @@ -50,9 +50,8 @@ TVM_REGISTER_API("_load_json") .set_body_typed(LoadJSON); TVM_REGISTER_API("_TVMSetStream") -.set_body([](TVMArgs args, TVMRetValue *ret) { - TVMSetStream(args[0], args[1], args[2]); - }); +.set_body_typed(TVMSetStream); + TVM_REGISTER_API("_save_param_dict") .set_body([](TVMArgs args, TVMRetValue *rv) { CHECK_EQ(args.size() % 2, 0u); diff --git a/src/api/api_codegen.cc b/src/api/api_codegen.cc index e44ebbec7085..73e26719cf15 100644 --- a/src/api/api_codegen.cc +++ b/src/api/api_codegen.cc @@ -41,8 +41,6 @@ TVM_REGISTER_API("codegen._Build") }); TVM_REGISTER_API("module._PackImportsToC") -.set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = PackImportsToC(args[0], args[1]); - }); +.set_body_typed(PackImportsToC); } // namespace codegen } // namespace tvm diff --git a/src/api/api_ir.cc b/src/api/api_ir.cc index c5680bb3df8d..2525059b47ba 100644 --- a/src/api/api_ir.cc +++ b/src/api/api_ir.cc @@ -31,54 +31,43 @@ namespace tvm { namespace ir { TVM_REGISTER_API("_Var") -.set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = Variable::make(args[1], args[0]); +.set_body_typed([](std::string s, Type t) { + return Variable::make(t, s); }); TVM_REGISTER_API("make.abs") -.set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = tvm::abs(args[0]); - }); +.set_body_typed(tvm::abs); TVM_REGISTER_API("make.floor") -.set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = tvm::floor(args[0]); - }); +.set_body_typed(tvm::floor); TVM_REGISTER_API("make.ceil") -.set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = tvm::ceil(args[0]); - }); +.set_body_typed(tvm::ceil); TVM_REGISTER_API("make.round") -.set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = tvm::round(args[0]); - }); +.set_body_typed(tvm::round); TVM_REGISTER_API("make.trunc") -.set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = tvm::trunc(args[0]); - }); +.set_body_typed(tvm::trunc); TVM_REGISTER_API("make._cast") -.set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = tvm::cast(args[0], args[1]); - }); +.set_body_typed(tvm::cast); TVM_REGISTER_API("make._range_by_min_extent") -.set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = Range::make_by_min_extent(args[0], args[1]); - }); +.set_body_typed(Range::make_by_min_extent); TVM_REGISTER_API("make.For") -.set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = For::make(args[0], - args[1], - args[2], - static_cast(args[3].operator int()), - static_cast(args[4].operator int()), - args[5]); - }); +.set_body_typed([]( + VarExpr loop_var, Expr min, Expr extent, + int for_type, int device_api, Stmt body +) { + return For::make(loop_var, + min, + extent, + static_cast(for_type), + static_cast(device_api), + body); +}); TVM_REGISTER_API("make.Load") .set_body([](TVMArgs args, TVMRetValue *ret) { @@ -101,114 +90,87 @@ TVM_REGISTER_API("make.Store") }); TVM_REGISTER_API("make.Realize") -.set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = Realize::make(args[0], - args[1], - args[2], - args[3], - args[4], - args[5]); - }); - +.set_body_typed(Realize::make); TVM_REGISTER_API("make.Call") -.set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = Call::make(args[0], - args[1], - args[2], - static_cast(args[3].operator int()), - args[4], - args[5]); - }); +.set_body_typed, int, FunctionRef, int)>([]( + Type type, std::string name, + Array args, int call_type, + FunctionRef func, int value_index +) { + return Call::make(type, + name, + args, + static_cast(call_type), + func, + value_index); +}); TVM_REGISTER_API("make.CommReducer") -.set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = CommReducerNode::make(args[0], - args[1], - args[2], - args[3]); - }); +.set_body_typed(CommReducerNode::make); // make from two arguments -#define REGISTER_MAKE1(Node) \ - TVM_REGISTER_API("make."#Node) \ - .set_body([](TVMArgs args, TVMRetValue *ret) { \ - *ret = Node::make(args[0]); \ - }) \ - -#define REGISTER_MAKE2(Node) \ +#define REGISTER_MAKE(Node) \ TVM_REGISTER_API("make."#Node) \ - .set_body([](TVMArgs args, TVMRetValue *ret) { \ - *ret = Node::make(args[0], args[1]); \ - }) \ - -#define REGISTER_MAKE3(Node) \ - TVM_REGISTER_API("make."#Node) \ - .set_body([](TVMArgs args, TVMRetValue *ret) { \ - *ret = Node::make(args[0], args[1], args[2]); \ - }) \ - -#define REGISTER_MAKE4(Node) \ - TVM_REGISTER_API("make."#Node) \ - .set_body([](TVMArgs args, TVMRetValue *ret) { \ - *ret = Node::make(args[0], args[1], args[2], args[3]); \ - }) \ - -#define REGISTER_MAKE5(Node) \ - TVM_REGISTER_API("make."#Node) \ - .set_body([](TVMArgs args, TVMRetValue *ret) { \ - *ret = Node::make(args[0], args[1], args[2], args[3], args[4]); \ - }) \ - - -REGISTER_MAKE5(Reduce); -REGISTER_MAKE4(AttrStmt); - -REGISTER_MAKE2(IntImm); -REGISTER_MAKE2(UIntImm); -REGISTER_MAKE2(FloatImm); -REGISTER_MAKE1(StringImm); - -REGISTER_MAKE2(Add); -REGISTER_MAKE2(Sub); -REGISTER_MAKE2(Mul); -REGISTER_MAKE2(Div); -REGISTER_MAKE2(Mod); -REGISTER_MAKE2(Min); -REGISTER_MAKE2(Max); -REGISTER_MAKE2(EQ); -REGISTER_MAKE2(NE); -REGISTER_MAKE2(LT); -REGISTER_MAKE2(LE); -REGISTER_MAKE2(GT); -REGISTER_MAKE2(GE); -REGISTER_MAKE2(And); -REGISTER_MAKE2(Or); - -REGISTER_MAKE1(Not); -REGISTER_MAKE3(Select); -REGISTER_MAKE3(Ramp); -REGISTER_MAKE2(Cast); -REGISTER_MAKE2(Broadcast); -REGISTER_MAKE2(Shuffle); -REGISTER_MAKE3(Let); -REGISTER_MAKE3(LetStmt); -REGISTER_MAKE3(AssertStmt); -REGISTER_MAKE3(ProducerConsumer); -REGISTER_MAKE5(Allocate); -REGISTER_MAKE4(Provide); -REGISTER_MAKE4(Prefetch); -REGISTER_MAKE1(Free); -REGISTER_MAKE2(Block); -REGISTER_MAKE3(IfThenElse); -REGISTER_MAKE1(Evaluate); + .set_body_typed(Node::make); \ + +REGISTER_MAKE(Reduce); +REGISTER_MAKE(AttrStmt); + +REGISTER_MAKE(IntImm); +REGISTER_MAKE(UIntImm); +REGISTER_MAKE(FloatImm); +REGISTER_MAKE(StringImm); + +REGISTER_MAKE(Add); +REGISTER_MAKE(Sub); +REGISTER_MAKE(Mul); +REGISTER_MAKE(Div); +REGISTER_MAKE(Mod); +REGISTER_MAKE(Min); +REGISTER_MAKE(Max); +REGISTER_MAKE(EQ); +REGISTER_MAKE(NE); +REGISTER_MAKE(LT); +REGISTER_MAKE(LE); +REGISTER_MAKE(GT); +REGISTER_MAKE(GE); +REGISTER_MAKE(And); +REGISTER_MAKE(Or); + +REGISTER_MAKE(Not); +REGISTER_MAKE(Select); +REGISTER_MAKE(Ramp); +REGISTER_MAKE(Cast); +REGISTER_MAKE(Broadcast); +REGISTER_MAKE(Shuffle); +REGISTER_MAKE(Let); +REGISTER_MAKE(LetStmt); +REGISTER_MAKE(AssertStmt); +REGISTER_MAKE(ProducerConsumer); +REGISTER_MAKE(Provide); +REGISTER_MAKE(Prefetch); +REGISTER_MAKE(Free); +REGISTER_MAKE(IfThenElse); +REGISTER_MAKE(Evaluate); + +// overloaded, needs special handling +TVM_REGISTER_API("make.Block") + .set_body_typed(static_cast(Block::make)); + +// has default args +TVM_REGISTER_API("make.Allocate") + .set_body_typed, Expr, Stmt)>([]( + VarExpr buffer_var, Type type, Array extents, Expr condition, Stmt body + ){ + return Allocate::make(buffer_var, type, extents, condition, body); + }); // operator overloading, smarter than make #define REGISTER_MAKE_BINARY_OP(Node, Func) \ TVM_REGISTER_API("make."#Node) \ - .set_body([](TVMArgs args, TVMRetValue *ret) { \ - Expr a = args[0], b = args[1]; \ - *ret = (Func(a, b)); \ + .set_body_typed([](Expr a, Expr b) { \ + return (Func(a, b)); \ }) #define REGISTER_MAKE_BIT_OP(Node, Func) \ diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc index aac73f1878f8..42d60b85e375 100644 --- a/src/api/api_lang.cc +++ b/src/api/api_lang.cc @@ -32,19 +32,14 @@ #include #include + namespace tvm { TVM_REGISTER_API("_min_value") -.set_body([](TVMArgs args, TVMRetValue* ret) { - Type t = args[0].operator Type(); - *ret = t.min(); - }); +.set_body_method(&Type::min); TVM_REGISTER_API("_max_value") -.set_body([](TVMArgs args, TVMRetValue* ret) { - Type t = args[0].operator Type(); - *ret = t.max(); - }); +.set_body_method(&Type::max); TVM_REGISTER_API("_const") .set_body([](TVMArgs args, TVMRetValue* ret) { @@ -58,9 +53,7 @@ TVM_REGISTER_API("_const") }); TVM_REGISTER_API("_str") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = ir::StringImm::make(args[0]); -}); +.set_body_typed(ir::StringImm::make); TVM_REGISTER_API("_Array") @@ -214,373 +207,217 @@ TVM_REGISTER_API("Range") }); TVM_REGISTER_API("_Buffer") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = BufferNode::make(args[0], - args[1], - args[2], - args[3], - args[4], - args[5], - args[6], - args[7], - args[8]); - }); +.set_body_typed(BufferNode::make); TVM_REGISTER_API("_BufferAccessPtr") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = args[0].operator Buffer() - .access_ptr(args[1], args[2], args[3], args[4]); - }); +.set_body_method(&Buffer::access_ptr); TVM_REGISTER_API("_BufferVLoad") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = args[0].operator Buffer() - .vload(args[1], args[2]); - }); +.set_body_method(&Buffer::vload); TVM_REGISTER_API("_BufferVStore") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = args[0].operator Buffer() - .vstore(args[1], args[2]); - }); +.set_body_method(&Buffer::vstore); TVM_REGISTER_API("_Layout") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = LayoutNode::make(args[0]); - }); +.set_body_typed(LayoutNode::make); TVM_REGISTER_API("_LayoutIndexOf") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = args[0].operator Layout() - .IndexOf(LayoutAxis::make(args[1])); +.set_body_typed([](Layout layout, std::string axis) { + return layout.IndexOf(LayoutAxis::make(axis)); }); TVM_REGISTER_API("_LayoutFactorOf") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = args[0].operator Layout() - .FactorOf(LayoutAxis::make(args[1])); +.set_body_typed([](Layout layout, std::string axis) { + return layout.FactorOf(LayoutAxis::make(axis)); }); TVM_REGISTER_API("_LayoutNdim") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = static_cast(args[0].operator Layout().ndim()); +.set_body_typed([](Layout layout) { + return layout.ndim(); }); TVM_REGISTER_API("_LayoutGetItem") -.set_body([](TVMArgs args, TVMRetValue* ret) { - const LayoutAxis& axis = args[0].operator Layout()[args[1]]; - *ret = axis.name(); +.set_body_typed([](Layout layout, int idx) { + const LayoutAxis& axis = layout[idx]; + return axis.name(); }); TVM_REGISTER_API("_BijectiveLayout") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = BijectiveLayoutNode::make(args[0], args[1]); - }); +.set_body_typed(BijectiveLayoutNode::make); TVM_REGISTER_API("_BijectiveLayoutForwardIndex") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = args[0].operator BijectiveLayout() - .ForwardIndex(args[1]); - }); +.set_body_method(&BijectiveLayout::ForwardIndex); TVM_REGISTER_API("_BijectiveLayoutBackwardIndex") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = args[0].operator BijectiveLayout() - .BackwardIndex(args[1]); - }); +.set_body_method(&BijectiveLayout::BackwardIndex); TVM_REGISTER_API("_BijectiveLayoutForwardShape") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = args[0].operator BijectiveLayout() - .ForwardShape(args[1]); - }); +.set_body_method(&BijectiveLayout::ForwardShape); TVM_REGISTER_API("_BijectiveLayoutBackwardShape") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = args[0].operator BijectiveLayout() - .BackwardShape(args[1]); - }); +.set_body_method(&BijectiveLayout::BackwardShape); TVM_REGISTER_API("_Tensor") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = TensorNode::make(args[0], - args[1], - args[2], - args[3]); - }); +.set_body_typed(TensorNode::make); TVM_REGISTER_API("_TensorIntrin") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = TensorIntrinNode::make(args[0], - args[1], - args[2], - args[3], - args[4], - args[5], - args[6]); - }); +.set_body_typed(TensorIntrinNode::make); TVM_REGISTER_API("_TensorIntrinCall") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = TensorIntrinCallNode::make(args[0], - args[1], - args[2], - args[3]); - }); +.set_body_typed(TensorIntrinCallNode::make); TVM_REGISTER_API("_TensorEqual") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = args[0].operator Tensor() == args[1].operator Tensor(); - }); +.set_body_method(&Tensor::operator==); TVM_REGISTER_API("_TensorHash") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = static_cast( - std::hash()(args[0].operator Tensor())); +.set_body_typed([](Tensor tensor) { + return static_cast(std::hash()(tensor)); }); TVM_REGISTER_API("_Placeholder") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = placeholder(args[0], - args[1], - args[2]); - }); +.set_body_typed, Type, std::string)>([]( + Array shape, Type dtype, std::string name +) { + return placeholder(shape, dtype, name); +}); TVM_REGISTER_API("_ComputeOp") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = ComputeOpNode::make(args[0], - args[1], - args[2], - args[3], - args[4]); - }); +.set_body_typed(ComputeOpNode::make); TVM_REGISTER_API("_ScanOp") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = ScanOpNode::make(args[0], - args[1], - args[2], - args[3], - args[4], - args[5], - args[6], - args[7]); - }); +.set_body_typed(ScanOpNode::make); TVM_REGISTER_API("_TensorComputeOp") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = TensorComputeOpNode::make(args[0], - args[1], - args[2], - args[3], - args[4], - args[5], - args[6], - args[7]); - }); +.set_body_typed(TensorComputeOpNode::make); TVM_REGISTER_API("_ExternOp") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = ExternOpNode::make(args[0], - args[1], - args[2], - args[3], - args[4], - args[5], - args[6]); - }); +.set_body_typed(ExternOpNode::make); TVM_REGISTER_API("_HybridOp") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = HybridOpNode::make(args[0], - args[1], - args[2], - args[3], - args[4], - args[5]); - }); +.set_body_typed(HybridOpNode::make); TVM_REGISTER_API("_OpGetOutput") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = args[0].operator Operation().output( - static_cast(args[1].operator int64_t())); - }); +.set_body_typed([](Operation op, int64_t output) { + return op.output(static_cast(output)); +}); TVM_REGISTER_API("_OpNumOutputs") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = args[0].operator Operation()->num_outputs(); - }); +.set_body_method(&OperationNode::num_outputs); TVM_REGISTER_API("_OpInputTensors") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = args[0].operator Operation()->InputTensors(); - }); +.set_body_method(&OperationNode::InputTensors); TVM_REGISTER_API("_IterVar") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = IterVarNode::make( - args[0], args[1], - static_cast(args[2].operator int()), - args[3]); - }); +.set_body_typed([]( + Range dom, Var var, int iter_type, std::string thread_tag +) { + return IterVarNode::make( + dom, var, + static_cast(iter_type), + thread_tag); +}); TVM_REGISTER_API("_CreateSchedule") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = create_schedule(args[0].operator Array()); - }); +.set_body_typed(create_schedule); TVM_REGISTER_API("_StageSetScope") -.set_body([](TVMArgs args, TVMRetValue* ret) { - args[0].operator Stage() - .set_scope(args[1]); - }); +.set_body_method(&Stage::set_scope); TVM_REGISTER_API("_StageBind") -.set_body([](TVMArgs args, TVMRetValue* ret) { - args[0].operator Stage() - .bind(args[1], args[2]); - }); +.set_body_method(&Stage::bind); TVM_REGISTER_API("_StageSplitByFactor") -.set_body([](TVMArgs args, TVMRetValue* ret) { - IterVar outer, inner; - args[0].operator Stage() - .split(args[1], args[2], &outer, &inner); - *ret = Array({outer, inner}); - }); +.set_body_typed(Stage, IterVar, Expr)>([]( + Stage stage, IterVar parent, Expr factor +) { + IterVar outer, inner; + stage.split(parent, factor, &outer, &inner); + return Array({outer, inner}); +}); TVM_REGISTER_API("_StageSplitByNParts") -.set_body([](TVMArgs args, TVMRetValue* ret) { - IterVar outer, inner; - args[0].operator Stage() - .split_by_nparts(args[1], args[2], &outer, &inner); - *ret = Array({outer, inner}); - }); +.set_body_typed(Stage, IterVar, Expr)>([]( + Stage stage, IterVar parent, Expr nparts +) { + IterVar outer, inner; + stage.split_by_nparts(parent, nparts, &outer, &inner); + return Array({outer, inner}); +}); TVM_REGISTER_API("_StageFuse") -.set_body([](TVMArgs args, TVMRetValue* ret) { +.set_body_typed)>([](Stage stage, Array axes) { IterVar fused; - args[0].operator Stage() - .fuse(args[1], &fused); - *ret = fused; + stage.fuse(axes, &fused); + return fused; }); TVM_REGISTER_API("_StageComputeAt") -.set_body([](TVMArgs args, TVMRetValue* ret) { - args[0].operator Stage() - .compute_at(args[1], args[2]); - }); +.set_body_method(&Stage::compute_at); TVM_REGISTER_API("_StageComputeInline") -.set_body([](TVMArgs args, TVMRetValue* ret) { - args[0].operator Stage() - .compute_inline(); - }); +.set_body_method(&Stage::compute_inline); TVM_REGISTER_API("_StageComputeRoot") -.set_body([](TVMArgs args, TVMRetValue* ret) { - args[0].operator Stage() - .compute_root(); - }); +.set_body_method(&Stage::compute_root); TVM_REGISTER_API("_StageReorder") -.set_body([](TVMArgs args, TVMRetValue* ret) { - args[0].operator Stage() - .reorder(args[1]); - }); +.set_body_method(&Stage::reorder); TVM_REGISTER_API("_StageTile") - .set_body([](TVMArgs args, TVMRetValue* ret) { +.set_body_typed(Stage, IterVar, IterVar, Expr, Expr)>([]( + Stage stage, + IterVar x_parent, IterVar y_parent, + Expr x_factor, Expr y_factor +) { IterVar x_outer, y_outer, x_inner, y_inner; - args[0].operator Stage() - .tile(args[1], args[2], - args[3], args[4], - &x_outer, &y_outer, - &x_inner, &y_inner); - *ret = Array({x_outer, y_outer, x_inner, y_inner}); + stage.tile(x_parent, y_parent, + x_factor, y_factor, + &x_outer, &y_outer, + &x_inner, &y_inner); + return Array({x_outer, y_outer, x_inner, y_inner}); }); TVM_REGISTER_API("_StageEnvThreads") - .set_body([](TVMArgs args, TVMRetValue* ret) { - args[0].operator Stage() - .env_threads(args[1]); - }); +.set_body_method(&Stage::env_threads); TVM_REGISTER_API("_StageSetStorePredicate") - .set_body([](TVMArgs args, TVMRetValue* ret) { - args[0].operator Stage() - .set_store_predicate(args[1]); - }); +.set_body_method(&Stage::set_store_predicate); TVM_REGISTER_API("_StageUnroll") - .set_body([](TVMArgs args, TVMRetValue* ret) { - args[0].operator Stage() - .unroll(args[1]); - }); +.set_body_method(&Stage::unroll); TVM_REGISTER_API("_StageVectorize") - .set_body([](TVMArgs args, TVMRetValue* ret) { - args[0].operator Stage() - .vectorize(args[1]); - }); +.set_body_method(&Stage::vectorize); TVM_REGISTER_API("_StageTensorize") - .set_body([](TVMArgs args, TVMRetValue* ret) { - args[0].operator Stage() - .tensorize(args[1], args[2]); - }); +.set_body_method(&Stage::tensorize); TVM_REGISTER_API("_StageParallel") - .set_body([](TVMArgs args, TVMRetValue* ret) { - args[0].operator Stage() - .parallel(args[1]); - }); +.set_body_method(&Stage::parallel); TVM_REGISTER_API("_StagePragma") - .set_body([](TVMArgs args, TVMRetValue* ret) { - args[0].operator Stage() - .pragma(args[1], args[2], args[3]); - }); +.set_body_method(&Stage::pragma); TVM_REGISTER_API("_StagePrefetch") - .set_body([](TVMArgs args, TVMRetValue *ret) { - args[0].operator Stage() - .prefetch(args[1], args[2], args[3]); - }); +.set_body_method(&Stage::prefetch); TVM_REGISTER_API("_StageStorageAlign") - .set_body([](TVMArgs args, TVMRetValue *ret) { - args[0].operator Stage() - .storage_align(args[1], args[2], args[3]); - }); +.set_body_method(&Stage::storage_align); TVM_REGISTER_API("_StageDoubleBuffer") - .set_body([](TVMArgs args, TVMRetValue *ret) { - args[0].operator Stage().double_buffer(); - }); +.set_body_method(&Stage::double_buffer); TVM_REGISTER_API("_StageOpenGL") - .set_body([](TVMArgs args, TVMRetValue *ret) { - args[0].operator Stage().opengl(); - }); +.set_body_method(&Stage::opengl); TVM_REGISTER_API("_ScheduleNormalize") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = args[0].operator Schedule() - .normalize(); - }); +.set_body_method(&Schedule::normalize); TVM_REGISTER_API("_ScheduleCreateGroup") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = args[0].operator Schedule() - .create_group(args[1], args[2], args[3]); - }); +.set_body_method(&Schedule::create_group); TVM_REGISTER_API("_ScheduleCacheRead") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = args[0].operator Schedule() - .cache_read(args[1], args[2], args[3]); - }); +.set_body_method(&Schedule::cache_read); TVM_REGISTER_API("_ScheduleCacheWrite") .set_body([](TVMArgs args, TVMRetValue* ret) { @@ -594,16 +431,9 @@ TVM_REGISTER_API("_ScheduleCacheWrite") }); TVM_REGISTER_API("_ScheduleRFactor") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = args[0].operator Schedule() - .rfactor(args[1], args[2], args[3]); - }); +.set_body_method(&Schedule::rfactor); TVM_REGISTER_API("_CommReducerCombine") -.set_body([](TVMArgs args, TVMRetValue* ret) { - const ir::CommReducerNode* combiner = - args[0].operator ir::CommReducer().as(); - *ret = (*combiner)(args[1], args[2]); - }); +.set_body_method(&ir::CommReducerNode::operator()); } // namespace tvm diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc index 2e1ab42e4cbe..6195aac1b93f 100644 --- a/src/api/api_pass.cc +++ b/src/api/api_pass.cc @@ -119,68 +119,43 @@ TVM_REGISTER_API("ir_pass.PostOrderVisit") }); // make from two arguments -#define REGISTER_PASS1(PassName) \ +#define REGISTER_PASS(PassName) \ TVM_REGISTER_API("ir_pass."#PassName) \ - .set_body([](TVMArgs args, TVMRetValue *ret) { \ - *ret = PassName(args[0]); \ - }) \ - -#define REGISTER_PASS2(PassName) \ - TVM_REGISTER_API("ir_pass."#PassName) \ - .set_body([](TVMArgs args, TVMRetValue *ret) { \ - *ret = PassName(args[0], args[1]); \ - }) \ - -#define REGISTER_PASS3(PassName) \ - TVM_REGISTER_API("ir_pass."#PassName) \ - .set_body([](TVMArgs args, TVMRetValue *ret) { \ - *ret = PassName(args[0], args[1], args[2]); \ - }) \ - -#define REGISTER_PASS4(PassName) \ - TVM_REGISTER_API("ir_pass."#PassName) \ - .set_body([](TVMArgs args, TVMRetValue *ret) { \ - *ret = PassName(args[0], args[1], args[2], args[3]); \ - }) \ - -#define REGISTER_PASS5(PassName) \ - TVM_REGISTER_API("ir_pass."#PassName) \ - .set_body([](TVMArgs args, TVMRetValue *ret) { \ - *ret = PassName(args[0], args[1], args[2], args[3], args[4]); \ - }) \ - -REGISTER_PASS1(ConvertSSA); -REGISTER_PASS1(VerifySSA); -REGISTER_PASS1(RewriteUnsafeSelect); -REGISTER_PASS4(Inline); -REGISTER_PASS4(IRTransform); -REGISTER_PASS1(VectorizeLoop); -REGISTER_PASS5(UnrollLoop); -REGISTER_PASS3(InjectCopyIntrin); -REGISTER_PASS2(ThreadSync); -REGISTER_PASS5(MakeAPI); -REGISTER_PASS2(BindDeviceType); -REGISTER_PASS1(SplitHostDevice); -REGISTER_PASS1(StorageRewrite); -REGISTER_PASS1(CoProcSync); -REGISTER_PASS1(LowerStorageAccessInfo); -REGISTER_PASS1(InjectVirtualThread); -REGISTER_PASS1(InjectPrefetch); -REGISTER_PASS2(InjectDoubleBuffer); -REGISTER_PASS2(LoopPartition); -REGISTER_PASS1(RemoveNoOp); -REGISTER_PASS2(SplitPipeline); -REGISTER_PASS2(LiftAttrScope); -REGISTER_PASS1(NarrowChannelAccess); -REGISTER_PASS2(LowerThreadAllreduce); -REGISTER_PASS2(LowerWarpMemory); -REGISTER_PASS2(RemapThreadAxis); -REGISTER_PASS2(LowerIntrin); -REGISTER_PASS1(LowerTVMBuiltin); -REGISTER_PASS1(CombineContextCall); -REGISTER_PASS2(VerifyMemory); -REGISTER_PASS2(VerifyGPUCode); -REGISTER_PASS1(DecorateDeviceScope); -REGISTER_PASS1(InstrumentBoundCheckers); + .set_body_typed(PassName); \ + + +REGISTER_PASS(ConvertSSA); +REGISTER_PASS(VerifySSA); +REGISTER_PASS(RewriteUnsafeSelect); +REGISTER_PASS(Inline); +REGISTER_PASS(IRTransform); +REGISTER_PASS(VectorizeLoop); +REGISTER_PASS(UnrollLoop); +REGISTER_PASS(InjectCopyIntrin); +REGISTER_PASS(ThreadSync); +REGISTER_PASS(MakeAPI); +REGISTER_PASS(BindDeviceType); +REGISTER_PASS(SplitHostDevice); +REGISTER_PASS(StorageRewrite); +REGISTER_PASS(CoProcSync); +REGISTER_PASS(LowerStorageAccessInfo); +REGISTER_PASS(InjectVirtualThread); +REGISTER_PASS(InjectPrefetch); +REGISTER_PASS(InjectDoubleBuffer); +REGISTER_PASS(LoopPartition); +REGISTER_PASS(RemoveNoOp); +REGISTER_PASS(SplitPipeline); +REGISTER_PASS(LiftAttrScope); +REGISTER_PASS(NarrowChannelAccess); +REGISTER_PASS(LowerThreadAllreduce); +REGISTER_PASS(LowerWarpMemory); +REGISTER_PASS(RemapThreadAxis); +REGISTER_PASS(LowerIntrin); +REGISTER_PASS(LowerTVMBuiltin); +REGISTER_PASS(CombineContextCall); +REGISTER_PASS(VerifyMemory); +REGISTER_PASS(VerifyGPUCode); +REGISTER_PASS(DecorateDeviceScope); +REGISTER_PASS(InstrumentBoundCheckers); } // namespace ir } // namespace tvm diff --git a/src/api/api_schedule.cc b/src/api/api_schedule.cc index 45e2eb4c9375..177360bf2ebb 100644 --- a/src/api/api_schedule.cc +++ b/src/api/api_schedule.cc @@ -33,15 +33,11 @@ namespace tvm { namespace schedule { TVM_REGISTER_API("schedule.AutoInlineElemWise") -.set_body([](TVMArgs args, TVMRetValue* ret) { - AutoInlineElemWise(args[0]); - }); +.set_body_typed(AutoInlineElemWise); TVM_REGISTER_API("schedule.AutoInlineInjective") -.set_body([](TVMArgs args, TVMRetValue* ret) { - AutoInlineInjective(args[0]); - }); +.set_body_typed(AutoInlineInjective); TVM_REGISTER_API("schedule.ScheduleOps") .set_body([](TVMArgs args, TVMRetValue* ret) { @@ -51,25 +47,17 @@ TVM_REGISTER_API("schedule.ScheduleOps") *ret = ScheduleOps(args[0], args[1], args[2]); }); -#define REGISTER_SCHEDULE_PASS1(PassName) \ - TVM_REGISTER_API("schedule."#PassName) \ - .set_body([](TVMArgs args, TVMRetValue *ret) { \ - *ret = PassName(args[0]); \ - }) \ - -#define REGISTER_SCHEDULE_PASS2(PassName) \ +#define REGISTER_SCHEDULE_PASS(PassName) \ TVM_REGISTER_API("schedule."#PassName) \ - .set_body([](TVMArgs args, TVMRetValue *ret) { \ - *ret = PassName(args[0], args[1]); \ - }) \ + .set_body_typed(PassName); \ -REGISTER_SCHEDULE_PASS1(InferBound); -REGISTER_SCHEDULE_PASS1(CreateReadGraph); -REGISTER_SCHEDULE_PASS2(PostDFSOrder); -REGISTER_SCHEDULE_PASS1(CreateAttachPath); -REGISTER_SCHEDULE_PASS1(ScanGetBody); -REGISTER_SCHEDULE_PASS1(ScanFixPointAnalysis); +REGISTER_SCHEDULE_PASS(InferBound); +REGISTER_SCHEDULE_PASS(CreateReadGraph); +REGISTER_SCHEDULE_PASS(PostDFSOrder); +REGISTER_SCHEDULE_PASS(CreateAttachPath); +REGISTER_SCHEDULE_PASS(ScanGetBody); +REGISTER_SCHEDULE_PASS(ScanFixPointAnalysis); } // namespace schedule } // namespace tvm diff --git a/src/codegen/codegen_opencl.cc b/src/codegen/codegen_opencl.cc index 96e1b9efe8dd..382124a7ed2d 100644 --- a/src/codegen/codegen_opencl.cc +++ b/src/codegen/codegen_opencl.cc @@ -263,8 +263,6 @@ runtime::Module BuildOpenCL(Array funcs) { } TVM_REGISTER_API("codegen.build_opencl") -.set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = BuildOpenCL(args[0]); - }); +.set_body_typed(BuildOpenCL); } // namespace codegen } // namespace tvm diff --git a/src/codegen/codegen_opengl.cc b/src/codegen/codegen_opengl.cc index 27d910e7211b..797a7d1c406e 100644 --- a/src/codegen/codegen_opengl.cc +++ b/src/codegen/codegen_opengl.cc @@ -302,9 +302,7 @@ runtime::Module BuildOpenGL(Array funcs) { } TVM_REGISTER_API("codegen.build_opengl") -.set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = BuildOpenGL(args[0]); -}); +.set_body_typed(BuildOpenGL); } // namespace codegen } // namespace tvm diff --git a/src/codegen/codegen_vhls.cc b/src/codegen/codegen_vhls.cc index 460647a6e180..a18312fe6af5 100644 --- a/src/codegen/codegen_vhls.cc +++ b/src/codegen/codegen_vhls.cc @@ -164,9 +164,7 @@ runtime::Module BuildSDAccel(Array funcs, std::string target_str) { } TVM_REGISTER_API("codegen.build_sdaccel") -.set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = BuildSDAccel(args[0], args[1]); - }); +.set_body_typed(BuildSDAccel); } // namespace codegen } // namespace tvm diff --git a/src/codegen/llvm/codegen_amdgpu.cc b/src/codegen/llvm/codegen_amdgpu.cc index 22c432cc9e4b..396ae5956556 100644 --- a/src/codegen/llvm/codegen_amdgpu.cc +++ b/src/codegen/llvm/codegen_amdgpu.cc @@ -265,9 +265,7 @@ runtime::Module BuildAMDGPU(Array funcs, std::string target) { } TVM_REGISTER_API("codegen.build_rocm") -.set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = BuildAMDGPU(args[0], args[1]); - }); +.set_body_typed(BuildAMDGPU); } // namespace codegen } // namespace tvm diff --git a/src/codegen/llvm/codegen_nvptx.cc b/src/codegen/llvm/codegen_nvptx.cc index b1b541d4ab74..290727fd9152 100644 --- a/src/codegen/llvm/codegen_nvptx.cc +++ b/src/codegen/llvm/codegen_nvptx.cc @@ -243,9 +243,7 @@ runtime::Module BuildNVPTX(Array funcs, std::string target) { } TVM_REGISTER_API("codegen.build_nvptx") -.set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = BuildNVPTX(args[0], args[1]); - }); +.set_body_typed(BuildNVPTX); } // namespace codegen } // namespace tvm diff --git a/src/codegen/opt/build_cuda_on.cc b/src/codegen/opt/build_cuda_on.cc index 581c33086bee..fda239f0766f 100644 --- a/src/codegen/opt/build_cuda_on.cc +++ b/src/codegen/opt/build_cuda_on.cc @@ -155,8 +155,6 @@ runtime::Module BuildCUDA(Array funcs) { } TVM_REGISTER_API("codegen.build_cuda") -.set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = BuildCUDA(args[0]); - }); +.set_body_typed(BuildCUDA); } // namespace codegen } // namespace tvm diff --git a/src/codegen/source_module.cc b/src/codegen/source_module.cc index 70047a6050db..88be7fed448d 100644 --- a/src/codegen/source_module.cc +++ b/src/codegen/source_module.cc @@ -188,8 +188,6 @@ runtime::Module DeviceSourceModuleCreate( } TVM_REGISTER_GLOBAL("module.source_module_create") -.set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = SourceModuleCreate(args[0], args[1]); - }); +.set_body_typed(SourceModuleCreate); } // namespace codegen } // namespace tvm diff --git a/src/codegen/spirv/build_vulkan.cc b/src/codegen/spirv/build_vulkan.cc index 2b1ef660fbdc..18ffad1a58bc 100644 --- a/src/codegen/spirv/build_vulkan.cc +++ b/src/codegen/spirv/build_vulkan.cc @@ -103,9 +103,7 @@ runtime::Module BuildSPIRV(Array funcs) { } TVM_REGISTER_API("codegen.build_vulkan") -.set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = BuildSPIRV(args[0]); - }); +.set_body_typed(BuildSPIRV); } // namespace codegen } // namespace tvm diff --git a/src/codegen/stackvm/codegen_stackvm.cc b/src/codegen/stackvm/codegen_stackvm.cc index 8c4c258095d3..2d71a20a6232 100644 --- a/src/codegen/stackvm/codegen_stackvm.cc +++ b/src/codegen/stackvm/codegen_stackvm.cc @@ -522,8 +522,6 @@ runtime::Module BuildStackVM(const Array& funcs) { } TVM_REGISTER_API("codegen.build_stackvm") -.set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = BuildStackVM(args[0]); - }); +.set_body_typed(BuildStackVM); } // namespace codegen } // namespace tvm diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index 735f1830d049..9af3f822a07d 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -51,9 +51,7 @@ Closure ClosureNode::make(tvm::Map env, Function func) { } TVM_REGISTER_API("relay._make.Closure") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = ClosureNode::make(args[0], args[1]); - }); +.set_body_typed(ClosureNode::make); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const ClosureNode* node, tvm::IRPrinter* p) { @@ -67,9 +65,7 @@ TupleValue TupleValueNode::make(tvm::Array value) { } TVM_REGISTER_API("relay._make.TupleValue") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = TupleValueNode::make(args[0]); - }); +.set_body_typed(TupleValueNode::make); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const TupleValueNode* node, tvm::IRPrinter* p) { @@ -90,10 +86,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) }); TVM_REGISTER_API("relay._make.TensorValue") -.set_body([](TVMArgs args, TVMRetValue* ret) { - runtime::NDArray data = args[0]; - *ret = TensorValueNode::make(data); - }); +.set_body_typed(TensorValueNode::make); RefValue RefValueNode::make(Value value) { NodePtr n = make_node(); @@ -102,9 +95,7 @@ RefValue RefValueNode::make(Value value) { } TVM_REGISTER_API("relay._make.RefValue") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = RefValueNode::make(args[0]); - }); +.set_body_typed(RefValueNode::make); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const RefValueNode* node, @@ -121,9 +112,7 @@ ConstructorValue ConstructorValueNode::make(Constructor constructor, } TVM_REGISTER_API("relay._make.ConstructorValue") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = ConstructorValueNode::make(args[0], args[1]); - }); +.set_body_typed(ConstructorValueNode::make); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const ConstructorValueNode* node, @@ -614,9 +603,7 @@ CreateInterpreter( } TVM_REGISTER_API("relay.backend.CreateInterpreter") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = CreateInterpreter(args[0], args[1], args[2]); - }); +.set_body_typed(CreateInterpreter); TVM_REGISTER_NODE_TYPE(ClosureNode); TVM_REGISTER_NODE_TYPE(TupleValueNode); diff --git a/src/relay/ir/adt.cc b/src/relay/ir/adt.cc index 2e7d854fbd2a..b59281a4f1fd 100644 --- a/src/relay/ir/adt.cc +++ b/src/relay/ir/adt.cc @@ -36,9 +36,7 @@ PatternWildcard PatternWildcardNode::make() { TVM_REGISTER_NODE_TYPE(PatternWildcardNode); TVM_REGISTER_API("relay._make.PatternWildcard") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = PatternWildcardNode::make(); - }); +.set_body_typed(PatternWildcardNode::make); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const PatternWildcardNode* node, @@ -55,9 +53,7 @@ PatternVar PatternVarNode::make(tvm::relay::Var var) { TVM_REGISTER_NODE_TYPE(PatternVarNode); TVM_REGISTER_API("relay._make.PatternVar") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = PatternVarNode::make(args[0]); - }); +.set_body_typed(PatternVarNode::make); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const PatternVarNode* node, @@ -76,9 +72,7 @@ PatternConstructor PatternConstructorNode::make(Constructor constructor, TVM_REGISTER_NODE_TYPE(PatternConstructorNode); TVM_REGISTER_API("relay._make.PatternConstructor") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = PatternConstructorNode::make(args[0], args[1]); - }); +.set_body_typed(PatternConstructorNode::make); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const PatternConstructorNode* node, @@ -100,9 +94,7 @@ Constructor ConstructorNode::make(std::string name_hint, TVM_REGISTER_NODE_TYPE(ConstructorNode); TVM_REGISTER_API("relay._make.Constructor") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = ConstructorNode::make(args[0], args[1], args[2]); - }); +.set_body_typed(ConstructorNode::make); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const ConstructorNode* node, @@ -124,9 +116,7 @@ TypeData TypeDataNode::make(GlobalTypeVar header, TVM_REGISTER_NODE_TYPE(TypeDataNode); TVM_REGISTER_API("relay._make.TypeData") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = TypeDataNode::make(args[0], args[1], args[2]); - }); +.set_body_typed(TypeDataNode::make); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const TypeDataNode* node, @@ -145,9 +135,7 @@ Clause ClauseNode::make(Pattern lhs, Expr rhs) { TVM_REGISTER_NODE_TYPE(ClauseNode); TVM_REGISTER_API("relay._make.Clause") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = ClauseNode::make(args[0], args[1]); - }); +.set_body_typed(ClauseNode::make); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const ClauseNode* node, @@ -166,9 +154,7 @@ Match MatchNode::make(Expr data, tvm::Array clauses) { TVM_REGISTER_NODE_TYPE(MatchNode); TVM_REGISTER_API("relay._make.Match") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = MatchNode::make(args[0], args[1]); - }); +.set_body_typed(MatchNode::make); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const MatchNode* node, diff --git a/src/relay/ir/alpha_equal.cc b/src/relay/ir/alpha_equal.cc index 967034519979..81017d4fddfa 100644 --- a/src/relay/ir/alpha_equal.cc +++ b/src/relay/ir/alpha_equal.cc @@ -505,18 +505,18 @@ bool AlphaEqual(const Expr& lhs, const Expr& rhs) { // TODO(@jroesch): move to correct namespace? TVM_REGISTER_API("relay._make._alpha_equal") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = AlphaEqualHandler(false).Equal(args[0], args[1]); +.set_body_typed([](NodeRef a, NodeRef b) { + return AlphaEqualHandler(false).Equal(a, b); }); TVM_REGISTER_API("relay._make._type_alpha_equal") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = AlphaEqualHandler(false).TypeEqual(args[0], args[1]); +.set_body_typed([](Type a, Type b) { + return AlphaEqualHandler(false).TypeEqual(a, b); }); TVM_REGISTER_API("relay._make._graph_equal") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = AlphaEqualHandler(true).Equal(args[0], args[1]); +.set_body_typed([](NodeRef a, NodeRef b) { + return AlphaEqualHandler(true).Equal(a, b); }); } // namespace relay } // namespace tvm diff --git a/src/relay/ir/base.cc b/src/relay/ir/base.cc index 9c35173bb47a..f60f6594559c 100644 --- a/src/relay/ir/base.cc +++ b/src/relay/ir/base.cc @@ -52,9 +52,7 @@ SourceName SourceName::Get(const std::string& name) { } TVM_REGISTER_API("relay._make.SourceName") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = SourceName::Get(args[0]); - }); +.set_body_typed(SourceName::Get); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const SourceNameNode* node, tvm::IRPrinter* p) { @@ -78,9 +76,7 @@ Span SpanNode::make(SourceName source, int lineno, int col_offset) { TVM_REGISTER_NODE_TYPE(SpanNode); TVM_REGISTER_API("relay._make.Span") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = SpanNode::make(args[0], args[1], args[2]); - }); +.set_body_typed(SpanNode::make); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const SpanNode* node, tvm::IRPrinter* p) { @@ -91,11 +87,9 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_REGISTER_NODE_TYPE(IdNode); TVM_REGISTER_API("relay._base.set_span") -.set_body([](TVMArgs args, TVMRetValue* ret) { - NodeRef node_ref = args[0]; +.set_body_typed([](NodeRef node_ref, Span sp) { auto rn = node_ref.as_derived(); CHECK(rn); - Span sp = args[1]; rn->span = sp; }); diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 3108bc2501fe..63d41c405e33 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -39,9 +39,7 @@ Constant ConstantNode::make(runtime::NDArray data) { TVM_REGISTER_NODE_TYPE(ConstantNode); TVM_REGISTER_API("relay._make.Constant") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = ConstantNode::make(args[0]); - }); +.set_body_typed(ConstantNode::make); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const ConstantNode* node, tvm::IRPrinter* p) { @@ -73,9 +71,7 @@ Tuple TupleNode::make(tvm::Array fields) { TVM_REGISTER_NODE_TYPE(TupleNode); TVM_REGISTER_API("relay._make.Tuple") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = TupleNode::make(args[0]); - }); +.set_body_typed(TupleNode::make); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const TupleNode* node, tvm::IRPrinter* p) { @@ -99,9 +95,7 @@ Var VarNode::make(std::string name_hint, Type type_annotation) { TVM_REGISTER_NODE_TYPE(VarNode); TVM_REGISTER_API("relay._make.Var") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = VarNode::make(args[0].operator std::string(), args[1]); - }); +.set_body_typed(static_cast(VarNode::make)); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const VarNode* node, tvm::IRPrinter* p) { @@ -122,9 +116,7 @@ GlobalVar GlobalVarNode::make(std::string name_hint) { TVM_REGISTER_NODE_TYPE(GlobalVarNode); TVM_REGISTER_API("relay._make.GlobalVar") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = GlobalVarNode::make(args[0]); - }); +.set_body_typed(GlobalVarNode::make); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const GlobalVarNode* node, tvm::IRPrinter* p) { @@ -201,9 +193,7 @@ Function FunctionSetAttr(const Function& func, const std::string& key, const Nod TVM_REGISTER_NODE_TYPE(FunctionNode); TVM_REGISTER_API("relay._make.Function") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = FunctionNode::make(args[0], args[1], args[2], args[3], args[4]); -}); +.set_body_typed(FunctionNode::make); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const FunctionNode* node, @@ -226,9 +216,7 @@ Call CallNode::make(Expr op, Array args, Attrs attrs, TVM_REGISTER_NODE_TYPE(CallNode); TVM_REGISTER_API("relay._make.Call") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = CallNode::make(args[0], args[1], args[2], args[3]); -}); +.set_body_typed(CallNode::make); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const CallNode* node, tvm::IRPrinter* p) { @@ -247,9 +235,7 @@ Let LetNode::make(Var var, Expr value, Expr body) { TVM_REGISTER_NODE_TYPE(LetNode); TVM_REGISTER_API("relay._make.Let") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = LetNode::make(args[0], args[1], args[2]); - }); +.set_body_typed(LetNode::make); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const LetNode* node, tvm::IRPrinter* p) { @@ -267,9 +253,8 @@ If IfNode::make(Expr cond, Expr true_branch, Expr false_branch) { TVM_REGISTER_NODE_TYPE(IfNode); -TVM_REGISTER_API("relay._make.If").set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = IfNode::make(args[0], args[1], args[2]); -}); +TVM_REGISTER_API("relay._make.If") +.set_body_typed(IfNode::make); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const IfNode* node, tvm::IRPrinter* p) { @@ -286,9 +271,8 @@ TupleGetItem TupleGetItemNode::make(Expr tuple, int index) { TVM_REGISTER_NODE_TYPE(TupleGetItemNode); -TVM_REGISTER_API("relay._make.TupleGetItem").set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = TupleGetItemNode::make(args[0], args[1]); -}); +TVM_REGISTER_API("relay._make.TupleGetItem") +.set_body_typed(TupleGetItemNode::make); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const TupleGetItemNode* node, tvm::IRPrinter* p) { @@ -301,9 +285,8 @@ RefCreate RefCreateNode::make(Expr value) { return RefCreate(n); } -TVM_REGISTER_API("relay._make.RefCreate").set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = RefCreateNode::make(args[0]); -}); +TVM_REGISTER_API("relay._make.RefCreate") +.set_body_typed(RefCreateNode::make); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const RefCreateNode* node, tvm::IRPrinter* p) { @@ -317,9 +300,7 @@ RefRead RefReadNode::make(Expr ref) { } TVM_REGISTER_API("relay._make.RefRead") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = RefReadNode::make(args[0]); -}); +.set_body_typed(RefReadNode::make); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const RefReadNode* node, tvm::IRPrinter* p) { @@ -334,9 +315,7 @@ RefWrite RefWriteNode::make(Expr ref, Expr value) { } TVM_REGISTER_API("relay._make.RefWrite") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = RefWriteNode::make(args[0], args[1]); -}); +.set_body_typed(RefWriteNode::make); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const RefWriteNode* node, tvm::IRPrinter* p) { @@ -344,9 +323,8 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) }); TVM_REGISTER_API("relay._expr.TempExprRealize") -.set_body([](TVMArgs args, TVMRetValue* ret) { - TempExpr temp = args[0]; - *ret = temp->Realize(); +.set_body_typed([](TempExpr temp) { + return temp->Realize(); }); } // namespace relay diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index d0cd30adda29..7a6250cd6580 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -346,9 +346,8 @@ void PostOrderVisit(const Expr& e, std::function fvisit) { } TVM_REGISTER_API("relay._ir_pass.post_order_visit") -.set_body([](TVMArgs args, TVMRetValue *ret) { - PackedFunc f = args[1]; - PostOrderVisit(args[0], [f](const Expr& n) { +.set_body_typed([](Expr expr, PackedFunc f) { + PostOrderVisit(expr, [f](const Expr& n) { f(n); }); }); diff --git a/src/relay/ir/hash.cc b/src/relay/ir/hash.cc index cb2be8b2c184..89ad6083fb8e 100644 --- a/src/relay/ir/hash.cc +++ b/src/relay/ir/hash.cc @@ -410,14 +410,14 @@ size_t StructuralHash::operator()(const Expr& expr) const { } TVM_REGISTER_API("relay._ir_pass._expr_hash") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = static_cast(RelayHashHandler().Hash(args[0])); - }); +.set_body_typed([](NodeRef ref) { + return static_cast(RelayHashHandler().Hash(ref)); +}); TVM_REGISTER_API("relay._ir_pass._type_hash") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = static_cast(RelayHashHandler().TypeHash(args[0])); - }); +.set_body_typed([](Type type) { + return static_cast(RelayHashHandler().TypeHash(type)); +}); } // namespace relay } // namespace tvm diff --git a/src/relay/ir/module.cc b/src/relay/ir/module.cc index 38c9756841fc..eabea2ecfeb0 100644 --- a/src/relay/ir/module.cc +++ b/src/relay/ir/module.cc @@ -181,66 +181,43 @@ Module ModuleNode::FromExpr( TVM_REGISTER_NODE_TYPE(ModuleNode); TVM_REGISTER_API("relay._make.Module") -.set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = ModuleNode::make(args[0], args[1]); - }); +.set_body_typed(ModuleNode::make); TVM_REGISTER_API("relay._make.Module_Add") -.set_body([](TVMArgs args, TVMRetValue *ret) { - Module mod = args[0]; - mod->Add(args[1], args[2], args[3]); - }); +.set_body_method(&ModuleNode::Add); TVM_REGISTER_API("relay._module.Module_AddDef") -.set_body([](TVMArgs args, TVMRetValue *ret) { - Module mod = args[0]; - mod->AddDef(args[1], args[2]); - }); +.set_body_method(&ModuleNode::AddDef); TVM_REGISTER_API("relay._module.Module_GetGlobalVar") -.set_body([](TVMArgs args, TVMRetValue *ret) { - Module mod = args[0]; - *ret = mod->GetGlobalVar(args[1]); - }); +.set_body_method(&ModuleNode::GetGlobalVar); TVM_REGISTER_API("relay._module.Module_GetGlobalTypeVar") -.set_body([](TVMArgs args, TVMRetValue *ret) { - Module mod = args[0]; - *ret = mod->GetGlobalTypeVar(args[1]); - }); +.set_body_method(&ModuleNode::GetGlobalTypeVar); TVM_REGISTER_API("relay._module.Module_Lookup") -.set_body([](TVMArgs args, TVMRetValue *ret) { - Module mod = args[0]; - GlobalVar var = args[1]; - *ret = mod->Lookup(var); +.set_body_typed([](Module mod, GlobalVar var) { + return mod->Lookup(var); }); TVM_REGISTER_API("relay._module.Module_Lookup_str") -.set_body([](TVMArgs args, TVMRetValue *ret) { - Module mod = args[0]; - std::string var_name = args[1]; - *ret = mod->Lookup(var_name); +.set_body_typed([](Module mod, std::string var) { + return mod->Lookup(var); }); TVM_REGISTER_API("relay._module.Module_LookupDef") -.set_body([](TVMArgs args, TVMRetValue *ret) { - Module mod = args[0]; - GlobalTypeVar var = args[1]; - *ret = mod->LookupDef(var); +.set_body_typed([](Module mod, GlobalTypeVar var) { + return mod->LookupDef(var); }); TVM_REGISTER_API("relay._module.Module_LookupDef_str") -.set_body([](TVMArgs args, TVMRetValue *ret) { - Module mod = args[0]; - std::string var_name = args[1]; - *ret = mod->LookupDef(var_name); +.set_body_typed([](Module mod, std::string var) { + return mod->LookupDef(var); }); TVM_REGISTER_API("relay._module.Module_Update") -.set_body([](TVMArgs args, TVMRetValue *ret) { - Module mod = args[0]; - mod->Update(args[1]); +.set_body_typed([](Module mod, Module from) { + mod->Update(from); }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) diff --git a/src/relay/ir/type.cc b/src/relay/ir/type.cc index fb0d919b46c3..8f0bdcba2b1b 100644 --- a/src/relay/ir/type.cc +++ b/src/relay/ir/type.cc @@ -56,10 +56,7 @@ IndexExpr TensorTypeNode::Size() const { TVM_REGISTER_NODE_TYPE(TensorTypeNode); TVM_REGISTER_API("relay._make.TensorType") -.set_body([](TVMArgs args, TVMRetValue* ret) { - Array shape = args[0]; - *ret = TensorTypeNode::make(shape, args[1]); -}); +.set_body_typed(TensorTypeNode::make); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const TensorTypeNode* node, @@ -77,10 +74,8 @@ TypeVar TypeVarNode::make(std::string name, Kind kind) { TVM_REGISTER_NODE_TYPE(TypeVarNode); TVM_REGISTER_API("relay._make.TypeVar") -.set_body([](TVMArgs args, TVMRetValue* ret) { - int kind = args[1]; - *ret = - TypeVarNode::make(args[0], static_cast(kind)); +.set_body_typed([](std::string name, int kind) { + return TypeVarNode::make(name, static_cast(kind)); }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) @@ -100,10 +95,9 @@ GlobalTypeVar GlobalTypeVarNode::make(std::string name, Kind kind) { TVM_REGISTER_NODE_TYPE(GlobalTypeVarNode); TVM_REGISTER_API("relay._make.GlobalTypeVar") -.set_body([](TVMArgs args, TVMRetValue* ret) { - int kind = args[1]; - *ret = GlobalTypeVarNode::make(args[0], static_cast(kind)); -}); +.set_body_typed([](std::string name, int kind) { + return GlobalTypeVarNode::make(name, static_cast(kind)); + }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const GlobalTypeVarNode *node, @@ -122,9 +116,7 @@ TypeCall TypeCallNode::make(Type func, tvm::Array args) { TVM_REGISTER_NODE_TYPE(TypeCallNode); TVM_REGISTER_API("relay._make.TypeCall") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = TypeCallNode::make(args[0], args[1]); -}); +.set_body_typed(TypeCallNode::make); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const TypeCallNode* node, @@ -142,9 +134,8 @@ IncompleteType IncompleteTypeNode::make(Kind kind) { TVM_REGISTER_NODE_TYPE(IncompleteTypeNode); TVM_REGISTER_API("relay._make.IncompleteType") -.set_body([](TVMArgs args, TVMRetValue* ret) { - int kind = args[0]; - *ret = IncompleteTypeNode::make(static_cast(kind)); +.set_body_typed([](int kind) { + return IncompleteTypeNode::make(static_cast(kind)); }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) @@ -169,9 +160,7 @@ FuncType FuncTypeNode::make(tvm::Array arg_types, TVM_REGISTER_NODE_TYPE(FuncTypeNode); TVM_REGISTER_API("relay._make.FuncType") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = FuncTypeNode::make(args[0], args[1], args[2], args[3]); -}); +.set_body_typed(FuncTypeNode::make); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const FuncTypeNode* node, @@ -196,9 +185,7 @@ TypeRelation TypeRelationNode::make(TypeRelationFn func, TVM_REGISTER_NODE_TYPE(TypeRelationNode); TVM_REGISTER_API("relay._make.TypeRelation") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = TypeRelationNode::make(args[0], args[1], args[2], args[3]); -}); +.set_body_typed(TypeRelationNode::make); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const TypeRelationNode* node, tvm::IRPrinter* p) { @@ -216,9 +203,7 @@ TupleType TupleTypeNode::make(Array fields) { TVM_REGISTER_NODE_TYPE(TupleTypeNode); TVM_REGISTER_API("relay._make.TupleType") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = TupleTypeNode::make(args[0]); -}); +.set_body_typed(TupleTypeNode::make); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const TupleTypeNode* node, @@ -233,9 +218,7 @@ RefType RefTypeNode::make(Type value) { } TVM_REGISTER_API("relay._make.RefType") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = RefTypeNode::make(args[0]); -}); +.set_body_typed(RefTypeNode::make); TVM_REGISTER_NODE_TYPE(RefTypeNode); diff --git a/src/relay/op/debug.cc b/src/relay/op/debug.cc index 3aea0c03f798..37fb090aa231 100644 --- a/src/relay/op/debug.cc +++ b/src/relay/op/debug.cc @@ -64,9 +64,7 @@ Expr MakeDebug(Expr expr, std::string name) { } TVM_REGISTER_API("relay.op._make.debug") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeDebug, args, rv); - }); +.set_body_typed(MakeDebug); } // namespace relay } // namespace tvm diff --git a/src/relay/op/image/resize.cc b/src/relay/op/image/resize.cc index 7ca762e7394a..ffa489edff76 100644 --- a/src/relay/op/image/resize.cc +++ b/src/relay/op/image/resize.cc @@ -105,9 +105,7 @@ Expr MakeResize(Expr data, TVM_REGISTER_API("relay.op.image._make.resize") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeResize, args, rv); - }); +.set_body_typed(MakeResize); RELAY_REGISTER_OP("image.resize") diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc index f2c0a27600d9..97cba7964000 100644 --- a/src/relay/op/nn/convolution.cc +++ b/src/relay/op/nn/convolution.cc @@ -170,9 +170,7 @@ Expr MakeConv2D(Expr data, TVM_REGISTER_API("relay.op.nn._make.conv2d") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeConv2D, args, rv); - }); +.set_body_typed(MakeConv2D); RELAY_REGISTER_OP("nn.conv2d") @@ -324,9 +322,7 @@ Expr MakeConv2DTranspose(Expr data, TVM_REGISTER_API("relay.op.nn._make.conv2d_transpose") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeConv2DTranspose, args, rv); - }); +.set_body_typed(MakeConv2DTranspose); RELAY_REGISTER_OP("nn.conv2d_transpose") .describe(R"code(Transposed 2D convolution layer (sometimes called Deconvolution). @@ -465,9 +461,7 @@ Expr MakeConv2DWinograd(Expr data, TVM_REGISTER_API("relay.op.nn._make.contrib_conv2d_winograd_without_weight_transform") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeConv2DWinograd, args, rv); - }); +.set_body_typed(MakeConv2DWinograd); RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_without_weight_transform") @@ -530,9 +524,7 @@ Expr MakeConv2DWinogradWeightTransform(Expr weight, TVM_REGISTER_API("relay.op.nn._make.contrib_conv2d_winograd_weight_transform") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeConv2DWinogradWeightTransform, args, rv); -}); +.set_body_typed(MakeConv2DWinogradWeightTransform); RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_weight_transform") @@ -580,9 +572,7 @@ Expr MakeConv2DWinogradNNPACK(Expr data, } TVM_REGISTER_API("relay.op.nn._make.contrib_conv2d_winograd_nnpack_without_weight_transform") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeConv2DWinogradNNPACK, args, rv); -}); +.set_body_typed(MakeConv2DWinogradNNPACK); RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_nnpack_without_weight_transform") .describe(R"code(Compute conv2d with winograd nnpack. Only supports NCHW layout. @@ -649,9 +639,7 @@ Expr MakeConv2DWinogradNNPACKWeightTransform(Expr weight, } TVM_REGISTER_API("relay.op.nn._make.contrib_conv2d_winograd_nnpack_weight_transform") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeConv2DWinogradNNPACKWeightTransform, args, rv); -}); +.set_body_typed(MakeConv2DWinogradNNPACKWeightTransform); RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_nnpack_weight_transform") .describe(R"code(Weight transformation of winograd fast convolution algorithm with NNPACK. @@ -698,9 +686,7 @@ Expr MakeConv2DNCHWc(Expr data, } TVM_REGISTER_API("relay.op.nn._make.contrib_conv2d_NCHWc") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeConv2DNCHWc, args, rv); - }); +.set_body_typed(MakeConv2DNCHWc); RELAY_REGISTER_OP("nn.contrib_conv2d_NCHWc") @@ -750,9 +736,7 @@ Expr MakeDepthwiseConv2DNCHWc(Expr data, } TVM_REGISTER_API("relay.op.nn._make.contrib_depthwise_conv2d_NCHWc") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeDepthwiseConv2DNCHWc, args, rv); - }); +.set_body_typed(MakeDepthwiseConv2DNCHWc); RELAY_REGISTER_OP("nn.contrib_depthwise_conv2d_NCHWc") @@ -910,9 +894,7 @@ Expr MakeDeformableConv2D(Expr data, } TVM_REGISTER_API("relay.op.nn._make.deformable_conv2d") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeDeformableConv2D, args, rv); - }); +.set_body_typed(MakeDeformableConv2D); } // namespace relay diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index d24431347f80..2356634c4ed0 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -78,9 +78,7 @@ Expr MakeBiasAdd(Expr data, TVM_REGISTER_API("relay.op.nn._make.bias_add") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeBiasAdd, args, rv); - }); +.set_body_typed(MakeBiasAdd); RELAY_REGISTER_OP("nn.bias_add") @@ -145,9 +143,7 @@ Expr MakeDense(Expr data, TVM_REGISTER_API("relay.op.nn._make.dense") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeDense, args, rv); - }); +.set_body_typed(MakeDense); RELAY_REGISTER_OP("nn.dense") @@ -179,9 +175,7 @@ Expr MakeLeakyRelu(Expr data, TVM_REGISTER_API("relay.op.nn._make.leaky_relu") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeLeakyRelu, args, rv); - }); +.set_body_typed(MakeLeakyRelu); RELAY_REGISTER_OP("nn.leaky_relu") @@ -244,9 +238,7 @@ Expr MakePRelu(Expr data, TVM_REGISTER_API("relay.op.nn._make.prelu") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakePRelu, args, rv); - }); +.set_body_typed(MakePRelu); RELAY_REGISTER_OP("nn.prelu") @@ -276,17 +268,14 @@ where :math:`*` is an channelwise multiplication for each sample in the batch. TVM_REGISTER_NODE_TYPE(SoftmaxAttrs); TVM_REGISTER_API("relay.op.nn._make.softmax") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - auto make_func = [](Expr data, int axis) { - auto attrs = make_node(); - attrs->axis = axis; - static const Op& op = Op::Get("nn.softmax"); - return CallNode::make(op, {data}, Attrs(attrs), {}); - }; - - runtime::detail::unpack_call(make_func, args, rv); +.set_body_typed([](Expr data, int axis) { + auto attrs = make_node(); + attrs->axis = axis; + static const Op& op = Op::Get("nn.softmax"); + return CallNode::make(op, {data}, Attrs(attrs), {}); }); + RELAY_REGISTER_OP("nn.softmax") .describe(R"code(Softmax layer. @@ -314,15 +303,11 @@ RELAY_REGISTER_OP("nn.softmax") // relay.nn.log_softmax TVM_REGISTER_API("relay.op.nn._make.log_softmax") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - auto make_func = [](Expr data, int axis) { - auto attrs = make_node(); - attrs->axis = axis; - static const Op& op = Op::Get("nn.log_softmax"); - return CallNode::make(op, {data}, Attrs(attrs), {}); - }; - - runtime::detail::unpack_call(make_func, args, rv); +.set_body_typed([](Expr data, int axis) { + auto attrs = make_node(); + attrs->axis = axis; + static const Op& op = Op::Get("nn.log_softmax"); + return CallNode::make(op, {data}, Attrs(attrs), {}); }); RELAY_REGISTER_OP("nn.log_softmax") @@ -382,9 +367,7 @@ Expr MakeBatchFlatten(Expr data) { TVM_REGISTER_API("relay.op.nn._make.batch_flatten") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeBatchFlatten, args, rv); - }); +.set_body_typed(MakeBatchFlatten); RELAY_REGISTER_OP("nn.batch_flatten") @@ -424,7 +407,7 @@ Example:: // relu TVM_REGISTER_API("relay.op.nn._make.relu") -.set_body_typed([](Expr data) { +.set_body_typed([](Expr data) { static const Op& op = Op::Get("nn.relu"); return CallNode::make(op, {data}, Attrs(), {}); }); @@ -469,9 +452,7 @@ Expr MakeLRN(Expr data, } TVM_REGISTER_API("relay.op.nn._make.lrn") - .set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeLRN, args, rv); - }); +.set_body_typed(MakeLRN); RELAY_REGISTER_OP("nn.lrn") .describe(R"code(LRN layer. @@ -509,9 +490,7 @@ Expr MakeL2Normalize(Expr data, } TVM_REGISTER_API("relay.op.nn._make.l2_normalize") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeL2Normalize, args, rv); - }); +.set_body_typed(MakeL2Normalize); RELAY_REGISTER_OP("nn.l2_normalize") .describe(R"code(L2 Normalization layer. @@ -556,9 +535,7 @@ Expr MakeDropout(Expr data, double rate) { } TVM_REGISTER_API("relay.op.nn._make.dropout") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeDropout, args, rv); - }); +.set_body_typed(MakeDropout); RELAY_REGISTER_OP("nn.dropout") .describe(R"code(Applies the dropout operation to the input array. @@ -622,9 +599,7 @@ Expr MakeBatchNorm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr movi } TVM_REGISTER_API("relay.op.nn._make.batch_norm") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeBatchNorm, args, rv); - }); +.set_body_typed(MakeBatchNorm); RELAY_REGISTER_OP("nn.batch_norm") .describe(R"code(Batch normalization layer (Ioffe and Szegedy, 2014). @@ -711,9 +686,7 @@ Expr MakeBatchMatmul(Expr x, TVM_REGISTER_API("relay.op.nn._make.batch_matmul") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeBatchMatmul, args, rv); - }); +.set_body_typed(MakeBatchMatmul); RELAY_REGISTER_OP("nn.batch_matmul") diff --git a/src/relay/op/nn/pad.cc b/src/relay/op/nn/pad.cc index c653e3b9f39d..98b9d671bff9 100644 --- a/src/relay/op/nn/pad.cc +++ b/src/relay/op/nn/pad.cc @@ -115,9 +115,7 @@ Expr MakePad(Expr data, Array > pad_width, double pad_value) { } TVM_REGISTER_API("relay.op.nn._make.pad") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakePad, args, rv); - }); +.set_body_typed(MakePad); RELAY_REGISTER_OP("nn.pad") .describe(R"code(Pad for n-D tensor. diff --git a/src/relay/op/nn/pooling.cc b/src/relay/op/nn/pooling.cc index 0717ee5c577f..df238b38c9cd 100644 --- a/src/relay/op/nn/pooling.cc +++ b/src/relay/op/nn/pooling.cc @@ -186,9 +186,7 @@ Array Pool2DCompute(const Attrs& attrs, } TVM_REGISTER_API("relay.op.nn._make.max_pool2d") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeMaxPool2D, args, rv); - }); +.set_body_typed(MakeMaxPool2D); RELAY_REGISTER_OP("nn.max_pool2d") @@ -242,9 +240,7 @@ Expr MakeAvgPool2D(Expr data, TVM_REGISTER_API("relay.op.nn._make.avg_pool2d") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeAvgPool2D, args, rv); - }); +.set_body_typed(MakeAvgPool2D); RELAY_REGISTER_OP("nn.avg_pool2d") @@ -345,9 +341,7 @@ Expr MakeGlobalAvgPool2D(Expr data, TVM_REGISTER_API("relay.op.nn._make.global_avg_pool2d") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeGlobalAvgPool2D, args, rv); - }); +.set_body_typed(MakeGlobalAvgPool2D); // GlobalAvgPool RELAY_REGISTER_OP("nn.global_avg_pool2d") @@ -378,9 +372,7 @@ Expr MakeGlobalMaxPool2D(Expr data, } TVM_REGISTER_API("relay.op.nn._make.global_max_pool2d") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeGlobalMaxPool2D, args, rv); - }); +.set_body_typed(MakeGlobalMaxPool2D); RELAY_REGISTER_OP("nn.global_max_pool2d") diff --git a/src/relay/op/nn/upsampling.cc b/src/relay/op/nn/upsampling.cc index 98458b9dc258..acefaf3e7e5d 100644 --- a/src/relay/op/nn/upsampling.cc +++ b/src/relay/op/nn/upsampling.cc @@ -110,9 +110,7 @@ Expr MakeUpSampling(Expr data, TVM_REGISTER_API("relay.op.nn._make.upsampling") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeUpSampling, args, rv); - }); +.set_body_typed(MakeUpSampling); RELAY_REGISTER_OP("nn.upsampling") diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc index 7bade46b31d4..b889b6ce51cd 100644 --- a/src/relay/op/tensor/reduce.cc +++ b/src/relay/op/tensor/reduce.cc @@ -265,8 +265,8 @@ bool ReduceRel(const Array& types, #define RELAY_REGISTER_REDUCE_OP(OpName) \ TVM_REGISTER_API("relay.op._make." OpName) \ - .set_body([](const TVMArgs& args, TVMRetValue* rv) { \ - auto make_func = [](Expr data, \ + .set_body_typed, bool, bool)>([]( \ + Expr data, \ Array axis, \ bool keepdims, \ bool exclude) { \ @@ -276,8 +276,6 @@ bool ReduceRel(const Array& types, attrs->exclude = exclude; \ static const Op& op = Op::Get(OpName); \ return CallNode::make(op, {data}, Attrs(attrs), {}); \ - }; \ - runtime::detail::unpack_call(make_func, args, rv); \ }); \ RELAY_REGISTER_OP(OpName) \ .set_num_inputs(1) \ diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index f86156bdbddc..873e75d9660b 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -81,9 +81,7 @@ Expr MakeCast(Expr data, } TVM_REGISTER_API("relay._make.cast") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeCast, args, rv); -}); +.set_body_typed(MakeCast); RELAY_REGISTER_OP("cast") .describe(R"code(Cast the data into a new data type. @@ -161,9 +159,7 @@ Expr MakeExpandDims(Expr data, } TVM_REGISTER_API("relay.op._make.expand_dims") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeExpandDims, args, rv); -}); +.set_body_typed(MakeExpandDims); RELAY_REGISTER_OP("expand_dims") .describe(R"code(Insert `num_newaxis` axises at the position given by `axis` @@ -279,9 +275,7 @@ Expr MakeConcatenate(Expr data, } TVM_REGISTER_API("relay.op._make.concatenate") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeConcatenate, args, rv); -}); +.set_body_typed(MakeConcatenate); RELAY_REGISTER_OP("concatenate") .describe(R"code(Concatenate the input tensors along the given axis. @@ -367,9 +361,7 @@ Expr MakeStack(Expr data, } TVM_REGISTER_API("relay.op._make.stack") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeStack, args, rv); -}); +.set_body_typed(MakeStack); RELAY_REGISTER_OP("stack") .describe(R"code(Stack the input tensors along the given axis. @@ -461,9 +453,7 @@ Expr MakeTranspose(Expr data, } TVM_REGISTER_API("relay.op._make.transpose") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeTranspose, args, rv); -}); +.set_body_typed(MakeTranspose); RELAY_REGISTER_OP("transpose") .describe(R"code(Permutes the dimensions of an array. @@ -598,9 +588,7 @@ Expr MakeReshape(Expr data, } TVM_REGISTER_API("relay.op._make.reshape") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeReshape, args, rv); -}); +.set_body_typed(MakeReshape); RELAY_REGISTER_OP("reshape") .describe(R"code(Reshapes the input array. @@ -698,9 +686,7 @@ Expr MakeReshapeLike(Expr data, TVM_REGISTER_API("relay.op._make.reshape_like") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeReshapeLike, args, rv); -}); +.set_body_typed(MakeReshapeLike); RELAY_REGISTER_OP("reshape_like") @@ -790,9 +776,7 @@ Expr MakeTake(Expr data, } TVM_REGISTER_API("relay.op._make.take") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeTake, args, rv); -}); +.set_body_typed(MakeTake); RELAY_REGISTER_OP("take") .describe(R"code(Take elements from an array along an axis. @@ -873,9 +857,7 @@ Expr MakeFull(Expr fill_value, } TVM_REGISTER_API("relay.op._make.full") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeFull, args, rv); -}); +.set_body_typed(MakeFull); RELAY_REGISTER_OP("full") .describe(R"code(Fill array with scalar value. @@ -910,9 +892,7 @@ Expr MakeZeros(Array shape, } TVM_REGISTER_API("relay.op._make.zeros") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeZeros, args, rv); - }); +.set_body_typed(MakeZeros); RELAY_REGISTER_OP("zeros") .describe(R"code(Fill array with zeros. @@ -933,9 +913,7 @@ Expr MakeOnes(Array shape, } TVM_REGISTER_API("relay.op._make.ones") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeOnes, args, rv); - }); +.set_body_typed(MakeOnes); RELAY_REGISTER_OP("ones") .describe(R"code(Fill array with ones. @@ -982,9 +960,7 @@ Expr MakeFullLike(Expr data, } TVM_REGISTER_API("relay.op._make.full_like") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeFullLike, args, rv); - }); +.set_body_typed(MakeFullLike); RELAY_REGISTER_OP("full_like") .describe(R"code(Return an scalar value array with the same shape @@ -1041,9 +1017,7 @@ Expr MakeArange(tvm::Expr start, } TVM_REGISTER_API("relay.op._make.arange") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeArange, args, rv); -}); +.set_body_typed(MakeArange); RELAY_REGISTER_OP("arange") .describe(R"code(Returns evenly spaced values within a given interval. @@ -1117,9 +1091,7 @@ Expr MakeRepeat(Expr data, } TVM_REGISTER_API("relay.op._make.repeat") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeRepeat, args, rv); -}); +.set_body_typed(MakeRepeat); RELAY_REGISTER_OP("repeat") .describe(R"code(Repeat elements of an array `repeats` times along axis `axis` @@ -1217,9 +1189,7 @@ Expr MakeTile(Expr data, } TVM_REGISTER_API("relay.op._make.tile") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeTile, args, rv); -}); +.set_body_typed(MakeTile); RELAY_REGISTER_OP("tile") .describe(R"code(Repeat the whole array multiple times. @@ -1280,9 +1250,7 @@ Expr MakeReverse(Expr data, } TVM_REGISTER_API("relay.op._make.reverse") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeReverse, args, rv); -}); +.set_body_typed(MakeReverse); RELAY_REGISTER_OP("reverse") .describe(R"code(Reverses the order of elements along given `axis` while preserving array shape. @@ -1345,9 +1313,7 @@ Array WhereCompute(const Attrs& attrs, } TVM_REGISTER_API("relay.op._make.where") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeWhere, args, rv); -}); +.set_body_typed(MakeWhere); RELAY_REGISTER_OP("where") .describe(R"code( @@ -1400,9 +1366,7 @@ Expr MakeSqueeze(Expr data, } TVM_REGISTER_API("relay.op._make.squeeze") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeSqueeze, args, rv); - }); +.set_body_typed(MakeSqueeze); bool SqueezeRel(const Array& types, @@ -1507,9 +1471,7 @@ Array CollapseSumLikeCompute(const Attrs& attrs, } TVM_REGISTER_API("relay.op._make.collapse_sum_like") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeCollapseSumLike, args, rv); - }); +.set_body_typed(MakeCollapseSumLike); RELAY_REGISTER_OP("collapse_sum_like") .describe(R"code(Collapse the first input to match the shape of the second input. @@ -1554,9 +1516,7 @@ Array BroadCastToCompute(const Attrs& attrs, } TVM_REGISTER_API("relay.op._make.broadcast_to") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeBroadCastTo, args, rv); - }); +.set_body_typed(MakeBroadCastTo); RELAY_REGISTER_OP("broadcast_to") .describe(R"code(Broadcast the first input to match the shape argument. @@ -1594,9 +1554,7 @@ Array BroadCastToLikeCompute(const Attrs& attrs, } TVM_REGISTER_API("relay.op._make.broadcast_to_like") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeBroadCastToLike, args, rv); - }); +.set_body_typed(MakeBroadCastToLike); RELAY_REGISTER_OP("broadcast_to_like") .describe(R"code(Broadcast the first input to match the shape of the second input. @@ -1806,9 +1764,7 @@ Array StridedSliceCompute(const Attrs& attrs, TVM_REGISTER_API("relay.op._make.strided_slice") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeStridedSlice, args, rv); - }); +.set_body_typed(MakeStridedSlice); RELAY_REGISTER_OP("strided_slice") @@ -2081,9 +2037,7 @@ Array SliceLikeCompute(const Attrs& attrs, TVM_REGISTER_API("relay.op._make.slice_like") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeSliceLike, args, rv); -}); +.set_body_typed(MakeSliceLike); RELAY_REGISTER_OP("slice_like") @@ -2144,9 +2098,7 @@ Expr MakeLayoutTransform(Expr data, } TVM_REGISTER_API("relay.op._make.layout_transform") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeLayoutTransform, args, rv); -}); +.set_body_typed(MakeLayoutTransform); RELAY_REGISTER_OP("layout_transform") .describe(R"code(Transform the input data layout. @@ -2174,9 +2126,7 @@ Expr MakeReverseReshape(Expr data, } TVM_REGISTER_API("relay.op._make._contrib_reverse_reshape") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeReverseReshape, args, rv); -}); +.set_body_typed(MakeReverseReshape); RELAY_REGISTER_OP("_contrib_reverse_reshape") .describe(R"code(Reshapes the input array where the special values are inferred from @@ -2250,9 +2200,7 @@ Expr MakeGatherND(Expr data, } TVM_REGISTER_API("relay.op._make.gather_nd") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeGatherND, args, rv); -}); +.set_body_typed(MakeGatherND); RELAY_REGISTER_OP("gather_nd") .describe(R"code(Gather elements or slices from data and store to diff --git a/src/relay/op/vision/multibox_op.cc b/src/relay/op/vision/multibox_op.cc index 2c9f76ba2015..56a03ff80bc9 100644 --- a/src/relay/op/vision/multibox_op.cc +++ b/src/relay/op/vision/multibox_op.cc @@ -73,9 +73,7 @@ Expr MakeMultiBoxPrior(Expr data, TVM_REGISTER_API("relay.op.vision._make.multibox_prior") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeMultiBoxPrior, args, rv); -}); +.set_body_typed(MakeMultiBoxPrior); RELAY_REGISTER_OP("vision.multibox_prior") @@ -147,9 +145,7 @@ Expr MakeMultiBoxTransformLoc(Expr cls_prob, } TVM_REGISTER_API("relay.op.vision._make.multibox_transform_loc") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeMultiBoxTransformLoc, args, rv); -}); +.set_body_typed(MakeMultiBoxTransformLoc); RELAY_REGISTER_OP("vision.multibox_transform_loc") .describe(R"doc("Location transformation for multibox detection." diff --git a/src/relay/op/vision/nms.cc b/src/relay/op/vision/nms.cc index 75161bfd1e92..5344bce3d641 100644 --- a/src/relay/op/vision/nms.cc +++ b/src/relay/op/vision/nms.cc @@ -59,9 +59,7 @@ Expr MakeGetValidCounts(Expr data, TVM_REGISTER_API("relay.op.vision._make.get_valid_counts") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeGetValidCounts, args, rv); -}); +.set_body_typed(MakeGetValidCounts); RELAY_REGISTER_OP("vision.get_valid_counts") @@ -125,9 +123,7 @@ Expr MakeNMS(Expr data, TVM_REGISTER_API("relay.op.vision._make.non_max_suppression") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeNMS, args, rv); -}); +.set_body_typed(MakeNMS); RELAY_REGISTER_OP("vision.non_max_suppression") diff --git a/src/relay/op/vision/rcnn_op.cc b/src/relay/op/vision/rcnn_op.cc index 70fe292ed9e5..0522ab845fad 100644 --- a/src/relay/op/vision/rcnn_op.cc +++ b/src/relay/op/vision/rcnn_op.cc @@ -62,9 +62,7 @@ Expr MakeROIAlign(Expr data, Expr rois, Array pooled_size, double spa } TVM_REGISTER_API("relay.op.vision._make.roi_align") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeROIAlign, args, rv); - }); +.set_body_typed(MakeROIAlign); RELAY_REGISTER_OP("vision.roi_align") .describe(R"doc(ROI Align operator. @@ -114,9 +112,7 @@ Expr MakeROIPool(Expr data, Expr rois, Array pooled_size, double spat } TVM_REGISTER_API("relay.op.vision._make.roi_pool") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeROIPool, args, rv); - }); +.set_body_typed(MakeROIPool); RELAY_REGISTER_OP("vision.roi_pool") .describe(R"doc(ROI Pool operator. @@ -182,9 +178,7 @@ Expr MakeProposal(Expr cls_prob, Expr bbox_pred, Expr im_info, Array } TVM_REGISTER_API("relay.op.vision._make.proposal") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeProposal, args, rv); - }); +.set_body_typed(MakeProposal); RELAY_REGISTER_OP("vision.proposal") .describe(R"code(Generate region proposals via RPN. diff --git a/src/relay/op/vision/yolo.cc b/src/relay/op/vision/yolo.cc index 310e30a51890..0a1d9614976e 100644 --- a/src/relay/op/vision/yolo.cc +++ b/src/relay/op/vision/yolo.cc @@ -71,9 +71,7 @@ Expr MakeYoloReorg(Expr data, TVM_REGISTER_API("relay.op.vision._make.yolo_reorg") -.set_body([](const TVMArgs& args, TVMRetValue* rv) { - runtime::detail::unpack_call(MakeYoloReorg, args, rv); -}); +.set_body_typed(MakeYoloReorg); RELAY_REGISTER_OP("vision.yolo_reorg") diff --git a/src/relay/pass/canonicalize_ops.cc b/src/relay/pass/canonicalize_ops.cc index c4350cc0c9db..9a4602750195 100644 --- a/src/relay/pass/canonicalize_ops.cc +++ b/src/relay/pass/canonicalize_ops.cc @@ -61,9 +61,7 @@ Expr CanonicalizeOps(const Expr& e) { } TVM_REGISTER_API("relay._ir_pass.canonicalize_ops") -.set_body([](TVMArgs args, TVMRetValue* ret) { -*ret = CanonicalizeOps(args[0]); -}); +.set_body_typed(CanonicalizeOps); } // namespace relay } // namespace tvm diff --git a/src/relay/pass/combine_parallel_conv2d.cc b/src/relay/pass/combine_parallel_conv2d.cc index cd7a852bcad7..7e76322d5a2a 100644 --- a/src/relay/pass/combine_parallel_conv2d.cc +++ b/src/relay/pass/combine_parallel_conv2d.cc @@ -355,9 +355,7 @@ Expr CombineParallelConv2D(const Expr& expr, uint64_t min_num_branches) { } TVM_REGISTER_API("relay._ir_pass.CombineParallelConv2D") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = CombineParallelConv2D(args[0], args[1]); -}); +.set_body_typed(CombineParallelConv2D); } // namespace relay } // namespace tvm diff --git a/src/relay/pass/dead_code.cc b/src/relay/pass/dead_code.cc index 06cd9091749b..c5c4f333ecfe 100644 --- a/src/relay/pass/dead_code.cc +++ b/src/relay/pass/dead_code.cc @@ -148,9 +148,7 @@ Expr DeadCodeElimination(const Expr& e) { } TVM_REGISTER_API("relay._ir_pass.dead_code_elimination") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = DeadCodeElimination(args[0]); - }); +.set_body_typed(DeadCodeElimination); } // namespace relay } // namespace tvm diff --git a/src/relay/pass/device_annotation.cc b/src/relay/pass/device_annotation.cc index 6f063830cbe9..46f4268cc970 100644 --- a/src/relay/pass/device_annotation.cc +++ b/src/relay/pass/device_annotation.cc @@ -493,19 +493,13 @@ Map CollectDeviceAnnotationOps(const Expr& expr) { } TVM_REGISTER_API("relay._ir_pass.CollectDeviceInfo") -.set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = CollectDeviceInfo(args[0]); -}); +.set_body_typed(CollectDeviceInfo); TVM_REGISTER_API("relay._ir_pass.RewriteDeviceAnnotation") -.set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = RewriteAnnotatedOps(args[0], args[1]); -}); +.set_body_typed(RewriteAnnotatedOps); TVM_REGISTER_API("relay._ir_pass.CollectDeviceAnnotationOps") -.set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = CollectDeviceAnnotationOps(args[0]); -}); +.set_body_typed(CollectDeviceAnnotationOps); } // namespace relay } // namespace tvm diff --git a/src/relay/pass/fold_constant.cc b/src/relay/pass/fold_constant.cc index 9d55a548be10..5bfee6cfe9f6 100644 --- a/src/relay/pass/fold_constant.cc +++ b/src/relay/pass/fold_constant.cc @@ -210,9 +210,7 @@ Expr FoldConstant(const Expr& expr) { } TVM_REGISTER_API("relay._ir_pass.FoldConstant") -.set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = FoldConstant(args[0]); -}); +.set_body_typed(FoldConstant); } // namespace relay } // namespace tvm diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index 4b50c64459a0..6de9c2d65f90 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -912,8 +912,6 @@ Expr FuseOps(const Expr& expr, int fuse_opt_level) { } TVM_REGISTER_API("relay._ir_pass.FuseOps") -.set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = FuseOps(args[0], args[1]); -}); +.set_body_typed(FuseOps); } // namespace relay } // namespace tvm diff --git a/src/relay/pass/gradient.cc b/src/relay/pass/gradient.cc index 8a5d1df53a26..5c5ea01ac2f3 100644 --- a/src/relay/pass/gradient.cc +++ b/src/relay/pass/gradient.cc @@ -247,10 +247,7 @@ Expr FirstOrderGradient(const Expr& re, const Module& mod) { } TVM_REGISTER_API("relay._ir_pass.first_order_gradient") -.set_body([](TVMArgs args, TVMRetValue* ret) { - CHECK_EQ(args.size(), 2); - *ret = FirstOrderGradient(args[0], args[1]); -}); +.set_body_typed(FirstOrderGradient); struct ReverseADType : TypeMutator { Type VisitType_(const TensorTypeNode* ttn) final { @@ -263,7 +260,7 @@ struct ReverseAD : ExprMutator { Var bp; const OpMap rev_map = Op::GetAttr("FPrimalGradient"); - ReverseAD(const Var& bp) : bp(bp) { } + ReverseAD(const Var& bp) : bp(bp) { } /// NOLINT(*) Expr VisitExpr_(const OpNode* op) final { LOG(FATAL) << "op should only be inside call"; @@ -349,10 +346,7 @@ Expr Gradient(const Expr& re, const Module& mod) { } TVM_REGISTER_API("relay._ir_pass.gradient") -.set_body([](TVMArgs args, TVMRetValue* ret) { - CHECK_EQ(args.size(), 2); - *ret = Gradient(args[0], args[1]); -}); +.set_body_typed(Gradient); } // namespace relay } // namespace tvm diff --git a/src/relay/pass/mac_count.cc b/src/relay/pass/mac_count.cc index 702e703cd902..c9ee4eec0337 100644 --- a/src/relay/pass/mac_count.cc +++ b/src/relay/pass/mac_count.cc @@ -147,9 +147,7 @@ int64_t GetTotalMacNumber(const Expr& expr) { } TVM_REGISTER_API("relay._ir_pass.GetTotalMacNumber") -.set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = GetTotalMacNumber(args[0]); -}); +.set_body_typed(GetTotalMacNumber); } // namespace mac_count } // namespace relay diff --git a/src/relay/pass/pass_manager.cc b/src/relay/pass/pass_manager.cc index fad3728d433e..d607247b3bc8 100644 --- a/src/relay/pass/pass_manager.cc +++ b/src/relay/pass/pass_manager.cc @@ -426,12 +426,7 @@ Pass CreateSequentialPass(const tvm::Array& passes, TVM_REGISTER_NODE_TYPE(PassInfoNode); TVM_REGISTER_API("relay._ir_pass.PassInfo") -.set_body([](TVMArgs args, TVMRetValue* ret) { - int opt_level = args[0]; - std::string name = args[1]; - tvm::Array required = args[2]; - *ret = PassInfoNode::make(opt_level, name, required); -}); +.set_body_typed(PassInfoNode::make); TVM_REGISTER_API("relay._ir_pass.Info") .set_body([](TVMArgs args, TVMRetValue* ret) { @@ -456,13 +451,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_REGISTER_NODE_TYPE(ModulePassNode); TVM_REGISTER_API("relay._ir_pass.CreateModulePass") -.set_body([](TVMArgs args, TVMRetValue* ret) { - PackedFunc pass_func = args[0]; - int opt_level = args[1]; - std::string name = args[2]; - tvm::Array required = args[3]; - *ret = CreateModulePass(pass_func, opt_level, name, required); -}); +.set_body_typed(CreateModulePass); TVM_REGISTER_API("relay._ir_pass.RunPass") .set_body([](TVMArgs args, TVMRetValue* ret) { @@ -487,13 +476,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_REGISTER_NODE_TYPE(FunctionPassNode); TVM_REGISTER_API("relay._ir_pass.CreateFunctionPass") -.set_body([](TVMArgs args, TVMRetValue* ret) { - PackedFunc pass_func = args[0]; - int opt_level = args[1]; - std::string name = args[2]; - tvm::Array required = args[3]; - *ret = CreateFunctionPass(pass_func, opt_level, name, required); -}); +.set_body_typed(CreateFunctionPass); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const FunctionPassNode* node, @@ -541,9 +524,7 @@ TVM_REGISTER_API("relay._ir_pass.SetContext") TVM_REGISTER_NODE_TYPE(PassContextNode); TVM_REGISTER_API("relay._ir_pass.PassContext") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = PassContextNode::make(); -}); +.set_body_typed(PassContextNode::make); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const PassContextNode* node, diff --git a/src/relay/pass/quantize.cc b/src/relay/pass/quantize.cc index cb0f9d9c5acb..5fa30535b002 100644 --- a/src/relay/pass/quantize.cc +++ b/src/relay/pass/quantize.cc @@ -571,20 +571,13 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) }); TVM_REGISTER_API("relay._quantize._GetCurrentQConfig") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = QConfig::Current(); - }); +.set_body_typed(QConfig::Current); TVM_REGISTER_API("relay._quantize._EnterQConfigScope") -.set_body([](TVMArgs args, TVMRetValue* ret) { - QConfig target = args[0]; - QConfig::EnterQConfigScope(target); - }); +.set_body_typed(QConfig::EnterQConfigScope); TVM_REGISTER_API("relay._quantize._ExitQConfigScope") -.set_body([](TVMArgs args, TVMRetValue* ret) { - QConfig::ExitQConfigScope(); - }); +.set_body_typed(QConfig::ExitQConfigScope); } // namespace quantize } // namespace relay diff --git a/src/relay/pass/simplify_inference.cc b/src/relay/pass/simplify_inference.cc index 28ebaaa75546..cecebc5c04ed 100644 --- a/src/relay/pass/simplify_inference.cc +++ b/src/relay/pass/simplify_inference.cc @@ -103,9 +103,7 @@ Expr SimplifyInference(const Expr& e) { } TVM_REGISTER_API("relay._ir_pass.simplify_inference") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = SimplifyInference(args[0]); - }); +.set_body_typed(SimplifyInference); } // namespace relay } // namespace tvm diff --git a/src/relay/pass/to_a_normal_form.cc b/src/relay/pass/to_a_normal_form.cc index bac6fd28faf5..5507de471ae5 100644 --- a/src/relay/pass/to_a_normal_form.cc +++ b/src/relay/pass/to_a_normal_form.cc @@ -491,9 +491,7 @@ Expr ToANormalForm(const Expr& e, const Module& m) { } TVM_REGISTER_API("relay._ir_pass.to_a_normal_form") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = ToANormalForm(args[0], args[1]); - }); +.set_body_typed(static_cast(ToANormalForm)); } // namespace relay } // namespace tvm diff --git a/src/relay/pass/to_graph_normal_form.cc b/src/relay/pass/to_graph_normal_form.cc index cc7e1a43068e..490a80f308ce 100644 --- a/src/relay/pass/to_graph_normal_form.cc +++ b/src/relay/pass/to_graph_normal_form.cc @@ -77,9 +77,7 @@ Expr ToGraphNormalForm(const Expr& e) { } TVM_REGISTER_API("relay._ir_pass.to_graph_normal_form") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = ToGraphNormalForm(args[0]); -}); +.set_body_typed(ToGraphNormalForm); } // namespace relay } // namespace tvm diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 5abf0b74ab68..30d4d79f6c86 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -801,8 +801,8 @@ Function InferType(const Function& func, } TVM_REGISTER_API("relay._ir_pass.infer_type") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = InferType(args[0], args[1]); +.set_body_typed([](const Expr& expr, const Module& mod_ref) { + return InferType(expr, mod_ref); }); } // namespace relay } // namespace tvm diff --git a/src/relay/pass/util.cc b/src/relay/pass/util.cc index fa655a785338..8e02cf127bfd 100644 --- a/src/relay/pass/util.cc +++ b/src/relay/pass/util.cc @@ -275,9 +275,7 @@ tvm::Array AllVars(const Expr& expr) { } TVM_REGISTER_API("relay._ir_pass.free_vars") -.set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = FreeVars(args[0]); - }); +.set_body_typed(FreeVars); TVM_REGISTER_API("relay._ir_pass.bound_vars") .set_body([](TVMArgs args, TVMRetValue* ret) { @@ -290,9 +288,7 @@ TVM_REGISTER_API("relay._ir_pass.bound_vars") }); TVM_REGISTER_API("relay._ir_pass.all_vars") - .set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = AllVars(args[0]); - }); +.set_body_typed(AllVars); TVM_REGISTER_API("relay._ir_pass.free_type_vars") .set_body([](TVMArgs args, TVMRetValue* ret) { diff --git a/src/relay/pass/well_formed.cc b/src/relay/pass/well_formed.cc index 86107d66e52f..4eaaa934e78b 100644 --- a/src/relay/pass/well_formed.cc +++ b/src/relay/pass/well_formed.cc @@ -79,10 +79,7 @@ bool WellFormed(const Expr& e) { } TVM_REGISTER_API("relay._ir_pass.well_formed") - .set_body([](TVMArgs args, TVMRetValue *ret) { - Expr e = args[0]; - *ret = WellFormed(e); - }); +.set_body_typed(WellFormed); } // namespace relay } // namespace tvm diff --git a/src/runtime/cuda/cuda_module.cc b/src/runtime/cuda/cuda_module.cc index a46f0ebfdbdc..55d9e648e154 100644 --- a/src/runtime/cuda/cuda_module.cc +++ b/src/runtime/cuda/cuda_module.cc @@ -308,18 +308,12 @@ Module CUDAModuleLoadBinary(void* strm) { } TVM_REGISTER_GLOBAL("module.loadfile_cubin") -.set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = CUDAModuleLoadFile(args[0], args[1]); - }); +.set_body_typed(CUDAModuleLoadFile); TVM_REGISTER_GLOBAL("module.loadfile_ptx") -.set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = CUDAModuleLoadFile(args[0], args[1]); - }); +.set_body_typed(CUDAModuleLoadFile); TVM_REGISTER_GLOBAL("module.loadbinary_cuda") -.set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = CUDAModuleLoadBinary(args[0]); - }); +.set_body_typed(CUDAModuleLoadBinary); } // namespace runtime } // namespace tvm diff --git a/src/runtime/metal/metal_module.mm b/src/runtime/metal/metal_module.mm index e1f0e3fd534b..af809d7619bd 100644 --- a/src/runtime/metal/metal_module.mm +++ b/src/runtime/metal/metal_module.mm @@ -310,13 +310,9 @@ Module MetalModuleLoadBinary(void* strm) { } TVM_REGISTER_GLOBAL("module.loadfile_metal") -.set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = MetalModuleLoadFile(args[0], args[1]); - }); +.set_body_typed(MetalModuleLoadFile); TVM_REGISTER_GLOBAL("module.loadbinary_metal") -.set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = MetalModuleLoadBinary(args[0]); - }); +.set_body_typed(MetalModuleLoadBinary); } // namespace runtime } // namespace tvm diff --git a/src/runtime/opencl/aocl/aocl_module.cc b/src/runtime/opencl/aocl/aocl_module.cc index 38e82edfe296..d9a3aa23c4c5 100644 --- a/src/runtime/opencl/aocl/aocl_module.cc +++ b/src/runtime/opencl/aocl/aocl_module.cc @@ -69,9 +69,7 @@ Module AOCLModuleLoadFile(const std::string& file_name, } TVM_REGISTER_GLOBAL("module.loadfile_aocx") -.set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = AOCLModuleLoadFile(args[0], args[1]); - }); +.set_body_typed(AOCLModuleLoadFile); } // namespace runtime } // namespace tvm diff --git a/src/runtime/opencl/opencl_module.cc b/src/runtime/opencl/opencl_module.cc index 543ffb9825b1..971ae3482014 100644 --- a/src/runtime/opencl/opencl_module.cc +++ b/src/runtime/opencl/opencl_module.cc @@ -281,18 +281,12 @@ Module OpenCLModuleLoadBinary(void* strm) { } TVM_REGISTER_GLOBAL("module.loadfile_cl") -.set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = OpenCLModuleLoadFile(args[0], args[1]); - }); +.set_body_typed(OpenCLModuleLoadFile); TVM_REGISTER_GLOBAL("module.loadfile_clbin") -.set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = OpenCLModuleLoadFile(args[0], args[1]); - }); +.set_body_typed(OpenCLModuleLoadFile); TVM_REGISTER_GLOBAL("module.loadbinary_opencl") -.set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = OpenCLModuleLoadBinary(args[0]); - }); +.set_body_typed(OpenCLModuleLoadBinary); } // namespace runtime } // namespace tvm diff --git a/src/runtime/opencl/sdaccel/sdaccel_module.cc b/src/runtime/opencl/sdaccel/sdaccel_module.cc index 9bfc9d2b2705..900d56433514 100644 --- a/src/runtime/opencl/sdaccel/sdaccel_module.cc +++ b/src/runtime/opencl/sdaccel/sdaccel_module.cc @@ -80,13 +80,9 @@ Module SDAccelModuleLoadBinary(void* strm) { } TVM_REGISTER_GLOBAL("module.loadfile_xclbin") -.set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = SDAccelModuleLoadFile(args[0], args[1]); - }); +.set_body_typed(SDAccelModuleLoadFile); TVM_REGISTER_GLOBAL("module.loadfile_awsxclbin") -.set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = SDAccelModuleLoadFile(args[0], args[1]); - }); +.set_body_typed(SDAccelModuleLoadFile); } // namespace runtime } // namespace tvm diff --git a/src/runtime/rocm/rocm_module.cc b/src/runtime/rocm/rocm_module.cc index b7b93c7c4dfd..6531f97d4b12 100644 --- a/src/runtime/rocm/rocm_module.cc +++ b/src/runtime/rocm/rocm_module.cc @@ -243,14 +243,10 @@ Module ROCMModuleLoadBinary(void* strm) { TVM_REGISTER_GLOBAL("module.loadbinary_hsaco") -.set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = ROCMModuleLoadBinary(args[0]); - }); +.set_body_typed(ROCMModuleLoadBinary); TVM_REGISTER_GLOBAL("module.loadbinary_hip") -.set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = ROCMModuleLoadBinary(args[0]); - }); +.set_body_typed(ROCMModuleLoadBinary); } // namespace runtime } // namespace tvm diff --git a/src/runtime/rpc/rpc_event_impl.cc b/src/runtime/rpc/rpc_event_impl.cc index dfbdb2699d39..7a142f3373db 100644 --- a/src/runtime/rpc/rpc_event_impl.cc +++ b/src/runtime/rpc/rpc_event_impl.cc @@ -64,8 +64,6 @@ PackedFunc CreateEventDrivenServer(PackedFunc fsend, } TVM_REGISTER_GLOBAL("rpc._CreateEventDrivenServer") -.set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = CreateEventDrivenServer(args[0], args[1], args[2]); - }); +.set_body_typed(CreateEventDrivenServer); } // namespace runtime } // namespace tvm diff --git a/src/runtime/rpc/rpc_socket_impl.cc b/src/runtime/rpc/rpc_socket_impl.cc index 33d852f5a575..16528bcc68a1 100644 --- a/src/runtime/rpc/rpc_socket_impl.cc +++ b/src/runtime/rpc/rpc_socket_impl.cc @@ -110,9 +110,7 @@ void RPCServerLoop(int sockfd) { } TVM_REGISTER_GLOBAL("rpc._Connect") -.set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = RPCClientConnect(args[0], args[1], args[2]); - }); +.set_body_typed(RPCClientConnect); TVM_REGISTER_GLOBAL("rpc._ServerLoop") .set_body([](TVMArgs args, TVMRetValue* rv) { diff --git a/src/runtime/stackvm/stackvm_module.cc b/src/runtime/stackvm/stackvm_module.cc index 5e6f96be50df..4e7d42279001 100644 --- a/src/runtime/stackvm/stackvm_module.cc +++ b/src/runtime/stackvm/stackvm_module.cc @@ -142,9 +142,7 @@ Module StackVMModuleCreate(std::unordered_map fmap, } TVM_REGISTER_GLOBAL("module.loadfile_stackvm") -.set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = StackVMModuleNode::LoadFromFile(args[0], args[1]); - }); +.set_body_typed(StackVMModuleNode::LoadFromFile); } // namespace runtime } // namespace tvm diff --git a/src/runtime/vulkan/vulkan_module.cc b/src/runtime/vulkan/vulkan_module.cc index cfa80bef151f..c1db14d35674 100644 --- a/src/runtime/vulkan/vulkan_module.cc +++ b/src/runtime/vulkan/vulkan_module.cc @@ -427,13 +427,9 @@ Module VulkanModuleLoadBinary(void* strm) { } TVM_REGISTER_GLOBAL("module.loadfile_vulkan") -.set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = VulkanModuleLoadFile(args[0], args[1]); - }); +.set_body_typed(VulkanModuleLoadFile); TVM_REGISTER_GLOBAL("module.loadbinary_vulkan") -.set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = VulkanModuleLoadBinary(args[0]); - }); +.set_body_typed(VulkanModuleLoadBinary); } // namespace runtime } // namespace tvm diff --git a/web/web_runtime.cc b/web/web_runtime.cc index 273d43b38f22..12bc53cd3407 100644 --- a/web/web_runtime.cc +++ b/web/web_runtime.cc @@ -60,16 +60,16 @@ struct RPCEnv { }; TVM_REGISTER_GLOBAL("tvm.rpc.server.workpath") -.set_body([](TVMArgs args, TVMRetValue* rv) { +.set_body_typed([](std::string path) { static RPCEnv env; - *rv = env.GetPath(args[0]); + return env.GetPath(path); }); TVM_REGISTER_GLOBAL("tvm.rpc.server.load_module") -.set_body([](TVMArgs args, TVMRetValue *rv) { - std::string file_name = "/rpc/" + args[0].operator std::string(); - *rv = Module::LoadFromFile(file_name, ""); +.set_body_typed([](std::string path) { + std::string file_name = "/rpc/" + path; LOG(INFO) << "Load module from " << file_name << " ..."; + return Module::LoadFromFile(file_name, ""); }); } // namespace contrib } // namespace tvm