Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR][Hybrid] fold attr&allocate/realize #110

Merged
merged 4 commits into from
Sep 12, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 25 additions & 32 deletions python/tvm/hybrid/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def __init__(self, src, base_lienno):
self.meta = None

self.functions = {}
self.assign_target = None
self.target = None

def init_function_parsing_env(self):
"""Initialize function parsing environment"""
Expand Down Expand Up @@ -304,29 +304,38 @@ def visit_Assign(self, node):
1.3 Var = tir.var()
2. (BufferStore) Buffer[PrimExpr, PrimExpr, ..., PrimExpr] = PrimExpr
3. (Store) Var[PrimExpr] = PrimExpr
4. with scope handlers with concise scoping and var def
4.1 var = tir.alloc_with_scope()
"""

if not len(node.targets) == 1:
self.report_error("Only one-valued assignment is supported now")
target = node.targets[0]

if isinstance(target, ast.Name):
# scenario 1
self.assign_target = target.id
rhs = self.visit(node.value)
# scenario 1&4
self.target = [target.id]
if not isinstance(node.value, ast.Call):
self.report_error("Unsupported assign stmt")
self.scope_emitter.update_symbol(target.id, rhs)
func = self.visit(node.value.func)
if Registry.is_with_scope(func):
# scenario 4
return self.visit(node.value)
else:
# scenario 1
rhs = self.visit(node.value)
self.scope_emitter.update_symbol(target.id, rhs)
elif isinstance(target, ast.Subscript):
# scenario 2&3
symbol, indexes = self.visit(target)
self.assign_target = (symbol, indexes)
rhs = self.visit(node.value)
if isinstance(symbol, tvm.tir.Buffer):
# BufferStore
return tvm.tir.BufferStore(symbol, tvm.runtime.convert(rhs), indexes)
else:
if len(indexes) != 1:
self.report_error("Invalid Store stmt")
# Store
return tvm.tir.Store(symbol, tvm.runtime.convert(rhs), indexes[0],
tvm.runtime.convert(True))
else:
Expand Down Expand Up @@ -393,9 +402,8 @@ def visit_With(self, node):
With(withitem* items, stmt* body, string? type_comment)
withitem = (expr context_expr, expr? optional_vars)
By now 2 types of With is supported:
1. with tir.block(*axes) as block_vars:
2. with tir.allocate() as
2. with tir.let/tir.Assert()/tir.attr()/tir.allocate()/tir.realize()
1. with tir.block(*axes)/tir.allocate() as targets:
2. with tir.let()/tir.Assert()/tir.attr()//tir.realize()
"""
if not len(node.items) == 1:
self.report_error("Only one with element is supported now")
Expand All @@ -405,42 +413,27 @@ def visit_With(self, node):
func_call = node.items[0].context_expr
func_node = func_call.func
func = self.visit(func_node)
is_tir_block = (isinstance(func_node, ast.Attribute)
and isinstance(func_node.value, ast.Name)
and func_node.value.id == "tir"
and func_node.attr == "block")

if not is_tir_block and node.items[0].optional_vars is not None:
self.report_error("Now only tir.block allows optional var")
if not Registry.is_with_scope(func):
self.report_error("Function not allowed in with scope")

if is_tir_block:
# preprocess block_var definitions
self.target = []
if node.items[0].optional_vars is not None:
# preprocess optional var names
if isinstance(node.items[0].optional_vars, ast.Name):
block_vars = [node.items[0].optional_vars.id]
self.target = [node.items[0].optional_vars.id]
elif isinstance(node.items[0].optional_vars, (ast.List, ast.Tuple)):
for var in node.items[0].optional_vars.elts:
if not isinstance(var, ast.Name):
self.report_error("Invalid block var definition")
block_vars = [var.id for var in node.items[0].optional_vars.elts]
elif node.items[0].optional_vars is None:
block_vars = []
self.report_error("Invalid optional var definition")
self.target = [var.id for var in node.items[0].optional_vars.elts]
else:
self.report_error("Invalid block var definition")
# update block vars into symbol table
block_vars = [tvm.te.var(name) for name in block_vars]
self.scope_emitter.new_scope(is_block=True)
for block_var in block_vars:
self.scope_emitter.update_symbol(block_var.name, block_var)

self.report_error("Invalid optional var definition")
# parse other arguments
args = [self.visit(arg) for arg in func_call.args]
kw_args = [self.visit(keyword) for keyword in func_call.keywords]
kw_args = {kw_arg[0]: kw_arg[1] for kw_arg in kw_args}

if is_tir_block:
args = [block_vars] + args

old_lineno, old_col_offset = self.current_lineno, self.current_col_offset
self.current_lineno, self.current_col_offset = \
self.base_lineno + func_call.lineno - 1, func_call.col_offset
Expand Down
76 changes: 49 additions & 27 deletions python/tvm/hybrid/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def auto_insert_body(self, pos, body):
self.kwargs["body"] = body


def func_wrapper(func_name, func_to_register, arg_list, category, concise):
def func_wrapper(func_name, func_to_register, arg_list, category, concise=False, with_var=False):
"""Helper function to wrap a function to be registered """

