From 92ed53ed922c404feb23f61cfe2af8d2211db443 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Sat, 4 Apr 2020 17:36:49 -0700 Subject: [PATCH] [REFACTOR][TIR] Migrate all low-level passes to the Pass Manager. (#5233) * [REFACTOR][TIR] Migrate all low-level passes to the Pass Manager. This PR migrates the tvm.lower to return IRModule of PrimFuncs instead of the LoweredFuncs. * Remove LoweredFunc. --- apps/lldb/tvm.py | 1 - docs/dev/codebase_walkthrough.rst | 9 - include/tvm/driver/driver_api.h | 25 ++- include/tvm/ir/module.h | 9 + include/tvm/target/codegen.h | 12 -- include/tvm/tir/analysis.h | 15 ++ include/tvm/tir/ir_pass.h | 80 --------- include/tvm/tir/lowered_func.h | 149 ----------------- include/tvm/tir/transform.h | 55 +++++++ python/tvm/driver/build_module.py | 99 +++++------ python/tvm/relay/backend/_backend.py | 16 +- .../relay/backend/graph_runtime_codegen.py | 6 +- python/tvm/runtime/__init__.py | 1 + python/tvm/target/build_config.py | 10 +- python/tvm/testing.py | 39 +++++ python/tvm/tir/__init__.py | 2 +- python/tvm/tir/analysis/analysis.py | 11 ++ python/tvm/tir/function.py | 2 + python/tvm/tir/stmt.py | 8 - python/tvm/tir/transform/transform.py | 30 ++++ src/contrib/hybrid/codegen_hybrid.h | 1 - src/driver/driver_api.cc | 69 +++++--- src/relay/backend/build_module.cc | 15 +- src/relay/backend/compile_engine.h | 4 +- src/relay/backend/graph_runtime_codegen.cc | 29 ++-- src/relay/backend/vm/compiler.cc | 39 +++-- src/relay/backend/vm/compiler.h | 2 +- src/relay/transforms/gradient.cc | 1 - src/target/build_common.h | 17 -- src/target/codegen.cc | 47 ------ src/target/llvm/codegen_llvm.cc | 2 +- src/target/llvm/llvm_module.cc | 34 ++-- src/target/source/codegen_c.h | 1 - src/target/spirv/codegen_spirv.h | 1 - src/target/stackvm/codegen_stackvm.h | 1 - src/tir/{pass => analysis}/verify_memory.cc | 33 +++- src/tir/ir/buffer.cc | 2 +- src/tir/ir/lowered_func.cc | 37 ----- src/tir/pass/ffi_api.cc | 11 -- src/tir/pass/storage_rewrite.cc | 23 --- .../lower_custom_datatypes.cc | 27 ++- .../make_packed_api.cc} | 155 +++++++++++------- .../{pass => transforms}/remap_thread_axis.cc | 36 ++-- src/tir/transforms/split_host_device.cc | 1 - tests/cpp/build_module_test.cc | 4 +- tests/python/integration/test_dot.py | 23 --- .../python/unittest/test_runtime_extension.py | 5 +- .../unittest/test_runtime_heterogeneous.py | 7 +- .../unittest/test_runtime_module_load.py | 4 +- .../unittest/test_target_codegen_llvm.py | 11 +- .../test_target_codegen_static_init.py | 16 +- .../unittest/test_target_codegen_vm_basic.py | 20 ++- .../unittest/test_target_custom_datatypes.py | 11 +- ....py => test_tir_analysis_verify_memory.py} | 45 +++-- .../unittest/test_tir_pass_bound_checkers.py | 22 +-- .../test_tir_pass_inject_double_buffer.py | 3 +- .../unittest/test_tir_pass_loop_partition.py | 4 +- .../unittest/test_tir_pass_storage_flatten.py | 4 +- ...test_tir_transform_combine_context_call.py | 7 +- .../test_tir_transform_lower_warp_memory.py | 4 +- ... => test_tir_transform_make_packed_api.py} | 10 +- .../test_tir_transform_thread_sync.py | 5 +- tutorials/dev/low_level_custom_pass.py | 2 +- 63 files changed, 608 insertions(+), 766 deletions(-) delete mode 100644 include/tvm/tir/lowered_func.h rename src/tir/{pass => analysis}/verify_memory.cc (86%) delete mode 100644 src/tir/ir/lowered_func.cc rename src/tir/{pass => transforms}/lower_custom_datatypes.cc (88%) rename src/tir/{pass/make_api.cc => transforms/make_packed_api.cc} (61%) rename src/tir/{pass => transforms}/remap_thread_axis.cc (73%) rename tests/python/unittest/{test_tir_pass_verify_memory.py => test_tir_analysis_verify_memory.py} (70%) rename tests/python/unittest/{test_tir_pass_makeapi.py => test_tir_transform_make_packed_api.py} (84%) diff --git a/apps/lldb/tvm.py b/apps/lldb/tvm.py index 811d32db6c75..135aeff5258a 100644 --- a/apps/lldb/tvm.py +++ b/apps/lldb/tvm.py @@ -46,7 +46,6 @@ def __lldb_init_module(debugger, _): "tvm::IterVarAttr", "tvm::IterVarRelation", "tvm::Layout", - "tir::LoweredFunc", "tvm::Map", "tvm::Map", "tvm::MemoryInfo", diff --git a/docs/dev/codebase_walkthrough.rst b/docs/dev/codebase_walkthrough.rst index b7eb06b32df6..a66328fef7c9 100644 --- a/docs/dev/codebase_walkthrough.rst +++ b/docs/dev/codebase_walkthrough.rst @@ -145,15 +145,6 @@ After lowering is done, ``build()`` function generates target machine code from Code generation is done by ``build_module()`` function, defined in ``python/tvm/target/codegen.py``. On the C++ side, code generation is implemented in ``src/target/codegen`` subdirectory. ``build_module()`` Python function will reach ``Build()`` function below in ``src/target/codegen/codegen.cc``: -:: - - runtime::Module Build(const Array& funcs, - const std::string& target) { - std::string build_f_name = "codegen.build_" + target; - const PackedFunc* bf = runtime::Registry::Get(build_f_name); - runtime::Module m = (*bf)(funcs, target); - return m; - } The ``Build()`` function looks up the code generator for the given target in the ``PackedFunc`` registry, and invokes the function found. For example, ``codegen.build_cuda`` function is registered in ``src/codegen/build_cuda_on.cc``, like this: diff --git a/include/tvm/driver/driver_api.h b/include/tvm/driver/driver_api.h index 64d51736b445..e6d442754446 100644 --- a/include/tvm/driver/driver_api.h +++ b/include/tvm/driver/driver_api.h @@ -32,8 +32,8 @@ #include #include #include +#include #include -#include #include #include @@ -43,15 +43,15 @@ namespace tvm { /*! -* \brief Build a LoweredFunc given a schedule, args and binds +* \brief Build an IRModule given a schedule, args and binds * \param sch The schedule to lower. * \param args The arguments to the function. * \param name The name of the lowered function. * \param binds Buffer assignments. * \param config The build configuration. -* \return The lowered function. +* \return The result module. */ -TVM_DLL Array lower( +TVM_DLL IRModule lower( te::Schedule sch, const Array& args, const std::string& name, @@ -59,44 +59,43 @@ TVM_DLL Array lower( const BuildConfig& config); /*! -* \brief Build a device and host module for a specific target from an array of lowered functions. +* \brief Build a device and host module for a specific target from an IRModule. * \param funcs The functions to be built. * \param target The target device to build for. * \param target_host The target for building host code. To use the default, pass Target() * \param config The build configuration. * \return The built module. */ -TVM_DLL runtime::Module build(const Array& funcs, +TVM_DLL runtime::Module build(const IRModule& funcs, const Target& target, const Target& target_host, const BuildConfig& config); /*! * \brief Build a device and host module for a specific target from a map - * contains target to a list of lowered functions pairs. This function is used + * contains target to IRModule. This function is used * for heterogeneous build. - * \param input The map contains target to a list of lowered functions pairs. + * \param input The map contains target to an IRModule. * \param target_host The target for building host code. To use the default, * pass Target(). * \param config The build configuration. * \return The built module that contains code for different processors. */ -TVM_DLL runtime::Module build(const Map>& input, +TVM_DLL runtime::Module build(const Map& input, const Target& target_host, const BuildConfig& config); /*! * \brief Build a device and host module for a specific target from a map - * contains target to a list of lowered functions pairs. This function is used + * contains target to IRModule. This function is used * for heterogeneous build. - * \param input The map contains target string to a list of lowered functions - * pairs. + * \param input The map contains target string to an IRModule. * \param target_host The target for building host code. To use the default, * pass Target(). * \param config The build configuration. * \return The built module that contains code for different processors. */ -TVM_DLL runtime::Module build(const Map>& input, +TVM_DLL runtime::Module build(const Map& input, const Target& target_host, const BuildConfig& config); } // namespace tvm diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index f63bf96ef2ab..b0776dee661f 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -297,6 +297,15 @@ class IRModule : public ObjectRef { CHECK(ptr != nullptr); return static_cast(ptr); } + + /*! + * \brief Construct an empty module. + * + * \returns The constructed module + */ + static IRModule Empty() { + return IRModule(Map()); + } /*! * \brief Construct a module from a standalone expression. * diff --git a/include/tvm/target/codegen.h b/include/tvm/target/codegen.h index c604eb5c93de..4b7ea56e705d 100644 --- a/include/tvm/target/codegen.h +++ b/include/tvm/target/codegen.h @@ -27,7 +27,6 @@ #include #include #include -#include #include #include @@ -41,17 +40,6 @@ using runtime::PackedFunc; using runtime::TVMArgs; using runtime::TVMRetValue; -/*! - * \brief Temporary backward compatible function to convert a list - * of LoweredFunc to a IRModule of PrimfFuncs - * \param funcs The input lowered function. - * \return The IRModule. - * - * \note This function is only used for code refactor and will be - * removed once the refactor completes. - */ -IRModule ToIRModule(const Array& funcs); - /*! * \brief Build a module from array of lowered function. * \param mod The Module to be built diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h index fe74a96ae118..6af99586d2f9 100644 --- a/include/tvm/tir/analysis.h +++ b/include/tvm/tir/analysis.h @@ -24,9 +24,12 @@ #ifndef TVM_TIR_ANALYSIS_H_ #define TVM_TIR_ANALYSIS_H_ +#include #include +#include #include + namespace tvm { namespace tir { @@ -59,6 +62,18 @@ struct ExprDeepEqual { */ Array UndefinedVars(const Stmt& stmt, const Array& defs); +/*! + * \brief Verify if memory accesses are legal for a specific target device type. + * + * In the case that tgt is cuda, if not all workload is bound with + * threads, CPU code is generated that tries to access GPU memory, + * which is illegal. This pass performs verification for this case. + * + * \param mod The module to be verified. + * \return Success of memory verification. + */ +void VerifyMemory(const IRModule& mod); + } // namespace tir } // namespace tvm #endif // TVM_TIR_ANALYSIS_H_ diff --git a/include/tvm/tir/ir_pass.h b/include/tvm/tir/ir_pass.h index 8ba008bf024d..e228ce32adab 100644 --- a/include/tvm/tir/ir_pass.h +++ b/include/tvm/tir/ir_pass.h @@ -31,7 +31,6 @@ #include #include #include -#include #include #include @@ -366,60 +365,6 @@ Stmt HoistIfThenElse(Stmt stmt); */ Stmt NarrowDataType(Stmt stmt, int target_bits); -/*! - * \brief Make an user callable API LoweredFunc. - * - * The main task of this function is to create code to : - * - Map the values in the api_args to Var that is required by body. - * - Insert assertions to check type/value of the passed arguments. - * - * \param body The body of the function. - * \param name The name of the function. - * \param api_args Arguments to the function, can be either Var, or Buffer - * \param num_unpacked_args Number of arguments that - * are processed in plain form instead of packed form. - * \param is_restricted Whether the caller can guarantee that each buffer argument do not overlap. - * It is recommended to set to true for optimized code if such invariant holds. - * - * \return a LoweredFunc with the specified signiture. - * - * \note - * The function signature have two cases - * - * let num_packed_args = len(api_args) - num_unpacked_args; - * - * if num_packed_args is zero: - * f(api_arg_0, api_arg_1, .., api_arg_n) where n == len(api_args) - * - * if num_packed_args is not zero: - * f(TVMArg* packed_args, int* packed_arg_type_ids, int num_packed_args, - * api_arg_k, api_arg_k+1, ... api_arg_n, - * TVMValue* out_ret_val, int* out_ret_tcode) - * - * where n == len(api_args), k == num_packed_args - * - * There is no thread_axis in generated function. - */ -LoweredFunc MakeAPI(Stmt body, - std::string name, - Array api_args, - int num_unpacked_args, - bool is_restricted); - -/*! - * \brief Remap the thread axis - * - * This can be used to get equivalent program which uses - * threadIdx.y in place of threadIdx.x by passing - * {"threadIdx.x": thread_axis("threadIdx.y")} - * - * - * \param f The device function to be lowered. - * \param axis_map The map from StringImm -> ItrVar - * \return Transformed function. - */ -LoweredFunc RemapThreadAxis(LoweredFunc f, Map axis_map); - /*! * \brief Rewrite the pointer content type of arguments, * as well as Alloc internal to the function to use @@ -432,31 +377,6 @@ LoweredFunc RemapThreadAxis(LoweredFunc f, Map axis_map); */ PrimFunc PointerValueTypeRewrite(PrimFunc f); -/*! - * \brief Lower custom datatypes. - * - * See tvm::datatypes::Registry for more information on adding custom datatypes. - * - * \param f The device function to be lowered. - * \param target The target device. - * \return Transformed function. - */ -LoweredFunc LowerCustomDatatypes(LoweredFunc f, const std::string& target); - -/*! - * \brief Verify if memory accesses are legal for a specific target device type. - * - * In the case that tgt is cuda, if not all workload is bound with - * threads, CPU code is generated that tries to access GPU memory, - * which is illegal. This pass performs verification for this case. - * - * \param func The function to be verified. - * \param device_type The target device type. - * \return Success of memory verification. - */ -bool VerifyMemory(LoweredFunc func, int device_type); - - /*! * \brief Verify the correctness of a GPU code * It will check the whether the amount of memory usage or the number of threads diff --git a/include/tvm/tir/lowered_func.h b/include/tvm/tir/lowered_func.h deleted file mode 100644 index 2d01c8958aef..000000000000 --- a/include/tvm/tir/lowered_func.h +++ /dev/null @@ -1,149 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/tir/lowered_func.h - * \brief Information about a lowered TVM function. - * This data structure is final step toward codegen. - */ -#ifndef TVM_TIR_LOWERED_FUNC_H_ -#define TVM_TIR_LOWERED_FUNC_H_ - -#include -#include -#include -#include - -namespace tvm { -namespace tir { - -// Internal node container of lowered function. -class LoweredFuncNode; - -/*! - * \brief LoweredFunc represents function after lowering. - * This is the final IR representation before codegen. - */ -class LoweredFunc : public FunctionRef { - public: - LoweredFunc() {} - explicit LoweredFunc(ObjectPtr n) : FunctionRef(n) {} - /*! - * \brief access the internal node container - * \return the pointer to the internal node container - */ - inline const LoweredFuncNode* operator->() const; - /*! \brief specify container node */ - using ContainerType = LoweredFuncNode; -}; - -/*! \brief specific type of lowered function */ -enum LoweredFuncType : int { - /*! \brief Function that can mix device and host calls */ - kMixedFunc = 0, - /*! \brief Only contains host code */ - kHostFunc = 1, - /*! \brief Only contains device code */ - kDeviceFunc = 2 -}; - -/*! \brief Node container of LoweredFunc */ -class LoweredFuncNode : public tir::FunctionBaseNode { - public: - /*! \brief The name of the function */ - std::string name; - /*! - * \brief The arguments of the function - * This function can only take pod type(int, float) and void* as arguments. - */ - Array args; - /*! - * \brief The IterVar axis of threads - * Each axis need host function to specify a size. - * \note Calling convention into LoweredFunc - * - * Assume we have a LoweredFunc f, a call into f - * Call(f, arg1, arg2, ..., arg_n, - * size_axis_1, size_axis_2, ... size_axis_m) - * - * Here n = len(args), m = len(thread_axis) - * - * The CodeGen should take this and translate this call - * to corresponding API specific kernel launchs or function calls. - */ - Array thread_axis; - /*! - * \brief The hint data type of Var handles defined in LetStmt - * Can be used as hint when generating type signiture. - * The creation rule is given by - * handle_data_type[var_handle] = make_const(the_type, 0); - * - * \note Expr is used instead Type, because Type cannot be hold by Map. - * constant Expr of given type is used. - */ - Map handle_data_type; - /*! \brief The type of the function */ - LoweredFuncType func_type{kMixedFunc}; - /*! \brief Whether this function is packed function */ - bool is_packed_func{true}; - /*! - * \brief Whether function ensures that argument pointers do not alias. - * This corresponds to restrict keyword in C. - */ - bool is_restricted{true}; - /*! \brief The body statment of the function */ - Stmt body; - /*! \return name of the operation */ - const std::string& func_name() const final { - return name; - } - // there is no return value, but return 1 - // to enable Call into this function. - int num_outputs() const final { - return 1; - } - void VisitAttrs(AttrVisitor* v) { - v->Visit("name", &name); - v->Visit("args", &args); - v->Visit("thread_axis", &thread_axis); - v->Visit("handle_data_type", &handle_data_type); - v->Visit("func_type", &func_type); - v->Visit("is_packed_func", &is_packed_func); - v->Visit("is_restricted", &is_restricted); - v->Visit("body", &body); - } - - static constexpr const char* _type_key = "LoweredFunc"; - TVM_DECLARE_FINAL_OBJECT_INFO(LoweredFuncNode, Object); -}; - -// Implementations of inline functions -inline const LoweredFuncNode* LoweredFunc::operator->() const { - return static_cast(get()); -} -} // namespace tir -} // namespace tvm - -namespace std { -template <> -struct hash<::tvm::tir::LoweredFunc> : public tvm::ObjectHash { -}; -} - -#endif // TVM_TIR_LOWERED_FUNC_H_ diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index 211e344fa1d8..860014d774a4 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -58,6 +58,61 @@ TVM_DLL Pass CreatePrimFuncPass(const runtime::TypedPackedFunc< const std::string& name, const tvm::Array& required); +/*! + * \brief Transform the high-level PrimFunc to a low-level version + * that can be used as an API function. + * + * + * The main task of this function is to create code to : + * - Map the values in the api_args to Var that is required by body. + * - Insert assertions to check type/value of the passed arguments. + * + * \param num_unpacked_args Number of arguments that + * are processed in plain form instead of packed form. + * + * \note + * The function signature have two cases + * + * let num_packed_args = len(api_args) - num_unpacked_args; + * + * if num_packed_args is zero: + * f(api_arg_0, api_arg_1, .., api_arg_n) where n == len(api_args) + * + * if num_packed_args is not zero: + * f(TVMArg* packed_args, int* packed_arg_type_ids, int num_packed_args, + * api_arg_k, api_arg_k+1, ... api_arg_n, + * TVMValue* out_ret_val, int* out_ret_tcode) + * + * where n == len(api_args), k == num_packed_args + * + * \return The pass. + */ +TVM_DLL Pass MakePackedAPI(int num_unpacked_args); + + +/*! + * \brief Remap the thread axis + * + * This can be used to get equivalent program which uses + * threadIdx.y in place of threadIdx.x by passing + * {"threadIdx.x": thread_axis("threadIdx.y")} + * + * + * \return The pass. + */ +TVM_DLL Pass RemapThreadAxis(Map axis_map); + + +/*! + * \brief Lower custom datatypes. + * + * See tvm::datatypes::Registry for more information on adding custom datatypes. + * + * \return The pass. + */ +TVM_DLL Pass LowerCustomDatatypes(); + + /*! * \brief Bind the device type ofthe function to be * the device_type specified in the target attribute. diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index e4bd2009841f..0dd01e186034 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -17,9 +17,6 @@ # pylint: disable=invalid-name """The build utils in python. - -This module provides the functions to transform schedule to -LoweredFunc and compiled Module. """ import warnings @@ -30,7 +27,6 @@ from tvm.ir import CallingConv from tvm.target import codegen, BuildConfig from tvm.tir import ir_pass -from tvm.tir.stmt import LoweredFunc from tvm.te import tensor from tvm.te import schedule from tvm import target as _target @@ -136,8 +132,8 @@ def lower(sch, Returns ------- - f : LoweredFunc or Stmt - The result function, if with_api_wrapper=False + m : IRModule or Stmt + The result IRModule, if simple_mode=False Then the Stmt before make api is returned. """ cfg = BuildConfig.current() @@ -199,16 +195,21 @@ def lower(sch, if simple_mode: return stmt - return ir_pass.MakeAPI(stmt, name, arg_list, 0, cfg.restricted_func) + f = tvm.tir.PrimFunc(arg_list, stmt).with_attr( + "global_symbol", tvm.runtime.String(name)) + if cfg.restricted_func: + f = f.with_attr("tir.no_alias", True) + mod = tvm.IRModule({name: f}) + return tvm.tir.transform.MakePackedAPI()(mod) -def _build_for_device(flist, target, target_host): +def _build_for_device(input_mod, target, target_host): """Build the lowered functions for a device with the given compilation target. Parameters ---------- - flist : list of LoweredFunc + input_mod : IRModule The schedule to be built. target : str or :any:`tvm.target.Target` @@ -219,8 +220,8 @@ def _build_for_device(flist, target, target_host): Returns ------- - fhost : list of LoweredFunc - A list of lowered functions for the host. + fhost : IRModule + The host IRModule. mdev : tvm.module A module that contains device code. @@ -229,14 +230,13 @@ def _build_for_device(flist, target, target_host): target_host = _target.create(target_host) device_type = ndarray.context(target.target_name, 0).device_type - for func in flist: - if not ir_pass.VerifyMemory(func, device_type): - raise ValueError( - "Direct host side access to device memory is detected in %s. " - "Did you forget to bind?" % func.name) + mod_mixed = input_mod + mod_mixed = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod_mixed) + tvm.tir.analysis.verify_memory(mod_mixed) - mod_mixed = tvm.testing.LoweredFuncsToIRModule(flist) - opt_mixed = [tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))] + opt_mixed = [] + if len(mod_mixed.functions) == 1: + opt_mixed += [tvm.tir.transform.Apply(lambda f: f.with_attr("tir.is_entry_func", True))] if BuildConfig.current().detect_global_barrier: opt_mixed += [tvm.tir.transform.ThreadSync("global")] opt_mixed += [tvm.tir.transform.ThreadSync("shared"), @@ -292,7 +292,7 @@ def build(inputs, Parameters ---------- - inputs : tvm.te.Schedule, LoweredFunc, or dict of target to LoweredFunc list + inputs : tvm.te.Schedule, IRModule, or dict of target to IRModule The schedule to be built args : list of Buffer or Tensor or Var, optional @@ -326,7 +326,7 @@ def build(inputs, ________ There are two typical example uses of this function depending on the type of the argument `inputs`: - 1. it is a list of lowered functions: + 1. it is an IRModule. .. code-block:: python @@ -335,10 +335,10 @@ def build(inputs, B = te.placeholder((n,), name='B') C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name='C') s = tvm.te.create_schedule(C.op) - f = tvm.lower(s, [A, B, C], name="test_add") - m = tvm.build(f, target="llvm") + m = tvm.lower(s, [A, B, C], name="test_add") + rt_mod = tvm.build(m, target="llvm") - 2. it is a dict of compilation target to list of lowered functions: + 2. it is a dict of compilation target to IRModule. .. code-block:: python @@ -349,9 +349,9 @@ def build(inputs, s1 = tvm.te.create_schedule(C.op) with tvm.target.cuda() as cuda_tgt: s2 = topi.cuda.schedule_injective(cuda_tgt, [C]) - f1 = tvm.lower(s1, [A, B, C], name="test_add1") - f2 = tvm.lower(s2, [A, B, C], name="test_add2") - m = tvm.build({"llvm": [f1], "cuda": [f2]}, target_host="llvm") + m1 = tvm.lower(s1, [A, B, C], name="test_add1") + m2 = tvm.lower(s2, [A, B, C], name="test_add2") + rt_mod = tvm.build({"llvm": m1, "cuda": m2}, target_host="llvm") Note ---- @@ -360,45 +360,36 @@ def build(inputs, if isinstance(inputs, schedule.Schedule): if args is None: raise ValueError("args must be given for build from schedule") - flist = lower(inputs, args, - name=name, - binds=binds) - if isinstance(flist, LoweredFunc): - flist = [flist] - elif isinstance(inputs, LoweredFunc): - if args: - raise ValueError("args must be done when build from LoweredFunc.") - flist = [inputs] + input_mod = lower(inputs, args, + name=name, + binds=binds) elif isinstance(inputs, (list, tuple, container.Array)): - flist = inputs + merged_mod = tvm.IRModule({}) + for x in inputs: + merged_mod.update(x) + input_mod = merged_mod + elif isinstance(inputs, tvm.IRModule): + input_mod = inputs elif not isinstance(inputs, (dict, container.Map)): - raise ValueError("inputs must be Schedule, LoweredFunc, list of " - "LoweredFunc, or dict of target to list of " - "LoweredFunc.") + raise ValueError("inputs must be Schedule, IRModule or dict of target to IRModule") if not isinstance(inputs, (dict, container.Map)): target = _target.Target.current() if target is None else target target = target if target else "llvm" - target_flist = {target: flist} + target_input_mod = {target: input_mod} else: - target_flist = inputs + target_input_mod = inputs - for tar, flist in target_flist.items(): + for tar, mod in target_input_mod.items(): if not isinstance(tar, (str, _target.Target)): raise ValueError("The key of inputs must be str or " "_target.Target when inputs is dict.") - fname_set = set() - for x in flist: - if not isinstance(x, LoweredFunc): - raise ValueError("inputs must be Schedule, LoweredFunc, list " - "of LoweredFunc, or dict of str to list of " - "LoweredFunc.") - if x.name in fname_set: - raise ValueError("Duplicate function name %s" % x.name) - fname_set.add(x.name) + if not isinstance(mod, tvm.IRModule): + raise ValueError("inputs must be Schedule, IRModule," + "or dict of str to IRModule.") if not target_host: - for tar, _ in target_flist.items(): + for tar, _ in target_input_mod.items(): tar = _target.create(tar) device_type = ndarray.context(tar.target_name, 0).device_type if device_type == ndarray.cpu(0).device_type: @@ -410,8 +401,8 @@ def build(inputs, mod_host_all = tvm.IRModule({}) device_modules = [] - for tar, flist in target_flist.items(): - mod_host, mdev = _build_for_device(flist, tar, target_host) + for tar, input_mod in target_input_mod.items(): + mod_host, mdev = _build_for_device(input_mod, tar, target_host) mod_host_all.update(mod_host) device_modules.append(mdev) diff --git a/python/tvm/relay/backend/_backend.py b/python/tvm/relay/backend/_backend.py index df0347bd2bae..641ff04adab0 100644 --- a/python/tvm/relay/backend/_backend.py +++ b/python/tvm/relay/backend/_backend.py @@ -17,7 +17,6 @@ """The interface of expr function exposed from C++.""" import tvm._ffi import tvm.driver -from tvm.ir import container as _container @tvm._ffi.register_func("relay.backend.lower") @@ -40,7 +39,7 @@ def lower(sch, inputs, func_name, source_func): Returns ------- - lowered_funcs : List[tvm.LoweredFunc] + mod : tvm.IRModule The result of lowering. """ # pylint: disable=broad-except, import-outside-toplevel @@ -56,20 +55,17 @@ def lower(sch, inputs, func_name, source_func): msg += "-----------------------------\n" msg += source_func.astext() raise RuntimeError(msg) - return f if isinstance( - f, (_container.Array, tuple, list)) else [f] + return f @tvm._ffi.register_func("relay.backend.build") -def build(funcs, target, target_host=None): +def build(mod, target, target_host=None): """Backend build function. Parameters ---------- - funcs : List[tvm.LoweredFunc] or Dict[str, List[tvm.LoweredFunc]] - A list of lowered functions or dictionary mapping from targets to - lowered functions. - + mod : tvm.IRModule or Dict[str, tvm.IRModule] + Input module target : tvm.Target The target to run the code on. @@ -84,7 +80,7 @@ def build(funcs, target, target_host=None): """ if target_host == "": target_host = None - return tvm.driver.build(funcs, target=target, target_host=target_host) + return tvm.driver.build(mod, target=target, target_host=target_host) @tvm._ffi.register_func("relay._tensor_value_repr") diff --git a/python/tvm/relay/backend/graph_runtime_codegen.py b/python/tvm/relay/backend/graph_runtime_codegen.py index 762210dbe428..3e5f0157b32f 100644 --- a/python/tvm/relay/backend/graph_runtime_codegen.py +++ b/python/tvm/relay/backend/graph_runtime_codegen.py @@ -48,7 +48,7 @@ def __init__(self, mod, target): self._get_graph_json = self._mod["get_graph_json"] self._list_params_name = self._mod["list_params_name"] self._get_param_by_name = self._mod["get_param_by_name"] - self._get_lowered_funcs = self._mod["get_lowered_funcs"] + self._get_irmodule = self._mod["get_irmodule"] self._setup(mod, target) def _setup(self, mod, target): @@ -74,14 +74,14 @@ def codegen(self, func): ------- graph_json : str The graph json that can be consumed by runtime. - lowered_funcs : List[tvm.LoweredFunc] or Dict[str, List[tvm.LoweredFunc]] + mod : IRModule or Dict[str, IRModule] The lowered functions. params : Dict[str, tvm.nd.NDArray] Additional constant parameters. """ self._codegen(func) graph_json = self._get_graph_json() - lowered_func = self._get_lowered_funcs() + lowered_func = self._get_irmodule() param_names = self._list_params_name() params = {} for name in param_names: diff --git a/python/tvm/runtime/__init__.py b/python/tvm/runtime/__init__.py index 24db0e86f22b..235ef0cf219e 100644 --- a/python/tvm/runtime/__init__.py +++ b/python/tvm/runtime/__init__.py @@ -28,3 +28,4 @@ from .ndarray import context, cpu, gpu, opencl, cl, vulkan, metal, mtl from .ndarray import vpi, rocm, opengl, ext_dev, micro_dev from .module import load_module, enabled, system_lib +from .container import String diff --git a/python/tvm/target/build_config.py b/python/tvm/target/build_config.py index c105175d3e26..6a0dcf743a0d 100644 --- a/python/tvm/target/build_config.py +++ b/python/tvm/target/build_config.py @@ -20,9 +20,7 @@ import tvm.ir from tvm.runtime import Object -from tvm.ir import container from tvm.tir import Stmt -from tvm.tir.stmt import LoweredFunc from . import _ffi_api @@ -48,17 +46,13 @@ def decorate(self, func): def dump(*args, **kwargs): """dump function""" retv = func(*args, **kwargs) - if not isinstance(retv, (Stmt, LoweredFunc, container.Array)): + if not isinstance(retv, (Stmt,)): return retv fname = func.func_name if hasattr(func, 'func_name') else func.__name__ pname = str(self._pass_id) + "_" + fname + "_ir.cc" with open(pname, "a") as f: - out = retv.body if isinstance(retv, LoweredFunc) else retv + out = retv f.write(str(out)) - if isinstance(retv, container.Array): - for x in retv: - out = x.body if isinstance(x, LoweredFunc) else x - f.write("---------%s\n%s\n-----------\n"%(x.name, str(out))) self._pass_id += 1 return retv return dump diff --git a/python/tvm/testing.py b/python/tvm/testing.py index 077ac35f69a0..9c429302f25d 100644 --- a/python/tvm/testing.py +++ b/python/tvm/testing.py @@ -14,9 +14,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + +# pylint: disable=invalid-name """ TVM testing utilities """ import logging import numpy as np +import tvm import tvm._ffi @@ -165,4 +168,40 @@ def compare_derivative(j, n_der, grad): x_name, grad.shape, dist, max_diff, avg_diff) +def MakeAPILegacy(stmt, name, args, num_unpacked_args, noalias): + """Legacy adapter to build a Module from statement. + + Used for migrating existing test cases only. + + Parameters + ---------- + stmt: Stmt + The input statement. + + name: str + The name of the funciton. + + args: list of Buffer or Vars + The function arguments + + num_unpacked_args: int + Number of unpacked arguments. + + nolias: bool + Whether allow noalias. + + Returns + ------- + mod : IRModule + The created IRModule. + """ + f = tvm.tir.PrimFunc(args, stmt).with_attr( + "global_symbol", tvm.runtime.String(name)) + f = f.with_attr("tir.is_entry_func", True) + if noalias: + f = f.with_attr("tir.no_alias", True) + mod = tvm.IRModule({name: f}) + return tvm.tir.transform.MakePackedAPI(num_unpacked_args)(mod) + + tvm._ffi._init_api("testing", __name__) diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index bd8e33fe4c3b..b5d9fb147722 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -29,7 +29,7 @@ from .stmt import Stmt, LetStmt, AssertStmt, ProducerConsumer, For from .stmt import BufferStore, Store, Provide, Allocate, AttrStmt, Free, Realize, SeqStmt -from .stmt import IfThenElse, Evaluate, Prefetch, LoweredFunc, stmt_seq, stmt_list +from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list from .function import PrimFunc diff --git a/python/tvm/tir/analysis/analysis.py b/python/tvm/tir/analysis/analysis.py index 84eeaac370c2..448d0e6c5f8e 100644 --- a/python/tvm/tir/analysis/analysis.py +++ b/python/tvm/tir/analysis/analysis.py @@ -55,3 +55,14 @@ def expr_deep_equal(lhs, rhs): tvm.ir.structural_equal """ return _ffi_api.expr_deep_equal(lhs, rhs) + + +def verify_memory(mod): + """Verify if module contains illegal host side direct memory access. + + Parameters + ---------- + mod: tvm.IRModule + The module to be verified. + """ + _ffi_api.verify_memory(mod) diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py index 37946f66b1bb..0ed1762a889c 100644 --- a/python/tvm/tir/function.py +++ b/python/tvm/tir/function.py @@ -18,6 +18,7 @@ import tvm._ffi import tvm.runtime +from tvm.runtime import Object from tvm.ir import BaseFunc from .buffer import Buffer from .expr import Var @@ -54,6 +55,7 @@ def __init__(self, param_list = [] buffer_map = {} if buffer_map is None else buffer_map for x in params: + x = tvm.runtime.convert(x) if not isinstance(x, Object) else x if isinstance(x, Buffer): var = Var(x.name, dtype="handle") param_list.append(var) diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py index 0badad3c092f..4531cdfc35ac 100644 --- a/python/tvm/tir/stmt.py +++ b/python/tvm/tir/stmt.py @@ -385,14 +385,6 @@ def __init__(self, func, value_index, dtype, bounds): _ffi_api.Prefetch, func, value_index, dtype, bounds) -@tvm._ffi.register_object -class LoweredFunc(Object): - """Represent a LoweredFunc in TVM.""" - MixedFunc = 0 - HostFunc = 1 - DeviceFunc = 2 - - def stmt_seq(*args): """Make sequence of statements diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index c823c1af5baa..64c31a5d9444 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -60,6 +60,36 @@ def _transform(func, mod, ctx): return _fpass.prim_func_pass(_transform, opt_level=0) +def LowerCustomDatatypes(): + """Lower custom datatypes. + + See tvm::datatypes::Registry for more information on adding custom datatypes. + + Returns + ------- + fpass : tvm.ir.transform.Pass + The result pass + """ + return _ffi_api.LowerCustomDatatypes() + + +def MakePackedAPI(num_unpacked_params=0): + """Transform the PrimFuncs in the module to a packed func API. + + Parameters + ---------- + num_unpacked_params : int + Number of parameters that we hope to directly pass via normal arguments + following the PackedFunc input signature. + + Returns + ------- + fpass : tvm.ir.transform.Pass + The result pass + """ + return _ffi_api.MakePackedAPI(num_unpacked_params) + + def BindDeviceType(): """Bind the device type of the function to be the device_type specified in the target attribute. diff --git a/src/contrib/hybrid/codegen_hybrid.h b/src/contrib/hybrid/codegen_hybrid.h index 6491491ec2b3..9784defcba88 100644 --- a/src/contrib/hybrid/codegen_hybrid.h +++ b/src/contrib/hybrid/codegen_hybrid.h @@ -27,7 +27,6 @@ #include #include #include -#include #include #include #include diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index d54d6f8773ce..ae1d5393ad34 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -26,8 +26,10 @@ #include #include +#include #include #include +#include #include #include @@ -39,7 +41,6 @@ namespace tvm { using runtime::TVMArgs; using runtime::TVMRetValue; using runtime::PackedFunc; -using tir::LoweredFunc; bool LLVMEnabled() { const runtime::PackedFunc* pf = runtime::Registry::Get("target.build.llvm"); @@ -166,17 +167,6 @@ tir::Stmt BuildStmt(te::Schedule sch, return stmt; } -Array lower(te::Schedule sch, - const Array& args, - const std::string& name, - const std::unordered_map& binds, - const BuildConfig& config) { - Array out_arg_list; - auto stmt = BuildStmt(sch, args, binds, true, &out_arg_list, config); - return Array({ tir::MakeAPI(stmt, name, out_arg_list, 0, config->restricted_func) }); -} - - transform::Pass BindTarget(Target target) { auto fpass = [target](tir::PrimFunc f, IRModule m, transform::PassContext ctx) { return WithAttr(std::move(f), tvm::attr::kTarget, target); @@ -198,18 +188,46 @@ transform::Pass FilterBy(FCond fcond) { } +IRModule lower(te::Schedule sch, + const Array& args, + const std::string& name, + const std::unordered_map& binds, + const BuildConfig& config) { + Array out_arg_list; + auto stmt = BuildStmt(sch, args, binds, true, &out_arg_list, config); + + Array params; + Map buffer_map; + + for (auto var : out_arg_list) { + if (auto* n = var.as()) { + params.push_back(GetRef(n)); + } else { + tir::Buffer buffer = Downcast(var); + tir::Var bptr(buffer->name, DataType::Handle()); + params.push_back(bptr); + buffer_map.Set(bptr, buffer); + } + } + + auto f = tir::PrimFunc(params, stmt, VoidType(), buffer_map); + f = WithAttr(std::move(f), "global_symbol", runtime::String(name)); + + if (config->restricted_func) { + f = WithAttr(std::move(f), "tir.no_alias", Integer(1)); + } + auto mod = IRModule(Map({{GlobalVar(name), f}})); + return tir::transform::MakePackedAPI(0)(mod); +} + + std::pair -split_dev_host_funcs(const Array& funcs, +split_dev_host_funcs(IRModule mod_mixed, const Target& target, const Target& target_host, const BuildConfig& config) { - for (const auto& x : funcs) { - CHECK(tir::VerifyMemory(x, target->device_type)) - << "Direct host side access to device memory is detected in " - << x->func_name() << ". Did you forget to bind?"; - } - - IRModule mod_mixed = codegen::ToIRModule(funcs); + mod_mixed = BindTarget(target)(std::move(mod_mixed)); + tir::VerifyMemory(mod_mixed); Array mixed_pass_list = {BindTarget(target)}; if (config->detect_global_barrier) { @@ -274,10 +292,9 @@ split_dev_host_funcs(const Array& funcs, // Build for heterogeneous execution. -runtime::Module build(const Map>& inputs, +runtime::Module build(const Map& inputs, const Target& target_host, const BuildConfig& config) { - Array fhost_all; std::vector device_modules; Target target_host_val = target_host; @@ -319,10 +336,10 @@ runtime::Module build(const Map>& inputs, } // Build for heterogeneous execution when target is a string. -runtime::Module build(const Map>& inputs, +runtime::Module build(const Map& inputs, const Target& target_host, const BuildConfig& config) { - Map> updated_input; + Map updated_input; for (const auto& it : inputs) { auto target = Target::Create(it.first); if (target->device_name == "vta") { @@ -334,11 +351,11 @@ runtime::Module build(const Map>& inputs, } // Build for homogeneous execution. -runtime::Module build(const Array& funcs, +runtime::Module build(const IRModule& funcs, const Target& target, const Target& target_host, const BuildConfig& config) { - Map> inputs = {{target, funcs}}; + Map inputs = {{target, funcs}}; return build(inputs, target_host, config); } diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 4073271073ea..eaf78bc1b0f7 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -38,7 +38,6 @@ namespace tvm { namespace relay { namespace backend { -using tir::LoweredFunc; using TargetsMap = Map; using namespace tvm::relay::transform; @@ -78,16 +77,16 @@ struct GraphCodegen { } Array GetExternalModules() { - return CallFunc >("get_external_modules", nullptr); + return CallFunc>("get_external_modules", nullptr); } - Map > GetLoweredFunc() { - return CallFunc > >("get_lowered_funcs", nullptr); + Map GetIRModule() { + return CallFunc>("get_irmodule", nullptr); } std::unordered_map GetParams() { std::unordered_map ret; - auto names = CallFunc >("list_params_name", nullptr); + auto names = CallFunc>("list_params_name", nullptr); for (auto expr : names) { auto key = expr.as()->value; ret[key] = CallFunc("get_param_by_name", key); @@ -152,9 +151,9 @@ class RelayBuildModule : public runtime::ModuleNode { this->SetParam(kv.first, kv.second->data); } }); - } else if (name == "get_lowered_funcs") { + } else if (name == "get_irmodule") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - *rv = this->graph_codegen_->GetLoweredFunc(); + *rv = this->graph_codegen_->GetIRModule(); }); } else if (name == "get_external_modules") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { @@ -452,7 +451,7 @@ class RelayBuildModule : public runtime::ModuleNode { ret_.graph_json = graph_codegen_->GetJSON(); ret_.params = graph_codegen_->GetParams(); - auto lowered_funcs = graph_codegen_->GetLoweredFunc(); + auto lowered_funcs = graph_codegen_->GetIRModule(); // When there is no lowered_funcs due to reasons such as optimization. if (lowered_funcs.size() == 0) { diff --git a/src/relay/backend/compile_engine.h b/src/relay/backend/compile_engine.h index 9bd6a4ef31b6..4a3a04d02dcd 100644 --- a/src/relay/backend/compile_engine.h +++ b/src/relay/backend/compile_engine.h @@ -27,7 +27,6 @@ #include #include -#include #include #include #include @@ -82,7 +81,8 @@ struct CachedFuncNode : public Object { /*! \brief The schedule to the function */ te::Schedule schedule; /*! \brief The lowered functions to support the function. */ - tvm::Array funcs; + IRModule funcs = IRModule::Empty(); + /*! \brief Parameter usage states in the shape function. */ tvm::Array shape_func_param_states; diff --git a/src/relay/backend/graph_runtime_codegen.cc b/src/relay/backend/graph_runtime_codegen.cc index 0587cd216e12..c7f1be82c371 100644 --- a/src/relay/backend/graph_runtime_codegen.cc +++ b/src/relay/backend/graph_runtime_codegen.cc @@ -55,7 +55,7 @@ using TargetsMap = std::unordered_map; /*! \brief Lowered outputs */ struct LoweredOutput { std::string graph_json; - Map > lowered_funcs; + Map lowered_funcs; Array external_mods; std::unordered_map params; }; @@ -214,19 +214,14 @@ class GraphRuntimeCodegen LoweredOutput ret; ret.graph_json = os.str(); ret.params = params_; + for (auto& kv : lowered_funcs_) { if (ret.lowered_funcs.count(kv.first) == 0) { - ret.lowered_funcs.Set(kv.first, Array()); - } - auto& vec = ret.lowered_funcs[kv.first]; - Array tmp; - for (auto f : kv.second) { - tmp.push_back(f); - } - for (auto f : vec) { - tmp.push_back(f); + ret.lowered_funcs.Set(kv.first, IRModule::Empty()); } - ret.lowered_funcs.Set(kv.first, tmp); + auto& mod = ret.lowered_funcs[kv.first]; + mod->Update(kv.second); + ret.lowered_funcs.Set(kv.first, mod); } ret.external_mods = compile_engine_->LowerExternalFunctions(); return ret; @@ -457,12 +452,9 @@ class GraphRuntimeCodegen CCacheKey key = (*pf0)(func, target); CachedFunc lowered_func = (*pf1)(compile_engine_, key); if (!lowered_funcs_.count(target->str())) { - lowered_funcs_[target->str()] = {}; + lowered_funcs_[target->str()] = IRModule::Empty(); } - for (auto f : lowered_func->funcs) { - lowered_funcs_[target->str()].insert(f); - } - + lowered_funcs_[target->str()]->Update(lowered_func->funcs); return GraphAddCallNode(op, _GetUniqueName(lowered_func->func_name), lowered_func->func_name); @@ -602,8 +594,7 @@ class GraphRuntimeCodegen /*! \brief plan memory of device result */ Map> storage_device_map_; /*! \brief lowered funcs */ - std::unordered_map> - lowered_funcs_; + std::unordered_map lowered_funcs_; /*! \brief name map */ std::unordered_map name_map_; /*! \brief compile engine */ @@ -655,7 +646,7 @@ class GraphRuntimeCodegenModule : public runtime::ModuleNode { CHECK_GT(this->output_.params.count(key), 0); *rv = this->output_.params[key]; }); - } else if (name == "get_lowered_funcs") { + } else if (name == "get_irmodule") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->output_.lowered_funcs; }); diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 4d15c76ccb75..78ebb0fc5383 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -226,6 +226,7 @@ std::vector ToAllocTensorShape32(NDArray shape) { return raw_shape; } + class VMFunctionCompiler : ExprFunctor { public: VMFunctionCompiler(VMCompilerContext* context, TargetsMap targets, Target target_host) @@ -407,12 +408,15 @@ class VMFunctionCompiler : ExprFunctor { CCacheKey key(func, target_host_); auto cfunc = engine_->LowerShapeFunc(key); int op_index = -1; - if (context_->seen_funcs.count(cfunc->funcs[0]) == 0) { + // pick the only function inside the context + CHECK_EQ(cfunc->funcs->functions.size(), 1); + auto pfunc = Downcast((*cfunc->funcs->functions.begin()).second); + if (context_->seen_funcs.count(pfunc) == 0) { op_index = context_->cached_funcs.size(); context_->cached_funcs.push_back(cfunc); - context_->seen_funcs[cfunc->funcs[0]] = op_index; + context_->seen_funcs[pfunc] = op_index; } else { - op_index = context_->seen_funcs[cfunc->funcs[0]]; + op_index = context_->seen_funcs[pfunc]; } // Prepare input and output registers @@ -494,13 +498,14 @@ class VMFunctionCompiler : ExprFunctor { context_->cached_funcs.push_back(cfunc); } else { // TODO(jroesch): support lowered funcs for multiple targets - CHECK_EQ(cfunc->funcs.size(), 1); - if (context_->seen_funcs.find(cfunc->funcs[0]) == context_->seen_funcs.end()) { + CHECK_EQ(cfunc->funcs->functions.size(), 1); + auto pfunc = Downcast((*cfunc->funcs->functions.begin()).second); + if (context_->seen_funcs.find(pfunc) == context_->seen_funcs.end()) { op_index = context_->cached_funcs.size(); context_->cached_funcs.push_back(cfunc); - context_->seen_funcs[cfunc->funcs[0]] = op_index; + context_->seen_funcs[pfunc] = op_index; } else { - op_index = context_->seen_funcs[cfunc->funcs[0]]; + op_index = context_->seen_funcs[pfunc]; } } @@ -862,11 +867,7 @@ void VMCompiler::Lower(IRModule mod, // update primitive function map size_t primitive_index = 0; for (const auto& cfunc : context_.cached_funcs) { - if (cfunc->target->str() == "ext_dev") { - exec_->primitive_map.insert({cfunc->func_name, primitive_index++}); - } else { - exec_->primitive_map.insert({cfunc->funcs[0]->name, primitive_index++}); - } + exec_->primitive_map.insert({cfunc->func_name, primitive_index++}); } } @@ -961,8 +962,6 @@ void VMCompiler::PopulateGlobalMap() { } void VMCompiler::Codegen() { - using tir::LoweredFunc; - if (!context_.module.defined()) { LOG(WARNING) << "Did you forget to call VMCompiler::Lower?"; return; @@ -971,15 +970,21 @@ void VMCompiler::Codegen() { if (cached_funcs.size() == 0) { return; } - std::unordered_map> funcs; + std::unordered_map funcs; + for (auto& cfunc : cached_funcs) { std::string target_str = cfunc->target->str(); + // NOTE: because module, is mutable, we need to make an + // explicit copy of the IRModule. + IRModule mod = cfunc->funcs; + mod.CopyOnWrite(); + if (target_str == "ext_dev") { continue; } else if (funcs.count(target_str) == 0) { - funcs.emplace(target_str, Array{cfunc->funcs[0]}); + funcs.emplace(target_str, mod); } else { - funcs[target_str].push_back(cfunc->funcs[0]); + funcs[target_str]->Update(mod); } } diff --git a/src/relay/backend/vm/compiler.h b/src/relay/backend/vm/compiler.h index f18e2c037459..c1040f1ed18e 100644 --- a/src/relay/backend/vm/compiler.h +++ b/src/relay/backend/vm/compiler.h @@ -76,7 +76,7 @@ struct VMCompilerContext { // List of cached functions std::vector cached_funcs; // The functions that have been lowered. - std::unordered_map seen_funcs; + std::unordered_map seen_funcs; }; diff --git a/src/relay/transforms/gradient.cc b/src/relay/transforms/gradient.cc index a3728e905922..d0ff169445fb 100644 --- a/src/relay/transforms/gradient.cc +++ b/src/relay/transforms/gradient.cc @@ -22,7 +22,6 @@ * \brief API for Automatic Differentiation for the Relay IR. */ #include -#include #include #include #include diff --git a/src/target/build_common.h b/src/target/build_common.h index 47ec8f032c40..fc45cef3a874 100644 --- a/src/target/build_common.h +++ b/src/target/build_common.h @@ -31,29 +31,12 @@ #include #include #include -#include #include #include #include "../runtime/meta_data.h" namespace tvm { namespace codegen { -// Extract function information from device function. -inline std::unordered_map -ExtractFuncInfo(const Array& funcs) { - std::unordered_map fmap; - for (tir::LoweredFunc f : funcs) { - runtime::FunctionInfo info; - for (size_t i = 0; i < f->args.size(); ++i) { - info.arg_types.push_back(f->args[i].dtype()); - } - for (size_t i = 0; i < f->thread_axis.size(); ++i) { - info.thread_axis_tags.push_back(f->thread_axis[i]->thread_tag); - } - fmap[f->name] = info; - } - return fmap; -} inline std::unordered_map ExtractFuncInfo(const IRModule& mod) { diff --git a/src/target/codegen.cc b/src/target/codegen.cc index 703328f8761f..0eceea81da17 100644 --- a/src/target/codegen.cc +++ b/src/target/codegen.cc @@ -43,50 +43,6 @@ namespace tvm { namespace codegen { -// convert legacy LoweredFunc to PrimFunc. -tir::PrimFunc ToPrimFunc(tir::LoweredFunc from) { - // remap args to attach type annotations. - Array args; - Map remap_vars; - - for (auto var : from->args) { - auto it = from->handle_data_type.find(var); - if (it != from->handle_data_type.end()) { - tir::Var new_var(var->name_hint, - PointerType(PrimType((*it).second->dtype))); - args.push_back(new_var); - remap_vars.Set(var, new_var); - } else { - args.push_back(var); - } - } - tir::PrimFunc func(args, Substitute(from->body, remap_vars)); - - func = WithAttr(std::move(func), attr::kGlobalSymbol, runtime::String(from->name)); - func = WithAttr(std::move(func), tir::attr::kDeviceThreadAxis, from->thread_axis); - if (from->func_type == tir::LoweredFuncType::kDeviceFunc) { - func = WithAttr(std::move(func), - attr::kCallingConv, Integer(CallingConv::kDeviceKernelLaunch)); - } - if (from->is_restricted) { - func = WithAttr(std::move(func), tir::attr::kNoAlias, Integer(1)); - } - return func; -} - -IRModule ToIRModule(const Array& funcs) { - Map functions; - for (size_t i = 0; i < funcs.size(); ++i) { - auto f = funcs[i]; - tir::PrimFunc pf = ToPrimFunc(f); - if (i == 0) { - pf = WithAttr(std::move(pf), tir::attr::kIsEntryFunc, Integer(1)); - } - functions.Set(GlobalVar(f->name), pf); - } - return IRModule(functions); -} - runtime::Module Build(IRModule mod, const Target& target) { if (BuildConfig::Current()->disable_assert) { mod = tir::transform::SkipAssert()(mod); @@ -284,9 +240,6 @@ runtime::Module PackImportsToLLVM(const runtime::Module& mod, TVM_REGISTER_GLOBAL("target.Build") .set_body_typed(Build); -TVM_REGISTER_GLOBAL("testing.LoweredFuncsToIRModule") -.set_body_typed(ToIRModule); - // Export two auxiliary function to the runtime namespace. TVM_REGISTER_GLOBAL("runtime.ModulePackImportsToC") .set_body_typed(PackImportsToC); diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 31465cd56bcb..450ebbcd02b8 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -448,7 +448,7 @@ CodeGenLLVM::CreateDebugInfo(llvm::Module* module) { auto debug_info = llvm::make_unique(); debug_info->di_builder_ = llvm::make_unique(*module); #endif - // TODO(tulloch): pass this information through relay::Span classes to the LoweredFunc instance? + // TODO(tulloch): pass this information through relay::Span classes to the IRModule instance? debug_info->file_ = debug_info->di_builder_->createFile("model.tvm", "/tmp/"); debug_info->compilation_unit_ = debug_info->di_builder_->createCompileUnit( llvm::dwarf::DW_LANG_C, debug_info->file_, "TVM", 0, "", 0, "", diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index 3f508a530c74..9ea77ac2d79f 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -67,20 +67,23 @@ class LLVMModuleNode final : public runtime::ModuleNode { } else if (name == "_get_target_triple") { std::string target_triple = tm_->getTargetTriple().str(); return PackedFunc([target_triple](TVMArgs args, TVMRetValue *rv) { - * rv = target_triple; + *rv = target_triple; }); } if (ee_ == nullptr) LazyInitJIT(); - // This LLVMModule is empty and no function can be retrieved. - if (entry_func_.empty()) return nullptr; - std::lock_guard lock(mutex_); - const std::string& fname = (name == runtime::symbol::tvm_module_main ? - entry_func_ : name); - TVMBackendPackedCFunc faddr = - reinterpret_cast(GetFunctionAddr(fname)); + TVMBackendPackedCFunc faddr; + if (name == runtime::symbol::tvm_module_main) { + const char* entry_name = reinterpret_cast( + GetGlobalAddr(runtime::symbol::tvm_module_main)); + CHECK(entry_name != nullptr) + << "Symbol " << runtime::symbol::tvm_module_main << " is not presented"; + faddr = reinterpret_cast(GetFunctionAddr(entry_name)); + } else { + faddr = reinterpret_cast(GetFunctionAddr(name)); + } if (faddr == nullptr) return PackedFunc(); return WrapPackedFunc(faddr, sptr_to_self); } @@ -205,6 +208,7 @@ class LLVMModuleNode final : public runtime::ModuleNode { std::unique_ptr cg = CodeGenLLVM::Create(tm_.get()); std::vector funcs; + std::string entry_func; for (auto kv : mod->functions) { CHECK(kv.second->IsInstance()) << "Can only lower IR Module with PrimFuncs"; @@ -212,7 +216,7 @@ class LLVMModuleNode final : public runtime::ModuleNode { if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); CHECK(global_symbol.defined()); - entry_func_ = global_symbol; + entry_func = global_symbol; } funcs.push_back(f); } @@ -225,8 +229,8 @@ class LLVMModuleNode final : public runtime::ModuleNode { cg->AddFunction(f); } - if (entry_func_.length() != 0) { - cg->AddMainFunction(entry_func_); + if (entry_func.length() != 0) { + cg->AddMainFunction(entry_func); } module_ = cg->Finish(); @@ -321,13 +325,7 @@ class LLVMModuleNode final : public runtime::ModuleNode { CHECK(ee_ != nullptr) << "Failed to initialize jit engine for " << mptr_->getTargetTriple(); ee_->runStaticConstructorsDestructors(false); - // setup context address. - // we will skip context setup if this LLVMModule is empty. - if (GetGlobalAddr(runtime::symbol::tvm_module_main) == 0) - return; - entry_func_ = - reinterpret_cast(GetGlobalAddr(runtime::symbol::tvm_module_main)); if (void** ctx_addr = reinterpret_cast( GetGlobalAddr(runtime::symbol::tvm_module_ctx))) { *ctx_addr = this; @@ -356,8 +354,6 @@ class LLVMModuleNode final : public runtime::ModuleNode { // The target configuration string std::string target_; - // Name of entry function. - std::string entry_func_; // JIT lock std::mutex mutex_; // execution engine diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h index c1894a379ddb..30ad890c923d 100644 --- a/src/target/source/codegen_c.h +++ b/src/target/source/codegen_c.h @@ -29,7 +29,6 @@ #include #include #include -#include #include #include #include diff --git a/src/target/spirv/codegen_spirv.h b/src/target/spirv/codegen_spirv.h index a5ccd549633d..edcee20f173f 100644 --- a/src/target/spirv/codegen_spirv.h +++ b/src/target/spirv/codegen_spirv.h @@ -27,7 +27,6 @@ #include #include #include -#include #include #include diff --git a/src/target/stackvm/codegen_stackvm.h b/src/target/stackvm/codegen_stackvm.h index 041c7a7225cf..fd370d285ea8 100644 --- a/src/target/stackvm/codegen_stackvm.h +++ b/src/target/stackvm/codegen_stackvm.h @@ -26,7 +26,6 @@ #include #include -#include #include #include #include diff --git a/src/tir/pass/verify_memory.cc b/src/tir/analysis/verify_memory.cc similarity index 86% rename from src/tir/pass/verify_memory.cc rename to src/tir/analysis/verify_memory.cc index 5e805f8f9560..d6a521f98487 100644 --- a/src/tir/pass/verify_memory.cc +++ b/src/tir/analysis/verify_memory.cc @@ -22,8 +22,10 @@ * \brief Pass to check if memory accesses are legal. */ #include -#include +#include #include +#include +#include namespace tvm { @@ -44,7 +46,7 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { public: /// Special member functions //@{ - explicit MemoryAccessVerifier(LoweredFunc f, int device_type) + explicit MemoryAccessVerifier(PrimFunc f, int device_type) : func_(f), dev_type_(device_type) {} virtual ~MemoryAccessVerifier() = default; MemoryAccessVerifier(const MemoryAccessVerifier &) = delete; @@ -116,7 +118,7 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { CHECK(V) << "Invalid Variable\n"; // Variable is from function args. Return true. - if (V == func_->args[0].get()) return true; + if (V == func_->params[0].get()) return true; // The value is expected to come from a tvm_struct_get Call. // Get the first argument of tvm_struct_get, and continue. @@ -179,18 +181,33 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { const ProducerConsumerNode *pc_{nullptr}; bool failure_{false}; ///< If the verification fails (i.e. has illegal access) //@} - LoweredFunc func_{nullptr}; ///< Function to be verified. + tir::PrimFunc func_{nullptr}; ///< Function to be verified. int dev_type_{kDLCPU}; ///< Device type std::unordered_map defs_; ///< Variable definitions }; } // namespace /// Interface of VerifyMemory pass -bool VerifyMemory(LoweredFunc func, int device_type) { - MemoryAccessVerifier v(func, device_type); - v.Run(); - return !v.Failed(); +void VerifyMemory(const IRModule& mod) { + for (auto kv : mod->functions) { + if (auto* n = kv.second.as()) { + PrimFunc func = GetRef(n); + auto target = func->GetAttr(tvm::attr::kTarget); + CHECK(target.defined()) + << "LowerWarpMemory: Require the target attribute"; + MemoryAccessVerifier v(func, target->device_type); + v.Run(); + if (v.Failed()) { + LOG(FATAL) + << "ValueError: Direct host side access to device memory is detected." + << " Did you forget to bind?\n" + << func; + } + } + } } +TVM_REGISTER_GLOBAL("tir.analysis.verify_memory") +.set_body_typed(VerifyMemory); } // namespace tir } // namespace tvm diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index eec7c108b8c8..6bbf6451b7ac 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -48,7 +48,7 @@ Buffer decl_buffer(Array shape, DataType dtype, std::string name) { return BufferNode::make( - Var(name, DataType::Handle()), + Var(name, PointerType(PrimType(dtype))), dtype, shape, Array(), diff --git a/src/tir/ir/lowered_func.cc b/src/tir/ir/lowered_func.cc deleted file mode 100644 index 8790f2b12e39..000000000000 --- a/src/tir/ir/lowered_func.cc +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file lowered_func.cc - */ -#include - -namespace tvm { -namespace tir { -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << "LoweredFunc(" << op->name << ", " << op << ")"; -}); - -TVM_REGISTER_NODE_TYPE(LoweredFuncNode); - - -} // namespace tir -} // namespace tvm diff --git a/src/tir/pass/ffi_api.cc b/src/tir/pass/ffi_api.cc index 83db1a900fc6..3083b6879635 100644 --- a/src/tir/pass/ffi_api.cc +++ b/src/tir/pass/ffi_api.cc @@ -105,13 +105,6 @@ TVM_REGISTER_GLOBAL("ir_pass.PostOrderVisit") }); }); -TVM_REGISTER_GLOBAL("ir_pass.LowerStorageAccess") -.set_body([](TVMArgs args, TVMRetValue *ret) { - LoweredFunc f = args[0]; - auto n = make_object(*f.operator->()); - n->body = LowerStorageAccessInfo(f->body); - *ret = LoweredFunc(n); -}); // make from two arguments #define REGISTER_PASS(PassName) \ @@ -128,7 +121,6 @@ REGISTER_PASS(VectorizeLoop); REGISTER_PASS(SkipVectorize); REGISTER_PASS(UnrollLoop); REGISTER_PASS(InjectCopyIntrin); -REGISTER_PASS(MakeAPI); REGISTER_PASS(StorageRewrite); REGISTER_PASS(CoProcSync); REGISTER_PASS(LowerStorageAccessInfo); @@ -138,9 +130,6 @@ REGISTER_PASS(InjectDoubleBuffer); REGISTER_PASS(LoopPartition); REGISTER_PASS(RemoveNoOp); REGISTER_PASS(LiftAttrScope); -REGISTER_PASS(RemapThreadAxis); -REGISTER_PASS(LowerCustomDatatypes); -REGISTER_PASS(VerifyMemory); REGISTER_PASS(VerifyGPUCode); REGISTER_PASS(DecorateDeviceScope); REGISTER_PASS(InstrumentBoundCheckers); diff --git a/src/tir/pass/storage_rewrite.cc b/src/tir/pass/storage_rewrite.cc index b4e6061a35d0..f3604b640349 100644 --- a/src/tir/pass/storage_rewrite.cc +++ b/src/tir/pass/storage_rewrite.cc @@ -994,29 +994,6 @@ class VectorAllocRewriter : public StmtExprMutator { }; -LoweredFunc PointerValueTypeRewrite(LoweredFunc f) { - auto n = make_object(*f.operator->()); - VectorAllocRewriter rewriter; - n->body = rewriter(n->body); - for (Var arg : f->args) { - if (arg.dtype().is_handle()) { - const auto& tvec = rewriter.acc_map_[arg.get()]; - if (tvec.size() == 1) { - PrimExpr dtype = make_const(tvec[0], 0); - n->handle_data_type.Set(arg, dtype); - } else { - // always set data type to be non vectorized so - // load/store can still work via scalarization - if (tvec.size() != 0 && !n->handle_data_type.count(arg)) { - PrimExpr dtype = make_const(tvec[0].with_lanes(1), 0); - n->handle_data_type.Set(arg, dtype); - } - } - } - } - return LoweredFunc(n); -} - PrimFunc PointerValueTypeRewrite(PrimFunc f) { auto* n = f.CopyOnWrite(); VectorAllocRewriter rewriter; diff --git a/src/tir/pass/lower_custom_datatypes.cc b/src/tir/transforms/lower_custom_datatypes.cc similarity index 88% rename from src/tir/pass/lower_custom_datatypes.cc rename to src/tir/transforms/lower_custom_datatypes.cc index b24fdf158f4a..6026f8c67567 100644 --- a/src/tir/pass/lower_custom_datatypes.cc +++ b/src/tir/transforms/lower_custom_datatypes.cc @@ -22,7 +22,9 @@ */ #include -#include +#include +#include +#include #include "../../target/datatype/registry.h" namespace tvm { @@ -129,11 +131,26 @@ class CustomDatatypesLowerer : public StmtExprMutator { std::string target_; }; -LoweredFunc LowerCustomDatatypes(LoweredFunc f, const std::string& target) { - auto n = make_object(*f.operator->()); - n->body = CustomDatatypesLowerer(target)(n->body); - return LoweredFunc(n); + +namespace transform { + +Pass LowerCustomDatatypes() { + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + auto target = f->GetAttr(tvm::attr::kTarget); + CHECK(target.defined()) + << "LowerCustomDatatypes: Require the target attribute"; + + n->body = CustomDatatypesLowerer(target->target_name)(std::move(n->body)); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.LowerCustomDatatypes", {}); } +TVM_REGISTER_GLOBAL("tir.transform.LowerCustomDatatypes") +.set_body_typed(LowerCustomDatatypes); + +} // namespace transform + } // namespace tir } // namespace tvm diff --git a/src/tir/pass/make_api.cc b/src/tir/transforms/make_packed_api.cc similarity index 61% rename from src/tir/pass/make_api.cc rename to src/tir/transforms/make_packed_api.cc index 861cd43e5376..c49b04442b2f 100644 --- a/src/tir/pass/make_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -18,20 +18,24 @@ */ /*! - * \file make_api.cc Build API function. + * \file make_packed_api.cc Lower PrimFunc to use the packed function API. */ #include #include #include +#include #include #include #include +#include +#include + #include #include #include -#include "ir_util.h" -#include "arg_binder.h" +#include "../pass/ir_util.h" +#include "../pass/arg_binder.h" namespace tvm { namespace tir { @@ -40,14 +44,18 @@ inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) { return AssertStmtNode::make(lhs == rhs, msg, EvaluateNode::make(0)); } -LoweredFunc MakeAPI(Stmt body, - std::string name, - Array api_args, - int num_unpacked_args, - bool is_restricted) { +PrimFunc MakePackedAPI(PrimFunc&& func, + int num_unpacked_args) { + auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); + CHECK(global_symbol.defined()) + << "MakePackedAPI: Expect PrimFunc to have the global_symbol attribute"; + std::string name_hint = global_symbol; + + auto* func_ptr = func.CopyOnWrite(); const Stmt nop = EvaluateNode::make(0); - int num_args = static_cast(api_args.size()); + int num_args = static_cast(func_ptr->params.size()); CHECK_LE(num_unpacked_args, num_args); + int num_packed_args = num_args - num_unpacked_args; // Data field definitions // The packed fields @@ -69,9 +77,10 @@ LoweredFunc MakeAPI(Stmt body, // local function definitions // load i-th argument as type t auto f_arg_value = [&](DataType t, int i) { - Array call_args{v_packed_args, - IntImm(DataType::Int(32), i), - IntImm(DataType::Int(32), intrinsic::kTVMValueContent)}; + Array call_args{ + v_packed_args, + IntImm(DataType::Int(32), i), + IntImm(DataType::Int(32), intrinsic::kTVMValueContent)}; // load 64 bit version DataType api_type = APIType(t); PrimExpr res = CallNode::make( @@ -83,13 +92,7 @@ LoweredFunc MakeAPI(Stmt body, } return res; }; - // get declaration of argument i - auto f_arg_decl = [&](int i) { - std::ostringstream os; - os << "arg" << i; - const VarNode* v = api_args[i].as(); - return Var(os.str(), v ? v->dtype: DataType::Handle()); - }; + // --------------------------- // start of logics // add signiture for packed arguments. @@ -99,16 +102,25 @@ LoweredFunc MakeAPI(Stmt body, args.push_back(v_num_packed_args); std::ostringstream os; - os << name << ": num_args should be " << num_packed_args; + os << name_hint << ": num_args should be " << num_packed_args; seq_init.emplace_back( MakeAssertEQ(v_num_packed_args, num_packed_args, os.str())); } - // Save the input variables and buffers that will be bound later. - std::vector > var_defs; - std::vector > buf_defs; - for (int i = 0; i < static_cast(api_args.size()); ++i) { - Var v_arg = f_arg_decl(i); + // Need to re-declare vars, in case some arguments also appears in the buffer. + std::vector > var_def; + std::vector > buffer_def; + + for (int i = 0; i < static_cast(func_ptr->params.size()); ++i) { + Var param = func_ptr->params[i]; + Var v_arg = Var("arg" + std::to_string(i), param->dtype); + + auto it = func_ptr->buffer_map.find(param); + if (it != func_ptr->buffer_map.end()) { + buffer_def.emplace_back(v_arg, (*it).second); + } else { + var_def.emplace_back(v_arg, param); + } if (i < num_packed_args) { // Value loads seq_init.emplace_back(LetStmtNode::make( @@ -123,35 +135,26 @@ LoweredFunc MakeAPI(Stmt body, DataType t = v_arg.dtype(); if (t.is_handle()) { std::ostringstream msg; - msg << name << ": Expect arg[" << i << "] to be pointer"; + msg << name_hint << ": Expect arg[" << i << "] to be pointer"; seq_check.emplace_back( AssertStmtNode::make(tcode == kTVMOpaqueHandle || - tcode == kTVMNDArrayHandle || - tcode == kTVMDLTensorHandle || - tcode == kTVMNullptr, msg.str(), nop)); + tcode == kTVMNDArrayHandle || + tcode == kTVMDLTensorHandle || + tcode == kTVMNullptr, msg.str(), nop)); } else if (t.is_int() || t.is_uint()) { std::ostringstream msg; - msg << name << ": Expect arg[" << i << "] to be int"; + msg << name_hint << ": Expect arg[" << i << "] to be int"; seq_check.emplace_back(AssertStmtNode::make(tcode == kDLInt, msg.str(), nop)); } else { CHECK(t.is_float()); std::ostringstream msg; - msg << name << ": Expect arg[" << i << "] to be float"; + msg << name_hint << ": Expect arg[" << i << "] to be float"; seq_check.emplace_back( AssertStmtNode::make(tcode == kDLFloat, msg.str(), nop)); } } else { args.push_back(v_arg); } - // add checks for functions. - if (api_args[i].as()) { - var_defs.emplace_back(std::make_pair(Downcast(api_args[i]), v_arg)); - } else { - // Buffer checks - CHECK(api_args[i].as()) - << "api_args can only be Buffer or Var"; - buf_defs.emplace_back(std::make_pair(Downcast(api_args[i]), v_arg)); - } } // allow return value if the function is packed. @@ -170,24 +173,22 @@ LoweredFunc MakeAPI(Stmt body, // either 0 or the original stride will be correctly used. Checks here have // to use the args that may have no let bining yet. Therefore, hoisting let // binding for args before buffer declaration is needed. - for (const auto& arg : var_defs) { - binder.Bind(arg.first, arg.second, arg.second->name_hint, true); + for (const auto& kv : var_def) { + binder.Bind(kv.second, kv.first, kv.first->name_hint, true); + } + + for (const auto& kv : buffer_def) { + binder.BindDLTensor(kv.second, device_type, device_id, + kv.first, kv.first->name_hint); } - for (const auto& buf_arg : buf_defs) { - binder.BindDLTensor(buf_arg.first, device_type, device_id, - buf_arg.second, buf_arg.second->name_hint); + if (num_unpacked_args == 0) { + func = WithAttr(std::move(func), tvm::attr::kCallingConv, Integer(CallingConv::kCPackedFunc)); } - ObjectPtr n = make_object(); - n->name = name; - n->args = args; - n->handle_data_type = binder.def_handle_dtype(); - n->is_packed_func = num_unpacked_args == 0; - n->is_restricted = is_restricted; - body = AttrStmtNode::make( + auto body = AttrStmtNode::make( make_zero(DataType::Int(32)), attr::compute_scope, - StringImmNode::make(name + "_compute_"), body); + StringImmNode::make(name_hint + "_compute_"), func_ptr->body); // Set device context if (vmap.count(device_id.get())) { PrimExpr node = StringImmNode::make("default"); @@ -203,21 +204,59 @@ LoweredFunc MakeAPI(Stmt body, device_type, device_id}, CallNode::Intrinsic))); body = SeqStmt({set_device, body}); } - n->body = MergeNest( + func_ptr->body = MergeNest( {seq_init, binder.init_nest(), seq_check, binder.asserts()}, body); - LoweredFunc f(n); - Array undefined = UndefinedVars(f->body, f->args); + func_ptr->params = args; + + Array undefined = UndefinedVars(func_ptr->body, func_ptr->params); if (undefined.size() != 0) { std::ostringstream os; for (Var v : undefined) { os << " \'" << v->name_hint << "\' "; } - os << " does not appear in api_args"; + os << " is not bound to any variables"; LOG(FATAL) << "Not all Vars are passed in api_args: " << os.str(); } - return f; + + + func_ptr->buffer_map = Map(); + func_ptr->checked_type_ = func_ptr->func_type_annotation(); + func_ptr->ret_type = PrimType(DataType::Int(32)); + + // return the function. + return std::move(func); } +namespace transform { + +Pass MakePackedAPI(int num_unpacked_args) { + auto pass_func = [num_unpacked_args](IRModule m, PassContext ctx) { + IRModuleNode* mptr = m.CopyOnWrite(); + std::vector > updates; + + for (const auto& kv : mptr->functions) { + if (auto* n = kv.second.as()) { + PrimFunc func = GetRef(n); + if (func->GetAttr(tvm::attr::kCallingConv, 0)->value + == static_cast(CallingConv::kDefault)) { + auto updated_func = MakePackedAPI(std::move(func), num_unpacked_args); + updates.push_back({kv.first, updated_func}); + } + } + } + + for (const auto& pair : updates) { + mptr->AddUnchecked(pair.first, pair.second); + } + return m; + }; + + return tvm::transform::CreateModulePass( + pass_func, 0, "tir.MakePackedAPI", {}); +} +TVM_REGISTER_GLOBAL("tir.transform.MakePackedAPI") +.set_body_typed(MakePackedAPI); +} // namespace transform } // namespace tir } // namespace tvm diff --git a/src/tir/pass/remap_thread_axis.cc b/src/tir/transforms/remap_thread_axis.cc similarity index 73% rename from src/tir/pass/remap_thread_axis.cc rename to src/tir/transforms/remap_thread_axis.cc index 4fa5dd3cbe9b..f695b3c777aa 100644 --- a/src/tir/pass/remap_thread_axis.cc +++ b/src/tir/transforms/remap_thread_axis.cc @@ -22,7 +22,8 @@ */ #include #include -#include +#include +#include #include @@ -74,8 +75,8 @@ class ThreadAxisRewriter : private StmtExprMutator { std::unordered_map vmap_; }; -LoweredFunc -RemapThreadAxis(LoweredFunc f, Map thread_map) { + +PrimFunc RemapThreadAxis(PrimFunc&& f, Map thread_map) { std::unordered_map tmap; for (const auto& kv : thread_map) { const StringImmNode* str = kv.first.as(); @@ -83,18 +84,33 @@ RemapThreadAxis(LoweredFunc f, Map thread_map) { tmap[str->value] = kv.second; } - CHECK_EQ(f->func_type, kDeviceFunc); - auto n = make_object(*f.operator->()); + auto thread_axis = f->GetAttr >(tir::attr::kDeviceThreadAxis); + auto* n = f.CopyOnWrite(); + // replace the thread axis - for (size_t i = 0; i < n->thread_axis.size(); ++i) { - auto it = tmap.find(n->thread_axis[i]->thread_tag); + for (size_t i = 0; i < thread_axis.size(); ++i) { + auto it = tmap.find(thread_axis[i]->thread_tag); if (it != tmap.end()) { - n->thread_axis.Set(i, it->second); + thread_axis.Set(i, it->second); } } - n->body = ThreadAxisRewriter(tmap).Rewrite(n->body); - return LoweredFunc(n); + n->body = ThreadAxisRewriter(tmap).Rewrite(std::move(n->body)); + return WithAttr(std::move(f), tir::attr::kDeviceThreadAxis, thread_axis); } + +namespace transform { + +Pass RemapThreadAxis(Map thread_map) { + auto pass_func = [thread_map](PrimFunc f, IRModule m, PassContext ctx) { + return RemapThreadAxis(std::move(f), thread_map); + }; + return CreatePrimFuncPass(pass_func, 0, "tir.RemapThreadAxis", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.RemapThreadAxis") +.set_body_typed(RemapThreadAxis); + +} // namespace transform } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index 838ad82d974f..ae32bdcbadea 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -264,7 +264,6 @@ class HostDeviceSplitter : public StmtMutator { std::string name_prefix_; // Number of device functions. int device_func_counter_{0}; - std::vector device_funcs_; std::unordered_map handle_data_type_; }; diff --git a/tests/cpp/build_module_test.cc b/tests/cpp/build_module_test.cc index c2c808fa7895..9333a3470715 100644 --- a/tests/cpp/build_module_test.cc +++ b/tests/cpp/build_module_test.cc @@ -117,8 +117,8 @@ TEST(BuildModule, Heterogeneous) { std::unordered_map binds; auto lowered_s1 = lower(s1, args1, "elemwise_add", binds, config); auto lowered_s2 = lower(s2, args2, "elemwise_sub", binds, config); - Map> inputs = {{target_cuda, lowered_s1}, - {target_llvm, lowered_s2}}; + Map inputs = {{target_cuda, lowered_s1}, + {target_llvm, lowered_s2}}; auto module = build(inputs, Target(), config); // Assertion for build. diff --git a/tests/python/integration/test_dot.py b/tests/python/integration/test_dot.py index 4f2b6aa99fcd..27f3788fef5c 100644 --- a/tests/python/integration/test_dot.py +++ b/tests/python/integration/test_dot.py @@ -18,29 +18,6 @@ from tvm import te import numpy as np -def lower(s, args, name="mydot"): - binds = {} - arg_list = [] - - for x in args: - assert isinstance(x, te.tensor.Tensor) - buf = tvm.tir.decl_buffer(x.shape, dtype=x.dtype, name=x.op.name) - binds[x] = buf - arg_list.append(buf) - s = s.normalize() - bounds = tvm.te.schedule.InferBound(s) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - stmt = tvm.tir.ir_pass.StorageFlatten(stmt, binds, 16) - stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt) - stmt = tvm.tir.ir_pass.Simplify(stmt) - fapi = tvm.tir.ir_pass.MakeAPI(stmt, name, arg_list, 0, True) - fapi = tvm.tir.ir_pass.LowerTVMBuiltin(fapi) - return fapi - - -def mybuild(fapi, target="llvm"): - return - def test_dot(): nn = 12 diff --git a/tests/python/unittest/test_runtime_extension.py b/tests/python/unittest/test_runtime_extension.py index 13de67efa8f4..d9088b64168d 100644 --- a/tests/python/unittest/test_runtime_extension.py +++ b/tests/python/unittest/test_runtime_extension.py @@ -38,8 +38,9 @@ def test_dltensor_compatible(): with ib.for_range(0, n - 1, "i") as i: A[i + 1] = A[i] + 1 stmt = ib.get() - fapi = tvm.tir.ir_pass.MakeAPI(stmt, "arange", [Ab], 0, True) - mod = tvm.testing.LoweredFuncsToIRModule([fapi]) + + + mod = tvm.testing.MakeAPILegacy(stmt, "arange", [Ab], 0, True) mod = tvm.tir.transform.LowerTVMBuiltin()(mod) f = tvm.target.codegen.build_module(mod, "stackvm") a = tvm.nd.array(np.zeros(10, dtype=dtype)) diff --git a/tests/python/unittest/test_runtime_heterogeneous.py b/tests/python/unittest/test_runtime_heterogeneous.py index 8ca61c1920ba..343b86717028 100644 --- a/tests/python/unittest/test_runtime_heterogeneous.py +++ b/tests/python/unittest/test_runtime_heterogeneous.py @@ -156,7 +156,7 @@ def check_device(device, target_device): elemwise_sub], name="elemwise_sub") - target_flist = {target_device: [lower_add], target_host: [lower_sub]} + target_flist = {target_device: lower_add, target_host: lower_sub} mhost = tvm.build(target_flist, target_host=target_host) ctx = [host_ctx, device_ctx] mod = graph_runtime.create(graph, mhost, ctx) @@ -354,8 +354,9 @@ def check_device(device, target_device): elemwise_sub], name="elemwise_sub") - target_flist = {target_device: [lower_add0, lower_add1], target_host: - [lower_sub]} + lower_add0.update(lower_add1) + target_flist = {target_device: lower_add0, target_host: + lower_sub} mhost = tvm.build(target_flist, target_host=target_host) ctx = [host_ctx, device_ctx] params = {} diff --git a/tests/python/unittest/test_runtime_module_load.py b/tests/python/unittest/test_runtime_module_load.py index 37ccb5e47830..f6abebd2fbc4 100644 --- a/tests/python/unittest/test_runtime_module_load.py +++ b/tests/python/unittest/test_runtime_module_load.py @@ -57,8 +57,8 @@ def save_object(names): tvm.tir.Store(Ab.data, tvm.tir.Load(dtype, Ab.data, i) + 1, i + 1)) - fapi = tvm.tir.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True) - m = tvm.driver.build(fapi, target="llvm") + m = tvm.testing.MakeAPILegacy(stmt, "ramp", [Ab], 0, True) + m = tvm.driver.build(m, target="llvm") for name in names: m.save(name) diff --git a/tests/python/unittest/test_target_codegen_llvm.py b/tests/python/unittest/test_target_codegen_llvm.py index 45554c5475a3..34135c6ef7ee 100644 --- a/tests/python/unittest/test_target_codegen_llvm.py +++ b/tests/python/unittest/test_target_codegen_llvm.py @@ -22,6 +22,7 @@ import ctypes import math + def test_llvm_intrin(): ib = tvm.tir.ir_builder.create() n = tvm.runtime.convert(4) @@ -34,7 +35,8 @@ def test_llvm_intrin(): tvm.tir.Call( "int32", "prefetch", args, tvm.tir.Call.Intrinsic, None, 0))) body = ib.get() - func = tvm.tir.ir_pass.MakeAPI(body, "prefetch", [A], 0, True) + + func = tvm.testing.MakeAPILegacy(body, "prefetch", [A], 0, True) fcode = tvm.build(func, None, "llvm") @@ -85,7 +87,7 @@ def test_llvm_lookup_intrin(): x = tvm.tir.call_llvm_intrin("uint8x8", "llvm.ctpop.i8", tvm.tir.const(1, 'uint32'), A) ib.emit(x) body = ib.get() - func = tvm.tir.ir_pass.MakeAPI(body, "ctpop", [A], 1, True) + func = tvm.testing.MakeAPILegacy(body, "ctpop", [A], 1, True) fcode = tvm.build(func, None, "llvm") @@ -307,8 +309,9 @@ def check_llvm(): f2 = tvm.lower(s, [A, B, C], name="fadd1") f1 = tvm.lower(s, [A, B, C], name="fadd2") m = tvm.build([f1, f2], "llvm") - fadd1 = m['fadd1'] fadd2 = m['fadd2'] + fadd1 = m['fadd1'] + ctx = tvm.cpu(0) # launch the kernel. n = nn @@ -665,6 +668,7 @@ def vectorizer(op): tvm.testing.assert_allclose(c_.asnumpy(), (a_.asnumpy() * 2).astype('int32')) if __name__ == "__main__": + test_multiple_func() test_llvm_large_uintimm() test_llvm_import() test_alignment() @@ -676,7 +680,6 @@ def vectorizer(op): test_llvm_vadd_pipeline() test_llvm_add_pipeline() test_llvm_intrin() - test_multiple_func() test_llvm_flip_pipeline() test_llvm_madd_pipeline() test_llvm_temp_space() diff --git a/tests/python/unittest/test_target_codegen_static_init.py b/tests/python/unittest/test_target_codegen_static_init.py index a9fa35f1a533..bd4d0d8cd52d 100644 --- a/tests/python/unittest/test_target_codegen_static_init.py +++ b/tests/python/unittest/test_target_codegen_static_init.py @@ -19,6 +19,18 @@ import ctypes import numpy as np + +def MakeAPILegacy(stmt, name, args, num_unpacked_args, noalias): + """Legacy adapter to create a API""" + f = tvm.tir.PrimFunc(args, stmt).with_attr( + "global_symbol", tvm.runtime.String(name)) + f = f.with_attr("tir.is_entry_func", True) + if noalias: + f = f.with_attr("tir.no_alias", True) + mod = tvm.IRModule.from_expr(f) + return tvm.tir.transform.MakePackedAPI()(mod) + + def test_static_callback(): dtype = 'int64' n = te.size_var('n') @@ -32,7 +44,7 @@ def test_static_callback(): with ib.for_range(0, n, "i", for_type="parallel") as i: A[i] = A[i] + 1 stmt = ib.get() - fapi = tvm.tir.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True) + fapi = tvm.testing.MakeAPILegacy(stmt, "ramp", [Ab], 0, True) f = tvm.driver.build(fapi, target="llvm") a = tvm.nd.array(np.zeros(10, dtype=dtype)) f(a) @@ -55,7 +67,7 @@ def test_cb(sh, A): return sh stmt = ib.get() - fapi = tvm.tir.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True) + fapi = tvm.testing.MakeAPILegacy(stmt, "ramp", [Ab], 0, True) f = tvm.driver.build(fapi, target="llvm") a = tvm.nd.array(np.zeros(10, dtype=dtype)) f(a) diff --git a/tests/python/unittest/test_target_codegen_vm_basic.py b/tests/python/unittest/test_target_codegen_vm_basic.py index 26464ceedfd7..ee0d89bf298b 100644 --- a/tests/python/unittest/test_target_codegen_vm_basic.py +++ b/tests/python/unittest/test_target_codegen_vm_basic.py @@ -26,6 +26,18 @@ def run_jit(fapi, check): s = f.get_source() check(f) + +def MakeAPILegacy(stmt, name, args, num_unpacked_args, noalias): + """Legacy adapter to create a API""" + f = tvm.tir.PrimFunc(args, stmt).with_attr( + "global_symbol", tvm.runtime.String(name)) + f = f.with_attr("tir.is_entry_func", True) + if noalias: + f = f.with_attr("tir.no_alias", True) + mod = tvm.IRModule.from_expr(f) + return tvm.tir.transform.MakePackedAPI()(mod) + + def test_stack_vm_basic(): a = tvm.nd.array(np.zeros(10, dtype='float32')) @tvm.register_func @@ -36,7 +48,7 @@ def tvm_call_back_get_shape(shape0): n = te.size_var('n') Ab = tvm.tir.decl_buffer((n, ), "float32") stmt = tvm.tir.Evaluate(tvm.tir.call_packed("tvm_call_back_get_shape", Ab.shape[0])) - fapi = tvm.tir.ir_pass.MakeAPI(stmt, "print_shape", [Ab], 0, True) + fapi = tvm.testing.MakeAPILegacy(stmt, "print_shape", [Ab], 0, True) run_jit(fapi, lambda f: f(a)) @@ -57,7 +69,7 @@ def test_stack_vm_loop(): ib.emit(tvm.tir.call_packed("tvm_stack_vm_print", i)) stmt = ib.get() - fapi = tvm.tir.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True) + fapi = tvm.testing.MakeAPILegacy(stmt, "ramp", [Ab], 0, True) a = tvm.nd.array(np.zeros(10, dtype=dtype)) def check(f): f(a) @@ -79,7 +91,7 @@ def test_stack_vm_cond(): A[i + 1] = A[i] + 2 stmt = ib.get() - fapi = tvm.tir.ir_pass.MakeAPI(stmt, "test", [Ab], 0, True) + fapi = tvm.testing.MakeAPILegacy(stmt, "test", [Ab], 0, True) def check(f): a = tvm.nd.array(np.zeros(10, dtype=dtype)) f(a) @@ -98,7 +110,7 @@ def test_vm_parallel(): with ib.for_range(0, n, "i", for_type="parallel") as i: A[i] = A[i] + 1 stmt = ib.get() - fapi = tvm.tir.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True) + fapi = tvm.testing.MakeAPILegacy(stmt, "ramp", [Ab], 0, True) def check(f): a = tvm.nd.array(np.zeros(10, dtype=dtype)) f(a) diff --git a/tests/python/unittest/test_target_custom_datatypes.py b/tests/python/unittest/test_target_custom_datatypes.py index 32f6e1865b63..f6723e2b1ee1 100644 --- a/tests/python/unittest/test_target_custom_datatypes.py +++ b/tests/python/unittest/test_target_custom_datatypes.py @@ -19,7 +19,6 @@ from tvm import te from ctypes import * import topi -import tvm.tir.ir_pass as ir_pass import numpy as np tgt = "llvm" @@ -51,10 +50,12 @@ def lower_datatypes_and_build(schedule, args): Once datatype lowering is integrated directly into TVM's lower/build process, we won't need to do this manually. TODO(gus) integrate datatype lowering into build process; change this test""" - flist = tvm.lower(schedule, args) - flist = [flist] - flist = [ir_pass.LowerCustomDatatypes(func, tgt) for func in flist] - return tvm.build(flist[0], target=tgt) + mod = tvm.lower(schedule, args) + target = tvm.target.create(tgt) + mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod) + mod = tvm.tir.transform.LowerCustomDatatypes()(mod) + return tvm.build(mod, target=tgt) + def test_bfloat_add_and_cast_1(): X = te.placeholder((3, ), name="X") diff --git a/tests/python/unittest/test_tir_pass_verify_memory.py b/tests/python/unittest/test_tir_analysis_verify_memory.py similarity index 70% rename from tests/python/unittest/test_tir_pass_verify_memory.py rename to tests/python/unittest/test_tir_analysis_verify_memory.py index 3747caed1586..f993c915aa9c 100644 --- a/tests/python/unittest/test_tir_pass_verify_memory.py +++ b/tests/python/unittest/test_tir_analysis_verify_memory.py @@ -15,12 +15,13 @@ # specific language governing permissions and limitations # under the License. import tvm +import pytest from tvm import te # The following DLDeviceType/TVMDeviceExtType values # are originally defined in dlpack.h and c_runtime_api.h. -gpu_devices = [2, 4, 7, 8, 10, 11] -other_devices = [1, 3, 9, 12] +gpu_devices = ["cuda", "opencl", "metal", "vulkan"] +other_devices = ["llvm", "ext_dev"] def lower(sch, args): @@ -39,8 +40,11 @@ def lower(sch, args): stmt = tvm.te.schedule.ScheduleOps(sch, bounds) stmt = tvm.tir.ir_pass.LoopPartition(stmt, False) stmt = tvm.tir.ir_pass.StorageFlatten(stmt, binds, 64) - func = tvm.tir.ir_pass.MakeAPI(stmt, "myadd", arg_list, 0, True) - return func + + f = tvm.tir.PrimFunc(arg_list, stmt).with_attr( + "global_symbol", tvm.runtime.String("test")) + mod = tvm.IRModule({"test": f}) + return tvm.tir.transform.MakePackedAPI()(mod) # All computations are bound. @@ -57,10 +61,13 @@ def test_verify_memory_all_bind(): s[B].bind(bx, te.thread_axis("blockIdx.x")) s[B].bind(tx, te.thread_axis("threadIdx.x")) - func = lower(s, [A, B]) + mod = lower(s, [A, B]) for dev_type in gpu_devices + other_devices: - assert tvm.tir.ir_pass.VerifyMemory(func, dev_type) + binded_mod = tvm.tir.transform.Apply( + lambda f: f.with_attr("target", tvm.target.create(dev_type)))(mod) + tvm.tir.analysis.verify_memory(binded_mod) + # Computations are not bound. @@ -74,12 +81,18 @@ def test_verify_memory_not_bind(): # B is not bound to threads. s = te.create_schedule(B.op) - func = lower(s, [A, B]) + mod = lower(s, [A, B]) for dev_type in gpu_devices: - assert not tvm.tir.ir_pass.VerifyMemory(func, dev_type) + binded_mod = tvm.tir.transform.Apply( + lambda f: f.with_attr("target", tvm.target.create(dev_type)))(mod) + with pytest.raises(ValueError): + tvm.tir.analysis.verify_memory(binded_mod) + for dev_type in other_devices: - assert tvm.tir.ir_pass.VerifyMemory(func, dev_type) + binded_mod = tvm.tir.transform.Apply( + lambda f: f.with_attr("target", tvm.target.create(dev_type)))(mod) + tvm.tir.analysis.verify_memory(binded_mod) # Computations are partially bound. @@ -98,16 +111,22 @@ def test_verify_memory_partially_bind(): s[C].bind(bx, te.thread_axis("blockIdx.x")) s[C].bind(tx, te.thread_axis("threadIdx.x")) - func = lower(s, [A, B, C, D]) + mod = lower(s, [A, B, C, D]) for dev_type in gpu_devices: - assert not tvm.tir.ir_pass.VerifyMemory(func, dev_type) + binded_mod = tvm.tir.transform.Apply( + lambda f: f.with_attr("target", tvm.target.create(dev_type)))(mod) + with pytest.raises(ValueError): + tvm.tir.analysis.verify_memory(binded_mod) + for dev_type in other_devices: - assert tvm.tir.ir_pass.VerifyMemory(func, dev_type) + binded_mod = tvm.tir.transform.Apply( + lambda f: f.with_attr("target", tvm.target.create(dev_type)))(mod) + tvm.tir.analysis.verify_memory(binded_mod) + if __name__ == "__main__": test_verify_memory_all_bind() test_verify_memory_not_bind() test_verify_memory_partially_bind() - diff --git a/tests/python/unittest/test_tir_pass_bound_checkers.py b/tests/python/unittest/test_tir_pass_bound_checkers.py index b3390972ab00..d6c89b2ab878 100644 --- a/tests/python/unittest/test_tir_pass_bound_checkers.py +++ b/tests/python/unittest/test_tir_pass_bound_checkers.py @@ -118,7 +118,6 @@ def test_in_bounds_vectorize_llvm(): s[B].vectorize(xi) # build and invoke the kernel. lowered_func = tvm.lower (s, [A, C], "llvm", simple_mode=False) - print (lowered_func.body) f = tvm.build(s, [A, C], "llvm") ctx = tvm.cpu(0) # launch the kernel. @@ -137,7 +136,6 @@ def test_in_bounds_loop_partition_basic_llvm(): s = te.create_schedule(T.op) xo, xi = s[T].split(T.op.axis[0], factor=4) lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False) - print (lowered_func.body) ctx = tvm.cpu(0) f = tvm.build(s, [A, B, T], "llvm") @@ -156,7 +154,6 @@ def test_out_of_bounds_loop_partition_basic_llvm(index_a, index_b): s = te.create_schedule(T.op) xo, xi = s[T].split(T.op.axis[0], factor=4) lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False) - print (lowered_func.body) ctx = tvm.cpu(0) f = tvm.build(s, [A, B, T], "llvm") @@ -205,12 +202,11 @@ def collect_branch_stmt (x): # after instrumentation assert_bound_instrumentation(stmt, check_attr_stmt, 2 * 3) assert_bound_instrumentation(stmt, check_branch_stmt, 2) - print (stmt) + branch_collector = list() collect_visit(stmt, collect_branch_stmt) assert(len(branch_collector) == 2) - print (branch_collector[0].condition) - print (branch_collector[1].condition) + def test_in_bounds_const_loop_partition_llvm(): with tvm.target.build_config(instrument_bound_checkers=True, partition_const_loop=True): @@ -222,7 +218,6 @@ def test_in_bounds_const_loop_partition_llvm(): s = te.create_schedule(T.op) xo, xi = s[T].split(T.op.axis[0], factor=4) lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False) - print (lowered_func.body) ctx = tvm.cpu(0) f = tvm.build(s, [A, B, T], "llvm") @@ -242,7 +237,6 @@ def test_out_of_bounds_const_loop_partition_llvm(index_a, index_b): s = te.create_schedule(T.op) xo, xi = s[T].split(T.op.axis[0], factor=4) lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False) - print (lowered_func.body) ctx = tvm.cpu(0) f = tvm.build(s, [A, B, T], "llvm") @@ -276,7 +270,6 @@ def test_in_bounds_conv_llvm(loop_tiling=False): if loop_tiling: oho, owo, ohi, owi = s[conv].tile(oh, ow, 16, 16) lowered_func = tvm.lower(s, [data, kernel, conv], simple_mode=True) - print (lowered_func.body) ctx = tvm.cpu (0) f = tvm.build(s, [data, kernel, conv], "llvm") @@ -320,7 +313,6 @@ def test_out_of_bounds_conv_llvm(data_offsets, kernel_offsets, loop_tiling=False if loop_tiling: oho, owo, ohi, owi = s[conv].tile(oh, ow, 16, 16) lowered_func = tvm.lower(s, [data, kernel, conv], simple_mode=True) - print (lowered_func.body) ctx = tvm.cpu (0) f = tvm.build(s, [data, kernel, conv], "llvm") @@ -341,7 +333,6 @@ def test_in_bounds_tensors_with_same_shapes1D_llvm(): T = te.compute((m, ), lambda i: A[i]*B[i]) s = te.create_schedule(T.op) lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False) - print (lowered_func.body) ctx = tvm.cpu(0) f = tvm.build(s, [A, B, T], "llvm") @@ -361,7 +352,6 @@ def test_out_of_bounds_tensors_with_diff_shapes1D_llvm(a_shape, b_shape, c_shape T = te.compute((m, ), lambda i: A[i]*B[i]) s = te.create_schedule(T.op) lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False) - print (lowered_func.body) ctx = tvm.cpu(0) f = tvm.build(s, [A, B, T], "llvm") @@ -380,7 +370,6 @@ def test_in_bounds_tensors_with_same_shapes2D_llvm(): T = te.compute((m, m), lambda i, j: A[i][j]*B[i][j]) s = te.create_schedule(T.op) lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False) - print (lowered_func.body) ctx = tvm.cpu(0) f = tvm.build(s, [A, B, T], "llvm") @@ -400,7 +389,6 @@ def test_out_of_bounds_tensors_with_diff_shapes2D_llvm(a_shape, b_shape, c_shape T = te.compute((m, m), lambda i, j: A[i][j]*B[i][j]) s = te.create_schedule(T.op) lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False) - print (lowered_func.body) ctx = tvm.cpu(0) f = tvm.build(s, [A, B, T], "llvm") @@ -419,7 +407,7 @@ def test_in_bounds_tensors_with_same_shapes3D_llvm(): T = te.compute((m, m, m), lambda i, j, p: A[i][j][p]*B[i][j][p]) s = te.create_schedule(T.op) lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False) - print (lowered_func.body) + ctx = tvm.cpu(0) f = tvm.build(s, [A, B, T], "llvm") @@ -439,7 +427,7 @@ def test_out_of_bounds_tensors_with_diff_shapes3D_llvm(a_shape, b_shape, c_shape T = te.compute((m, m, m), lambda i, j, p: A[i][j][p]*B[i][j][p]) s = te.create_schedule(T.op) lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False) - print (lowered_func.body) + ctx = tvm.cpu(0) f = tvm.build(s, [A, B, T], "llvm") @@ -460,7 +448,7 @@ def test_out_of_bounds_tensors_with_zero_shape_op_with_not_zero_shape_llvm(): D = te.compute((), lambda : C + 1) s = te.create_schedule(D.op) stmt = tvm.lower (s, [A, scale, D], simple_mode=True) - print (stmt) + # build and invoke the kernel. f = tvm.build(s, [A, scale, D], "llvm") ctx = tvm.cpu(0) diff --git a/tests/python/unittest/test_tir_pass_inject_double_buffer.py b/tests/python/unittest/test_tir_pass_inject_double_buffer.py index 94e29c68d930..95a10547463c 100644 --- a/tests/python/unittest/test_tir_pass_inject_double_buffer.py +++ b/tests/python/unittest/test_tir_pass_inject_double_buffer.py @@ -40,8 +40,7 @@ def test_double_buffer(): stmt = tvm.tir.ir_pass.Simplify(stmt) assert isinstance(stmt.body.body, tvm.tir.Allocate) assert stmt.body.body.extents[0].value == 2 - f = tvm.tir.ir_pass.MakeAPI(stmt, "db", [A.asobject(), C.asobject()], 2, True) - mod = tvm.testing.LoweredFuncsToIRModule([f]) + mod = tvm.testing.MakeAPILegacy(stmt, "db", [A.asobject(), C.asobject()], 2, True) f = tvm.tir.transform.ThreadSync("shared")(mod)["db"] count = [0] diff --git a/tests/python/unittest/test_tir_pass_loop_partition.py b/tests/python/unittest/test_tir_pass_loop_partition.py index 7e383ddf7810..0818c0ed0fe2 100644 --- a/tests/python/unittest/test_tir_pass_loop_partition.py +++ b/tests/python/unittest/test_tir_pass_loop_partition.py @@ -381,7 +381,7 @@ def test_multilevel_splitting_with_indivisble_factors(): ## But this does the right thing. with tvm.target.build_config(partition_const_loop=True): - lowered_body = tvm.lower(s, [A, B]).body + lowered_body = tvm.lower(s, [A, B], name="x")["x"].body def visit_stmt(op): return(isinstance(op, tvm.tir.Max)) num_max = collect_visit(lowered_body, visit_stmt) @@ -407,7 +407,7 @@ def test_double_splitting_with_indivisible_factors(): # Find the beginning of the Halide IR corresponding to kernel code # and make sure it doesn't have an if statements left - top_produce = find_top_produce(f.body) + top_produce = find_top_produce(f["fadd1"].body) assert(not any(collect_visit(top_produce, lambda x: isinstance(x, tvm.tir.IfThenElse)))) # check functional correctness of generated code diff --git a/tests/python/unittest/test_tir_pass_storage_flatten.py b/tests/python/unittest/test_tir_pass_storage_flatten.py index dbfcd20f0843..da9253f1dfca 100644 --- a/tests/python/unittest/test_tir_pass_storage_flatten.py +++ b/tests/python/unittest/test_tir_pass_storage_flatten.py @@ -92,9 +92,7 @@ def test_flatten_double_buffer(): stmt = tvm.tir.ir_pass.Simplify(stmt) assert isinstance(stmt.body.body, tvm.tir.Allocate) assert stmt.body.body.extents[0].value == 2 - f = tvm.tir.ir_pass.MakeAPI(stmt, "db", [A.asobject(), C.asobject()], 2, True) - f = tvm.tir.ir_pass.MakeAPI(stmt, "db", [A.asobject(), C.asobject()], 2, True) - mod = tvm.testing.LoweredFuncsToIRModule([f]) + mod = tvm.testing.MakeAPILegacy(stmt, "db", [A.asobject(), C.asobject()], 2, True) f = tvm.tir.transform.ThreadSync("shared")(mod)["db"] count = [0] diff --git a/tests/python/unittest/test_tir_transform_combine_context_call.py b/tests/python/unittest/test_tir_transform_combine_context_call.py index 8140ddb97cba..6f2bc65450be 100644 --- a/tests/python/unittest/test_tir_transform_combine_context_call.py +++ b/tests/python/unittest/test_tir_transform_combine_context_call.py @@ -36,12 +36,7 @@ def device_context(dev_id): ib.emit(tvm.tir.call_extern ("int32", "fadd", device_context(0), A)) body = ib.get() - f = tvm.tir.ir_pass.MakeAPI(body, "func", [dev_type, n], 2, True) - - # temp adapter to convert loweredFunc to IRModule - # to test passes in the new style.x - mod = tvm.testing.LoweredFuncsToIRModule([f]) - + mod = tvm.testing.MakeAPILegacy(body, "func", [dev_type, n], 2, True) mod = tvm.tir.transform.CombineContextCall()(mod) assert mod["func"].body.value.dtype == "handle" diff --git a/tests/python/unittest/test_tir_transform_lower_warp_memory.py b/tests/python/unittest/test_tir_transform_lower_warp_memory.py index 167899a46838..cf6ef721fcc5 100644 --- a/tests/python/unittest/test_tir_transform_lower_warp_memory.py +++ b/tests/python/unittest/test_tir_transform_lower_warp_memory.py @@ -35,10 +35,8 @@ def test_lower_warp_mem(): cuda_target = tvm.target.create("cuda") assert cuda_target.thread_warp_size == 32 - f = tvm.lower(s, [A, B], name="f") + mod = tvm.lower(s, [A, B], name="f") - - mod = tvm.testing.LoweredFuncsToIRModule([f]) mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", cuda_target))(mod) fdevice = tvm.tir.transform.SplitHostDevice()(mod)["f_kernel0"] mod = tvm.IRModule.from_expr(fdevice) diff --git a/tests/python/unittest/test_tir_pass_makeapi.py b/tests/python/unittest/test_tir_transform_make_packed_api.py similarity index 84% rename from tests/python/unittest/test_tir_pass_makeapi.py rename to tests/python/unittest/test_tir_transform_make_packed_api.py index 6b28ef6cee18..b44d3c4ba7f7 100644 --- a/tests/python/unittest/test_tir_pass_makeapi.py +++ b/tests/python/unittest/test_tir_transform_make_packed_api.py @@ -35,11 +35,11 @@ def test_makeapi(): stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, C:Cb}, 64) num_unpacked_args = 2 - f = tvm.tir.ir_pass.MakeAPI( - stmt, "myadd", [n, Ab, Bb, Cb], num_unpacked_args, True) - assert(f.handle_data_type[Ab.data].dtype == Ab.dtype) - assert(len(f.args) == 7) - output_ssa = False + f = tvm.tir.PrimFunc([n, Ab, Bb, Cb], stmt).with_attr( + "tir.no_alias", True).with_attr("global_symbol", tvm.runtime.String("myadd")) + mod = tvm.IRModule.from_expr(f) + f = tvm.tir.transform.MakePackedAPI(num_unpacked_args)(mod)["main"] + assert(len(f.params) == 7) if __name__ == "__main__": diff --git a/tests/python/unittest/test_tir_transform_thread_sync.py b/tests/python/unittest/test_tir_transform_thread_sync.py index 6c9e7f9b76b7..64b454fe7645 100644 --- a/tests/python/unittest/test_tir_transform_thread_sync.py +++ b/tests/python/unittest/test_tir_transform_thread_sync.py @@ -37,10 +37,9 @@ def test_thread_storage_sync(): Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A') A2b = tvm.tir.decl_buffer(A2.shape, A2.dtype, name='A2') stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b}, 64) - f = tvm.tir.ir_pass.MakeAPI(stmt, "test", [Ab, A2b], 0, True) - cuda_target = tvm.target.create("cuda") - mod = tvm.testing.LoweredFuncsToIRModule([f]) + cuda_target = tvm.target.create("cuda") + mod = tvm.testing.MakeAPILegacy(stmt, "test", [Ab, A2b], 0, True) mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", cuda_target))(mod) fdevice = tvm.tir.transform.SplitHostDevice()(mod)["test_kernel0"] mod = tvm.IRModule.from_expr(fdevice) diff --git a/tutorials/dev/low_level_custom_pass.py b/tutorials/dev/low_level_custom_pass.py index 298b24f6d046..25ca279bf339 100644 --- a/tutorials/dev/low_level_custom_pass.py +++ b/tutorials/dev/low_level_custom_pass.py @@ -36,7 +36,7 @@ - Visitor design pattern. Otherwise, check the `Python AST module `_ to see how an AST visitor is implemented. -- How a HalideIR/Schedule is lowered to either a LoweredFunc class or a LLVM module. Otherwise, +- How a Schedule is lowered to either an IRModule class or a LLVM module. Otherwise, take a look at ``python/tvm/build_module.py`` to get some basics. """