Skip to content

Commit

Permalink
[REFACTOR][TIR] Migrate all low-level passes to the Pass Manager. (ap…
Browse files Browse the repository at this point in the history
…ache#5233)

* [REFACTOR][TIR] Migrate all low-level passes to the Pass Manager.

This PR migrates the tvm.lower to return IRModule of PrimFuncs
instead of the LoweredFuncs.

* Remove LoweredFunc.
  • Loading branch information
tqchen authored and zhiics committed Apr 17, 2020
1 parent d391753 commit 92ed53e
Show file tree
Hide file tree
Showing 63 changed files with 608 additions and 766 deletions.
1 change: 0 additions & 1 deletion apps/lldb/tvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ def __lldb_init_module(debugger, _):
"tvm::IterVarAttr",
"tvm::IterVarRelation",
"tvm::Layout",
"tir::LoweredFunc",
"tvm::Map",
"tvm::Map",
"tvm::MemoryInfo",
Expand Down
9 changes: 0 additions & 9 deletions docs/dev/codebase_walkthrough.rst
Original file line number Diff line number Diff line change
Expand Up @@ -145,15 +145,6 @@ After lowering is done, ``build()`` function generates target machine code from

Code generation is done by ``build_module()`` function, defined in ``python/tvm/target/codegen.py``. On the C++ side, code generation is implemented in ``src/target/codegen`` subdirectory. ``build_module()`` Python function will reach ``Build()`` function below in ``src/target/codegen/codegen.cc``:

::

runtime::Module Build(const Array<LoweredFunc>& funcs,
const std::string& target) {
std::string build_f_name = "codegen.build_" + target;
const PackedFunc* bf = runtime::Registry::Get(build_f_name);
runtime::Module m = (*bf)(funcs, target);
return m;
}


The ``Build()`` function looks up the code generator for the given target in the ``PackedFunc`` registry, and invokes the function found. For example, ``codegen.build_cuda`` function is registered in ``src/codegen/build_cuda_on.cc``, like this:
Expand Down
25 changes: 12 additions & 13 deletions include/tvm/driver/driver_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@
#include <tvm/runtime/packed_func.h>
#include <tvm/target/target.h>
#include <tvm/support/with.h>
#include <tvm/ir/module.h>
#include <tvm/te/schedule_pass.h>
#include <tvm/tir/lowered_func.h>

#include <string>
#include <vector>
Expand All @@ -43,60 +43,59 @@

namespace tvm {
/*!
* \brief Build a LoweredFunc given a schedule, args and binds
* \brief Build an IRModule given a schedule, args and binds
* \param sch The schedule to lower.
* \param args The arguments to the function.
* \param name The name of the lowered function.
* \param binds Buffer assignments.
* \param config The build configuration.
* \return The lowered function.
* \return The result module.
*/
TVM_DLL Array<tir::LoweredFunc> lower(
TVM_DLL IRModule lower(
te::Schedule sch,
const Array<te::Tensor>& args,
const std::string& name,
const std::unordered_map<te::Tensor, tir::Buffer>& binds,
const BuildConfig& config);

/*!
* \brief Build a device and host module for a specific target from an array of lowered functions.
* \brief Build a device and host module for a specific target from an IRModule.
* \param funcs The functions to be built.
* \param target The target device to build for.
* \param target_host The target for building host code. To use the default, pass Target()
* \param config The build configuration.
* \return The built module.
*/
TVM_DLL runtime::Module build(const Array<tir::LoweredFunc>& funcs,
TVM_DLL runtime::Module build(const IRModule& funcs,
const Target& target,
const Target& target_host,
const BuildConfig& config);

/*!
* \brief Build a device and host module for a specific target from a map
* contains target to a list of lowered functions pairs. This function is used
* contains target to IRModule. This function is used
* for heterogeneous build.
* \param input The map contains target to a list of lowered functions pairs.
* \param input The map contains target to an IRModule.
* \param target_host The target for building host code. To use the default,
* pass Target().
* \param config The build configuration.
* \return The built module that contains code for different processors.
*/
TVM_DLL runtime::Module build(const Map<Target, Array<tir::LoweredFunc>>& input,
TVM_DLL runtime::Module build(const Map<Target, IRModule>& input,
const Target& target_host,
const BuildConfig& config);

/*!
* \brief Build a device and host module for a specific target from a map
* contains target to a list of lowered functions pairs. This function is used
* contains target to IRModule. This function is used
* for heterogeneous build.
* \param input The map contains target string to a list of lowered functions
* pairs.
* \param input The map contains target string to an IRModule.
* \param target_host The target for building host code. To use the default,
* pass Target().
* \param config The build configuration.
* \return The built module that contains code for different processors.
*/
TVM_DLL runtime::Module build(const Map<std::string, Array<tir::LoweredFunc>>& input,
TVM_DLL runtime::Module build(const Map<std::string, IRModule>& input,
const Target& target_host,
const BuildConfig& config);
} // namespace tvm
Expand Down
9 changes: 9 additions & 0 deletions include/tvm/ir/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,15 @@ class IRModule : public ObjectRef {
CHECK(ptr != nullptr);
return static_cast<IRModuleNode*>(ptr);
}

/*!
* \brief Construct an empty module.
*
* \returns The constructed module
*/
static IRModule Empty() {
return IRModule(Map<GlobalVar, BaseFunc>());
}
/*!
* \brief Construct a module from a standalone expression.
*
Expand Down
12 changes: 0 additions & 12 deletions include/tvm/target/codegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
#include <tvm/runtime/packed_func.h>
#include <tvm/ir/module.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/lowered_func.h>
#include <tvm/target/target.h>

#include <string>
Expand All @@ -41,17 +40,6 @@ using runtime::PackedFunc;
using runtime::TVMArgs;
using runtime::TVMRetValue;

/*!
* \brief Temporary backward compatible function to convert a list
* of LoweredFunc to a IRModule of PrimfFuncs
* \param funcs The input lowered function.
* \return The IRModule.
*
* \note This function is only used for code refactor and will be
* removed once the refactor completes.
*/
IRModule ToIRModule(const Array<tir::LoweredFunc>& funcs);

/*!
* \brief Build a module from array of lowered function.
* \param mod The Module to be built
Expand Down
15 changes: 15 additions & 0 deletions include/tvm/tir/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,12 @@
#ifndef TVM_TIR_ANALYSIS_H_
#define TVM_TIR_ANALYSIS_H_

#include <tvm/ir/module.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/function.h>
#include <tvm/tir/stmt.h>


namespace tvm {
namespace tir {

Expand Down Expand Up @@ -59,6 +62,18 @@ struct ExprDeepEqual {
*/
Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& defs);

/*!
* \brief Verify if memory accesses are legal for a specific target device type.
*
* In the case that tgt is cuda, if not all workload is bound with
* threads, CPU code is generated that tries to access GPU memory,
* which is illegal. This pass performs verification for this case.
*
* \param mod The module to be verified.
* \return Success of memory verification.
*/
void VerifyMemory(const IRModule& mod);

} // namespace tir
} // namespace tvm
#endif // TVM_TIR_ANALYSIS_H_
80 changes: 0 additions & 80 deletions include/tvm/tir/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
#include <tvm/tir/expr.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/function.h>
#include <tvm/tir/lowered_func.h>

#include <unordered_map>
#include <unordered_set>
Expand Down Expand Up @@ -366,60 +365,6 @@ Stmt HoistIfThenElse(Stmt stmt);
*/
Stmt NarrowDataType(Stmt stmt, int target_bits);

/*!
* \brief Make an user callable API LoweredFunc.
*
* The main task of this function is to create code to :
* - Map the values in the api_args to Var that is required by body.
* - Insert assertions to check type/value of the passed arguments.
*
* \param body The body of the function.
* \param name The name of the function.
* \param api_args Arguments to the function, can be either Var, or Buffer
* \param num_unpacked_args Number of arguments that
* are processed in plain form instead of packed form.
* \param is_restricted Whether the caller can guarantee that each buffer argument do not overlap.
* It is recommended to set to true for optimized code if such invariant holds.
*
* \return a LoweredFunc with the specified signiture.
*
* \note
* The function signature have two cases
*
* let num_packed_args = len(api_args) - num_unpacked_args;
*
* if num_packed_args is zero:
* f(api_arg_0, api_arg_1, .., api_arg_n) where n == len(api_args)
*
* if num_packed_args is not zero:
* f(TVMArg* packed_args, int* packed_arg_type_ids, int num_packed_args,
* api_arg_k, api_arg_k+1, ... api_arg_n,
* TVMValue* out_ret_val, int* out_ret_tcode)
*
* where n == len(api_args), k == num_packed_args
*
* There is no thread_axis in generated function.
*/
LoweredFunc MakeAPI(Stmt body,
std::string name,
Array<ObjectRef> api_args,
int num_unpacked_args,
bool is_restricted);

/*!
* \brief Remap the thread axis
*
* This can be used to get equivalent program which uses
* threadIdx.y in place of threadIdx.x by passing
* {"threadIdx.x": thread_axis("threadIdx.y")}
*
*
* \param f The device function to be lowered.
* \param axis_map The map from StringImm -> ItrVar
* \return Transformed function.
*/
LoweredFunc RemapThreadAxis(LoweredFunc f, Map<PrimExpr, IterVar> axis_map);

/*!
* \brief Rewrite the pointer content type of arguments,
* as well as Alloc internal to the function to use
Expand All @@ -432,31 +377,6 @@ LoweredFunc RemapThreadAxis(LoweredFunc f, Map<PrimExpr, IterVar> axis_map);
*/
PrimFunc PointerValueTypeRewrite(PrimFunc f);

/*!
* \brief Lower custom datatypes.
*
* See tvm::datatypes::Registry for more information on adding custom datatypes.
*
* \param f The device function to be lowered.
* \param target The target device.
* \return Transformed function.
*/
LoweredFunc LowerCustomDatatypes(LoweredFunc f, const std::string& target);

/*!
* \brief Verify if memory accesses are legal for a specific target device type.
*
* In the case that tgt is cuda, if not all workload is bound with
* threads, CPU code is generated that tries to access GPU memory,
* which is illegal. This pass performs verification for this case.
*
* \param func The function to be verified.
* \param device_type The target device type.
* \return Success of memory verification.
*/
bool VerifyMemory(LoweredFunc func, int device_type);


/*!
* \brief Verify the correctness of a GPU code
* It will check the whether the amount of memory usage or the number of threads
Expand Down
Loading

0 comments on commit 92ed53e

Please sign in to comment.