From dcc47686e36038023fc1de5b294c76b15f3b4447 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Thu, 2 Apr 2020 13:22:23 -0700 Subject: [PATCH] [REFACTOR][TIR] Migrate low-level pass functions to Pass Manager, (#5213) - Migrate LowerTVMBultin - Migrate inferFragment, LowerThreadAllreduce - Migrate ThreadSync - Refactor target::Build to directly take IRModule. - Remove un-used legacy functions. --- include/tvm/driver/driver_api.h | 14 --- include/tvm/target/codegen.h | 22 +++- include/tvm/tir/ir_pass.h | 31 ------ include/tvm/tir/transform.h | 44 +++++++- python/tvm/driver/build_module.py | 56 ++++++---- python/tvm/target/codegen.py | 12 +- python/tvm/tir/transform/transform.py | 73 +++++++++++-- src/driver/driver_api.cc | 103 +++++++++--------- src/target/codegen.cc | 45 ++------ src/tir/pass/ffi_api.cc | 5 - src/tir/transforms/combine_context_call.cc | 7 +- .../lower_device_storage_access_info.cc | 5 - src/tir/transforms/lower_intrin.cc | 10 -- .../lower_thread_allreduce.cc | 30 ++++- .../{pass => transforms}/lower_tvm_builtin.cc | 27 +++-- src/tir/transforms/lower_warp_memory.cc | 10 +- src/tir/{pass => transforms}/skip_assert.cc | 20 +++- .../tensorcore_infer_fragment.cc} | 23 +++- .../thread_storage_sync.cc} | 26 ++++- tests/python/integration/test_dot.py | 3 +- .../python/unittest/test_runtime_extension.py | 5 +- .../unittest/test_runtime_module_load.py | 3 +- .../unittest/test_target_codegen_c_host.py | 5 +- .../unittest/test_target_codegen_device.py | 60 +--------- .../test_target_codegen_static_init.py | 6 +- .../unittest/test_target_codegen_vm_basic.py | 7 +- ...e_sync.py => test_tir_pass_coproc_sync.py} | 29 ----- .../test_tir_transform_thread_sync.py | 55 ++++++++++ 28 files changed, 407 insertions(+), 329 deletions(-) rename src/tir/{pass => transforms}/lower_thread_allreduce.cc (92%) rename src/tir/{pass => transforms}/lower_tvm_builtin.cc (95%) rename src/tir/{pass => transforms}/skip_assert.cc (74%) rename src/tir/{pass/infer_fragment.cc => 
transforms/tensorcore_infer_fragment.cc} (93%) rename src/tir/{pass/storage_sync.cc => transforms/thread_storage_sync.cc} (95%) rename tests/python/unittest/{test_tir_pass_storage_sync.py => test_tir_pass_coproc_sync.py} (78%) create mode 100644 tests/python/unittest/test_tir_transform_thread_sync.py diff --git a/include/tvm/driver/driver_api.h b/include/tvm/driver/driver_api.h index 58f7d6d641a4..64d51736b445 100644 --- a/include/tvm/driver/driver_api.h +++ b/include/tvm/driver/driver_api.h @@ -57,20 +57,6 @@ TVM_DLL Array lower( const std::string& name, const std::unordered_map& binds, const BuildConfig& config); -/*! -* \brief Split host/device function and running necessary pass before build -* \param funcs The functions to be built. -* \param target The target device to build for. -* \param target_host The target for building host code. To use the default, pass Target() -* \param config The build configuration. -* \return The Array> with 2 elements. First is host function Array, - second is device function array -*/ -TVM_DLL Array > split_dev_host_funcs( - const Array& funcs, - const Target& target, - const Target& target_host, - const BuildConfig& config); /*! * \brief Build a device and host module for a specific target from an array of lowered functions. diff --git a/include/tvm/target/codegen.h b/include/tvm/target/codegen.h index 3851c46433b7..c604eb5c93de 100644 --- a/include/tvm/target/codegen.h +++ b/include/tvm/target/codegen.h @@ -25,6 +25,7 @@ #define TVM_TARGET_CODEGEN_H_ #include +#include #include #include #include @@ -40,16 +41,25 @@ using runtime::PackedFunc; using runtime::TVMArgs; using runtime::TVMRetValue; +/*! + * \brief Temporary backward compatible function to convert a list + * of LoweredFunc to a IRModule of PrimfFuncs + * \param funcs The input lowered function. + * \return The IRModule. + * + * \note This function is only used for code refactor and will be + * removed once the refactor completes. 
+ */ +IRModule ToIRModule(const Array& funcs); + /*! * \brief Build a module from array of lowered function. - * \param funcs The functions to be built. + * \param mod The Module to be built * \param target The target to be built. - * \return The builded module. - * - * \note Calls global API function "_codegen_build_" + target + * \return The result runtime::Module. */ -runtime::Module Build(const Array& funcs, - const std::string& target); +runtime::Module Build(IRModule mod, const Target& target); + /*! * \brief Pack imported device library to a C file. * Compile the C file and link with the host library diff --git a/include/tvm/tir/ir_pass.h b/include/tvm/tir/ir_pass.h index f056d9f97845..6a1a1788c312 100644 --- a/include/tvm/tir/ir_pass.h +++ b/include/tvm/tir/ir_pass.h @@ -477,12 +477,6 @@ LoweredFunc RemapThreadAxis(LoweredFunc f, Map axis_map); */ LoweredFunc LowerTVMBuiltin(LoweredFunc f); -/*! - * \brief Combine context function calls. - * \param f The host function to be lowered. - * \return Transformed function. - */ -LoweredFunc CombineContextCall(LoweredFunc f); /*! * \brief Rewrite the pointer content type of arguments, @@ -496,7 +490,6 @@ LoweredFunc CombineContextCall(LoweredFunc f); */ LoweredFunc PointerValueTypeRewrite(LoweredFunc f); - /*! * \brief Rewrite the pointer content type of arguments, * as well as Alloc internal to the function to use @@ -509,23 +502,6 @@ LoweredFunc PointerValueTypeRewrite(LoweredFunc f); */ PrimFunc PointerValueTypeRewrite(PrimFunc f); -/*! - * \brief Lower attached storage access information on device. - * Do this pass after all storage access analysis finish. - * - * \param func The device function to be lowered. - * \return Transformed function. - */ -LoweredFunc LowerDeviceStorageAccessInfo(LoweredFunc func); - -/*! - * \brief Lower intrinsic function calls. - * \param f The device function to be lowered. - * \param target The target device. - * \return Transformed function. 
- */ -LoweredFunc LowerIntrin(LoweredFunc f, const std::string& target); - /*! * \brief Lower custom datatypes. * @@ -545,13 +521,6 @@ LoweredFunc LowerCustomDatatypes(LoweredFunc f, const std::string& target); */ LoweredFunc InferFragment(LoweredFunc f); -/*! - * \brief skip assert stmt generation - * \param f The function to be transformed. - * \return Transformed function. - */ -LoweredFunc SkipAssert(LoweredFunc f); - /*! * \brief Verify if memory accesses are legal for a specific target device type. * diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index a414bccfafa1..d809e07ad6db 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -59,11 +59,40 @@ TVM_DLL Pass CreatePrimFuncPass(const runtime::TypedPackedFunc< const tvm::Array& required); /*! - * \brief Combine context calls in the host function. + * \brief skip assert stmt. * * \return The pass. */ -TVM_DLL Pass CombineContextCall(); +TVM_DLL Pass SkipAssert(); + +/*! + * \brief Insert sync between parallel read/write of shared buffers. + * + * \param storage_scope The storage scope considered. + * \return The pass. + */ +TVM_DLL Pass ThreadSync(std::string storage_scope); + + +/*! + * \brief Lower cross thread alleduce. + * + * \return The pass. + */ +TVM_DLL Pass LowerThreadAllreduce(); + +/*! + * \brief Infer the TensorCore fragment infomation using tensor intrinsics + * + * \return The pass. + */ +TVM_DLL Pass InferFragment(); + +/*! + * \brief Lower builtin intrinsics. + * \return The pass. + */ +TVM_DLL Pass LowerTVMBuiltin(); /*! * \brief Lower the target specific function intrinsics in each of the function. @@ -72,6 +101,12 @@ TVM_DLL Pass CombineContextCall(); */ TVM_DLL Pass LowerIntrin(); +/*! + * \brief Lower warp memory access to low-level device related function calls. + * \return The pass. + */ +TVM_DLL Pass LowerWarpMemory(); + /*! * \brief Lower attached storage access information on device. 
* @@ -82,10 +117,11 @@ TVM_DLL Pass LowerIntrin(); TVM_DLL Pass LowerDeviceStorageAccessInfo(); /*! - * \brief Lower warp memory access to low-level device related function calls. + * \brief Combine context calls in the host function. + * * \return The pass. */ -TVM_DLL Pass LowerWarpMemory(); +TVM_DLL Pass CombineContextCall(); /*! diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 88231aaba56c..7eda40de7215 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -222,6 +222,15 @@ def _build_for_device(flist, target, target_host): mdev : tvm.module A module that contains device code. """ + @tvm.tir.transform.prim_func_pass(opt_level=0) + class BindTarget: + def __init__(self, target): + self.target = target + + # pylint: disable=unused-argument + def transform_function(self, func, mod, ctx): + return func.with_attr("target", self.target) + target = _target.create(target) device_type = ndarray.context(target.target_name, 0).device_type fhost = [] @@ -250,30 +259,39 @@ def _build_for_device(flist, target, target_host): else: raise ValueError("unknown function type %d" % func.func_type) - for i, func in enumerate(fdevice): - warp_size = target.thread_warp_size - fdevice[i] = ir_pass.LowerWarpMemory(func, warp_size) - if "gpu" in target.keys and not fdevice: warnings.warn( "Specified target %s, but cannot find device code, did you do " "bind?" 
% target) fhost = [ir_pass.BindDeviceType(x, device_type) for x in fhost] - fhost = [ir_pass.LowerTVMBuiltin(x) for x in fhost] if device_type == ndarray.cpu(0).device_type and target_host == target: assert not fdevice target_host = _target.create(target_host) - fdevice = [ir_pass.LowerDeviceStorageAccessInfo(x) for x in fdevice] - fhost = [ir_pass.LowerDeviceStorageAccessInfo(x) for x in fhost] - fdevice = [ir_pass.LowerIntrin(x, target.target_name) for x in fdevice] - fhost = [ir_pass.LowerIntrin(x, target_host.target_name) for x in fhost] - fhost = [ir_pass.CombineContextCall(x) for x in fhost] - mdev = codegen.build_module(fdevice, str(target)) if fdevice else None - return fhost, mdev + # device optimizations + mod_dev = tvm.testing.LoweredFuncsToIRModule(fdevice) + opt_device = tvm.ir.transform.Sequential( + [BindTarget(target), + tvm.tir.transform.LowerWarpMemory(), + tvm.tir.transform.LowerDeviceStorageAccessInfo(), + tvm.tir.transform.LowerIntrin()]) + mod_dev = opt_device(mod_dev) + + # host optimizations + mod_host = tvm.testing.LoweredFuncsToIRModule(fhost) + opt_host = tvm.ir.transform.Sequential( + [BindTarget(target_host), + tvm.tir.transform.LowerTVMBuiltin(), + tvm.tir.transform.LowerDeviceStorageAccessInfo(), + tvm.tir.transform.LowerIntrin(), + tvm.tir.transform.CombineContextCall()]) + mod_host = opt_host(mod_host) + + rt_mod_dev = codegen.build_module(mod_dev, target) if fdevice else None + return mod_host, rt_mod_dev def build(inputs, @@ -402,19 +420,19 @@ def build(inputs, if not target_host: target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm" - fhost_all = [] + mod_host_all = tvm.IRModule({}) + device_modules = [] for tar, flist in target_flist.items(): - fhost, mdev = _build_for_device(flist, tar, target_host) - # Save the current lowered functions of the host and the device module. 
- fhost_all += fhost + mod_host, mdev = _build_for_device(flist, tar, target_host) + mod_host_all.update(mod_host) device_modules.append(mdev) # Generate a unified host module. - mhost = codegen.build_module(fhost_all, str(target_host)) + rt_mod_host = codegen.build_module(mod_host_all, target_host) # Import all modules. for mdev in device_modules: if mdev: - mhost.import_module(mdev) - return mhost + rt_mod_host.import_module(mdev) + return rt_mod_host diff --git a/python/tvm/target/codegen.py b/python/tvm/target/codegen.py index e7bedaa1bbad..dc65c5b72d4d 100644 --- a/python/tvm/target/codegen.py +++ b/python/tvm/target/codegen.py @@ -17,15 +17,16 @@ # under the License. """Code generation related functions.""" from . import _ffi_api +from . import target as _tgt -def build_module(lowered_func, target): - """Build lowered_func into Module. +def build_module(mod, target): + """Build IRModule into Module. Parameters ---------- - lowered_func : LoweredFunc - The lowered function + mod : tvm.IRModule + The ir module. target : str The target module type. @@ -35,7 +36,8 @@ def build_module(lowered_func, target): module : runtime.Module The corressponding module. """ - return _ffi_api.Build(lowered_func, target) + target = _tgt.create(target) if isinstance(target, str) else target + return _ffi_api.Build(mod, target) def llvm_lookup_intrinsic_id(name): diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 7c2b3c812714..6be4a38fec03 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -16,19 +16,78 @@ # under the License. """Wrapping existing transformations.""" # pylint: disable=invalid-name - from . import _ffi_api -def CombineContextCall(): - """Combine context calls in the host function. +def SkipAssert(): + """Skip assert stmt. 
Returns ------- fpass : tvm.ir.transform.Pass The result pass """ - return _ffi_api.CombineContextCall() + return _ffi_api.SkipAssert() + + +def ThreadSync(storage_scope): + """ Insert sync between parallel read/write of shared buffers. + + Parameters + ---------- + storage_scope: str + The target storage scope. + + Returns + ------- + fpass : tvm.ir.transform.Pass + The result pass + """ + return _ffi_api.ThreadSync(storage_scope) + + +def LowerThreadAllreduce(): + """Lower cross thread alleduce. + + Returns + ------- + fpass : tvm.ir.transform.Pass + The result pass + """ + return _ffi_api.LowerThreadAllreduce() + + +def InferFragment(): + """ Infer the TensorCore fragment infomation using tensor intrinsics. + + Returns + ------- + fpass : tvm.ir.transform.Pass + The result pass + """ + return _ffi_api.InferFragment() + + +def LowerWarpMemory(): + """Lower warp memory access to low-level device related function calls. + + Returns + ------- + fpass : tvm.ir.transform.Pass + The result pass + """ + return _ffi_api.LowerWarpMemory() + + +def LowerTVMBuiltin(): + """Lower tvm builtin intrinsics. + + Returns + ------- + fpass : tvm.ir.transform.Pass + The result pass + """ + return _ffi_api.LowerTVMBuiltin() def LowerIntrin(): @@ -57,15 +116,15 @@ def LowerDeviceStorageAccessInfo(): return _ffi_api.LowerDeviceStorageAccessInfo() -def LowerWarpMemory(): - """Lower warp memory access to low-level device related function calls. +def CombineContextCall(): + """Combine context calls in the host function. 
Returns ------- fpass : tvm.ir.transform.Pass The result pass """ - return _ffi_api.LowerWarpMemory() + return _ffi_api.CombineContextCall() def NarrowDataType(): diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 0f56f9d654ae..f59e7646a2ac 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -24,6 +24,8 @@ #include #include #include + +#include #include #include #include @@ -174,10 +176,20 @@ Array lower(te::Schedule sch, return Array({ tir::MakeAPI(stmt, name, out_arg_list, 0, config->restricted_func) }); } -Array > split_dev_host_funcs(const Array& funcs, - const Target& target, - const Target& target_host, - const BuildConfig& config) { + +transform::Pass BindTarget(Target target) { + auto fpass = [target](tir::PrimFunc f, IRModule m, transform::PassContext ctx) { + return WithAttr(std::move(f), tvm::attr::kTarget, target); + }; + return tir::transform::CreatePrimFuncPass(fpass, 0, "BindTarget", {}); +} + + +std::pair +split_dev_host_funcs(const Array& funcs, + const Target& target, + const Target& target_host, + const BuildConfig& config) { std::unordered_set all_names; for (const auto& x : funcs) { CHECK(all_names.count(x->name) == 0) @@ -217,13 +229,6 @@ Array > split_dev_host_funcs(const Array& funcs, } } - for (size_t i = 0; i < fdevice.size(); i++) { - auto warp_size = target->thread_warp_size; - auto func = fdevice[i]; - func = tir::LowerWarpMemory(fdevice[i], warp_size); - fdevice.Set(i, func); - } - auto keys = target->keys(); bool target_is_gpu = std::find(keys.begin(), keys.end(), "gpu") != keys.end(); if (target_is_gpu && fdevice.size() == 0) { @@ -232,53 +237,46 @@ Array > split_dev_host_funcs(const Array& funcs, << " but cannot find device code. 
Did you forget to bind?"; } - for (size_t i = 0; i < fdevice.size(); ++i) { - auto func = fdevice[i]; - func = tir::LowerIntrin(func, target->target_name); - fdevice.Set(i, func); - } if (target->device_type == target::llvm()->device_type && - target_host == target) { + target_host == target) { CHECK(fdevice.empty()) << "No device code should be generated when target " << "and host_target are both llvm target." << "\n"; } - for (size_t i = 0; i < fdevice.size(); ++i) { - auto func = fdevice[i]; - func = tir::LowerDeviceStorageAccessInfo(func); - fdevice.Set(i, func); - } - for (size_t i = 0; i < fhost.size(); ++i) { auto func = fhost[i]; func = tir::BindDeviceType(func, target->device_type); - func = tir::LowerDeviceStorageAccessInfo(func); - func = tir::LowerTVMBuiltin(func); fhost.Set(i, func); } - for (size_t i = 0; i < fhost.size(); ++i) { - auto func = fhost[i]; - func = tir::LowerIntrin(func, target_host->target_name); - func = tir::LowerDeviceStorageAccessInfo(func); - func = tir::CombineContextCall(func); - fhost.Set(i, func); - } - return {fhost, fdevice}; + // host pipeline + auto mhost = codegen::ToIRModule(fhost); + auto host_pass_list = { + BindTarget(target_host), + tir::transform::LowerTVMBuiltin(), + tir::transform::LowerIntrin(), + tir::transform::LowerDeviceStorageAccessInfo(), + tir::transform::CombineContextCall(), + }; + auto opt_host = transform::Sequential(host_pass_list); + mhost = opt_host(mhost); + + // device pipeline + auto mdevice = codegen::ToIRModule(fdevice); + auto device_pass_list = { + BindTarget(target), + tir::transform::LowerWarpMemory(), + tir::transform::LowerIntrin(), + tir::transform::LowerDeviceStorageAccessInfo(), + }; + auto opt_device = transform::Sequential(device_pass_list); + mdevice = opt_device(mdevice); + + return {mhost, mdevice}; } -// Create a module for a specific device (target). The lowered functions -// associated with the host is returned as well. 
-runtime::Module DeviceBuild(const Array& fdevice, - const Target& target) { - if (!fdevice.empty()) { - return codegen::Build(fdevice, target->str()); - } else { - return runtime::Module(nullptr); - } -} // Build for heterogeneous execution. runtime::Module build(const Map>& inputs, @@ -301,20 +299,21 @@ runtime::Module build(const Map>& inputs, target_host_val = DefaultTargetHost(target_host_val); } + IRModule mhost_all = IRModule(Map()); + for (const auto& it : inputs) { - auto host_dev_funcs = + auto pair = split_dev_host_funcs(it.second, it.first, target_host_val, config); - auto& fhost = host_dev_funcs[0]; - auto& fdevice = host_dev_funcs[1]; - // Get the module for a certain target. - runtime::Module mdev = DeviceBuild(fdevice, it.first); - for (const auto& it : fhost) { - fhost_all.push_back(it); + auto& mhost = pair.first; + auto& mdevice = pair.second; + + mhost_all->Update(mhost); + if (mdevice->functions.size() != 0) { + device_modules.push_back(codegen::Build(mdevice, it.first)); } - device_modules.push_back(mdev); } - runtime::Module mhost = codegen::Build(fhost_all, target_host_val->str()); + runtime::Module mhost = codegen::Build(mhost_all, target_host_val); // Import all modules for (const auto& it : device_modules) { if (it.operator->()) { diff --git a/src/target/codegen.cc b/src/target/codegen.cc index 1981c21b4488..a977d35b2198 100644 --- a/src/target/codegen.cc +++ b/src/target/codegen.cc @@ -26,6 +26,7 @@ #include #include +#include #include #include @@ -42,18 +43,6 @@ namespace tvm { namespace codegen { -// The new build function. -// adapt the old function to the new one -runtime::Module BuildForIRModule(const IRModule& module, - const Target& target) { - std::string build_f_name = "target.build." + target->target_name; - // the build function. - const PackedFunc* bf = runtime::Registry::Get(build_f_name); - CHECK(bf != nullptr) - << "target.build." 
<< target << " is not enabled"; - return (*bf)(module, target->str()); -} - // convert legacy LoweredFunc to PrimFunc. tir::PrimFunc ToPrimFunc(tir::LoweredFunc from) { // remap args to attach type annotations. @@ -97,24 +86,16 @@ IRModule ToIRModule(const Array& funcs) { return IRModule(functions); } -runtime::Module Build(const Array& funcs, - const std::string& target) { - std::string mode = target; - size_t pos = mode.find(' '); - if (pos != std::string::npos) { - mode = mode.substr(0, pos); - } - Array transformed_funcs; +runtime::Module Build(IRModule mod, const Target& target) { if (BuildConfig::Current()->disable_assert) { - for (const auto& x : funcs) { - auto func = tir::SkipAssert(x); - transformed_funcs.push_back(func); - } + mod = tir::transform::SkipAssert()(mod); } - - return BuildForIRModule( - transformed_funcs.size() != 0 ? ToIRModule(transformed_funcs) : ToIRModule(funcs), - Target::Create(target)); + std::string build_f_name = "target.build." + target->target_name; + // the build function. + const PackedFunc* bf = runtime::Registry::Get(build_f_name); + CHECK(bf != nullptr) + << "target.build." << target << " is not enabled"; + return (*bf)(mod, target->str()); } /*! 
\brief Helper class to serialize module */ @@ -300,13 +281,7 @@ runtime::Module PackImportsToLLVM(const runtime::Module& mod, } TVM_REGISTER_GLOBAL("target.Build") -.set_body([](TVMArgs args, TVMRetValue *ret) { - if (args[0].IsObjectRef()) { - *ret = Build({args[0]}, args[1]); - } else { - *ret = Build(args[0], args[1]); - } - }); +.set_body_typed(Build); TVM_REGISTER_GLOBAL("testing.LoweredFuncsToIRModule") .set_body_typed(ToIRModule); diff --git a/src/tir/pass/ffi_api.cc b/src/tir/pass/ffi_api.cc index d13461ea926e..ff821fe48517 100644 --- a/src/tir/pass/ffi_api.cc +++ b/src/tir/pass/ffi_api.cc @@ -135,7 +135,6 @@ REGISTER_PASS(SplitHostDevice); REGISTER_PASS(StorageRewrite); REGISTER_PASS(CoProcSync); REGISTER_PASS(LowerStorageAccessInfo); -REGISTER_PASS(LowerDeviceStorageAccessInfo) REGISTER_PASS(InjectVirtualThread); REGISTER_PASS(InjectPrefetch); REGISTER_PASS(InjectDoubleBuffer); @@ -143,12 +142,8 @@ REGISTER_PASS(LoopPartition); REGISTER_PASS(RemoveNoOp); REGISTER_PASS(LiftAttrScope); REGISTER_PASS(LowerThreadAllreduce); -REGISTER_PASS(LowerWarpMemory); REGISTER_PASS(RemapThreadAxis); -REGISTER_PASS(LowerIntrin); REGISTER_PASS(LowerCustomDatatypes); -REGISTER_PASS(LowerTVMBuiltin); -REGISTER_PASS(CombineContextCall); REGISTER_PASS(VerifyMemory); REGISTER_PASS(VerifyGPUCode); REGISTER_PASS(DecorateDeviceScope); diff --git a/src/tir/transforms/combine_context_call.cc b/src/tir/transforms/combine_context_call.cc index 324c1704aa63..f8e14a2a8fb3 100644 --- a/src/tir/transforms/combine_context_call.cc +++ b/src/tir/transforms/combine_context_call.cc @@ -109,18 +109,13 @@ class ContextCallCombiner final : public StmtExprMutator { std::unordered_map ctx_map_; }; -LoweredFunc CombineContextCall(LoweredFunc f) { - auto n = make_object(*f.operator->()); - n->body = ContextCallCombiner().Combine(n->body); - return LoweredFunc(n); -} namespace transform { Pass CombineContextCall() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { auto* n = 
f.CopyOnWrite(); - n->body = ContextCallCombiner().Combine(n->body); + n->body = ContextCallCombiner().Combine(std::move(n->body)); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.CombineContextCall", {}); diff --git a/src/tir/transforms/lower_device_storage_access_info.cc b/src/tir/transforms/lower_device_storage_access_info.cc index 5797665ffdc0..2438da958a70 100644 --- a/src/tir/transforms/lower_device_storage_access_info.cc +++ b/src/tir/transforms/lower_device_storage_access_info.cc @@ -142,11 +142,6 @@ Stmt LowerStorageAccessInfo(Stmt stmt) { return StorageAccessInfoLower()(std::move(stmt)); } -LoweredFunc LowerDeviceStorageAccessInfo(LoweredFunc f) { - auto n = make_object(*f.operator->()); - n->body = LowerStorageAccessInfo(f->body); - return LoweredFunc(n); -} namespace transform { diff --git a/src/tir/transforms/lower_intrin.cc b/src/tir/transforms/lower_intrin.cc index 6d4863d6a3ed..41a94937d4ce 100644 --- a/src/tir/transforms/lower_intrin.cc +++ b/src/tir/transforms/lower_intrin.cc @@ -283,16 +283,6 @@ Stmt LowerIntrinStmt(Stmt stmt, const std::string target_name) { return IntrinInjecter(&analyzer, target_name)(std::move(stmt)); } -LoweredFunc -LowerIntrin(LoweredFunc f, const std::string& target) { - auto n = make_object(*f.operator->()); - std::istringstream is(target); - std::string target_name; - is >> target_name; - n->body = LowerIntrinStmt(n->body, target_name); - return LoweredFunc(n); -} - namespace transform { Pass LowerIntrin() { diff --git a/src/tir/pass/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc similarity index 92% rename from src/tir/pass/lower_thread_allreduce.cc rename to src/tir/transforms/lower_thread_allreduce.cc index 259a3a62d24b..e7e89f899d4f 100644 --- a/src/tir/pass/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -23,9 +23,14 @@ */ #include #include +#include #include +#include +#include + #include -#include "ir_util.h" + +#include "../pass/ir_util.h" 
#include "../../arith/compute_expr.h" #include "../../runtime/thread_storage_scope.h" @@ -342,5 +347,28 @@ LowerThreadAllreduce(LoweredFunc f, int warp_size) { n->body = ThreadAllreduceBuilder(warp_size)(n->body); return LoweredFunc(n); } + +namespace transform { + +Pass LowerThreadAllreduce() { + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + auto target = f->GetAttr(tvm::attr::kTarget); + CHECK(target.defined()) + << "LowerThreadAllreduce: Require the target attribute"; + auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); + CHECK(calling_conv.defined() && + calling_conv->value == static_cast(CallingConv::kDeviceKernelLaunch)) + << "LowerThreadAllreeduce: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; + n->body = ThreadAllreduceBuilder(target->thread_warp_size)(n->body); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.LowerThreadAllreduce", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.LowerThreadAllreduce") +.set_body_typed(LowerThreadAllreduce); + +} // namespace transform } // namespace tir } // namespace tvm diff --git a/src/tir/pass/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc similarity index 95% rename from src/tir/pass/lower_tvm_builtin.cc rename to src/tir/transforms/lower_tvm_builtin.cc index 106a604abce0..58c966b21711 100644 --- a/src/tir/pass/lower_tvm_builtin.cc +++ b/src/tir/transforms/lower_tvm_builtin.cc @@ -18,14 +18,18 @@ */ /*! - * Lower TVM related buildin intrinsics such as packed call. - * \file lower_tvm_buildin.cc + * Lower TVM related builtin intrinsics such as packed call. 
+ * \file tir/transforms/lower_tvm_buildin.cc */ #include #include +#include #include +#include + #include -#include "ir_util.h" + +#include "../pass/ir_util.h" #include "../../arith/compute_expr.h" namespace tvm { @@ -368,11 +372,20 @@ class BuiltinLower : public StmtExprMutator { uint64_t max_arg_stack_{0}; }; -LoweredFunc LowerTVMBuiltin(LoweredFunc f) { - auto n = make_object(*f.operator->()); - n->body = BuiltinLower().Build(n->body); - return LoweredFunc(n); +namespace transform { + +Pass LowerTVMBuiltin() { + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + n->body = BuiltinLower().Build(n->body); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.LowerTVMBuiltin", {}); } +TVM_REGISTER_GLOBAL("tir.transform.LowerTVMBuiltin") +.set_body_typed(LowerTVMBuiltin); + +} // namespace transform } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index 808b0816a971..0361100f1f57 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -385,14 +385,6 @@ class WarpMemoryRewriter : private StmtMutator { std::unordered_map var_dom_; }; -LoweredFunc -LowerWarpMemory(LoweredFunc f, int warp_size) { - CHECK_EQ(f->func_type, kDeviceFunc); - auto n = make_object(*f.operator->()); - n->body = WarpMemoryRewriter(warp_size).Rewrite(n->body); - return LoweredFunc(n); -} - namespace transform { Pass LowerWarpMemory() { @@ -401,7 +393,7 @@ Pass LowerWarpMemory() { auto target = f->GetAttr(tvm::attr::kTarget); CHECK(target.defined()) << "LowerWarpMemory: Require the target attribute"; - n->body = WarpMemoryRewriter(target->thread_warp_size).Rewrite(n->body); + n->body = WarpMemoryRewriter(target->thread_warp_size).Rewrite(std::move(n->body)); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.LowerWarpMemory", {}); diff --git a/src/tir/pass/skip_assert.cc b/src/tir/transforms/skip_assert.cc 
similarity index 74% rename from src/tir/pass/skip_assert.cc rename to src/tir/transforms/skip_assert.cc index 14f59f090cac..2857639f2e78 100644 --- a/src/tir/pass/skip_assert.cc +++ b/src/tir/transforms/skip_assert.cc @@ -19,7 +19,9 @@ #include #include +#include #include +#include namespace tvm { namespace tir { @@ -37,11 +39,21 @@ Stmt SkipAssert(Stmt stmt) { return AssertSkipper()(std::move(stmt)); } -LoweredFunc SkipAssert(LoweredFunc f) { - auto n = make_object(*f.operator->()); - n->body = SkipAssert(f->body); - return LoweredFunc(n); +namespace transform { + +Pass SkipAssert() { + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + n->body = AssertSkipper()(std::move(n->body)); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.SkipAssert", {}); } +TVM_REGISTER_GLOBAL("tir.transform.SkipAssert") +.set_body_typed(SkipAssert); + +} // namespace transform + } // namespace tir } // namespace tvm diff --git a/src/tir/pass/infer_fragment.cc b/src/tir/transforms/tensorcore_infer_fragment.cc similarity index 93% rename from src/tir/pass/infer_fragment.cc rename to src/tir/transforms/tensorcore_infer_fragment.cc index 608945a7a68a..fad423392937 100644 --- a/src/tir/pass/infer_fragment.cc +++ b/src/tir/transforms/tensorcore_infer_fragment.cc @@ -23,11 +23,15 @@ */ #include #include +#include #include +#include + #include #include -#include "ir_util.h" -#include "storage_access.h" + +#include "../pass/storage_access.h" +#include "../pass/ir_util.h" #include "../../runtime/thread_storage_scope.h" namespace tvm { @@ -221,5 +225,20 @@ LoweredFunc InferFragment(LoweredFunc f) { return LoweredFunc(n); } +namespace transform { + +Pass InferFragment() { + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + n->body = InferFragment(std::move(n->body)); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.InferFragment", {}); } + 
+TVM_REGISTER_GLOBAL("tir.transform.InferFragment") .set_body_typed(InferFragment); + +} // namespace transform } // namespace tir } // namespace tvm diff --git a/src/tir/pass/storage_sync.cc b/src/tir/transforms/thread_storage_sync.cc similarity index 95% rename from src/tir/pass/storage_sync.cc rename to src/tir/transforms/thread_storage_sync.cc index 7e81ba613cda..b631a6200d47 100644 --- a/src/tir/pass/storage_sync.cc +++ b/src/tir/transforms/thread_storage_sync.cc @@ -18,16 +18,21 @@ */ /*! - * \file storage_sync.cc + * \file thread_storage_sync.cc */ #include #include #include +#include #include +#include +#include + #include #include -#include "ir_util.h" -#include "storage_access.h" + +#include "../pass/ir_util.h" +#include "../pass/storage_access.h" #include "../../runtime/thread_storage_scope.h" namespace tvm { @@ -376,5 +381,20 @@ LoweredFunc ThreadSync(LoweredFunc f, std::string storage_scope) { return LoweredFunc(n); } +namespace transform { + +Pass ThreadSync(std::string storage_scope) { + auto pass_func = [storage_scope](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + n->body = ThreadSync(std::move(n->body), storage_scope); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.ThreadSync", {}); } + +TVM_REGISTER_GLOBAL("tir.transform.ThreadSync") +.set_body_typed(ThreadSync); + +} // namespace transform } // namespace tir } // namespace tvm diff --git a/tests/python/integration/test_dot.py b/tests/python/integration/test_dot.py index c66e596ef50c..4f2b6aa99fcd 100644 --- a/tests/python/integration/test_dot.py +++ b/tests/python/integration/test_dot.py @@ -50,13 +50,12 @@ def test_dot(): k = te.reduce_axis((0, n), 'k') C = te.compute((1,), lambda _: te.sum(A[k] * B[k], axis=k), name='C') s = te.create_schedule(C.op) - fapi = lower(s, [A, B, C]) def verify(target): if not tvm.runtime.enabled(target): print("Target %s is not enabled" % target) return - f = tvm.target.codegen.build_module(fapi, target) + f = 
tvm.driver.build(s, [A, B, C], target) # verify ctx = tvm.cpu(0) a = tvm.nd.array(np.random.uniform(size=(nn,)).astype(A.dtype), ctx) diff --git a/tests/python/unittest/test_runtime_extension.py b/tests/python/unittest/test_runtime_extension.py index 375b99b0ad31..13de67efa8f4 100644 --- a/tests/python/unittest/test_runtime_extension.py +++ b/tests/python/unittest/test_runtime_extension.py @@ -39,8 +39,9 @@ def test_dltensor_compatible(): A[i + 1] = A[i] + 1 stmt = ib.get() fapi = tvm.tir.ir_pass.MakeAPI(stmt, "arange", [Ab], 0, True) - fapi = tvm.tir.ir_pass.LowerTVMBuiltin(fapi) - f = tvm.target.codegen.build_module(fapi, "stackvm") + mod = tvm.testing.LoweredFuncsToIRModule([fapi]) + mod = tvm.tir.transform.LowerTVMBuiltin()(mod) + f = tvm.target.codegen.build_module(mod, "stackvm") a = tvm.nd.array(np.zeros(10, dtype=dtype)) aview = MyTensorView(a) f(aview) diff --git a/tests/python/unittest/test_runtime_module_load.py b/tests/python/unittest/test_runtime_module_load.py index e7771e3c6721..37ccb5e47830 100644 --- a/tests/python/unittest/test_runtime_module_load.py +++ b/tests/python/unittest/test_runtime_module_load.py @@ -58,8 +58,7 @@ def save_object(names): tvm.tir.Load(dtype, Ab.data, i) + 1, i + 1)) fapi = tvm.tir.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True) - fapi = tvm.tir.ir_pass.LowerTVMBuiltin(fapi) - m = tvm.target.codegen.build_module(fapi, "llvm") + m = tvm.driver.build(fapi, target="llvm") for name in names: m.save(name) diff --git a/tests/python/unittest/test_target_codegen_c_host.py b/tests/python/unittest/test_target_codegen_c_host.py index 1604ffb2293b..c96531e4710e 100644 --- a/tests/python/unittest/test_target_codegen_c_host.py +++ b/tests/python/unittest/test_target_codegen_c_host.py @@ -74,9 +74,8 @@ def check_c(): binds = {A : Ab} # BUILD and invoke the kernel. 
f1 = tvm.lower(s, [A,B,C], name="fadd_pipeline") - fsplits = [x for x in tvm.tir.ir_pass.SplitHostDevice(f1)] - fsplits[0] = tvm.tir.ir_pass.LowerTVMBuiltin(fsplits[0]) - mhost = tvm.target.codegen.build_module(fsplits[0], "c") + mhost = tvm.build(f1, target="c") + temp = util.tempdir() path_dso = temp.relpath("temp.so") mhost.export_library(path_dso) diff --git a/tests/python/unittest/test_target_codegen_device.py b/tests/python/unittest/test_target_codegen_device.py index 88abca8d2820..ddb35f31fe1d 100644 --- a/tests/python/unittest/test_target_codegen_device.py +++ b/tests/python/unittest/test_target_codegen_device.py @@ -63,79 +63,27 @@ def test_add_pipeline(): s[D].bind(xi, te.thread_axis("threadIdx.x")) s[D].bind(xo, te.thread_axis("blockIdx.x")) - # compile to IR - s = s.normalize() - bounds = tvm.te.schedule.InferBound(s) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A') - Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B') - Db = tvm.tir.decl_buffer(D.shape, D.dtype, name='D') - stmt = tvm.tir.ir_pass.LoopPartition(stmt, False) - stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, D:Db}, 64) - stmt = tvm.tir.ir_pass.Simplify(stmt) - fapi = tvm.tir.ir_pass.MakeAPI(stmt, "myadd", [Ab, Bb, Db], 0, True) - fsplits = [x for x in tvm.tir.ir_pass.SplitHostDevice(fapi)] - # lower the floordiv(use stackvm rules so it works for all targets) - fsplits = [tvm.tir.ir_pass.LowerIntrin(x, "stackvm") for x in fsplits] - fsplits[0] = tvm.tir.ir_pass.LowerTVMBuiltin(fsplits[0]) - def check_target(device, host="stackvm"): ctx = tvm.context(device, 0) if not ctx.exist: return if not tvm.runtime.enabled(host): return - mhost = tvm.target.codegen.build_module(fsplits[0], host) - mdev = tvm.target.codegen.build_module(fsplits[1:], device) - mhost.import_module(mdev) - code = mdev.get_source() - f = mhost.entry_func - # launch the kernel. 
- n = 1027 - a = tvm.nd.array(np.random.uniform(size=n).astype(Ab.dtype), ctx) - b = tvm.nd.array(np.random.uniform(size=()).astype(Bb.dtype), ctx) - d = tvm.nd.array(np.zeros(n, dtype=Db.dtype), ctx) - f(a, b, d) - tvm.testing.assert_allclose( - d.asnumpy(), a.asnumpy() + b.asnumpy() + 1) - - def check_module_save(device, host="stackvm"): - ctx = tvm.context(device, 0) - if not ctx.exist: - return - if not tvm.runtime.enabled(host): - return - if device == "cuda": - fmt = "ptx" - elif device == "rocm": - fmt = "hsaco" - else: - fmt = device - mhost = tvm.target.codegen.build_module(fsplits[0], host) - mdev = tvm.target.codegen.build_module(fsplits[1:], device) - temp = util.tempdir() - mpath = temp.relpath("test.%s" % fmt) - mdev.save(mpath) - mdev2 = tvm.runtime.load_module(mpath) - mhost.import_module(mdev2) + mhost = tvm.driver.build(s, [A, B, D], target=device, target_host=host) f = mhost.entry_func # launch the kernel. n = 1027 - a = tvm.nd.array(np.random.uniform(size=n).astype(Ab.dtype), ctx) - b = tvm.nd.array(np.random.uniform(size=()).astype(Bb.dtype), ctx) - d = tvm.nd.array(np.zeros(n, dtype=Db.dtype), ctx) + a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx) + b = tvm.nd.array(np.random.uniform(size=()).astype(B.dtype), ctx) + d = tvm.nd.array(np.zeros(n, dtype=D.dtype), ctx) f(a, b, d) tvm.testing.assert_allclose( d.asnumpy(), a.asnumpy() + b.asnumpy() + 1) - check_target("cuda", host="stackvm") check_target("cuda", host="llvm") - check_module_save("cuda", host="stackvm") check_target("nvptx", host="llvm") check_target("vulkan", host="llvm") - check_module_save("vulkan", host="stackvm") check_target("rocm", host="llvm") - check_module_save("rocm", host="llvm") if __name__ == "__main__": diff --git a/tests/python/unittest/test_target_codegen_static_init.py b/tests/python/unittest/test_target_codegen_static_init.py index 3b5f17a4243a..a9fa35f1a533 100644 --- a/tests/python/unittest/test_target_codegen_static_init.py +++ 
b/tests/python/unittest/test_target_codegen_static_init.py @@ -33,8 +33,7 @@ def test_static_callback(): A[i] = A[i] + 1 stmt = ib.get() fapi = tvm.tir.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True) - fapi = tvm.tir.ir_pass.LowerTVMBuiltin(fapi) - f = tvm.target.codegen.build_module(fapi, "llvm") + f = tvm.driver.build(fapi, target="llvm") a = tvm.nd.array(np.zeros(10, dtype=dtype)) f(a) f(a) @@ -57,8 +56,7 @@ def test_cb(sh, A): stmt = ib.get() fapi = tvm.tir.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True) - fapi = tvm.tir.ir_pass.LowerTVMBuiltin(fapi) - f = tvm.target.codegen.build_module(fapi, "llvm") + f = tvm.driver.build(fapi, target="llvm") a = tvm.nd.array(np.zeros(10, dtype=dtype)) f(a) diff --git a/tests/python/unittest/test_target_codegen_vm_basic.py b/tests/python/unittest/test_target_codegen_vm_basic.py index e2ff4875e6fd..26464ceedfd7 100644 --- a/tests/python/unittest/test_target_codegen_vm_basic.py +++ b/tests/python/unittest/test_target_codegen_vm_basic.py @@ -22,7 +22,7 @@ def run_jit(fapi, check): for target in ["llvm", "stackvm"]: if not tvm.runtime.enabled(target): continue - f = tvm.target.codegen.build_module(fapi, target) + f = tvm.driver.build(fapi, target=target) s = f.get_source() check(f) @@ -37,8 +37,6 @@ def tvm_call_back_get_shape(shape0): Ab = tvm.tir.decl_buffer((n, ), "float32") stmt = tvm.tir.Evaluate(tvm.tir.call_packed("tvm_call_back_get_shape", Ab.shape[0])) fapi = tvm.tir.ir_pass.MakeAPI(stmt, "print_shape", [Ab], 0, True) - fapi = tvm.tir.ir_pass.LowerTVMBuiltin(fapi) - fapi = tvm.tir.ir_pass.LowerIntrin(fapi, "stackvm") run_jit(fapi, lambda f: f(a)) @@ -60,7 +58,6 @@ def test_stack_vm_loop(): stmt = ib.get() fapi = tvm.tir.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True) - fapi = tvm.tir.ir_pass.LowerTVMBuiltin(fapi) a = tvm.nd.array(np.zeros(10, dtype=dtype)) def check(f): f(a) @@ -83,7 +80,6 @@ def test_stack_vm_cond(): stmt = ib.get() fapi = tvm.tir.ir_pass.MakeAPI(stmt, "test", [Ab], 0, True) - fapi = 
tvm.tir.ir_pass.LowerTVMBuiltin(fapi) def check(f): a = tvm.nd.array(np.zeros(10, dtype=dtype)) f(a) @@ -103,7 +99,6 @@ def test_vm_parallel(): A[i] = A[i] + 1 stmt = ib.get() fapi = tvm.tir.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True) - fapi = tvm.tir.ir_pass.LowerTVMBuiltin(fapi) def check(f): a = tvm.nd.array(np.zeros(10, dtype=dtype)) f(a) diff --git a/tests/python/unittest/test_tir_pass_storage_sync.py b/tests/python/unittest/test_tir_pass_coproc_sync.py similarity index 78% rename from tests/python/unittest/test_tir_pass_storage_sync.py rename to tests/python/unittest/test_tir_pass_coproc_sync.py index 9edfa9575435..b0e2050e2ee9 100644 --- a/tests/python/unittest/test_tir_pass_storage_sync.py +++ b/tests/python/unittest/test_tir_pass_coproc_sync.py @@ -17,34 +17,6 @@ import tvm from tvm import te -def test_storage_sync(): - m = te.size_var('m') - l = te.size_var('l') - A = te.placeholder((m, l), name='A') - - A1 = te.compute((m, l), lambda i, j: A[i, j], name='A1') - A2 = te.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2') - - s = te.create_schedule(A2.op) - xo, xi = s[A2].split(A2.op.axis[0], factor=8) - s[A2].bind(xo, te.thread_axis("blockIdx.x")) - s[A1].compute_at(s[A2], xo) - s[A1].set_scope("shared") - - bounds = tvm.te.schedule.InferBound(s) - assert isinstance(bounds, tvm.container.Map) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A') - A2b = tvm.tir.decl_buffer(A2.shape, A2.dtype, name='A2') - stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b}, 64) - f = tvm.tir.ir_pass.MakeAPI(stmt, "test", [Ab, A2b], 0, True) - flist = tvm.tir.ir_pass.SplitHostDevice(f) - f = flist[1] - f = tvm.tir.ir_pass.ThreadSync(f, "shared") - body_list = tvm.tir.stmt_list(f.body.body.body.body) - assert(body_list[1].value.name == "tvm_storage_sync") - - def test_coproc_sync(): @tvm.register_func("tvm.info.mem.global.cache") def meminfo_cache(): @@ -133,6 +105,5 @@ def __check_list(tvm_array, py_list): if 
__name__ == "__main__": test_coproc_sync() - test_storage_sync() test_coproc_sync2() test_coproc_sync3() diff --git a/tests/python/unittest/test_tir_transform_thread_sync.py b/tests/python/unittest/test_tir_transform_thread_sync.py new file mode 100644 index 000000000000..e692e23b0878 --- /dev/null +++ b/tests/python/unittest/test_tir_transform_thread_sync.py @@ -0,0 +1,55 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import tvm +from tvm import te + +def test_thread_storage_sync(): + m = te.size_var('m') + l = te.size_var('l') + A = te.placeholder((m, l), name='A') + + A1 = te.compute((m, l), lambda i, j: A[i, j], name='A1') + A2 = te.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2') + + s = te.create_schedule(A2.op) + xo, xi = s[A2].split(A2.op.axis[0], factor=8) + s[A2].bind(xo, te.thread_axis("blockIdx.x")) + s[A1].compute_at(s[A2], xo) + s[A1].set_scope("shared") + + bounds = tvm.te.schedule.InferBound(s) + assert isinstance(bounds, tvm.container.Map) + stmt = tvm.te.schedule.ScheduleOps(s, bounds) + Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A') + A2b = tvm.tir.decl_buffer(A2.shape, A2.dtype, name='A2') + stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b}, 64) + f = tvm.tir.ir_pass.MakeAPI(stmt, "test", [Ab, A2b], 0, True) + flist = tvm.tir.ir_pass.SplitHostDevice(f) + f = flist[1] + fname = f.name + mod = tvm.testing.LoweredFuncsToIRModule([f]) + + cuda_target = tvm.target.create("cuda") + mod = tvm.IRModule.from_expr(mod[fname].with_attr("target", cuda_target)) + f = tvm.tir.transform.ThreadSync("shared")(mod)["main"] + body_list = tvm.tir.stmt_list(f.body.body.body.body) + assert(body_list[1].value.name == "tvm_storage_sync") + + + +if __name__ == "__main__": + test_thread_storage_sync()