def wrap_func(parser, node, args, kwargs):
Expand Down Expand Up @@ -145,25 +145,34 @@ def wrap_func(parser, node, args, kwargs):
parser.scope_emitter.loop_stack[-1].pop()
parser.scope_emitter.pop_scope()
elif category == Category.WITH_SCOPE:
# automatically parse body for with_scope handlers
if isinstance(node, ast.With):
# the scope handler is used inside with context/for context
parser.scope_emitter.new_scope()
parser.scope_emitter.node_stack[-1].extend(reversed(node.body))
body = parser.get_body()
parser.scope_emitter.pop_scope()
if not with_var:
if isinstance(node, ast.With) and node.items[0].optional_vars is not None:
parser.report_error("Function " + func_name + " expects no optional vars")
# automatically parse body for with_scope handlers without optional vars
if isinstance(node, ast.With):
parser.scope_emitter.new_scope()
parser.scope_emitter.node_stack[-1].extend(reversed(node.body))
body = parser.get_body()
parser.scope_emitter.pop_scope()
else:
body = parser.get_body()
else:
# the with scope handler is used in concise scoping
if not concise:
parser.report_error("Concise scoping is not allowed here")
body = parser.get_body()
if isinstance(node, ast.With) and node.items[0].optional_vars is None:
parser.report_error("Function " + func_name + " expects optional vars")
body = None

if not isinstance(node, ast.With) and not concise:
parser.report_error("Concise scoping is not allowed here")

reader = CallArgumentReader(func_name, args, kwargs, parser)
pos_only, kwargs, varargs = arg_list

internal_args = list()
if category == Category.WITH_SCOPE:
internal_args.extend([parser, node, body])
if not with_var:
internal_args.extend([parser, node, body])
else:
internal_args.extend([parser, node])
elif category == Category.FOR_SCOPE:
internal_args.extend([parser, node, body, loop_vars])
elif category == Category.SPECIAL_STMT:
Expand All @@ -184,15 +193,17 @@ def wrap_func(parser, node, args, kwargs):
return wrap_func


def get_arg_list(origin_func, category):
def get_arg_list(origin_func, category, with_var=False):
"""Helper function to get the argument list of Function

Parameters
----------
origin_func: function
The function to get the argument list
category: Category
The category of registerde function
The category of registered function
with_var: bool, optional
Whether the with scope handler neeeds optional vars
"""
full_arg_spec = inspect.getfullargspec(origin_func)

Expand All @@ -202,11 +213,18 @@ def get_arg_list(origin_func, category):
defaults = tuple()

if category == Category.WITH_SCOPE:
if len(args) < 3 or args[0] != "parser" or args[1] != "node" or args[2] != "body":
raise RuntimeError(
"TVM Hybrid Script register error : the first three arguments of with scope handler"
"must be parser, node, body")
args = args[3:]
if not with_var:
if len(args) < 3 or args[0] != "parser" or args[1] != "node" or args[2] != "body":
raise RuntimeError(
"TVM Hybrid Script register error : the first three arguments of "
"this with scope handler must be parser, node, body")
args = args[3:]
else:
if len(args) < 2 or args[0] != "parser" or args[1] != "node":
raise RuntimeError(
"TVM Hybrid Script register error : the first two arguments of "
"this with scope handler must be parser, node")
args = args[2:]
elif category == Category.FOR_SCOPE:
if len(args) < 4 or args[0] != "parser" or \
args[1] != "node" or args[2] != "body" or args[3] != "loop_vars":
Expand Down Expand Up @@ -259,19 +277,21 @@ def decorate(origin_func):
func_name = "tir." + origin_func.__qualname__ if name is None else name
Registry.functions[func_name] = \
func_wrapper(func_name, origin_func, get_arg_list(origin_func, Category.INTRIN),
Category.INTRIN, concise=False), Category.INTRIN
Category.INTRIN), Category.INTRIN
return origin_func

return decorate


def register_with_scope(concise=False, name=None):
def register_with_scope(concise=False, with_var=False, name=None):
"""Decorator to register function under with scope handler

Parameters
----------
concise: bool, optional
whether this scope handler is allowed in concise scoping
whether this with scope handler is allowed in concise scoping
with_var: bool, optional
whether this with scope handler neeeds optional vars
name: str, optional
registered name for the function

Expand All @@ -288,8 +308,10 @@ def decorate(origin_func):
"""Register function under category with_scope"""
func_name = "tir." + origin_func.__qualname__ if name is None else name
Registry.functions[func_name] = \
func_wrapper(func_name, origin_func, get_arg_list(origin_func, Category.WITH_SCOPE),
Category.WITH_SCOPE, concise=concise), Category.WITH_SCOPE
func_wrapper(func_name, origin_func,
get_arg_list(origin_func, Category.WITH_SCOPE, with_var),
Category.WITH_SCOPE, concise=concise, with_var=with_var), \
Category.WITH_SCOPE
return origin_func

return decorate
Expand All @@ -308,7 +330,7 @@ def decorate(origin_func):
func_name = "tir." + origin_func.__qualname__ if name is None else name
Registry.functions[func_name] = \
func_wrapper(func_name, origin_func, get_arg_list(origin_func, Category.FOR_SCOPE),
Category.FOR_SCOPE, concise=False), Category.FOR_SCOPE
Category.FOR_SCOPE), Category.FOR_SCOPE
return origin_func

return decorate
Expand Down Expand Up @@ -337,7 +359,7 @@ def decorate(origin_func):
func_name = "tir." + origin_func.__qualname__ if name is None else name
Registry.functions[func_name] = \
func_wrapper(func_name, origin_func, get_arg_list(origin_func, Category.SPECIAL_STMT),
Category.SPECIAL_STMT, concise=False), Category.SPECIAL_STMT
Category.SPECIAL_STMT), Category.SPECIAL_STMT
return origin_func

return decorate
Loading