From a07beee7f1aa954ccf8324a16cfadfa82156581c Mon Sep 17 00:00:00 2001 From: kun-zh <32951065+kun-zh@users.noreply.github.com> Date: Tue, 30 Oct 2018 06:39:38 +0800 Subject: [PATCH] [PASS] add a pass for the specific hardware accelarator when it is not binded (#1999) --- include/tvm/ir.h | 5 ++++ include/tvm/ir_pass.h | 9 +++++++ src/api/api_pass.cc | 1 + src/pass/detect_device.cc | 21 +++++++++++++++ src/pass/split_host_device.cc | 3 ++- .../test_pass_decorate_device_scope.py | 26 +++++++++++++++++++ 6 files changed, 64 insertions(+), 1 deletion(-) create mode 100644 src/pass/detect_device.cc create mode 100644 tests/python/unittest/test_pass_decorate_device_scope.py diff --git a/include/tvm/ir.h b/include/tvm/ir.h index 14e60146567f8..212234303c616 100644 --- a/include/tvm/ir.h +++ b/include/tvm/ir.h @@ -237,6 +237,11 @@ constexpr const char* pipeline_exec_scope = "pipeline_exec_scope"; */ constexpr const char* opengl_stage_scope = "opengl_stage_scope"; +/*! + * \brief Mark that it is in the device scope. + */ +constexpr const char* device_scope = "device_scope"; + /*! * \brief Check if attr_key is a pragma key extension * \param attr_key The attr key to be compared diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index 9403a2e6151bc..332becb7aa389 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -326,6 +326,15 @@ Stmt RewriteUnsafeSelect(Stmt stmt); */ Stmt LowerStorageAccessInfo(Stmt stmt); +/*! + * \brief Decorate the stmt with a device scope, this is helpful for + * hardware accelerator without thread blocks. + * + * \param stmt The stmt to be trasnformed + * \return Transformed stmt. + */ +Stmt DecorateDeviceScope(Stmt stmt); + /*! * \brief Make an user callable API LoweredFunc. * diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc index 1e571ca0dc41c..575535f26e81f 100644 --- a/src/api/api_pass.cc +++ b/src/api/api_pass.cc @@ -154,5 +154,6 @@ REGISTER_PASS1(LowerTVMBuiltin); REGISTER_PASS1(CombineContextCall); REGISTER_PASS2(VerifyMemory); REGISTER_PASS2(VerifyGPUCode); +REGISTER_PASS1(DecorateDeviceScope); } // namespace ir } // namespace tvm diff --git a/src/pass/detect_device.cc b/src/pass/detect_device.cc new file mode 100644 index 0000000000000..c5fb0dd1b8f30 --- /dev/null +++ b/src/pass/detect_device.cc @@ -0,0 +1,21 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file detect_device.cc + */ + +#include +#include +#include "../pass/ir_util.h" + +namespace tvm { +namespace ir { +Stmt DecorateDeviceScope(Stmt stmt) { + Stmt body = AttrStmt::make(make_zero(Int(32)), + ir::attr::device_scope, + 0, + stmt); + return body; +} + +} // namespace ir +} // namespace tvm diff --git a/src/pass/split_host_device.cc b/src/pass/split_host_device.cc index 112c2c173df12..4cfbc7c90d8c7 100644 --- a/src/pass/split_host_device.cc +++ b/src/pass/split_host_device.cc @@ -153,7 +153,8 @@ class HostDeviceSplitter : public IRMutator { Stmt Mutate_(const AttrStmt *op, const Stmt& s) final { if (op->attr_key == attr::thread_extent || - op->attr_key == attr::pipeline_exec_scope) { + op->attr_key == attr::pipeline_exec_scope || + op->attr_key == attr::device_scope) { return SplitDeviceFunc(s); } return IRMutator::Mutate_(op, s); diff --git a/tests/python/unittest/test_pass_decorate_device_scope.py b/tests/python/unittest/test_pass_decorate_device_scope.py new file mode 100644 index 0000000000000..1d9eb899a6427 --- /dev/null +++ b/tests/python/unittest/test_pass_decorate_device_scope.py @@ -0,0 +1,26 @@ +import tvm + +def test_decorate_device(): + m = tvm.var('m') + l = tvm.var('l') + A = tvm.placeholder((m, l), name='A') + + A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1') + A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2') + + s = tvm.create_schedule(A2.op) + xo, xi = s[A2].split(A2.op.axis[0], factor=8) + s[A1].compute_at(s[A2], xo) + s[A1].set_scope("shared") + + bounds = tvm.schedule.InferBound(s) + stmt = tvm.schedule.ScheduleOps(s, bounds) + stmt1 = tvm.ir_pass.Simplify(stmt) + stmt2 = tvm.ir_pass.DecorateDeviceScope(stmt1) + assert isinstance(stmt2, tvm.stmt.AttrStmt) + assert stmt2.attr_key == "device_scope" + assert stmt1 == stmt2.body + +if __name__ == "__main__": + test_decorate_device() +