Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REFACTOR][TIR] Migrate LowerTVMBuiltin, InferFragment, LowerThreadAllreduce, ThreadSync to Pass Manager #5213

Merged
merged 1 commit into from
Apr 2, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 0 additions & 14 deletions include/tvm/driver/driver_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,20 +57,6 @@ TVM_DLL Array<tir::LoweredFunc> lower(
const std::string& name,
const std::unordered_map<te::Tensor, tir::Buffer>& binds,
const BuildConfig& config);
/*!
* \brief Split host/device function and running necessary pass before build
* \param funcs The functions to be built.
* \param target The target device to build for.
* \param target_host The target for building host code. To use the default, pass Target()
* \param config The build configuration.
* \return The Array<Array<LoweredFunc>> with 2 elements. First is host function Array,
second is device function array
*/
TVM_DLL Array<Array<tir::LoweredFunc> > split_dev_host_funcs(
const Array<tir::LoweredFunc>& funcs,
const Target& target,
const Target& target_host,
const BuildConfig& config);

/*!
* \brief Build a device and host module for a specific target from an array of lowered functions.
Expand Down
22 changes: 16 additions & 6 deletions include/tvm/target/codegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#define TVM_TARGET_CODEGEN_H_

#include <tvm/runtime/packed_func.h>
#include <tvm/ir/module.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/lowered_func.h>
#include <tvm/target/target.h>
Expand All @@ -40,16 +41,25 @@ using runtime::PackedFunc;
using runtime::TVMArgs;
using runtime::TVMRetValue;

/*!
* \brief Temporary backward compatible function to convert a list
* of LoweredFunc to an IRModule of PrimFuncs
* \param funcs The input lowered function.
* \return The IRModule.
*
* \note This function is only used for code refactor and will be
* removed once the refactor completes.
*/
IRModule ToIRModule(const Array<tir::LoweredFunc>& funcs);

/*!
* \brief Build a module from array of lowered function.
* \param funcs The functions to be built.
* \param mod The Module to be built
* \param target The target to be built.
* \return The builded module.
*
* \note Calls global API function "_codegen_build_" + target
* \return The result runtime::Module.
*/
runtime::Module Build(const Array<tir::LoweredFunc>& funcs,
const std::string& target);
runtime::Module Build(IRModule mod, const Target& target);

/*!
* \brief Pack imported device library to a C file.
* Compile the C file and link with the host library
Expand Down
31 changes: 0 additions & 31 deletions include/tvm/tir/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -477,12 +477,6 @@ LoweredFunc RemapThreadAxis(LoweredFunc f, Map<PrimExpr, IterVar> axis_map);
*/
LoweredFunc LowerTVMBuiltin(LoweredFunc f);

/*!
* \brief Combine context function calls.
* \param f The host function to be lowered.
* \return Transformed function.
*/
LoweredFunc CombineContextCall(LoweredFunc f);

/*!
* \brief Rewrite the pointer content type of arguments,
Expand All @@ -496,7 +490,6 @@ LoweredFunc CombineContextCall(LoweredFunc f);
*/
LoweredFunc PointerValueTypeRewrite(LoweredFunc f);


/*!
* \brief Rewrite the pointer content type of arguments,
* as well as Alloc internal to the function to use
Expand All @@ -509,23 +502,6 @@ LoweredFunc PointerValueTypeRewrite(LoweredFunc f);
*/
PrimFunc PointerValueTypeRewrite(PrimFunc f);

/*!
* \brief Lower attached storage access information on device.
* Do this pass after all storage access analysis finish.
*
* \param func The device function to be lowered.
* \return Transformed function.
*/
LoweredFunc LowerDeviceStorageAccessInfo(LoweredFunc func);

/*!
* \brief Lower intrinsic function calls.
* \param f The device function to be lowered.
* \param target The target device.
* \return Transformed function.
*/
LoweredFunc LowerIntrin(LoweredFunc f, const std::string& target);

/*!
* \brief Lower custom datatypes.
*
Expand All @@ -545,13 +521,6 @@ LoweredFunc LowerCustomDatatypes(LoweredFunc f, const std::string& target);
*/
LoweredFunc InferFragment(LoweredFunc f);

/*!
* \brief skip assert stmt generation
* \param f The function to be transformed.
* \return Transformed function.
*/
LoweredFunc SkipAssert(LoweredFunc f);

/*!
* \brief Verify if memory accesses are legal for a specific target device type.
*
Expand Down
44 changes: 40 additions & 4 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,40 @@ TVM_DLL Pass CreatePrimFuncPass(const runtime::TypedPackedFunc<
const tvm::Array<tvm::PrimExpr>& required);

/*!
* \brief Combine context calls in the host function.
* \brief skip assert stmt.
*
* \return The pass.
*/
TVM_DLL Pass CombineContextCall();
TVM_DLL Pass SkipAssert();

/*!
* \brief Insert sync between parallel read/write of shared buffers.
*
* \param storage_scope The storage scope considered.
* \return The pass.
*/
TVM_DLL Pass ThreadSync(std::string storage_scope);


/*!
* \brief Lower cross thread allreduce.
*
* \return The pass.
*/
TVM_DLL Pass LowerThreadAllreduce();

/*!
* \brief Infer the TensorCore fragment information using tensor intrinsics.
*
* \return The pass.
*/
TVM_DLL Pass InferFragment();

/*!
* \brief Lower builtin intrinsics.
* \return The pass.
*/
TVM_DLL Pass LowerTVMBuiltin();

