Skip to content

Commit

Permalink
[PASS] add a pass for the specific hardware accelarator when it is no…
Browse files Browse the repository at this point in the history
…t binded (apache#1999)
  • Loading branch information
kun-zh authored and Wei Chen committed Feb 19, 2019
1 parent 74165ce commit a07beee
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 1 deletion.
5 changes: 5 additions & 0 deletions include/tvm/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,11 @@ constexpr const char* pipeline_exec_scope = "pipeline_exec_scope";
*/
constexpr const char* opengl_stage_scope = "opengl_stage_scope";

/*!
* \brief Attribute key marking that the annotated stmt is in the device scope.
*/
constexpr const char* device_scope = "device_scope";

/*!
* \brief Check if attr_key is a pragma key extension
* \param attr_key The attr key to be compared
Expand Down
9 changes: 9 additions & 0 deletions include/tvm/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,15 @@ Stmt RewriteUnsafeSelect(Stmt stmt);
*/
Stmt LowerStorageAccessInfo(Stmt stmt);

/*!
* \brief Decorate the stmt with a device scope. This is helpful for
* hardware accelerators without thread blocks.
*
* \param stmt The stmt to be transformed.
* \return Transformed stmt.
*/
Stmt DecorateDeviceScope(Stmt stmt);

/*!
* \brief Make an user callable API LoweredFunc.
*
Expand Down
1 change: 1 addition & 0 deletions src/api/api_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -154,5 +154,6 @@ REGISTER_PASS1(LowerTVMBuiltin);
REGISTER_PASS1(CombineContextCall);
REGISTER_PASS2(VerifyMemory);
REGISTER_PASS2(VerifyGPUCode);
REGISTER_PASS1(DecorateDeviceScope);
} // namespace ir
} // namespace tvm
21 changes: 21 additions & 0 deletions src/pass/detect_device.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
/*!
* Copyright (c) 2018 by Contributors
* \file detect_device.cc
*/

#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include "../pass/ir_util.h"

namespace tvm {
namespace ir {
/*!
 * \brief Wrap a statement in an attribute statement keyed by device_scope.
 *
 * The input statement becomes the body of the returned AttrStmt; it is not
 * otherwise inspected or rewritten here.
 *
 * \param stmt The statement to wrap.
 * \return An AttrStmt with key attr::device_scope whose body is stmt.
 */
Stmt DecorateDeviceScope(Stmt stmt) {
  return AttrStmt::make(make_zero(Int(32)), ir::attr::device_scope, 0, stmt);
}

} // namespace ir
} // namespace tvm
3 changes: 2 additions & 1 deletion src/pass/split_host_device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,8 @@ class HostDeviceSplitter : public IRMutator {

Stmt Mutate_(const AttrStmt *op, const Stmt& s) final {
if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::pipeline_exec_scope) {
op->attr_key == attr::pipeline_exec_scope ||
op->attr_key == attr::device_scope) {
return SplitDeviceFunc(s);
}
return IRMutator::Mutate_(op, s);
Expand Down
26 changes: 26 additions & 0 deletions tests/python/unittest/test_pass_decorate_device_scope.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import tvm

def test_decorate_device():
    """Check that DecorateDeviceScope wraps a stmt in a "device_scope" AttrStmt."""
    dim_m = tvm.var('m')
    dim_l = tvm.var('l')
    shape = (dim_m, dim_l)
    A = tvm.placeholder(shape, name='A')

    A1 = tvm.compute(shape, lambda i, j: A[i, j], name='A1')
    A2 = tvm.compute(shape, lambda i, j: A1[i, j] + 3, name='A2')

    # Build a schedule with a split outer axis and a shared-scope producer.
    sched = tvm.create_schedule(A2.op)
    outer, _inner = sched[A2].split(A2.op.axis[0], factor=8)
    sched[A1].compute_at(sched[A2], outer)
    sched[A1].set_scope("shared")

    bounds = tvm.schedule.InferBound(sched)
    lowered = tvm.ir_pass.Simplify(tvm.schedule.ScheduleOps(sched, bounds))
    decorated = tvm.ir_pass.DecorateDeviceScope(lowered)

    # The pass must return an AttrStmt keyed "device_scope" whose body is the input.
    assert isinstance(decorated, tvm.stmt.AttrStmt)
    assert decorated.attr_key == "device_scope"
    assert lowered == decorated.body

if __name__ == "__main__":
    test_decorate_device()

0 comments on commit a07beee

Please sign in to comment.