Skip to content

Commit

Permalink
[COMPILER] Upgrade to meet latest TVM IR pragma convention (#32)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed Jul 12, 2018
1 parent b668e4d commit 8ae7ae8
Showing 1 changed file with 21 additions and 6 deletions.
27 changes: 21 additions & 6 deletions vta/python/vta/ir_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,21 @@
from .environment import get_env


def _match_pragma(stmt, key):
"""Internal helper to match stmt to pragma stmt.
Parameters
----------
stmt : Stmt
The AttrStmt
key : str
The pragma key
"""
return ((stmt.attr_key == "pragma_" + key) or
(stmt.attr_key == "pragma_scope" and stmt.value.value == key))


def fold_uop_loop(stmt_in):
"""Detect and fold uop loop.
Expand Down Expand Up @@ -255,7 +270,7 @@ def inject_skip_copy(stmt_in):
Transformed statement
"""
def _do_fold(stmt):
if (stmt.attr_key == "pragma_scope" and stmt.value.value == "skip_dma_copy"):
if _match_pragma(stmt, "skip_dma_copy"):
return tvm.make.Evaluate(0)
return None
return tvm.ir_pass.IRTransform(
Expand All @@ -277,12 +292,12 @@ def inject_coproc_sync(stmt_in):
"""
success = [False]
def _do_fold(stmt):
if stmt.attr_key == "pragma_scope" and stmt.value.value == "coproc_sync":
if _match_pragma(stmt, "coproc_sync"):
success[0] = True
sync = tvm.make.Call(
"int32", "vta.coproc_sync", [], tvm.expr.Call.Intrinsic, None, 0)
return tvm.make.Block(stmt.body, tvm.make.Evaluate(sync))
elif stmt.attr_key == "pragma_scope" and stmt.value.value == "trim_loop":
elif _match_pragma(stmt, "trim_loop"):
op = stmt.body
assert isinstance(op, tvm.stmt.For)
return tvm.make.For(
Expand Down Expand Up @@ -561,15 +576,15 @@ def annotate_alu_coproc_scope(stmt_in):
"""
env = get_env()
def _do_fold(stmt):
if (stmt.attr_key == "pragma_scope" and stmt.value.value == "alu"):
if _match_pragma(stmt, "alu"):
irb = tvm.ir_builder.create()
irb.scope_attr(env.dev.vta_axis, "coproc_scope",
env.dev.get_task_qid(env.dev.QID_COMPUTE))
irb.scope_attr(env.dev.vta_axis, "coproc_uop_scope",
tvm.make.StringImm("VTAPushALUOp"))
irb.emit(stmt)
return irb.get()
elif (stmt.attr_key == "pragma_scope" and stmt.value.value == "skip_alu"):
elif _match_pragma(stmt, "skip_alu"):
return tvm.make.Evaluate(0)
return stmt

Expand Down Expand Up @@ -631,7 +646,7 @@ def _flatten_loop(src_coeff, dst_coeff, extents):

return rev_src_coeff, rev_dst_coeff, rev_extents

if (stmt.attr_key == "pragma_scope" and stmt.value.value == "alu"):
if _match_pragma(stmt, "alu"):
# Get to the innermost loop body
loop_body = stmt.body
nest_size = 0
Expand Down

0 comments on commit 8ae7ae8

Please sign in to comment.