diff --git a/CMakeLists.txt b/CMakeLists.txt
index 58b75238d62c..017c278fa402 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -134,7 +134,7 @@ file(GLOB_RECURSE COMPILER_SRCS
     src/tir/*.cc
     src/driver/*.cc
     src/printer/*.cc
-    src/api/*.cc
+    src/support/*.cc
     )
 
 file(GLOB CODEGEN_SRCS
diff --git a/docs/dev/codebase_walkthrough.rst b/docs/dev/codebase_walkthrough.rst
index 0732c26f0c58..8513ce5bd89d 100644
--- a/docs/dev/codebase_walkthrough.rst
+++ b/docs/dev/codebase_walkthrough.rst
@@ -55,7 +55,7 @@ We use a simple example that uses the low level TVM API directly. The example is
    B = tvm.placeholder((n,), name='B')
    C = tvm.compute(A.shape, lambda i: A[i] + B[i], name="C")
 
-Here, types of ``A``, ``B``, ``C`` are ``tvm.tensor.Tensor``, defined in ``python/tvm/tensor.py``. The Python ``Tensor`` is backed by C++ ``Tensor``, implemented in ``include/tvm/tensor.h`` and ``src/lang/tensor.cc``. All Python types in TVM can be thought of as a handle to the underlying C++ type with the same name. If you look at the definition of Python ``Tensor`` type below, you can see it is a subclass of ``Object``.
+Here, the types of ``A``, ``B``, and ``C`` are ``tvm.tensor.Tensor``, defined in ``python/tvm/te/tensor.py``. The Python ``Tensor`` is backed by the C++ ``Tensor``, implemented in ``include/tvm/te/tensor.h`` and ``src/te/tensor.cc``. All Python types in TVM can be thought of as a handle to the underlying C++ type with the same name. If you look at the definition of the Python ``Tensor`` type below, you can see it is a subclass of ``Object``.
 
 ::
 
@@ -68,24 +68,12 @@ Here, types of ``A``, ``B``, ``C`` are ``tvm.tensor.Tensor``, defined in ``pytho
 
 The object protocol is the basis of exposing C++ types to frontend languages, including Python. The way TVM implements Python wrapping is not straightforward. It is briefly covered in `this document <http://docs.tvm.ai/dev/runtime.html>`_, and details are in ``python/tvm/_ffi/`` if you are interested.
 
-``Tensor`` is created by functions in ``python/tvm/api.py``, which in turn calls into C++ functions exposed in ``src/api/api_lang.cc``. All C++ functions that are callable from Python are exposed in the ``src/api`` subdirectory. For example, the ``tvm.compute()`` function above calls into ``_ComputeOp`` API exposed in ``src/api/api_lang.cc``:
-
-::
-
-  TVM_REGISTER_GLOBAL("_ComputeOp")
-  .set_body([](TVMArgs args, TVMRetValue* ret) {
-      *ret = ComputeOpNode::make(args[0],
-                                 args[1],
-                                 args[2],
-                                 args[3],
-                                 args[4]);
-    });
-
 We use the ``TVM_REGISTER_*`` macro to expose C++ functions to frontend languages, in the form of a `PackedFunc <http://docs.tvm.ai/dev/runtime.html#packedfunc>`_. A ``PackedFunc`` is another mechanism by which TVM implements interoperability between C++ and Python. In particular, this is what makes calling Python functions from the C++ codebase very easy.
+You can also check out `FFI Navigator <https://github.com/tqchen/ffi-navigator>`_, which allows you to navigate between Python and C++ FFI calls.
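+
+For example, every function registered with ``TVM_REGISTER_GLOBAL`` can be fetched from Python by its global name; a minimal sketch using ``tvm.get_global_func``:
+
+::
+
+   import tvm
+
+   # Returns a PackedFunc handle to the C++ implementation registered
+   # under the global name "te.CreateSchedule" (see src/te/schedule/schedule_lang.cc below).
+   fcreate = tvm.get_global_func("te.CreateSchedule")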
 
-A ``Tensor`` object has an ``Operation`` object associated with it, defined in ``python/tvm/tensor.py``, ``include/tvm/operation.h``, and ``src/tvm/op`` subdirectory. A ``Tensor`` is an output of its ``Operation`` object. Each ``Operation`` object has in turn ``input_tensors()`` method, which returns a list of input ``Tensor`` to it. This way we can keep track of dependencies between ``Operation``.
+A ``Tensor`` object has an ``Operation`` object associated with it, defined in ``python/tvm/te/tensor.py``, ``include/tvm/te/operation.h``, and the ``src/te/operation`` subdirectory. A ``Tensor`` is an output of its ``Operation`` object. Each ``Operation`` object in turn has an ``input_tensors()`` method, which returns a list of input ``Tensor`` objects. This way we can keep track of dependencies between ``Operation`` objects.
 
-We pass the operation corresponding to the output tensor ``C`` to ``tvm.create_schedule()`` function in ``python/tvm/schedule.py``.
+We pass the operation corresponding to the output tensor ``C`` to the ``tvm.create_schedule()`` function in ``python/tvm/te/schedule.py``.
 
 ::
 
@@ -103,7 +91,7 @@ This function is mapped to the C++ function in ``include/tvm/schedule.h``.
 
 ``Stage`` corresponds to one ``Operation``. In the vector add example above, there are two placeholder ops and one compute op, so the schedule ``s`` contains three stages. Each ``Stage`` holds information about a loop nest structure, types of each loop (``Parallel``, ``Vectorized``, ``Unrolled``), and where to execute its computation in the loop nest of the next ``Stage``, if any.
 
-``Schedule`` and ``Stage`` are defined in ``tvm/python/schedule.py``, ``include/tvm/schedule.h``, and ``src/schedule/schedule_ops.cc``.
+``Schedule`` and ``Stage`` are defined in ``python/tvm/te/schedule.py``, ``include/tvm/te/schedule.h``, and ``src/te/schedule/schedule_ops.cc``.
 
 To keep it simple, we call ``tvm.build(...)`` on the default schedule created by ``create_schedule()`` function above.
 
@@ -112,7 +100,7 @@ To keep it simple, we call ``tvm.build(...)`` on the default schedule created by
 
    target = "cuda"
    fadd = tvm.build(s, [A, B, C], target)
 
-``tvm.build()``, defined in ``python/tvm/build_module.py``, takes a schedule, input and output ``Tensor``, and a target, and returns a ``tvm.Module`` object, defined in ``python/tvm/module.py``. A ``Module`` object contains a compiled function which can be invoked with function call syntax.
+``tvm.build()``, defined in ``python/tvm/driver/build_module.py``, takes a schedule, input and output ``Tensor``, and a target, and returns a :py:class:`tvm.runtime.Module` object. A :py:class:`tvm.runtime.Module` object contains a compiled function which can be invoked with function call syntax.
 
 The process of ``tvm.build()`` can be divided into two steps:
 
@@ -133,14 +121,14 @@ Lowering is done by ``tvm.lower()`` function, defined in ``python/tvm/build_modul
 
    stmt = schedule.ScheduleOps(sch, bounds)
    ...
 
-Bound inference is the process where all loop bounds and sizes of intermediate buffers are inferred. If you target the CUDA backend and you use shared memory, its required minimum size is automatically determined here. Bound inference is implemented in ``src/schedule/bound.cc``, ``src/schedule/graph.cc`` and ``src/schedule/message_passing.cc``. For more information on how bound inference works, see `InferBound Pass`_.
+Bound inference is the process where all loop bounds and sizes of intermediate buffers are inferred. If you target the CUDA backend and you use shared memory, its required minimum size is automatically determined here. Bound inference is implemented in ``src/te/schedule/bound.cc``, ``src/te/schedule/graph.cc`` and ``src/te/schedule/message_passing.cc``. For more information on how bound inference works, see `InferBound Pass`_.
 
 .. _InferBound Pass: http://docs.tvm.ai/dev/inferbound.html
 
-``stmt``, which is the output of ``ScheduleOps()``, represents an initial loop nest structure. If you have applied ``reorder`` or ``split`` primitives to your schedule, then the initial loop nest already reflects those changes. ``ScheduleOps()`` is defined in ``src/schedule/schedule_ops.cc``.
+``stmt``, which is the output of ``ScheduleOps()``, represents an initial loop nest structure. If you have applied ``reorder`` or ``split`` primitives to your schedule, then the initial loop nest already reflects those changes. ``ScheduleOps()`` is defined in ``src/te/schedule/schedule_ops.cc``.
 
-Next, we apply a number of lowering passes to ``stmt``. These passes are implemented in ``src/pass`` subdirectory. For example, if you have applied ``vectorize`` or ``unroll`` primitives to your schedule, they are applied in loop vectorization and unrolling passes below.
+Next, we apply a number of lowering passes to ``stmt``. These passes are implemented in the ``src/tir/pass`` subdirectory. For example, if you have applied ``vectorize`` or ``unroll`` primitives to your schedule, they are applied in the loop vectorization and unrolling passes below.
 
 ::
 
@@ -157,7 +145,7 @@ Next, we apply a number of lowering passes to ``stmt``. These passes are impleme
 
 After lowering is done, ``build()`` function generates target machine code from the lowered function. This code can contain SSE or AVX instructions if you target x86, or PTX instructions for CUDA target. In addition to target specific machine code, TVM also generates host side code that is responsible for memory management, kernel launch etc.
 
-Code generation is done by ``build_module()`` function, defined in ``python/tvm/codegen.py``. On the C++ side, code generation is implemented in ``src/codegen`` subdirectory. ``build_module()`` Python function will reach ``Build()`` function below in ``src/codegen/codegen.cc``:
+Code generation is done by the ``build_module()`` function, defined in ``python/tvm/target/codegen.py``. On the C++ side, code generation is implemented in the ``src/target/codegen`` subdirectory. The ``build_module()`` Python function will reach the ``Build()`` function below in ``src/target/codegen/codegen.cc``:
 
 ::
 
diff --git a/python/tvm/_api_internal.py b/python/tvm/_api_internal.py
deleted file mode 100644
index 571523757cac..000000000000
--- a/python/tvm/_api_internal.py
+++ /dev/null
@@ -1,25 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""Namespace of internal API
-
-The functions in this namespace are automatically exported from C++ side via PackedFunc
-that is registered by "TVM_REGISTER_*" macro. This way makes calling Python functions from C++
-side very easily.
-
-Each string starts with "_" in the "TVM_REGISTER_*" macro is an internal API. You can find
-all the functions in "api_lang.cc", "api_base.cc", "api_arith.cc" and "api_ir.cc" under "src/api".
-"""
diff --git a/python/tvm/_ffi/registry.py b/python/tvm/_ffi/registry.py
index be1578550a3b..e4b8b18b4805 100644
--- a/python/tvm/_ffi/registry.py
+++ b/python/tvm/_ffi/registry.py
@@ -19,7 +19,6 @@
 """FFI registry to register function and objects."""
 import sys
 import ctypes
-from .. import _api_internal
 from .base import _LIB, check_call, py_str, c_str, string_types, _FFI_MODE, _RUNTIME_ONLY
 
@@ -288,17 +287,11 @@ def _init_api_prefix(module_name, prefix):
     module = sys.modules[module_name]
 
     for name in list_global_func_names():
-        if prefix == "api":
-            fname = name
-            if name.startswith("_"):
-                target_module = sys.modules["tvm._api_internal"]
-            else:
-                target_module = module
-        else:
-            if not name.startswith(prefix):
-                continue
-            fname = name[len(prefix)+1:]
-            target_module = module
+        if not name.startswith(prefix):
+            continue
+
+        fname = name[len(prefix)+1:]
+        target_module = module
 
         if fname.find(".") != -1:
             continue
diff --git a/src/api/api_arith.cc b/src/api/api_arith.cc
deleted file mode 100644
index 3942f6ef0f20..000000000000
--- a/src/api/api_arith.cc
+++ /dev/null
@@ -1,153 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * Implementation of API functions related to arith
- * \file api_arith.cc
- */
-#include
-#include
-#include
-#include
-
-#include
-#include
-#include
-
-#include
-
-namespace tvm {
-namespace arith {
-
-TVM_REGISTER_GLOBAL("arith.intset_single_point")
-.set_body_typed(IntSet::single_point);
-
-TVM_REGISTER_GLOBAL("arith.intset_vector")
-.set_body_typed(IntSet::vector);
-
-TVM_REGISTER_GLOBAL("arith.intset_interval")
-.set_body_typed(IntSet::interval);
-
-
-TVM_REGISTER_GLOBAL("arith.DetectLinearEquation")
-.set_body_typed(DetectLinearEquation);
-
-TVM_REGISTER_GLOBAL("arith.DetectClipBound")
-.set_body_typed(DetectClipBound);
-
-TVM_REGISTER_GLOBAL("arith.DeduceBound")
-.set_body_typed([](
-  PrimExpr v, PrimExpr cond,
-  const Map<Var, IntSet> hint_map,
-  const Map<Var, IntSet> relax_map
-) {
-  return DeduceBound(v, cond, hint_map, relax_map);
-});
-
-
-TVM_REGISTER_GLOBAL("arith.DomainTouched")
-.set_body_typed(DomainTouched);
-
-TVM_REGISTER_GLOBAL("arith.IntervalSetGetMin")
-.set_body_method(&IntSet::min);
-
-TVM_REGISTER_GLOBAL("arith.IntervalSetGetMax")
-.set_body_method(&IntSet::max);
-
-TVM_REGISTER_GLOBAL("arith.IntSetIsNothing")
-.set_body_method(&IntSet::is_nothing);
-
-TVM_REGISTER_GLOBAL("arith.IntSetIsEverything")
-.set_body_method(&IntSet::is_everything);
-
-ConstIntBound MakeConstIntBound(int64_t min_value, int64_t max_value) {
-  return ConstIntBound(min_value, max_value);
-}
-
-TVM_REGISTER_GLOBAL("arith.ConstIntBound")
-.set_body_typed(MakeConstIntBound);
-
-ModularSet MakeModularSet(int64_t coeff, int64_t base) {
-  return ModularSet(coeff, base);
-}
-
-TVM_REGISTER_GLOBAL("arith.ModularSet")
-.set_body_typed(MakeModularSet);
-
-TVM_REGISTER_GLOBAL("arith.CreateAnalyzer")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    using runtime::PackedFunc;
-    using runtime::TypedPackedFunc;
-    auto self = std::make_shared<Analyzer>();
-    auto f = [self](std::string name) -> PackedFunc {
-      if (name == "const_int_bound") {
-        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
-            *ret = self->const_int_bound(args[0]);
-          });
-      } else if (name == "modular_set") {
-        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
-            *ret = self->modular_set(args[0]);
-          });
-      } else if (name == "const_int_bound_update") {
-        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
-            self->const_int_bound.Update(args[0], args[1], args[2]);
-          });
-      } else if (name == "Simplify") {
-        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
-            *ret = self->Simplify(args[0]);
-          });
-      } else if (name == "rewrite_simplify") {
-        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
-            *ret = self->rewrite_simplify(args[0]);
-          });
-      } else if (name == "canonical_simplify") {
-        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
-            *ret = self->canonical_simplify(args[0]);
-          });
-      } else if (name == "int_set") {
-        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
-            *ret = self->int_set(args[0], args[1]);
-          });
-      } else if (name == "bind") {
-        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
-            if (args[1].IsObjectRef<Range>()) {
-              self->Bind(args[0], args[1].operator Range());
-            } else {
-              self->Bind(args[0], args[1].operator PrimExpr());
-            }
-          });
-      } else if (name == "enter_constraint_context") {
-        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
-            // can't use make_shared due to noexcept(false) decl in destructor,
-            // see https://stackoverflow.com/a/43907314
-            auto ctx = std::shared_ptr<With<ConstraintContext> >(
-                new With<ConstraintContext>(self.get(), args[0]));
-            auto fexit = [ctx](TVMArgs, TVMRetValue*) mutable {
-              ctx.reset();
-            };
-            *ret = PackedFunc(fexit);
-          });
-      }
-      return PackedFunc();
-    };
-    *ret = TypedPackedFunc<PackedFunc(std::string)>(f);
-});
-
-} // namespace arith
-} // namespace tvm
diff --git a/src/api/api_ir.cc b/src/api/api_ir.cc
deleted file mode 100644
index 1e71baf305d4..000000000000
--- a/src/api/api_ir.cc
+++ /dev/null
@@ -1,237 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * Implementation of API functions related to IR build
- * \file api_ir.cc
- */
-#include
-#include
-#include
-
-#include
-
-namespace tvm {
-namespace tir {
-
-TVM_REGISTER_GLOBAL("tir.Var")
-.set_body_typed([](std::string s, DataType t) {
-    return Var(s, t);
-  });
-
-TVM_REGISTER_GLOBAL("tir.SizeVar")
-.set_body_typed([](std::string s, DataType t) {
-    return SizeVar(s, t);
-  });
-
-TVM_REGISTER_GLOBAL("tir.abs")
-.set_body_typed(tvm::abs);
-
-TVM_REGISTER_GLOBAL("tir.isnan")
-.set_body_typed(tvm::isnan);
-
-TVM_REGISTER_GLOBAL("tir.floor")
-.set_body_typed(tvm::floor);
-
-TVM_REGISTER_GLOBAL("tir.ceil")
-.set_body_typed(tvm::ceil);
-
-TVM_REGISTER_GLOBAL("tir.round")
-.set_body_typed(tvm::round);
-
-TVM_REGISTER_GLOBAL("tir.nearbyint")
-.set_body_typed(tvm::nearbyint);
-
-TVM_REGISTER_GLOBAL("tir.trunc")
-.set_body_typed(tvm::trunc);
-
-TVM_REGISTER_GLOBAL("tir._cast")
-.set_body_typed(tvm::cast);
-
-TVM_REGISTER_GLOBAL("ir.range_by_min_extent")
-.set_body_typed(Range::make_by_min_extent);
-
-
-TVM_REGISTER_GLOBAL("tir.SeqStmt")
-.set_body_typed([](Array<Stmt> seq) {
-  return SeqStmt(std::move(seq));
-});
-
-TVM_REGISTER_GLOBAL("tir.For")
-.set_body_typed([](
-  Var loop_var, PrimExpr min, PrimExpr extent,
-  int for_type, int device_api, Stmt body) {
-  return ForNode::make(loop_var,
-                       min,
-                       extent,
-                       static_cast<ForType>(for_type),
-                       static_cast<DeviceAPI>(device_api),
-                       body);
-});
-
-TVM_REGISTER_GLOBAL("tir.Load")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
-    DataType t = args[0];
-    if (args.size() == 3) {
-      *ret = LoadNode::make(t, args[1], args[2], const_true(t.lanes()));
-    } else {
-      *ret = LoadNode::make(t, args[1], args[2], args[3]);
-    }
-  });
-
-TVM_REGISTER_GLOBAL("tir.Store")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
-    PrimExpr value = args[1];
-    if (args.size() == 3) {
-      *ret = StoreNode::make(args[0], value, args[2], const_true(value.dtype().lanes()));
-    } else {
-      *ret = StoreNode::make(args[0], value, args[2], args[3]);
-    }
-  });
-
-TVM_REGISTER_GLOBAL("tir.Realize")
-.set_body_typed(RealizeNode::make);
-
-TVM_REGISTER_GLOBAL("tir.Call")
-.set_body_typed([](
-  DataType type, std::string name,
-  Array<PrimExpr> args, int call_type,
-  FunctionRef func, int value_index
-) {
-  return CallNode::make(type,
-                        name,
-                        args,
-                        static_cast<CallNode::CallType>(call_type),
-                        func,
-                        value_index);
-});
-
-TVM_REGISTER_GLOBAL("tir.CommReducer")
-.set_body_typed(CommReducerNode::make);
-
-// make from two arguments
-#define REGISTER_MAKE(NodeName)                             \
-  TVM_REGISTER_GLOBAL("tir."#NodeName)                      \
-  .set_body_typed(NodeName ## Node::make);                  \
-
-
-REGISTER_MAKE(Reduce);
-REGISTER_MAKE(AttrStmt);
-
-REGISTER_MAKE(StringImm);
-
-REGISTER_MAKE(Add);
-REGISTER_MAKE(Sub);
-REGISTER_MAKE(Mul);
-REGISTER_MAKE(Div);
-REGISTER_MAKE(Mod);
-REGISTER_MAKE(FloorDiv);
-REGISTER_MAKE(FloorMod);
-REGISTER_MAKE(Min);
-REGISTER_MAKE(Max);
-REGISTER_MAKE(EQ);
-REGISTER_MAKE(NE);
-REGISTER_MAKE(LT);
-REGISTER_MAKE(LE);
-REGISTER_MAKE(GT);
-REGISTER_MAKE(GE);
-REGISTER_MAKE(And);
-REGISTER_MAKE(Or);
-
-REGISTER_MAKE(Not);
-REGISTER_MAKE(Select);
-REGISTER_MAKE(Ramp);
-REGISTER_MAKE(Cast);
-REGISTER_MAKE(Broadcast);
-REGISTER_MAKE(Shuffle);
-REGISTER_MAKE(Let);
-REGISTER_MAKE(LetStmt);
-REGISTER_MAKE(AssertStmt);
-REGISTER_MAKE(ProducerConsumer);
-REGISTER_MAKE(Provide);
-REGISTER_MAKE(Prefetch);
-REGISTER_MAKE(Free);
-REGISTER_MAKE(IfThenElse);
-REGISTER_MAKE(Evaluate);
-
-// overloaded, needs special handling
-// has default args
-TVM_REGISTER_GLOBAL("tir.Allocate")
-  .set_body_typed([](
-    Var buffer_var, DataType type, Array<PrimExpr> extents, PrimExpr condition, Stmt body
-  ){
-    return AllocateNode::make(buffer_var, type, extents, condition, body);
-  });
-
-// operator overloading, smarter than make
-#define REGISTER_MAKE_BINARY_OP(Node, Func)                 \
-  TVM_REGISTER_GLOBAL("tir."#Node)                          \
-  .set_body_typed([](PrimExpr a, PrimExpr b) {              \
-    return (Func(a, b));                                    \
-  })
-
-#define REGISTER_MAKE_BIT_OP(Node, Func)                                      \
-  TVM_REGISTER_GLOBAL("tir."#Node)                                            \
-  .set_body([](TVMArgs args, TVMRetValue *ret) {                              \
-    bool lhs_is_int = args[0].type_code() == kDLInt;                          \
-    bool rhs_is_int = args[1].type_code() == kDLInt;                          \
-    if (lhs_is_int) {                                                         \
-      *ret = (Func(args[0].operator int(), args[1].operator PrimExpr()));     \
-    } else if (rhs_is_int) {                                                  \
-      *ret = (Func(args[0].operator PrimExpr(), args[1].operator int()));     \
-    } else {                                                                  \
-      *ret = (Func(args[0].operator PrimExpr(), args[1].operator PrimExpr())); \
-    }                                                                         \
-  })
-
-
-REGISTER_MAKE_BINARY_OP(_OpAdd, operator+);
-REGISTER_MAKE_BINARY_OP(_OpSub, operator-);
-REGISTER_MAKE_BINARY_OP(_OpMul, operator*);
-REGISTER_MAKE_BINARY_OP(_OpDiv, div);
-REGISTER_MAKE_BINARY_OP(_OpMod, truncmod);
-REGISTER_MAKE_BINARY_OP(_OpIndexDiv, indexdiv);
-REGISTER_MAKE_BINARY_OP(_OpIndexMod, indexmod);
-REGISTER_MAKE_BINARY_OP(_OpFloorDiv, floordiv);
-REGISTER_MAKE_BINARY_OP(_OpFloorMod, floormod);
-REGISTER_MAKE_BINARY_OP(_OpTruncDiv, truncdiv);
-REGISTER_MAKE_BINARY_OP(_OpTruncMod, truncmod);
-REGISTER_MAKE_BINARY_OP(_OpPow, pow);
-REGISTER_MAKE_BINARY_OP(_OpMin, min);
-REGISTER_MAKE_BINARY_OP(_OpMax, max);
-REGISTER_MAKE_BINARY_OP(_OpEQ, operator==);
-REGISTER_MAKE_BINARY_OP(_OpNE, operator!=);
-REGISTER_MAKE_BINARY_OP(_OpLT, operator<);  // NOLINT(*)
-REGISTER_MAKE_BINARY_OP(_OpLE, operator<=); // NOLINT(*)
-REGISTER_MAKE_BINARY_OP(_OpGT, operator>);  // NOLINT(*)
-REGISTER_MAKE_BINARY_OP(_OpGE, operator>=);
-REGISTER_MAKE_BINARY_OP(_OpAnd, operator&&);
-REGISTER_MAKE_BINARY_OP(_OpOr, operator||);
-REGISTER_MAKE_BIT_OP(bitwise_and, operator&);
-REGISTER_MAKE_BIT_OP(bitwise_or, operator|);
-REGISTER_MAKE_BIT_OP(bitwise_xor, operator^);
-REGISTER_MAKE_BIT_OP(left_shift, operator<<); // NOLINT(*)
-REGISTER_MAKE_BIT_OP(right_shift, operator>>);
-TVM_REGISTER_GLOBAL("tir._OpIfThenElse")
-.set_body_typed([] (PrimExpr cond, PrimExpr true_value, PrimExpr false_value) {
-  return if_then_else(cond, true_value, false_value);
-});
-
-} // namespace tir
-} // namespace tvm
diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc
deleted file mode 100644
index 613b82311aed..000000000000
--- a/src/api/api_lang.cc
+++ /dev/null
@@ -1,223 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * Implementation of API functions related to Higher DSL build.
- * \file api_lang.cc
- */
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-
-#include
-#include
-
-namespace tvm {
-
-TVM_REGISTER_GLOBAL("tir.min_value")
-.set_body_typed(min_value);
-
-TVM_REGISTER_GLOBAL("tir.max_value")
-.set_body_typed(max_value);
-
-TVM_REGISTER_GLOBAL("ir.Range")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    *ret = Range(args[0], args[1]);
-  });
-
-namespace tir {
-TVM_REGISTER_GLOBAL("tir.IterVar")
-.set_body_typed([](Range dom, Var var, int iter_type, std::string thread_tag) {
-  return IterVarNode::make(
-      dom, var,
-      static_cast<IterVarType>(iter_type),
-      thread_tag);
-});
-}
-
-namespace te {
-TVM_REGISTER_GLOBAL("te.Tensor")
-.set_body_typed(TensorNode::make);
-
-TVM_REGISTER_GLOBAL("te.TensorIntrin")
-.set_body_typed(TensorIntrinNode::make);
-
-TVM_REGISTER_GLOBAL("te.TensorIntrinCall")
-.set_body_typed(TensorIntrinCallNode::make);
-
-TVM_REGISTER_GLOBAL("te.TensorEqual")
-.set_body_method(&Tensor::operator==);
-
-TVM_REGISTER_GLOBAL("te.TensorHash")
-.set_body_typed([](Tensor tensor) -> int64_t {
-    return static_cast<int64_t>(std::hash<Tensor>()(tensor));
-  });
-
-TVM_REGISTER_GLOBAL("te.Placeholder")
-.set_body_typed([](Array<PrimExpr> shape, DataType dtype, std::string name) {
-  return placeholder(shape, dtype, name);
-});
-
-TVM_REGISTER_GLOBAL("te.ComputeOp")
-.set_body_typed(ComputeOpNode::make);
-
-TVM_REGISTER_GLOBAL("te.ScanOp")
-.set_body_typed(ScanOpNode::make);
-
-TVM_REGISTER_GLOBAL("te.TensorComputeOp")
-.set_body_typed(TensorComputeOpNode::make);
-
-TVM_REGISTER_GLOBAL("te.ExternOp")
-.set_body_typed(ExternOpNode::make);
-
-TVM_REGISTER_GLOBAL("te.HybridOp")
-.set_body_typed(HybridOpNode::make);
-
-TVM_REGISTER_GLOBAL("te.OpGetOutput")
-.set_body_typed([](Operation op, int64_t output) {
-  return op.output(static_cast<size_t>(output));
-});
-
-TVM_REGISTER_GLOBAL("te.OpNumOutputs")
-.set_body_method(&OperationNode::num_outputs);
-
-TVM_REGISTER_GLOBAL("te.OpInputTensors")
-.set_body_method(&OperationNode::InputTensors);
-
-TVM_REGISTER_GLOBAL("te.CreateSchedule")
-.set_body_typed(create_schedule);
-
-TVM_REGISTER_GLOBAL("te.StageSetScope")
-.set_body_method(&Stage::set_scope);
-
-TVM_REGISTER_GLOBAL("te.StageBind")
-.set_body_method(&Stage::bind);
-
-TVM_REGISTER_GLOBAL("te.StageSplitByFactor")
-.set_body_typed([](Stage stage, IterVar parent, PrimExpr factor) {
-  IterVar outer, inner;
-  stage.split(parent, factor, &outer, &inner);
-  return Array<IterVar>({outer, inner});
-});
-
-TVM_REGISTER_GLOBAL("te.StageSplitByNParts")
-.set_body_typed([](Stage stage, IterVar parent, PrimExpr nparts) {
-  IterVar outer, inner;
-  stage.split_by_nparts(parent, nparts, &outer, &inner);
-  return Array<IterVar>({outer, inner});
-});
-
-TVM_REGISTER_GLOBAL("te.StageFuse")
-.set_body_typed([](Stage stage, Array<IterVar> axes) {
-    IterVar fused;
-    stage.fuse(axes, &fused);
-    return fused;
-  });
-
-TVM_REGISTER_GLOBAL("te.StageComputeAt")
-.set_body_method(&Stage::compute_at);
-
-TVM_REGISTER_GLOBAL("te.StageComputeInline")
-.set_body_method(&Stage::compute_inline);
-
-TVM_REGISTER_GLOBAL("te.StageComputeRoot")
-.set_body_method(&Stage::compute_root);
-
-TVM_REGISTER_GLOBAL("te.StageReorder")
-.set_body_method(&Stage::reorder);
-
-TVM_REGISTER_GLOBAL("te.StageTile")
-.set_body_typed([](
-  Stage stage,
-  IterVar x_parent, IterVar y_parent,
-  PrimExpr x_factor, PrimExpr y_factor
-) {
-    IterVar x_outer, y_outer, x_inner, y_inner;
-    stage.tile(x_parent, y_parent,
-               x_factor, y_factor,
-               &x_outer, &y_outer,
-               &x_inner, &y_inner);
-    return Array<IterVar>({x_outer, y_outer, x_inner, y_inner});
-  });
-
-TVM_REGISTER_GLOBAL("te.StageEnvThreads")
-.set_body_method(&Stage::env_threads);
-
-TVM_REGISTER_GLOBAL("te.StageSetStorePredicate")
-.set_body_method(&Stage::set_store_predicate);
-
-TVM_REGISTER_GLOBAL("te.StageUnroll")
-.set_body_method(&Stage::unroll);
-
-TVM_REGISTER_GLOBAL("te.StageVectorize")
-.set_body_method(&Stage::vectorize);
-
-TVM_REGISTER_GLOBAL("te.StageTensorize")
-.set_body_method(&Stage::tensorize);
-
-TVM_REGISTER_GLOBAL("te.StageParallel")
-.set_body_method(&Stage::parallel);
-
-TVM_REGISTER_GLOBAL("te.StagePragma")
-.set_body_method(&Stage::pragma);
-
-TVM_REGISTER_GLOBAL("te.StagePrefetch")
-.set_body_method(&Stage::prefetch);
-
-TVM_REGISTER_GLOBAL("te.StageStorageAlign")
-.set_body_method(&Stage::storage_align);
-
-TVM_REGISTER_GLOBAL("te.StageDoubleBuffer")
-.set_body_method(&Stage::double_buffer);
-
-TVM_REGISTER_GLOBAL("te.StageOpenGL")
-.set_body_method(&Stage::opengl);
-
-TVM_REGISTER_GLOBAL("te.ScheduleNormalize")
-.set_body_method(&Schedule::normalize);
-
-TVM_REGISTER_GLOBAL("te.ScheduleCreateGroup")
-.set_body_method(&Schedule::create_group);
-
-TVM_REGISTER_GLOBAL("te.ScheduleCacheRead")
-.set_body_method(&Schedule::cache_read);
-
-TVM_REGISTER_GLOBAL("te.ScheduleCacheWrite")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-    if (args[1].IsObjectRef<Tensor>()) {
-      *ret = args[0].operator Schedule()
-          .cache_write(args[1].operator Tensor(), args[2]);
-    } else {
-      *ret = args[0].operator Schedule()
-          .cache_write(args[1].operator Array<Tensor>(), args[2]);
-    }
-  });
-
-TVM_REGISTER_GLOBAL("te.ScheduleRFactor")
-.set_body_method(&Schedule::rfactor);
-} // namespace te
-
-TVM_REGISTER_GLOBAL("te.CommReducerCombine")
-.set_body_method(&tir::CommReducerNode::operator());
-
-} // namespace tvm
diff --git a/src/api/api_schedule.cc b/src/api/api_schedule.cc
deleted file mode 100644
index a53c6e99a999..000000000000
--- a/src/api/api_schedule.cc
+++ /dev/null
@@ -1,63 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * Implementation of API functions related to schedule pass.
- * \file api_schedule.cc
- */
-#include
-#include
-#include
-#include
-#include
-
-#include "../te/schedule/graph.h"
-
-namespace tvm {
-namespace te {
-
-TVM_REGISTER_GLOBAL("schedule.AutoInlineElemWise")
-.set_body_typed(AutoInlineElemWise);
-
-
-TVM_REGISTER_GLOBAL("schedule.AutoInlineInjective")
-.set_body_typed(AutoInlineInjective);
-
-TVM_REGISTER_GLOBAL("schedule.ScheduleOps")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
-  if (args.size() == 2)
-    *ret = ScheduleOps(args[0], args[1], false);
-  else
-    *ret = ScheduleOps(args[0], args[1], args[2]);
-});
-
-#define REGISTER_SCHEDULE_PASS(PassName)                      \
-  TVM_REGISTER_GLOBAL("schedule."#PassName)                   \
-  .set_body_typed(PassName);                                  \
-
-
-REGISTER_SCHEDULE_PASS(InferBound);
-REGISTER_SCHEDULE_PASS(CreateReadGraph);
-REGISTER_SCHEDULE_PASS(PostDFSOrder);
-REGISTER_SCHEDULE_PASS(CreateAttachPath);
-REGISTER_SCHEDULE_PASS(ScanGetBody);
-REGISTER_SCHEDULE_PASS(ScanFixPointAnalysis);
-
-} // namespace te
-} // namespace tvm
diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc
index b12e5f51f4fb..9df5aa2d246d 100644
--- a/src/arith/analyzer.cc
+++ b/src/arith/analyzer.cc
@@ -20,6 +20,7 @@
 /*!
  * \file tvm/arith/analyzer.cc
  */
+#include <tvm/runtime/registry.h>
 #include
 #include
 #include
@@ -109,5 +110,64 @@ PrimExpr Analyzer::Simplify(const PrimExpr& expr) {
   return res;
 }
 
+TVM_REGISTER_GLOBAL("arith.CreateAnalyzer")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+    using runtime::PackedFunc;
+    using runtime::TypedPackedFunc;
+    auto self = std::make_shared<Analyzer>();
+    auto f = [self](std::string name) -> PackedFunc {
+      if (name == "const_int_bound") {
+        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
+            *ret = self->const_int_bound(args[0]);
+          });
+      } else if (name == "modular_set") {
+        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
+            *ret = self->modular_set(args[0]);
+          });
+      } else if (name == "const_int_bound_update") {
+        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
+            self->const_int_bound.Update(args[0], args[1], args[2]);
+          });
+      } else if (name == "Simplify") {
+        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
+            *ret = self->Simplify(args[0]);
+          });
+      } else if (name == "rewrite_simplify") {
+        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
+            *ret = self->rewrite_simplify(args[0]);
+          });
+      } else if (name == "canonical_simplify") {
+        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
+            *ret = self->canonical_simplify(args[0]);
+          });
+      } else if (name == "int_set") {
+        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
+            *ret = self->int_set(args[0], args[1]);
+          });
+      } else if (name == "bind") {
+        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
+            if (args[1].IsObjectRef<Range>()) {
+              self->Bind(args[0], args[1].operator Range());
+            } else {
+              self->Bind(args[0], args[1].operator PrimExpr());
+            }
+          });
+      } else if (name == "enter_constraint_context") {
+        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
+            // can't use make_shared due to noexcept(false) decl in destructor,
+            // see https://stackoverflow.com/a/43907314
+            auto ctx = std::shared_ptr<With<ConstraintContext> >(
+                new With<ConstraintContext>(self.get(), args[0]));
+            auto fexit = [ctx](TVMArgs, TVMRetValue*) mutable {
+              ctx.reset();
+            };
+            *ret = PackedFunc(fexit);
+          });
+      }
+      return PackedFunc();
+    };
+    *ret = TypedPackedFunc<PackedFunc(std::string)>(f);
+});
+
 } // namespace arith
 } // namespace tvm
diff --git a/src/arith/bound_deducer.cc b/src/arith/bound_deducer.cc
index df8f40230e04..26be5d51115f 100644
--- a/src/arith/bound_deducer.cc
+++ b/src/arith/bound_deducer.cc
@@ -21,11 +21,11 @@
  * \file bound_deducer.cc
  * \brief Utility to deduce bound of expression
  */
+#include <tvm/runtime/registry.h>
 #include
 #include
 #include
 #include
-#include
 #include
 #include
 
@@ -362,5 +362,16 @@ IntSet DeduceBound(PrimExpr v, PrimExpr e,
   return DeduceBound(v, e, hmap, rmap);
 }
 
+
+TVM_REGISTER_GLOBAL("arith.DeduceBound")
+.set_body_typed([](
+  PrimExpr v, PrimExpr cond,
+  const Map<Var, IntSet> hint_map,
+  const Map<Var, IntSet> relax_map
+) {
+  return DeduceBound(v, cond, hint_map, relax_map);
+});
+
+
 } // namespace arith
 } // namespace tvm
diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc
index 7fb90a5e87c1..9ef5723e153e 100644
--- a/src/arith/const_int_bound.cc
+++ b/src/arith/const_int_bound.cc
@@ -20,6 +20,7 @@
 /*!
  * \file tvm/arith/const_int_bound.cc
  */
+#include <tvm/runtime/registry.h>
 #include
 #include
 #include
@@ -41,6 +42,13 @@ ConstIntBound::ConstIntBound(
   data_ = std::move(node);
 }
 
+ConstIntBound MakeConstIntBound(int64_t min_value, int64_t max_value) {
+  return ConstIntBound(min_value, max_value);
+}
+
+TVM_REGISTER_GLOBAL("arith.ConstIntBound")
+.set_body_typed(MakeConstIntBound);
+
 inline void PrintBoundValue(std::ostream& os, int64_t val) {
   if (val == ConstIntBound::kPosInf) {
     os << "pos_inf";
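Relocated registrations such as ``arith.ConstIntBound`` above remain reachable from Python exactly as before; a minimal sketch, assuming the ``tvm.arith.ConstIntBound`` frontend wrapper:

::

   import tvm

   # Calls through to the "arith.ConstIntBound" PackedFunc registered above.
   bound = tvm.arith.ConstIntBound(0, 10)
   assert bound.min_value == 0
   assert bound.max_value == 10
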
diff --git a/src/arith/detect_linear_equation.cc b/src/arith/detect_linear_equation.cc
index 53adf35eb6ee..cc9c745a24b8 100644
--- a/src/arith/detect_linear_equation.cc
+++ b/src/arith/detect_linear_equation.cc
@@ -21,6 +21,7 @@
  * \file detect_linear_equation.cc
  * \brief Utility to detect patterns in the expression.
  */
+#include <tvm/runtime/registry.h>
 #include
 #include
 #include
@@ -268,6 +269,12 @@ Array<PrimExpr> DetectClipBound(const PrimExpr& e, const Array<Var>& vars) {
   return ret;
 }
 
+TVM_REGISTER_GLOBAL("arith.DetectLinearEquation")
+.set_body_typed(DetectLinearEquation);
+TVM_REGISTER_GLOBAL("arith.DetectClipBound")
+.set_body_typed([](const PrimExpr& e, const Array<Var>& vars) {
+  return DetectClipBound(e, vars);
+});
 } // namespace arith
 } // namespace tvm
diff --git a/src/arith/domain_touched.cc b/src/arith/domain_touched.cc
index aa1ba4eb67be..4eecabdb6d8c 100644
--- a/src/arith/domain_touched.cc
+++ b/src/arith/domain_touched.cc
@@ -119,5 +119,8 @@ Domain DomainTouched(Stmt stmt,
   return FuncTouchedDomain(tensor, consider_calls, consider_provides).Find(stmt);
 }
 
+TVM_REGISTER_GLOBAL("arith.DomainTouched")
+.set_body_typed(DomainTouched);
+
 } // namespace arith
 } // namespace tvm
diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc
index adb38799fdf2..8c5afb1be8b5 100644
--- a/src/arith/int_set.cc
+++ b/src/arith/int_set.cc
@@ -820,5 +820,28 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
       << "[" << op->min_value << ", " << op->max_value << ']';
   });
 
+
+
+TVM_REGISTER_GLOBAL("arith.intset_single_point")
+.set_body_typed(IntSet::single_point);
+
+TVM_REGISTER_GLOBAL("arith.intset_vector")
+.set_body_typed(IntSet::vector);
+
+TVM_REGISTER_GLOBAL("arith.intset_interval")
+.set_body_typed(IntSet::interval);
+
+TVM_REGISTER_GLOBAL("arith.IntervalSetGetMin")
+.set_body_method(&IntSet::min);
+
+TVM_REGISTER_GLOBAL("arith.IntervalSetGetMax")
+.set_body_method(&IntSet::max);
+
+TVM_REGISTER_GLOBAL("arith.IntSetIsNothing")
+.set_body_method(&IntSet::is_nothing);
+
+TVM_REGISTER_GLOBAL("arith.IntSetIsEverything")
+.set_body_method(&IntSet::is_everything);
+
 } // namespace arith
 } // namespace tvm
diff --git a/src/arith/modular_set.cc b/src/arith/modular_set.cc
index c3031ca0edfc..40cd7f8793ee 100644
--- a/src/arith/modular_set.cc
+++ b/src/arith/modular_set.cc
@@ -21,6 +21,7 @@
  * \file modular_set.cc
  * \brief Modular set analysis
  */
+#include <tvm/runtime/registry.h>
 #include
 #include
 #include
@@ -52,6 +53,12 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
       << op->base << ')';
   });
 
+ModularSet MakeModularSet(int64_t coeff, int64_t base) {
+  return ModularSet(coeff, base);
+}
+
+TVM_REGISTER_GLOBAL("arith.ModularSet")
+.set_body_typed(MakeModularSet);
 
 // internal entry for const int bound
 struct ModularSetAnalyzer::Entry {
diff --git a/src/ir/expr.cc b/src/ir/expr.cc
index 4feabeb8e505..6244c7645acc 100644
--- a/src/ir/expr.cc
+++ b/src/ir/expr.cc
@@ -134,6 +134,14 @@ Range Range::make_by_min_extent(PrimExpr min, PrimExpr extent) {
   return Range(make_object<RangeNode>(min, extent));
 }
 
+TVM_REGISTER_GLOBAL("ir.range_by_min_extent")
+.set_body_typed(Range::make_by_min_extent);
+
+TVM_REGISTER_GLOBAL("ir.Range")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+    *ret = Range(args[0], args[1]);
+  });
+
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
 .set_dispatch<RangeNode>([](const ObjectRef& node, ReprPrinter* p) {
     auto* op = static_cast<const RangeNode*>(node.get());
diff --git a/src/api/api_test.cc b/src/support/ffi_testing.cc
similarity index 97%
rename from src/api/api_test.cc
rename to src/support/ffi_testing.cc
index 2a1e60539bdf..9053f6298999 100644
--- a/src/api/api_test.cc
+++ b/src/support/ffi_testing.cc
@@ -18,13 +18,13 @@
  */
 
 /*!
- * Code mainly used for test purposes.
- * \file api_test.cc
+ * FFI registration code used for frontend testing purposes.
+ * \file ffi_testing.cc
  */
+#include <tvm/runtime/registry.h>
 #include
 #include
 #include
-#include
 #include
 
 namespace tvm {
diff --git a/src/te/operation/compute_op.cc b/src/te/operation/compute_op.cc
index 1886d976555b..6123c613d0bd 100644
--- a/src/te/operation/compute_op.cc
+++ b/src/te/operation/compute_op.cc
@@ -21,6 +21,7 @@
  * \brief Compute Op.
  * \file compute_op.cc
  */
+#include <tvm/runtime/registry.h>
 #include
 #include
 #include
@@ -156,6 +157,10 @@ Operation ComputeOpNode::make(std::string name,
   return Operation(n);
 }
 
+TVM_REGISTER_GLOBAL("te.ComputeOp")
+.set_body_typed(ComputeOpNode::make);
+
+
 // The schedule related logics
 Array<Tensor> ComputeOpNode::InputTensors() const {
   Array<Tensor> ret;
diff --git a/src/te/operation/extern_op.cc b/src/te/operation/extern_op.cc
index c1e55046102b..62c8dfd30d49 100644
--- a/src/te/operation/extern_op.cc
+++ b/src/te/operation/extern_op.cc
@@ -21,6 +21,7 @@
  * \brief External computation rule.
  * \file extern_op.cc
  */
+#include <tvm/runtime/registry.h>
 #include
 #include
 #include
@@ -86,6 +87,10 @@ Operation ExternOpNode::make(std::string name,
   return Operation(n);
 }
 
+TVM_REGISTER_GLOBAL("te.ExternOp")
+.set_body_typed(ExternOpNode::make);
+
+
 Array<Tensor> ExternOpNode::InputTensors() const {
   return inputs;
 }
diff --git a/src/te/operation/hybrid_op.cc b/src/te/operation/hybrid_op.cc
index bb883ae47004..70abf34523b9 100644
--- a/src/te/operation/hybrid_op.cc
+++ b/src/te/operation/hybrid_op.cc
@@ -21,6 +21,7 @@
  * \brief Hybrid computation rule.
  * \file hybrid_op.cc
  */
+#include <tvm/runtime/registry.h>
 #include
 #include
 #include
@@ -83,6 +84,10 @@ Operation HybridOpNode::make(std::string name,
   return res;
 }
 
+TVM_REGISTER_GLOBAL("te.HybridOp")
+.set_body_typed(HybridOpNode::make);
+
+
 Array<Tensor> HybridOpNode::InputTensors() const {
   // Because input tensors could be potentially inlined into hybrid scripts,
   // we need to check if all input tensors are used in the body.
diff --git a/src/te/operation/placeholder_op.cc b/src/te/operation/placeholder_op.cc
index 866ef949cf49..d48be4c53668 100644
--- a/src/te/operation/placeholder_op.cc
+++ b/src/te/operation/placeholder_op.cc
@@ -21,6 +21,7 @@
  * \brief Placeholder op.
  * \file placeholder_op.cc
  */
+#include <tvm/runtime/registry.h>
 #include
 
 namespace tvm {
@@ -67,6 +68,11 @@ Tensor placeholder(Array<PrimExpr> shape, DataType dtype, std::string name) {
   return PlaceholderOpNode::make(name, shape, dtype).output(0);
 }
 
+TVM_REGISTER_GLOBAL("te.Placeholder")
+.set_body_typed([](Array<PrimExpr> shape, DataType dtype, std::string name) {
+  return placeholder(shape, dtype, name);
+});
+
 Array<Tensor> PlaceholderOpNode::InputTensors() const {
   return {};
 }
diff --git a/src/te/operation/scan_op.cc b/src/te/operation/scan_op.cc
index cacfd8c4a4f1..956a297f5b3c 100644
--- a/src/te/operation/scan_op.cc
+++ b/src/te/operation/scan_op.cc
@@ -21,6 +21,7 @@
  * \brief Scan Operator.
  * \file scan_op.cc
  */
+#include <tvm/runtime/registry.h>
 #include
 #include
 #include
@@ -120,6 +121,10 @@ Operation ScanOpNode::make(std::string name,
   return Operation(n);
 }
 
+TVM_REGISTER_GLOBAL("te.ScanOp")
+.set_body_typed(ScanOpNode::make);
+
+
 Array<Tensor> scan(Array<Tensor> init,
                    Array<Tensor> update,
                    Array<Tensor> state_placeholder,
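All of the relocated registrations use one of two forms of the ``TVM_REGISTER_GLOBAL`` macro: ``set_body_typed``, which derives argument conversion from a C++ signature, and ``set_body``, which receives raw ``TVMArgs`` and is useful for optional or overloaded arguments. A minimal sketch (the ``demo.*`` names are illustrative, not from this patch):

::

   #include <tvm/runtime/registry.h>

   using namespace tvm::runtime;

   // Typed form: conversions are generated from the lambda signature.
   TVM_REGISTER_GLOBAL("demo.AddOne")
   .set_body_typed([](int x) { return x + 1; });

   // Untyped form: manual argument handling via TVMArgs/TVMRetValue.
   TVM_REGISTER_GLOBAL("demo.Sum2")
   .set_body([](TVMArgs args, TVMRetValue* ret) {
     *ret = args[0].operator int() + args[1].operator int();
   });
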
diff --git a/src/te/operation/tensor_compute_op.cc b/src/te/operation/tensor_compute_op.cc
index 8ce621ccc55b..4cdc9e1f8d32 100644
--- a/src/te/operation/tensor_compute_op.cc
+++ b/src/te/operation/tensor_compute_op.cc
@@ -21,6 +21,7 @@
  * \brief Tensor Compute Op.
  * \file tensor_compute_op.cc
  */
+#include <tvm/runtime/registry.h>
 #include
 #include
 #include
@@ -72,6 +73,10 @@ Operation TensorComputeOpNode::make(std::string name,
   return Operation(n);
 }
 
+TVM_REGISTER_GLOBAL("te.TensorComputeOp")
+.set_body_typed(TensorComputeOpNode::make);
+
+
 Array<Tensor> TensorComputeOpNode::InputTensors() const {
   return inputs;
 }
diff --git a/src/te/schedule/auto_inline_elem_wise.cc b/src/te/schedule/auto_inline_elem_wise.cc
index 3a2226780f20..6d79f4a8d1d6 100644
--- a/src/te/schedule/auto_inline_elem_wise.cc
+++ b/src/te/schedule/auto_inline_elem_wise.cc
@@ -20,6 +20,7 @@
 /*!
  * \file auto_inline_elem_wise.cc
  */
+#include <tvm/runtime/registry.h>
 #include
 #include
 #include
@@ -111,5 +112,12 @@ void AutoInlineInjective(Schedule sch) {
   }
 }
 
+TVM_REGISTER_GLOBAL("schedule.AutoInlineElemWise")
+.set_body_typed(AutoInlineElemWise);
+
+
+TVM_REGISTER_GLOBAL("schedule.AutoInlineInjective")
+.set_body_typed(AutoInlineInjective);
+
 } // namespace te
 } // namespace tvm
diff --git a/src/te/schedule/bound.cc b/src/te/schedule/bound.cc
index 27896e6738a8..50cbafd2b654 100644
--- a/src/te/schedule/bound.cc
+++ b/src/te/schedule/bound.cc
@@ -21,6 +21,7 @@
  * \file bound.cc
  * \brief The bound inference logic.
  */
+#include <tvm/runtime/registry.h>
 #include
 #include
 #include
@@ -259,5 +260,8 @@ Map<IterVar, Range> InferBound(const Schedule& sch) {
   return Map<IterVar, Range>(ret.begin(), ret.end());
 }
 
+TVM_REGISTER_GLOBAL("schedule.InferBound")
+.set_body_typed(InferBound);
+
 } // namespace te
 } // namespace tvm
diff --git a/src/te/schedule/graph.cc b/src/te/schedule/graph.cc
index eff0a25c569a..9dce36f220ef 100644
--- a/src/te/schedule/graph.cc
+++ b/src/te/schedule/graph.cc
@@ -21,6 +21,7 @@
  * \file graph.cc
  * \brief Utilities to get information about schedule graph.
  */
+#include <tvm/runtime/registry.h>
 #include
 #include
 #include
@@ -429,5 +430,24 @@ Map<IterVar, PrimExpr> ScanFixPointAnalysis(const Operation& scan_op) {
   return ret;
 }
 
+
+TVM_REGISTER_GLOBAL("schedule.CreateReadGraph")
+.set_body_typed(CreateReadGraph);
+
+TVM_REGISTER_GLOBAL("schedule.PostDFSOrder")
+.set_body_typed([](const Array<Operation>& roots,
+                   const ReadGraph& g) {
+  return PostDFSOrder(roots, g);
+});
+
+TVM_REGISTER_GLOBAL("schedule.CreateAttachPath")
+.set_body_typed(CreateAttachPath);
+
+TVM_REGISTER_GLOBAL("schedule.ScanGetBody")
+.set_body_typed(ScanGetBody);
+
+TVM_REGISTER_GLOBAL("schedule.ScanFixPointAnalysis")
+.set_body_typed(ScanFixPointAnalysis);
+
 } // namespace te
 } // namespace tvm
diff --git a/src/te/schedule/schedule_lang.cc b/src/te/schedule/schedule_lang.cc
index 1763bd64c15f..d3b448d37790 100644
--- a/src/te/schedule/schedule_lang.cc
+++ b/src/te/schedule/schedule_lang.cc
@@ -20,6 +20,7 @@
 /*!
  * \file schedule_lang.cc
  */
+#include <tvm/runtime/registry.h>
 #include
 #include
 #include
@@ -848,5 +849,118 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     auto* op = static_cast<const ScheduleNode*>(node.get());
     p->stream << "schedule(" << op << ")";
   });
+
+
+TVM_REGISTER_GLOBAL("te.CreateSchedule")
+.set_body_typed(create_schedule);
+
+TVM_REGISTER_GLOBAL("te.StageSetScope")
+.set_body_method(&Stage::set_scope);
+
+TVM_REGISTER_GLOBAL("te.StageBind")
+.set_body_method(&Stage::bind);
+
+TVM_REGISTER_GLOBAL("te.StageSplitByFactor")
+.set_body_typed([](Stage stage, IterVar parent, PrimExpr factor) {
+  IterVar outer, inner;
+  stage.split(parent, factor, &outer, &inner);
+  return Array<IterVar>({outer, inner});
+});
+
+TVM_REGISTER_GLOBAL("te.StageSplitByNParts")
+.set_body_typed([](Stage stage, IterVar parent, PrimExpr nparts) {
+  IterVar outer, inner;
+  stage.split_by_nparts(parent, nparts, &outer, &inner);
+  return Array<IterVar>({outer, inner});
+});
+
+TVM_REGISTER_GLOBAL("te.StageFuse")
+.set_body_typed([](Stage stage, Array<IterVar> axes) {
+    IterVar fused;
+    stage.fuse(axes, &fused);
+    return fused;
+  });
+
+TVM_REGISTER_GLOBAL("te.StageComputeAt")
+.set_body_method(&Stage::compute_at);
+
+TVM_REGISTER_GLOBAL("te.StageComputeInline")
+.set_body_method(&Stage::compute_inline);
+
+TVM_REGISTER_GLOBAL("te.StageComputeRoot")
+.set_body_method(&Stage::compute_root);
+
+TVM_REGISTER_GLOBAL("te.StageReorder")
+.set_body_method(&Stage::reorder);
+
+TVM_REGISTER_GLOBAL("te.StageTile")
+.set_body_typed([](
+  Stage stage,
+  IterVar x_parent, IterVar y_parent,
+  PrimExpr x_factor, PrimExpr y_factor
+) {
+    IterVar x_outer, y_outer, x_inner, y_inner;
+    stage.tile(x_parent, y_parent,
+               x_factor, y_factor,
+               &x_outer, &y_outer,
+               &x_inner, &y_inner);
+    return Array<IterVar>({x_outer, y_outer, x_inner, y_inner});
+  });
+
+TVM_REGISTER_GLOBAL("te.StageEnvThreads")
+.set_body_method(&Stage::env_threads);
+
+TVM_REGISTER_GLOBAL("te.StageSetStorePredicate")
+.set_body_method(&Stage::set_store_predicate);
+
+TVM_REGISTER_GLOBAL("te.StageUnroll")
+.set_body_method(&Stage::unroll);
+
+TVM_REGISTER_GLOBAL("te.StageVectorize")
+.set_body_method(&Stage::vectorize);
+
+TVM_REGISTER_GLOBAL("te.StageTensorize")
+.set_body_method(&Stage::tensorize);
+
+TVM_REGISTER_GLOBAL("te.StageParallel")
+.set_body_method(&Stage::parallel);
+
+TVM_REGISTER_GLOBAL("te.StagePragma")
+.set_body_method(&Stage::pragma);
+
+TVM_REGISTER_GLOBAL("te.StagePrefetch")
+.set_body_method(&Stage::prefetch);
+
+TVM_REGISTER_GLOBAL("te.StageStorageAlign")
+.set_body_method(&Stage::storage_align);
+
+TVM_REGISTER_GLOBAL("te.StageDoubleBuffer")
+.set_body_method(&Stage::double_buffer);
+
+TVM_REGISTER_GLOBAL("te.StageOpenGL")
+.set_body_method(&Stage::opengl);
+
+TVM_REGISTER_GLOBAL("te.ScheduleNormalize")
+.set_body_method(&Schedule::normalize);
+
+TVM_REGISTER_GLOBAL("te.ScheduleCreateGroup")
+.set_body_method(&Schedule::create_group);
+
+TVM_REGISTER_GLOBAL("te.ScheduleCacheRead")
+.set_body_method(&Schedule::cache_read);
+
+TVM_REGISTER_GLOBAL("te.ScheduleCacheWrite")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+    if (args[1].IsObjectRef<Tensor>()) {
+      *ret = args[0].operator Schedule()
+          .cache_write(args[1].operator Tensor(), args[2]);
+    } else {
+      *ret = args[0].operator Schedule()
+          .cache_write(args[1].operator Array<Tensor>(), args[2]);
+    }
+  });
+
+TVM_REGISTER_GLOBAL("te.ScheduleRFactor")
+.set_body_method(&Schedule::rfactor);
 } // namespace te
 } // namespace tvm
diff --git a/src/te/schedule/schedule_ops.cc b/src/te/schedule/schedule_ops.cc
index 0930f26372c4..a110bc458fe9 100644
--- a/src/te/schedule/schedule_ops.cc
+++ b/src/te/schedule/schedule_ops.cc
@@ -20,6 +20,7 @@
 /*!
  * \file schedule_ops.cc
  */
+#include <tvm/runtime/registry.h>
 #include
 #include
 #include
@@ -423,5 +424,13 @@ Stmt ScheduleOps(
   return post_proc(std::move(body));
 }
 
+TVM_REGISTER_GLOBAL("schedule.ScheduleOps")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+  if (args.size() == 2)
+    *ret = ScheduleOps(args[0], args[1], false);
+  else
+    *ret = ScheduleOps(args[0], args[1], args[2]);
+});
+
 } // namespace te
 } // namespace tvm
diff --git a/src/te/tensor.cc b/src/te/tensor.cc
index f200514468cb..cb14f6a35270 100644
--- a/src/te/tensor.cc
+++ b/src/te/tensor.cc
@@ -20,6 +20,7 @@
 /*!
  * \file tensor.cc
  */
+#include <tvm/runtime/registry.h>
 #include
 #include
 #include
@@ -147,5 +148,33 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
 
 TVM_REGISTER_NODE_TYPE(TensorIntrinCallNode);
 
+TVM_REGISTER_GLOBAL("te.Tensor")
+.set_body_typed(TensorNode::make);
+
+TVM_REGISTER_GLOBAL("te.TensorIntrin")
+.set_body_typed(TensorIntrinNode::make);
+
+TVM_REGISTER_GLOBAL("te.TensorIntrinCall")
+.set_body_typed(TensorIntrinCallNode::make);
+
+TVM_REGISTER_GLOBAL("te.TensorEqual")
+.set_body_method(&Tensor::operator==);
+
+TVM_REGISTER_GLOBAL("te.TensorHash")
+.set_body_typed([](Tensor tensor) -> int64_t {
+    return static_cast<int64_t>(std::hash<Tensor>()(tensor));
+  });
+
+TVM_REGISTER_GLOBAL("te.OpGetOutput")
+.set_body_typed([](Operation op, int64_t output) {
+  return op.output(static_cast<size_t>(output));
+});
+
+TVM_REGISTER_GLOBAL("te.OpNumOutputs")
+.set_body_method(&OperationNode::num_outputs);
+
+TVM_REGISTER_GLOBAL("te.OpInputTensors")
+.set_body_method(&OperationNode::InputTensors);
+
 } // namespace te
 } // namespace tvm
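The ``te.*`` globals registered above back the user-facing tensor expression API; a minimal sketch, assuming the ``tvm.te`` frontend namespace from the same refactor series:

::

   import tvm
   from tvm import te

   A = te.placeholder((10,), name="A")          # backed by "te.Placeholder"
   B = te.compute(A.shape, lambda i: A[i] + 1)  # backed by "te.ComputeOp"
   print(B.op.input_tensors)                    # backed by "te.OpInputTensors"
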
diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc
index d06c33f79dcc..22844745982f 100644
--- a/src/tir/ir/expr.cc
+++ b/src/tir/ir/expr.cc
@@ -20,6 +20,7 @@
 /*!
  * \file expr.cc
  */
+#include <tvm/runtime/registry.h>
 #include
 #include
 #include
@@ -45,6 +46,17 @@ SizeVar::SizeVar(std::string name_hint, DataType t)
 SizeVarNode::SizeVarNode(DataType t, std::string name_hint)
     : VarNode(t, std::move(name_hint)) {}
 
+
+TVM_REGISTER_GLOBAL("tir.Var")
+.set_body_typed([](std::string s, DataType t) {
+    return Var(s, t);
+  });
+
+TVM_REGISTER_GLOBAL("tir.SizeVar")
+.set_body_typed([](std::string s, DataType t) {
+    return SizeVar(s, t);
+  });
+
 IterVar IterVarNode::make(Range dom,
                           Var var,
                           IterVarType t,
@@ -57,6 +69,14 @@ IterVar IterVarNode::make(Range dom,
   return IterVar(n);
 }
 
+TVM_REGISTER_GLOBAL("tir.IterVar")
+.set_body_typed([](Range dom, Var var, int iter_type, std::string thread_tag) {
+  return IterVarNode::make(
+      dom, var,
+      static_cast<IterVarType>(iter_type),
+      thread_tag);
+});
+
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
 .set_dispatch<IterVarNode>([](const ObjectRef& node, ReprPrinter* p) {
     auto* op = static_cast<const IterVarNode*>(node.get());
@@ -83,6 +103,9 @@ PrimExpr StringImmNode::make(std::string value) {
   return PrimExpr(node);
 }
 
+TVM_REGISTER_GLOBAL("tir.StringImm")
+.set_body_typed(StringImmNode::make);
+
 PrimExpr CastNode::make(DataType t, PrimExpr value) {
   CHECK(value.defined());
   CHECK_EQ(t.lanes(), value.dtype().lanes());
@@ -311,6 +334,13 @@ Array<PrimExpr> CommReducerNode::operator()(Array<PrimExpr> a, Array<PrimExpr> b
   });
 }
 
+TVM_REGISTER_GLOBAL("tir.CommReducer")
+.set_body_typed(CommReducerNode::make);
+
+TVM_REGISTER_GLOBAL("tir.CommReducerCombine")
+.set_body_method(&tir::CommReducerNode::operator());
+
+
 PrimExpr ReduceNode::make(CommReducer combiner, Array<PrimExpr> source,
                           Array<IterVar> axis, PrimExpr condition, int value_index) {
   for (size_t i = 0; i < axis.size(); ++i) {
@@ -334,6 +364,11 @@ PrimExpr ReduceNode::make(CommReducer combiner, Array<PrimExpr> source,
   return PrimExpr(n);
 }
 
+
+TVM_REGISTER_GLOBAL("tir.Reduce")
+.set_body_typed(ReduceNode::make);
+
+
 PrimExpr AnyNode::make() {
   auto n = make_object<AnyNode>();
   return PrimExpr(n);
@@ -659,5 +694,104 @@ TVM_REGISTER_NODE_TYPE(CommReducerNode);
 TVM_REGISTER_NODE_TYPE(ReduceNode);
 TVM_REGISTER_NODE_TYPE(AnyNode);
 
+
+TVM_REGISTER_GLOBAL("tir.Add")
+.set_body_typed(AddNode::make);
+
+TVM_REGISTER_GLOBAL("tir.Sub")
+.set_body_typed(SubNode::make);
+
+TVM_REGISTER_GLOBAL("tir.Mul")
+.set_body_typed(MulNode::make);
+
+TVM_REGISTER_GLOBAL("tir.Div")
+.set_body_typed(DivNode::make);
+
+TVM_REGISTER_GLOBAL("tir.Mod")
+.set_body_typed(ModNode::make);
+
+TVM_REGISTER_GLOBAL("tir.FloorDiv")
+.set_body_typed(FloorDivNode::make);
+
+TVM_REGISTER_GLOBAL("tir.FloorMod")
+.set_body_typed(FloorModNode::make);
+
+TVM_REGISTER_GLOBAL("tir.Min")
+.set_body_typed(MinNode::make);
+
+TVM_REGISTER_GLOBAL("tir.Max")
+.set_body_typed(MaxNode::make);
+
+TVM_REGISTER_GLOBAL("tir.EQ")
+.set_body_typed(EQNode::make);
+
+TVM_REGISTER_GLOBAL("tir.NE")
+.set_body_typed(NENode::make);
+
+TVM_REGISTER_GLOBAL("tir.LT")
+.set_body_typed(LTNode::make);
+
+TVM_REGISTER_GLOBAL("tir.LE")
+.set_body_typed(LENode::make);
+
+TVM_REGISTER_GLOBAL("tir.GT")
+.set_body_typed(GTNode::make);
+
+TVM_REGISTER_GLOBAL("tir.GE")
+.set_body_typed(GENode::make);
+
+TVM_REGISTER_GLOBAL("tir.And")
+.set_body_typed(AndNode::make);
+
+TVM_REGISTER_GLOBAL("tir.Or")
+.set_body_typed(OrNode::make);
+
+TVM_REGISTER_GLOBAL("tir.Not")
+.set_body_typed(NotNode::make);
+
+TVM_REGISTER_GLOBAL("tir.Select")
+.set_body_typed(SelectNode::make);
+
+TVM_REGISTER_GLOBAL("tir.Ramp")
+.set_body_typed(RampNode::make);
+
+TVM_REGISTER_GLOBAL("tir.Cast")
+.set_body_typed(CastNode::make);
+
+TVM_REGISTER_GLOBAL("tir.Broadcast")
+.set_body_typed(BroadcastNode::make);
+
+TVM_REGISTER_GLOBAL("tir.Shuffle")
+.set_body_typed(ShuffleNode::make);
+
+TVM_REGISTER_GLOBAL("tir.Let")
+.set_body_typed(LetNode::make);
+
+TVM_REGISTER_GLOBAL("tir.Load")
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+    DataType t = args[0];
+    if (args.size() == 3) {
+      *ret = LoadNode::make(t, args[1], args[2], const_true(t.lanes()));
+    } else {
+      *ret = LoadNode::make(t, args[1], args[2], args[3]);
+    }
+  });
+
+
+
+TVM_REGISTER_GLOBAL("tir.Call")
+.set_body_typed([](
+  DataType type, std::string name,
+  Array<PrimExpr> args, int call_type,
+  FunctionRef func, int value_index
+) {
+  return CallNode::make(type,
+                        name,
+                        args,
+                        static_cast<CallNode::CallType>(call_type),
+                        func,
+                        value_index);
+});
+
 } // namespace tir
 } // namespace tvm
diff --git a/src/tir/ir/op.cc b/src/tir/ir/op.cc
index 58f8b6b76da8..452c3bbc68a2 100644
--- a/src/tir/ir/op.cc
+++ b/src/tir/ir/op.cc
@@ -662,4 +662,90 @@ TVM_REGISTER_GLOBAL("node.LargeUIntImm")
 
 TVM_REGISTER_GLOBAL("node.String")
 .set_body_typed(tir::StringImmNode::make);
 
+TVM_REGISTER_GLOBAL("tir.min_value")
+.set_body_typed(min_value);
+
+TVM_REGISTER_GLOBAL("tir.max_value")
+.set_body_typed(max_value);
+
+TVM_REGISTER_GLOBAL("tir.abs")
+.set_body_typed(tvm::abs);
+
+TVM_REGISTER_GLOBAL("tir.isnan")
+.set_body_typed(tvm::isnan);
+
+TVM_REGISTER_GLOBAL("tir.floor")
+.set_body_typed(tvm::floor);
+
+TVM_REGISTER_GLOBAL("tir.ceil")
+.set_body_typed(tvm::ceil);
+
+TVM_REGISTER_GLOBAL("tir.round")
+.set_body_typed(tvm::round);
+
+TVM_REGISTER_GLOBAL("tir.nearbyint")
+.set_body_typed(tvm::nearbyint);
+
+TVM_REGISTER_GLOBAL("tir.trunc")
+.set_body_typed(tvm::trunc);
+
+TVM_REGISTER_GLOBAL("tir._cast")
+.set_body_typed(tvm::cast);
+
+
+
+// operator overloading, smarter than make
+#define REGISTER_MAKE_BINARY_OP(Node, Func)                 \
+  TVM_REGISTER_GLOBAL("tir."#Node)                          \
+  .set_body_typed([](PrimExpr a, PrimExpr b) {              \
+    return (Func(a, b));                                    \
+  })
+
+#define REGISTER_MAKE_BIT_OP(Node, Func)                                      \
+  TVM_REGISTER_GLOBAL("tir."#Node)                                            \
+  .set_body([](TVMArgs args, TVMRetValue *ret) {                              \
+    bool lhs_is_int = args[0].type_code() == kDLInt;                          \
+    bool rhs_is_int = args[1].type_code() == kDLInt;                          \
+    if (lhs_is_int) {                                                         \
+      *ret = (Func(args[0].operator int(), args[1].operator PrimExpr()));     \
+    } else if (rhs_is_int) {                                                  \
+      *ret = (Func(args[0].operator PrimExpr(), args[1].operator int()));     \
+    } else {                                                                  \
+      *ret = (Func(args[0].operator PrimExpr(), args[1].operator PrimExpr())); \
+    }                                                                         \
+  })
+
+
+REGISTER_MAKE_BINARY_OP(_OpAdd, operator+);
+REGISTER_MAKE_BINARY_OP(_OpSub, operator-);
+REGISTER_MAKE_BINARY_OP(_OpMul, operator*);
+REGISTER_MAKE_BINARY_OP(_OpDiv, div);
+REGISTER_MAKE_BINARY_OP(_OpMod, truncmod);
+REGISTER_MAKE_BINARY_OP(_OpIndexDiv, indexdiv);
+REGISTER_MAKE_BINARY_OP(_OpIndexMod, indexmod);
+REGISTER_MAKE_BINARY_OP(_OpFloorDiv, floordiv);
+REGISTER_MAKE_BINARY_OP(_OpFloorMod, floormod);
+REGISTER_MAKE_BINARY_OP(_OpTruncDiv, truncdiv);
+REGISTER_MAKE_BINARY_OP(_OpTruncMod, truncmod);
+REGISTER_MAKE_BINARY_OP(_OpPow, pow);
+REGISTER_MAKE_BINARY_OP(_OpMin, min);
+REGISTER_MAKE_BINARY_OP(_OpMax, max);
+REGISTER_MAKE_BINARY_OP(_OpEQ, operator==);
+REGISTER_MAKE_BINARY_OP(_OpNE, operator!=);
+REGISTER_MAKE_BINARY_OP(_OpLT, operator<);  // NOLINT(*)
+REGISTER_MAKE_BINARY_OP(_OpLE, operator<=); // NOLINT(*)
+REGISTER_MAKE_BINARY_OP(_OpGT, operator>);  // NOLINT(*)
+REGISTER_MAKE_BINARY_OP(_OpGE, operator>=);
+REGISTER_MAKE_BINARY_OP(_OpAnd, operator&&);
+REGISTER_MAKE_BINARY_OP(_OpOr, operator||);
+REGISTER_MAKE_BIT_OP(bitwise_and, operator&);
+REGISTER_MAKE_BIT_OP(bitwise_or, operator|);
+REGISTER_MAKE_BIT_OP(bitwise_xor, operator^);
+REGISTER_MAKE_BIT_OP(left_shift, operator<<); // NOLINT(*)
+REGISTER_MAKE_BIT_OP(right_shift, operator>>);
+
+TVM_REGISTER_GLOBAL("tir._OpIfThenElse")
+.set_body_typed([] (PrimExpr cond, PrimExpr true_value, PrimExpr false_value) {
+  return if_then_else(cond, true_value, false_value);
+});
 } // namespace tvm
diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc
index 0cd2aba319ee..a8fe9cd2bad3 100644
--- a/src/tir/ir/stmt.cc
+++ b/src/tir/ir/stmt.cc
@@ -20,7 +20,7 @@
 /*!
  * \file tvm/tir/stmt.cc
  */
-
+#include <tvm/runtime/registry.h>
 #include
 #include
 #include "../pass/ir_util.h"
@@ -40,6 +40,9 @@ Stmt LetStmtNode::make(Var var, PrimExpr value, Stmt body) {
   return Stmt(node);
 }
 
+TVM_REGISTER_GLOBAL("tir.LetStmt")
+.set_body_typed(LetStmtNode::make);
+
 Stmt AttrStmtNode::make(ObjectRef node,
                         std::string attr_key,
                         PrimExpr value,
@@ -52,6 +55,10 @@ Stmt AttrStmtNode::make(ObjectRef node,
   return Stmt(n);
 }
 
+TVM_REGISTER_GLOBAL("tir.AttrStmt")
+.set_body_typed(AttrStmtNode::make);
+
+
 Stmt AssertStmtNode::make(PrimExpr condition, PrimExpr message, Stmt body) {
   CHECK(condition.defined());
   CHECK(message.dtype() == DataType::Int(32) ||
@@ -66,6 +73,10 @@ Stmt AssertStmtNode::make(PrimExpr condition, PrimExpr message, Stmt body) {
   return Stmt(node);
 }
 
+TVM_REGISTER_GLOBAL("tir.AssertStmt")
+.set_body_typed(AssertStmtNode::make);
+
+
 Stmt ProducerConsumerNode::make(FunctionRef func, bool is_producer, Stmt body) {
   CHECK(body.defined());
 
@@ -76,6 +87,10 @@ Stmt ProducerConsumerNode::make(FunctionRef func, bool is_producer, Stmt body) {
   return Stmt(node);
 }
 
+TVM_REGISTER_GLOBAL("tir.ProducerConsumer")
+.set_body_typed(ProducerConsumerNode::make);
+
+
 Stmt ForNode::make(Var loop_var,
                    PrimExpr min,
                    PrimExpr extent,
@@ -99,6 +114,19 @@ Stmt ForNode::make(Var loop_var,
   return Stmt(node);
 }
 
+TVM_REGISTER_GLOBAL("tir.For")
+.set_body_typed([](
+  Var loop_var, PrimExpr min, PrimExpr extent,
+  int for_type, int device_api, Stmt body) {
+  return ForNode::make(loop_var,
+                       min,
+                       extent,
+                       static_cast<ForType>(for_type),
+                       static_cast<DeviceAPI>(device_api),
+                       body);
+});
+
+
 Stmt StoreNode::make(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr predicate) {
   CHECK(value.defined());
   CHECK(index.defined());
@@ -114,6 +142,18 @@ Stmt StoreNode::make(Var buffer_var, PrimExpr value, PrimExpr index, PrimExpr pr
   return Stmt(node);
 }
 
+
+TVM_REGISTER_GLOBAL("tir.Store")
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+    PrimExpr value = args[1];
+    if (args.size() == 3) {
+      *ret = StoreNode::make(args[0], value, args[2], const_true(value.dtype().lanes()));
+    } else {
+      *ret = StoreNode::make(args[0], value, args[2], args[3]);
+    }
+  });
+
+
 Stmt ProvideNode::make(FunctionRef func, int value_index, PrimExpr value, Array<PrimExpr> args) {
   CHECK(value_index >=0 && value_index < func->num_outputs())
       << "value index output function return value bound";
@@ -131,6 +171,10 @@ Stmt ProvideNode::make(FunctionRef func, int value_index, PrimExpr value, Array<
   return Stmt(node);
 }
 
+TVM_REGISTER_GLOBAL("tir.Provide")
+.set_body_typed(ProvideNode::make);
+
+
 Stmt AllocateNode::make(Var buffer_var,
                         DataType dtype,
                         Array<PrimExpr> extents,
@@ -157,6 +201,15 @@ Stmt AllocateNode::make(Var buffer_var,
   return Stmt(node);
 }
 
+// overloaded, needs special handling
+// has default args
+TVM_REGISTER_GLOBAL("tir.Allocate")
+.set_body_typed([](
+  Var buffer_var, DataType type, Array<PrimExpr> extents, PrimExpr condition, Stmt body
+  ){
+  return AllocateNode::make(buffer_var, type, extents, condition, body);
+});
+
 int32_t AllocateNode::constant_allocation_size(const Array<PrimExpr>& extents) {
   int64_t result = 1;
   for (size_t i = 0; i < extents.size(); ++i) {
@@ -178,12 +231,16 @@ Stmt FreeNode::make(Var buffer_var) {
   return Stmt(node);
 }
 
+TVM_REGISTER_GLOBAL("tir.Free")
+.set_body_typed(FreeNode::make);
+
+
 Stmt RealizeNode::make(FunctionRef func,
-                      int value_index,
-                      DataType dtype,
-                      Region bounds,
-                      PrimExpr condition,
-                      Stmt body) {
+                       int value_index,
+                       DataType dtype,
+                       Region bounds,
+                       PrimExpr condition,
+                       Stmt body) {
   for (size_t i = 0; i < bounds.size(); ++i) {
     CHECK(bounds[i]->min.defined());
     CHECK(bounds[i]->extent.defined());
@@ -204,6 +261,11 @@ Stmt RealizeNode::make(FunctionRef func,
   return Stmt(node);
 }
 
+
+TVM_REGISTER_GLOBAL("tir.Realize")
+.set_body_typed(RealizeNode::make);
+
+
 Stmt PrefetchNode::make(FunctionRef func, int value_index, DataType dtype, Region bounds) {
   for (size_t i = 0; i < bounds.size(); ++i) {
     CHECK(bounds[i]->min.defined());
@@ -220,12 +282,21 @@ Stmt PrefetchNode::make(FunctionRef func, int value_index, DataType dtype, Regio
   return Stmt(node);
 }
 
+TVM_REGISTER_GLOBAL("tir.Prefetch")
+.set_body_typed(PrefetchNode::make);
+
+
 SeqStmt::SeqStmt(Array<Stmt> seq) {
   auto node = make_object<SeqStmtNode>();
   node->seq = std::move(seq);
   data_ = std::move(node);
 }
 
+TVM_REGISTER_GLOBAL("tir.SeqStmt")
+.set_body_typed([](Array<Stmt> seq) {
+  return SeqStmt(std::move(seq));
+});
+
 Stmt IfThenElseNode::make(PrimExpr condition, Stmt then_case, Stmt else_case) {
   CHECK(condition.defined());
   CHECK(then_case.defined());
@@ -238,6 +309,10 @@ Stmt IfThenElseNode::make(PrimExpr condition, Stmt then_case, Stmt else_case) {
   return Stmt(node);
 }
 
+TVM_REGISTER_GLOBAL("tir.IfThenElse")
+.set_body_typed(IfThenElseNode::make);
+
+
 Stmt EvaluateNode::make(PrimExpr value) {
   CHECK(value.defined());
 
@@ -246,6 +321,9 @@ Stmt EvaluateNode::make(PrimExpr value) {
   return Stmt(node);
 }
 
+TVM_REGISTER_GLOBAL("tir.Evaluate")
+.set_body_typed(EvaluateNode::make);
+
 // Printers
 
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
diff --git a/src/api/api_pass.cc b/src/tir/pass/ffi_api.cc
similarity index 99%
rename from src/api/api_pass.cc
rename to src/tir/pass/ffi_api.cc
index 75d5439b7f1b..233bfa51d614 100644
--- a/src/api/api_pass.cc
+++ b/src/tir/pass/ffi_api.cc
@@ -19,7 +19,7 @@
 
 /*!
  * Exposure of pass functions.
- * \file api_pass.cc
+ * \file ffi_api.cc
  */
 #include
 #include
@@ -136,8 +136,8 @@ TVM_REGISTER_GLOBAL("ir_pass.LowerStorageAccess")
 
 // make from two arguments
 #define REGISTER_PASS(PassName)                            \
-  TVM_REGISTER_GLOBAL("ir_pass."#PassName)                  \
-  .set_body_typed(PassName);                                \
+  TVM_REGISTER_GLOBAL("ir_pass."#PassName)                 \
+  .set_body_typed(PassName);                               \
 
 
 REGISTER_PASS(ConvertSSA);
diff --git a/tests/python/unittest/test_runtime_error.py b/tests/python/unittest/test_runtime_error.py
index d1a2d983ff25..ac019a0aab40 100644
--- a/tests/python/unittest/test_runtime_error.py
+++ b/tests/python/unittest/test_runtime_error.py
@@ -27,7 +27,7 @@ def test_op_translation():
     except tvm.error.OpNotImplemented as e:
         msg = str(e)
         assert isinstance(e, NotImplementedError)
-        assert msg.find("api_test.cc") != -1
+        assert msg.find("ffi_testing.cc") != -1
 
     fchk_eq = tvm.testing.test_check_eq_callback(
         "InternalError: myop")
@@ -36,14 +36,14 @@ def test_op_translation():
         assert False
     except tvm.error.InternalError as e:
         msg = str(e)
-        assert msg.find("api_test.cc") != -1
+        assert msg.find("ffi_testing.cc") != -1
 
     try:
         tvm.testing.ErrorTest(0, 1)
         assert False
     except ValueError as e:
         msg = str(e)
-        assert msg.find("api_test.cc") != -1
+        assert msg.find("ffi_testing.cc") != -1
 
 
 def test_deep_callback():