/*!
* \brief Lower the target specific function intrinsics in each of the function.
Expand All @@ -72,6 +101,12 @@ TVM_DLL Pass CombineContextCall();
*/
TVM_DLL Pass LowerIntrin();

/*!
* \brief Lower warp memory access to low-level device related function calls.
* \return The pass.
*/
TVM_DLL Pass LowerWarpMemory();

/*!
* \brief Lower attached storage access information on device.
*
Expand All @@ -82,10 +117,11 @@ TVM_DLL Pass LowerIntrin();
TVM_DLL Pass LowerDeviceStorageAccessInfo();

/*!
* \brief Lower warp memory access to low-level device related function calls.
* \brief Combine context calls in the host function.
*
* \return The pass.
*/
TVM_DLL Pass LowerWarpMemory();
TVM_DLL Pass CombineContextCall();


/*!
Expand Down
56 changes: 37 additions & 19 deletions python/tvm/driver/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,15 @@ def _build_for_device(flist, target, target_host):
mdev : tvm.module
A module that contains device code.
"""
@tvm.tir.transform.prim_func_pass(opt_level=0)
class BindTarget:
def __init__(self, target):
self.target = target

# pylint: disable=unused-argument
def transform_function(self, func, mod, ctx):
return func.with_attr("target", self.target)

target = _target.create(target)
device_type = ndarray.context(target.target_name, 0).device_type
fhost = []
Expand Down Expand Up @@ -250,30 +259,39 @@ def _build_for_device(flist, target, target_host):
else:
raise ValueError("unknown function type %d" % func.func_type)

for i, func in enumerate(fdevice):
warp_size = target.thread_warp_size
fdevice[i] = ir_pass.LowerWarpMemory(func, warp_size)

if "gpu" in target.keys and not fdevice:
warnings.warn(
"Specified target %s, but cannot find device code, did you do "
"bind?" % target)

fhost = [ir_pass.BindDeviceType(x, device_type) for x in fhost]
zhiics marked this conversation as resolved.
Show resolved Hide resolved
fhost = [ir_pass.LowerTVMBuiltin(x) for x in fhost]

if device_type == ndarray.cpu(0).device_type and target_host == target:
assert not fdevice

target_host = _target.create(target_host)
fdevice = [ir_pass.LowerDeviceStorageAccessInfo(x) for x in fdevice]
fhost = [ir_pass.LowerDeviceStorageAccessInfo(x) for x in fhost]
fdevice = [ir_pass.LowerIntrin(x, target.target_name) for x in fdevice]
fhost = [ir_pass.LowerIntrin(x, target_host.target_name) for x in fhost]
fhost = [ir_pass.CombineContextCall(x) for x in fhost]
mdev = codegen.build_module(fdevice, str(target)) if fdevice else None

return fhost, mdev
# device optimizations
mod_dev = tvm.testing.LoweredFuncsToIRModule(fdevice)
opt_device = tvm.ir.transform.Sequential(
[BindTarget(target),
tvm.tir.transform.LowerWarpMemory(),
tvm.tir.transform.LowerDeviceStorageAccessInfo(),
tvm.tir.transform.LowerIntrin()])
mod_dev = opt_device(mod_dev)

# host optimizations
mod_host = tvm.testing.LoweredFuncsToIRModule(fhost)
opt_host = tvm.ir.transform.Sequential(
[BindTarget(target_host),
tvm.tir.transform.LowerTVMBuiltin(),
tvm.tir.transform.LowerDeviceStorageAccessInfo(),
tvm.tir.transform.LowerIntrin(),
tvm.tir.transform.CombineContextCall()])
mod_host = opt_host(mod_host)

rt_mod_dev = codegen.build_module(mod_dev, target) if fdevice else None
return mod_host, rt_mod_dev


def build(inputs,
Expand Down Expand Up @@ -402,19 +420,19 @@ def build(inputs,
if not target_host:
target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm"

fhost_all = []
mod_host_all = tvm.IRModule({})

device_modules = []
for tar, flist in target_flist.items():
fhost, mdev = _build_for_device(flist, tar, target_host)
# Save the current lowered functions of the host and the device module.
fhost_all += fhost
mod_host, mdev = _build_for_device(flist, tar, target_host)
mod_host_all.update(mod_host)
device_modules.append(mdev)

# Generate a unified host module.
mhost = codegen.build_module(fhost_all, str(target_host))
rt_mod_host = codegen.build_module(mod_host_all, target_host)

# Import all modules.
for mdev in device_modules:
if mdev:
mhost.import_module(mdev)
return mhost
rt_mod_host.import_module(mdev)
return rt_mod_host
12 changes: 7 additions & 5 deletions python/tvm/target/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,16 @@
# under the License.
"""Code generation related functions."""
from . import _ffi_api
from . import target as _tgt


def build_module(lowered_func, target):
"""Build lowered_func into Module.
def build_module(mod, target):
"""Build IRModule into Module.

Parameters
----------
lowered_func : LoweredFunc
The lowered function
mod : tvm.IRModule
The ir module.

target : str
The target module type.
Expand All @@ -35,7 +36,8 @@ def build_module(lowered_func, target):
module : runtime.Module
The corresponding module.
"""
return _ffi_api.Build(lowered_func, target)
target = _tgt.create(target) if isinstance(target, str) else target
return _ffi_api.Build(mod, target)


def llvm_lookup_intrinsic_id(name):
Expand Down
Loading