diff --git a/python/tvm/script/__init__.py b/python/tvm/script/__init__.py index 4b9f07354f70..4cf7828290a7 100644 --- a/python/tvm/script/__init__.py +++ b/python/tvm/script/__init__.py @@ -16,5 +16,4 @@ # under the License. """TVM Script APIs of TVM Python Package, aimed to support TIR""" -from .utils import create_module, asscript, tir, module -from .parser import from_source +from .parser import from_source, create_module, asscript, tir, module diff --git a/python/tvm/script/_ffi_api.py b/python/tvm/script/_ffi_api.py index 92c38909f446..926d17b1667e 100644 --- a/python/tvm/script/_ffi_api.py +++ b/python/tvm/script/_ffi_api.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""FFI APIs for tvm.tvmscript""" +"""FFI APIs for tvm.script""" import tvm._ffi tvm._ffi._init_api("script", __name__) diff --git a/python/tvm/script/scope_emitter.py b/python/tvm/script/context_maintainer.py similarity index 70% rename from python/tvm/script/scope_emitter.py rename to python/tvm/script/context_maintainer.py index 69ad26731492..8ad39354e5cf 100644 --- a/python/tvm/script/scope_emitter.py +++ b/python/tvm/script/context_maintainer.py @@ -14,17 +14,24 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""TVM Script Scope Emitter for TIR""" +"""TVM Script Context Maintainer for TIR""" from tvm.te import schedule -class ScopeEmitter: - """Maintain the nodes and symbols of scopes""" +class ContextMaintainer: + """Maintain all the necessary context info""" def __init__(self, parser): - self.node_stack = [[]] # AST nodes of scopes - self.symbols = [dict()] # Symbols of scopes + # scope context + self.node_stack = [] # AST nodes of scopes + self.symbols = [] # symbols of scopes + # function context + self.func_params = [] # parameter list of function + self.func_buffer_map = {} # buffer_map of function + self.func_dict_attr = {} # func_attr of function + self.func_var_env_dict = {} # map from var to env_name + # parser self.parser = parser def pop_scope(self): @@ -32,9 +39,11 @@ def pop_scope(self): self.symbols.pop() self.node_stack.pop() - def new_scope(self): - """ Creating a new scope """ - self.node_stack.append([]) + def new_scope(self, nodes=None): + """Creating a new scope""" + if nodes is None: + nodes = [] + self.node_stack.append(list(reversed(nodes))) self.symbols.append(dict()) def update_symbol(self, name, symbol): @@ -60,3 +69,6 @@ def lookup_symbol(self, name): if name in symbols: return symbols[name] return None + + def report_error(self, message): + self.parser.report_error(message) diff --git a/python/tvm/script/intrin.py b/python/tvm/script/intrin.py index 21570b91111a..63bc676bc889 100644 --- a/python/tvm/script/intrin.py +++ b/python/tvm/script/intrin.py @@ -14,127 +14,127 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""TVM Script Parser Intrinsic Functions - -IRNodes (StmtNodes without body, PrimExprNodes and more) are called intrins -""" -# pylint: disable=redefined-builtin +"""TVM Script Parser Intrinsic Classes""" +# pylint: disable=redefined-builtin, relative-beyond-top-level import tvm.tir -from .registry import register_intrin +from .registry import register +from .utils import get_param_list + + +class Intrin: + def __init__(self, intrin, stmt=False): + self.intrin = intrin + self.stmt = stmt + + def signature(self): + return "tir." + self.intrin.__name__, get_param_list(self.intrin) + def handle(self, arg_list): + return self.intrin(*arg_list) -@register_intrin() + +@register def bool(imm): return tvm.tir.const(imm, "bool") -@register_intrin() +@register def int8(imm): return tvm.tir.const(imm, "int8") -@register_intrin() +@register def int16(imm): return tvm.tir.const(imm, "int16") -@register_intrin() +@register def int32(imm): return tvm.tir.const(imm, "int32") -@register_intrin() +@register def int64(imm): return tvm.tir.const(imm, "int64") -@register_intrin() +@register def uint8(imm): return tvm.tir.const(imm, "uint8") -@register_intrin() +@register def uint16(imm): return tvm.tir.const(imm, "uint16") -@register_intrin() +@register def uint32(imm): return tvm.tir.const(imm, "uint32") -@register_intrin() +@register def uint64(imm): return tvm.tir.const(imm, "uint64") -@register_intrin() +@register def float8(imm): return tvm.tir.const(imm, "float8") -@register_intrin() +@register def float16(imm): return tvm.tir.const(imm, "float16") -@register_intrin() +@register def float32(imm): return tvm.tir.const(imm, "float32") -@register_intrin() +@register def float64(imm): return tvm.tir.const(imm, "float64") -@register_intrin() +@register def floordiv(x, y): return tvm.tir.floordiv(x, y) -@register_intrin() +@register def floormod(x, y): return tvm.tir.floormod(x, y) -@register_intrin() +@register def load(dtype, var, index, predicate=True): return tvm.tir.Load(dtype, var, index, predicate) -@register_intrin() +@register def cast(value, dtype): return tvm.tir.Cast(dtype, value) -@register_intrin() +@register def ramp(base, stride, lanes): return tvm.tir.Ramp(base, stride, lanes) -@register_intrin() +@register def broadcast(value, lanes): return tvm.tir.Broadcast(value, lanes) -@register_intrin() -def evaluate(value): - return tvm.tir.Evaluate(value) - - -@register_intrin() -def store(var, index, value, predicate=True): - return tvm.tir.Store(var, value, index, predicate) - - -@register_intrin() +@register def iter_var(var, dom, iter_type, thread_tag): iter_type = getattr(tvm.tir.IterVar, iter_type) return tvm.tir.IterVar(dom, var, iter_type, thread_tag) -@register_intrin() +@register def max(a, b): # pylint: disable=redefined-builtin return tvm.tir.Max(a, b) @@ -148,21 +148,39 @@ def get_axis(begin, end, iter_type): return tvm.tir.IterVar(block_var_dom, "bv", iter_type_dict[iter_type]) -@register_intrin() +@register def range(begin, end): return get_axis(begin, end, "data_par") -@register_intrin() +@register def reduce_axis(begin, end): return get_axis(begin, end, "reduce") -@register_intrin() +@register def scan_axis(begin, end): return get_axis(begin, end, "scan") -@register_intrin() +@register def opaque_axis(begin, end): return get_axis(begin, end, "opaque") + + +@register +class EvaluateIntrin(Intrin): + def __init__(self): + def evaluate(value): + return tvm.tir.Evaluate(value) + + super().__init__(evaluate, stmt=True) + + +@register +class StoreIntrin(Intrin): + def __init__(self): + def store(var, index, value, predicate=True): + return tvm.tir.Store(var, value, index, predicate) + + super().__init__(store, stmt=True) diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py index 56710fc7a60f..70aa3fe34387 100644 --- a/python/tvm/script/parser.py +++ b/python/tvm/script/parser.py @@ -16,25 +16,69 @@ # under the License. """TVM Script Parser For TIR""" # pylint: disable=invalid-name, missing-docstring, inconsistent-return-statements, no-else-return -# pylint: disable=unnecessary-comprehension, unused-argument, import-outside-toplevel -# pylint: disable=unused-import +# pylint: disable=unnecessary-comprehension, unused-argument +# pylint: disable=relative-beyond-top-level import json import operator +import inspect from typed_ast import ast3 as ast -import tvm._ffi -from tvm import tir +import tvm +from tvm import IRModule from tvm._ffi.base import TVMError from tvm.ir import GlobalVar from tvm.tir import all as _all from tvm.tir import expr as _expr -from . import scope_emitter, special_stmt, scope_handler, intrin, ty +from . import context_maintainer, ty from .meta_unparser import MetaUnparser from .registry import Registry +from .intrin import Intrin +from .special_stmt import SpecialStmt +from .scope_handler import ScopeHandler, WithScopeHandler, ForScopeHandler from . import _ffi_api +class CallArgumentReader(object): + """A helper class which read required argument from passed arguments""" + + def __init__(self, func_name, args, kwargs, parser): + self.func_name = func_name + self.args = args + self.kwargs = kwargs + self.parser = parser + + def get_pos_only_arg(self, pos, name): + """Get corresponding position only function argument from argument list""" + if len(self.args) >= pos: + arg = self.args[pos - 1] + elif name not in self.kwargs: + self.parser.report_error(self.func_name + " misses argument " + name) + else: + arg = self.kwargs[name] + + return arg + + def get_kwarg(self, pos, name, default): + """Get corresponding keyword function argument from argument list + If user doesn't provide the argument, set it to default value + """ + if len(self.args) >= pos: + arg = self.args[pos - 1] + elif name in self.kwargs: + arg = self.kwargs[name] + else: + return default + + return arg + + def get_varargs(self, pos): + """Get corresponding variable argument from argument list""" + if len(self.args) >= pos and len(self.kwargs) == 0: + return self.args[pos - 1 :] + return [] + + class TVMScriptParserError(RuntimeError): """TVM script Parser Runtime Error""" @@ -58,33 +102,29 @@ class TVMScriptParser(ast.NodeVisitor): """ _binop_maker = { - ast.Add: tir.Add, - ast.Sub: tir.Sub, - ast.Mult: tir.Mul, - ast.Div: tir.Div, - ast.FloorDiv: tir.FloorDiv, - ast.Mod: tir.FloorMod, + ast.Add: tvm.tir.Add, + ast.Sub: tvm.tir.Sub, + ast.Mult: tvm.tir.Mul, + ast.Div: tvm.tir.Div, + ast.FloorDiv: tvm.tir.FloorDiv, + ast.Mod: tvm.tir.FloorMod, ast.BitOr: operator.or_, ast.BitAnd: operator.and_, ast.BitXor: operator.xor, - ast.Gt: tir.GT, - ast.GtE: tir.GE, - ast.Lt: tir.LT, - ast.LtE: tir.LE, - ast.Eq: tir.EQ, - ast.NotEq: tir.NE, - ast.And: tir.And, - ast.Or: tir.Or, + ast.Gt: tvm.tir.GT, + ast.GtE: tvm.tir.GE, + ast.Lt: tvm.tir.LT, + ast.LtE: tvm.tir.LE, + ast.Eq: tvm.tir.EQ, + ast.NotEq: tvm.tir.NE, + ast.And: tvm.tir.And, + ast.Or: tvm.tir.Or, } - _unaryop_maker = {ast.USub: operator.neg, ast.Invert: operator.invert, ast.Not: tir.Not} + _unaryop_maker = {ast.USub: operator.neg, ast.Invert: operator.invert, ast.Not: tvm.tir.Not} def __init__(self, src, base_lienno): - self.params = None - self.buffer_map = None - self.dict_attr = None - self.scope_emitter = None - self.var_env_dict = None + self.context = None self.src = src.split("\n") self.base_lineno = base_lienno @@ -93,15 +133,10 @@ def __init__(self, src, base_lienno): self.meta = None self.functions = {} - self.target = None def init_function_parsing_env(self): """Initialize function parsing environment""" - self.params = [] # parameter list - self.buffer_map = {} # buffer map - self.dict_attr = {} # dict attr - self.scope_emitter = scope_emitter.ScopeEmitter(self) # scope emitter - self.var_env_dict = {} # map from var to thread env name + self.context = context_maintainer.ContextMaintainer(self) # scope emitter @staticmethod def is_meta(node): @@ -170,15 +205,40 @@ def report_error(self, message, lineno=None, col_offset=None): col_offset = self.current_col_offset raise TVMScriptParserError(self.wrap_line_col(message, lineno, col_offset)) - def get_body(self): + def parse_body(self): body = [] - while len(self.scope_emitter.node_stack[-1]) > 0: - res = self.visit(self.scope_emitter.node_stack[-1].pop()) + while len(self.context.node_stack[-1]) > 0: + res = self.visit(self.context.node_stack[-1].pop()) if res is not None: body.append(res) return tvm.tir.SeqStmt(body) if len(body) > 1 else body[0] - def get_type(self, type_node): + def parse_arg_list(self, func, node_call): + assert isinstance(node_call, ast.Call) + # collect arguments + args = [self.visit(arg) for arg in node_call.args] + kw_args = [self.visit(keyword) for keyword in node_call.keywords] + kw_args = {kw_arg[0]: kw_arg[1] for kw_arg in kw_args} + # get the name and parameter list of func + if isinstance(func, (Intrin, ScopeHandler, SpecialStmt)): + func_name, param_list = func.signature() + else: + print(func) + raise Exception("Internal Error") + # check arguments and parameter list and get a list of arguments + reader = CallArgumentReader(func_name, args, kw_args, self) + pos_only, kwargs, varargs = param_list + internal_args = list() + for i, arg_name in enumerate(pos_only): + internal_args.append(reader.get_pos_only_arg(i + 1, arg_name)) + for i, arg_info in enumerate(kwargs): + arg_name, default = arg_info + internal_args.append(reader.get_kwarg(i + 1 + len(pos_only), arg_name, default=default)) + if varargs is not None: + internal_args.extend(reader.get_varargs(len(pos_only) + len(kwargs) + 1)) + return internal_args + + def parse_type(self, type_node): """ Parse type """ if type_node is None: self.report_error("missing type annotation") @@ -267,7 +327,6 @@ def visit_ClassDef(self, node): for body_element in node.body: if isinstance(body_element, ast.FunctionDef): self.visit(body_element) - from .utils import create_module return create_module(self.functions) @@ -282,69 +341,76 @@ def visit_FunctionDef(self, node): """ self.init_function_parsing_env() + self.context.new_scope(nodes=node.body) + # add parameters of function for arg in node.args.args: - arg_var = tvm.te.var(arg.arg, self.get_type(arg.annotation)) - self.scope_emitter.update_symbol(arg.arg, arg_var) - self.params.append(arg_var) - - # visit the body of function - self.scope_emitter.node_stack[-1].extend(reversed(node.body)) + arg_var = tvm.te.var(arg.arg, self.parse_type(arg.annotation)) + self.context.update_symbol(arg.arg, arg_var) + self.context.func_params.append(arg_var) # fetch the body and return a tir.PrimFunc func = tvm.tir.PrimFunc( - self.params, - self.get_body(), - ret_type=self.get_type(node.returns), - buffer_map=self.buffer_map, - attrs=tvm.ir.make_node("DictAttrs", **self.dict_attr), + self.context.func_params, + self.parse_body(), + ret_type=self.parse_type(node.returns), + buffer_map=self.context.func_buffer_map, + attrs=tvm.ir.make_node("DictAttrs", **self.context.func_dict_attr), ) self.functions[GlobalVar(node.name)] = func + + self.context.pop_scope() return func def visit_Assign(self, node): """Assign visitor AST abstract grammar: Assign(expr* targets, expr value, string? type_comment) - By now only 3 types of Assign is supported: + + By now 3 patterns of Assign is supported: 1. special stmts with return value - 1.1 Buffer = tir.buffer_bind()/tir.buffer_decl() + 1.1 Buffer = tir.match_buffer()/tir.buffer_decl() 1.2 Var = tir.var() 1.3 Var = tir.env_thread() 2. (BufferStore) Buffer[PrimExpr, PrimExpr, ..., PrimExpr] = PrimExpr 3. (Store) Var[PrimExpr] = PrimExpr 4. with scope handlers with concise scoping and var def - 4.1 var = tir.alloc_with_scope() + 4.1 var = tir.allocate() """ if not len(node.targets) == 1: self.report_error("Only one-valued assignment is supported now") - target = node.targets[0] - if isinstance(target, ast.Name): - # scenario 1&4 - self.target = [target.id] - if not isinstance(node.value, ast.Call): - self.report_error("Unsupported assign stmt") + if isinstance(node.targets[0], ast.Name) and isinstance(node.value, ast.Call): + # Pattern 1 & Pattern 4 func = self.visit(node.value.func) - if Registry.is_with_scope(func): - # scenario 4 - return self.visit(node.value) + arg_list = self.parse_arg_list(func, node.value) + if isinstance(func, WithScopeHandler): + if not func.concise_scope or not func.def_symbol: + self.report_error( + "with scope handler " + func.signature()[0] + " is not suitable here" + ) + # Pattern 4 + func.enter_scope(node, self.context) + arg_list = self.parse_arg_list(func, node.value) + func.body = self.parse_body() + return func.exit_scope(node, self.context, arg_list) + elif isinstance(func, SpecialStmt): + # Pattern 1 + func.handle(node, self.context, arg_list) else: - # scenario 1 - rhs = self.visit(node.value) - self.scope_emitter.update_symbol(target.id, rhs) - elif isinstance(target, ast.Subscript): - # scenario 2&3 - symbol, indexes = self.visit(target) + self.report_error("Unsupported Assign stmt") + elif isinstance(node.targets[0], ast.Subscript): + # Pattern 2 & Pattern 3 + symbol, indexes = self.visit(node.targets[0]) rhs = self.visit(node.value) if isinstance(symbol, tvm.tir.Buffer): - # BufferStore + # Pattern 2 return tvm.tir.BufferStore(symbol, tvm.runtime.convert(rhs), indexes) else: if len(indexes) != 1: self.report_error("Invalid Store stmt") - # Store + # Pattern 3 return tvm.tir.Store( symbol, tvm.runtime.convert(rhs), indexes[0], tvm.runtime.convert(True) ) @@ -355,14 +421,17 @@ def visit_AnnAssign(self, node): """AnnAssign visitor AST abstract grammar: AnnAssign(expr target, expr annotation, expr? value, int simple) - Corresponds to concise mode of with tir.let() + + Pattern corresponds to concise mode of with tir.let() """ if isinstance(node.target, ast.Name): value = self.visit(node.value) - var = tvm.te.var(node.target.id, self.get_type(node.annotation)) - self.scope_emitter.update_symbol(var.name, var) - return tvm.tir.LetStmt(var, value, self.visit(self.scope_emitter.node_stack[-1].pop())) + var = tvm.te.var(node.target.id, self.parse_type(node.annotation)) + self.context.update_symbol(var.name, var) + body = self.parse_body() + self.context.remove_symbol(var.name) + return tvm.tir.LetStmt(var, value, body) else: self.report_error("Unsupported AnnAssign stmt") @@ -370,40 +439,45 @@ def visit_Assert(self, node): """Assert visitor AST abstract grammar: Assert(expr test, expr? msg) - Corresponds to concise mode of with tir.assert() + + Pattern corresponds to concise mode of with tir.Assert() """ condition = self.visit(node.test) if node.msg is None: self.report_error("Message of AssertStmt can't be None") message = self.visit(node.msg) - return tvm.tir.AssertStmt(condition, tvm.runtime.convert(message), self.get_body()) + body = self.parse_body() + return tvm.tir.AssertStmt(condition, tvm.runtime.convert(message), body) def visit_For(self, node): """For visitor AST abstract grammar: For(expr target, expr iter, stmt* body, stmt* orelse, string? type_comment) - By now only 1 type of For is supported: - 1. for name in tir.serial/parallel/vectorized/unroll(begin, end) + By now 1 pattern of For is supported: + 1. for scope handler + for name in tir.serial()/tir.parallel()/tir.vectorized()/tir.unroll() """ - # check node.iter, which is a Call if not isinstance(node.iter, ast.Call): self.report_error("The loop iter should be a Call") func = self.visit(node.iter.func) - if not Registry.is_for_scope(func): - self.report_error("Function not allowed in for scope") - # collect arguments - args = [self.visit(arg) for arg in node.iter.args] - kw_args = [self.visit(keyword) for keyword in node.iter.keywords] - kw_args = {kw_arg[0]: kw_arg[1] for kw_arg in kw_args} - + if not isinstance(func, ForScopeHandler): + self.report_error("Only for scope handlers can be used in for stmt") + # prepare for new for scope old_lineno, old_col_offset = self.current_lineno, self.current_col_offset self.current_lineno, self.current_col_offset = ( self.base_lineno + node.iter.lineno - 1, node.iter.col_offset, ) - res = func(self, node, args, kw_args) + self.context.new_scope(nodes=node.body) + # for scope handler process the scope + func.enter_scope(node, self.context) + func.body = self.parse_body() + arg_list = self.parse_arg_list(func, node.iter) + res = func.exit_scope(node, self.context, arg_list) + # exit the scope + self.context.pop_scope() self.current_lineno, self.current_col_offset = old_lineno, old_col_offset return res @@ -412,10 +486,13 @@ def visit_With(self, node): AST abstract grammar: With(withitem* items, stmt* body, string? type_comment) withitem = (expr context_expr, expr? optional_vars) - By now 2 types of With is supported: - 1. with tir.allocate() as targets: - 2. with tir.let()/tir.Assert()/tir.attr()//tir.realize() + By now 2 patterns of With is supported: + 1. with scope handler with symbol def + with tir.allocate() as targets: + 2. with scope handler without symbol def + with tir.let()/tir.Assert()/tir.attr()//tir.realize() """ + if not len(node.items) == 1: self.report_error("Only one with element is supported now") if not isinstance(node.items[0].context_expr, ast.Call): @@ -425,32 +502,22 @@ def visit_With(self, node): func_node = func_call.func func = self.visit(func_node) - if not Registry.is_with_scope(func): + if not isinstance(func, WithScopeHandler): self.report_error("Function not allowed in with scope") - - self.target = [] - if node.items[0].optional_vars is not None: - # preprocess optional var names - if isinstance(node.items[0].optional_vars, ast.Name): - self.target = [node.items[0].optional_vars.id] - elif isinstance(node.items[0].optional_vars, (ast.List, ast.Tuple)): - for var in node.items[0].optional_vars.elts: - if not isinstance(var, ast.Name): - self.report_error("Invalid optional var definition") - self.target = [var.id for var in node.items[0].optional_vars.elts] - else: - self.report_error("Invalid optional var definition") - # parse other arguments - args = [self.visit(arg) for arg in func_call.args] - kw_args = [self.visit(keyword) for keyword in func_call.keywords] - kw_args = {kw_arg[0]: kw_arg[1] for kw_arg in kw_args} - + # prepare for new block scope old_lineno, old_col_offset = self.current_lineno, self.current_col_offset self.current_lineno, self.current_col_offset = ( self.base_lineno + func_call.lineno - 1, func_call.col_offset, ) - res = func(self, node, args, kw_args) + self.context.new_scope(nodes=node.body) + # with scope handler process the scope + func.enter_scope(node, self.context) + func.body = self.parse_body() + arg_list = self.parse_arg_list(func, func_call) + res = func.exit_scope(node, self.context, arg_list) + # exit the scope + self.context.pop_scope() self.current_lineno, self.current_col_offset = old_lineno, old_col_offset return res @@ -462,19 +529,18 @@ def visit_If(self, node): condition = self.visit(node.test) # then body - self.scope_emitter.new_scope() - self.scope_emitter.node_stack[-1].extend(reversed(node.body)) - then_body = self.get_body() - self.scope_emitter.pop_scope() + self.context.new_scope(nodes=node.body) + then_body = self.parse_body() + self.context.pop_scope() # else body if len(node.orelse) > 0: - self.scope_emitter.new_scope() - self.scope_emitter.node_stack[-1].extend(reversed(node.orelse)) - else_body = self.get_body() - self.scope_emitter.pop_scope() + self.context.new_scope(nodes=node.orelse) + else_body = self.parse_body() + self.context.pop_scope() else: else_body = None + return tvm.tir.IfThenElse(condition, then_body, else_body) def visit_Call(self, node): @@ -482,22 +548,30 @@ def visit_Call(self, node): AST abstract grammar: Call(expr func, expr* args, keyword* keywords) keyword = (identifier? arg, expr value) - All the functions used outside With and For are registered in special_stmt or intrin + + By now 3 patterns of Call is allowed + 1. Intrin representing PrimExpr/IterVar + 1.1 tir.int/uint/float8/16/32/64/floormod/floordiv/load/cast/ramp/broadcast/max + 1.2 tir.range/reduce_axis/scan_axis/opaque_axis + 2. tir.Op(dtype, ...) + 3. other callable functions """ func = self.visit(node.func) - # collect arguments - args = [self.visit(arg) for arg in node.args] - kw_args = [self.visit(keyword) for keyword in node.keywords] - kw_args = {kw_arg[0]: kw_arg[1] for kw_arg in kw_args} - - if callable(func): - if Registry.is_registered(func): - return func(self, node, args, kw_args) - else: + if isinstance(func, Intrin) and not func.stmt: + # pattern 1 + arg_list = self.parse_arg_list(func, node) + return func.handle(arg_list) + else: + args = [self.visit(arg) for arg in node.args] + kw_args = [self.visit(keyword) for keyword in node.keywords] + kw_args = {kw_arg[0]: kw_arg[1] for kw_arg in kw_args} + if isinstance(func, tvm.tir.op.Op): + # pattern 2 + return tvm.tir.Call(kw_args["dtype"], func, args) + elif callable(func): + # pattern 3 return func(*args, **kw_args) - elif isinstance(func, tvm.tir.op.Op): - return tvm.tir.Call(kw_args["dtype"], func, args) self.report_error("Unsupported function call") @@ -505,17 +579,35 @@ def visit_Expr(self, node): """Expr visitor AST abstract grammar: Expr(expr value) - Now only 3 types of `Expr` stmt is allowed: - 1. reducer.step()/tir.store() - 2. tir.attr()/tir.assert()/tir.allocate()/tir.realize() - 3. tir.set_func_attr() + + Now only 3 types of Expr stmt is allowed: + 1. Intrin representing Stmt without body + tir.store()/tir.evaluate() + 2. with scope handlers with concise scoping without var def + tir.attr()/tir.assert()/tir.allocate()/tir.realize() + 3. special stmt without var def + tir.func_attr() """ if not isinstance(node.value, ast.Call): self.report_error("Unsupported Expr stmt") - res = self.visit(node.value) - if res is None or isinstance(res, tvm.tir.Stmt): - return res + + func = self.visit(node.value.func) + arg_list = self.parse_arg_list(func, node.value) + + if isinstance(func, Intrin) and func.stmt: + # pattern 1 + return func.handle(arg_list) + elif isinstance(func, WithScopeHandler) and func.concise_scope and not func.def_symbol: + # pattern 2 + func.enter_scope(node, self.context) + func.body = self.parse_body() + return func.exit_scope(node, self.context, arg_list) + elif isinstance(func, SpecialStmt) and not func.def_symbol: + # pattern 3 + func.handle(node, self.context, arg_list) + return + self.report_error("Invalid Expr stmt") def visit_BinOp(self, node): @@ -572,7 +664,7 @@ def visit_Subscript(self, node): slice = Slice(expr? lower, expr? upper, expr? step) | ExtSlice(slice* dims) | Index(expr value) - By now only 2 types of Subscript are supported: + By now 2 patterns of Subscript are supported: 1. Buffer[index, index, ...], Buffer element access(BufferLoad & BufferStore) Var[index] Buffer element access() 2. meta[type_key][index], Meta info access @@ -587,7 +679,7 @@ def visit_Subscript(self, node): indexes = self.visit(node.slice.value) indexes = list(indexes) if isinstance(indexes, tuple) else [indexes] if isinstance(node.ctx, ast.Load): - if isinstance(symbol, tir.expr.Var): + if isinstance(symbol, tvm.tir.expr.Var): return tvm.tir.Load("float32", symbol, indexes, True) else: return tvm.tir.BufferLoad(symbol, indexes) @@ -632,7 +724,7 @@ def visit_Attribute(self, node): if isinstance(node.value, ast.Name): if node.value.id == "tir": func_name = "tir." + node.attr - res = Registry.look_up_function(func_name) + res = Registry.lookup(func_name) if res is not None: return res try: @@ -696,10 +788,10 @@ def visit_Name(self, node): name = node.id if name == "meta": return self.meta - symbol = Registry.look_up_function(name) + symbol = Registry.lookup(name) if symbol is not None: return symbol - symbol = self.scope_emitter.lookup_symbol(name) + symbol = self.context.lookup_symbol(name) if symbol is not None: return symbol self.report_error("Unknown identifier %s" % name) @@ -749,10 +841,95 @@ def from_source(src, func_lineno=0): parser.wrap_line_col(msg, parser.current_lineno, parser.current_col_offset).split("\n") ) inject_e[-1] = "TVM" + inject_e[-1][6:] - raise TVMError("\n".join(inject_e)) + raise TVMError("\n".join(inject_e)) from e except Exception as e: inject_e = parser.wrap_line_col(str(e), parser.current_lineno, parser.current_col_offset) - raise TVMScriptParserError(inject_e) + raise TVMScriptParserError(inject_e) from e + + +def _parse(script_in): + """Helper function to parse TVM script into TIR""" + return from_source(inspect.getsource(script_in), inspect.getsourcelines(script_in)[1]) + + +def create_module(functions=None): + """Construct a module from list of functions. + + Parameters + ----------- + functions: Optional[dict]. + Map of GlobalVar or str to PrimFunc + + Returns + ------- + mod : IRModule + An IRModule containing the passed definitions + """ + + return IRModule(functions=functions) + + +def asscript(input_ir, show_meta=False): + """Transform a PrimFunc or IRModule to python syntax script + + Parameters + ---------- + input_ir : Union[PrimFunc, IRModule] + The PrimFunc or IRModule to be dumped + + show_meta : bool + Whether show meta + + Returns + ------- + script : str + The Python script + """ + + return _ffi_api.AsTVMScript(input_ir, show_meta) + + +def tir(script_in): + """Decorate a python function or class as tvm script. + + The tvm function or parsing support parsing to the internal TIR. + + Returns + ------- + output : Union[Function, Module] + The Function or Module in IR. + """ + + if inspect.isfunction(script_in): + result = _parse(script_in) + elif inspect.isclass(script_in): + result = TVMScriptClass(script_in) + else: + raise TypeError("Only function and class are supported") + result.__name__ = script_in.__name__ + result.__qualname__ = script_in.__qualname__ + return result + + +def module(script_in): + """Decorate a python function or class as tvm script. + + Alias for tvm.script.tir for now. + + Returns + ------- + output : Union[Function, Module] + The Function or Module in IR. + """ + return tir(script_in) + + +class TVMScriptClass: + """Helper class for decorating a class""" + def __init__(self, script_in): + self.script = script_in -tvm._ffi._init_api("script", __name__) + def __call__(self, *args, **kwargs): + # call the parser to transform tvm script into TIR + return _parse(self.script) diff --git a/python/tvm/script/registry.py b/python/tvm/script/registry.py index acbc444a4190..389570115935 100644 --- a/python/tvm/script/registry.py +++ b/python/tvm/script/registry.py @@ -15,19 +15,8 @@ # specific language governing permissions and limitations # under the License. """TVM Script Parser Function Registry """ -# pylint: disable=inconsistent-return-statements +# pylint: disable=inconsistent-return-statements, relative-beyond-top-level, import-outside-toplevel import inspect -from enum import Enum -from typed_ast import ast3 as ast - -import tvm - - -class Category(Enum): - INTRIN = 0 - WITH_SCOPE = 1 - FOR_SCOPE = 2 - SPECIAL_STMT = 3 class Registry(object): @@ -35,355 +24,35 @@ class Registry(object): All these maps are static """ - functions = dict() + registrations = dict() @staticmethod - def look_up_function(func_name): - """look up a registered function by name""" - if func_name in Registry.functions: - return Registry.functions[func_name][0] + def lookup(name): + if name in Registry.registrations: + # every time we create a new handler + # since we may want to keep some local info inside it + return Registry.registrations[name]() return None - @staticmethod - def is_intrin(func): - """check whether a function belongs to intrin""" - return (func, Category.INTRIN) in Registry.functions.values() - - @staticmethod - def is_with_scope(func): - """check whether a function belongs to with scope handlers""" - return (func, Category.WITH_SCOPE) in Registry.functions.values() - - @staticmethod - def is_for_scope(func): - """check whether a function belongs to for scope handlers""" - return (func, Category.FOR_SCOPE) in Registry.functions.values() - - @staticmethod - def is_special_stmt(func): - """check whether a function belongs to special stmts""" - return (func, Category.SPECIAL_STMT) in Registry.functions.values() - - @staticmethod - def is_registered(func): - """check whether a function is registered""" - return ( - Registry.is_intrin(func) - or Registry.is_with_scope(func) - or Registry.is_for_scope(func) - or Registry.is_special_stmt(func) - ) - - -class CallArgumentReader(object): - """A helper class which read required argument from passed arguments""" - - def __init__(self, func_name, args, kwargs, parser): - self.func_name = func_name - self.args = args - self.kwargs = kwargs - self.parser = parser - - def get_pos_only_arg(self, pos, name): - """Get corresponding position only function argument from argument list""" - if len(self.args) >= pos: - arg = self.args[pos - 1] - elif name not in self.kwargs: - self.parser.report_error(self.func_name + " misses argument " + name) - else: - arg = self.kwargs[name] - - return arg - - def get_kwarg(self, pos, name, default): - """Get corresponding keyword function argument from argument list - If user doesn't provide the argument, set it to default value - """ - if len(self.args) >= pos: - arg = self.args[pos - 1] - elif name in self.kwargs: - arg = self.kwargs[name] - else: - return default - - return arg - - def get_varargs(self, pos): - """Get corresponding variable argument from argument list""" - if len(self.args) >= pos and len(self.kwargs) == 0: - return self.args[pos - 1 :] - return [] - - def auto_insert_body(self, pos, body): - """Automatically provide body as function call argument""" - if len(self.args) >= pos: - self.args.insert(pos - 1, body) - else: - self.kwargs["body"] = body - - -def func_wrapper(func_name, func_to_register, arg_list, category, concise=False, with_var=False): - """Helper function to wrap a function to be registered """ - - def wrap_func(parser, node, args, kwargs): - if category == Category.FOR_SCOPE: - # automatically parse loop vars and body for for_scope handlers - loop_var_names = list() - if isinstance(node.target, ast.Name): - loop_var_names.append(node.target.id) - elif isinstance(node.target, ast.Tuple): - for elt in node.target.elts: - if not isinstance(elt, ast.Name): - parser.report_error("Invalid loop var") - loop_var_names.append(elt.id) - else: - parser.report_error("Invalid loop var") - loop_vars = [tvm.te.var(name, dtype="int32") for name in loop_var_names] - - parser.scope_emitter.new_scope() - parser.scope_emitter.node_stack[-1].extend(reversed(node.body)) - for loop_var in loop_vars: - parser.scope_emitter.update_symbol(loop_var.name, loop_var) - body = parser.get_body() - parser.scope_emitter.pop_scope() - elif category == Category.WITH_SCOPE: - if not with_var: - if isinstance(node, ast.With) and node.items[0].optional_vars is not None: - parser.report_error("Function " + func_name + " expects no optional vars") - # automatically parse body for with_scope handlers without optional vars - if isinstance(node, ast.With): - parser.scope_emitter.new_scope() - parser.scope_emitter.node_stack[-1].extend(reversed(node.body)) - body = parser.get_body() - parser.scope_emitter.pop_scope() - else: - body = parser.get_body() - else: - if isinstance(node, ast.With) and node.items[0].optional_vars is None: - parser.report_error("Function " + func_name + " expects optional vars") - body = None - - if not isinstance(node, ast.With) and not concise: - parser.report_error("Concise scoping is not allowed here") - - reader = CallArgumentReader(func_name, args, kwargs, parser) - pos_only, kwargs, varargs = arg_list - - internal_args = list() - if category == Category.WITH_SCOPE: - if not with_var: - internal_args.extend([parser, node, body]) - else: - internal_args.extend([parser, node]) - elif category == Category.FOR_SCOPE: - internal_args.extend([parser, node, body, loop_vars]) - elif category == Category.SPECIAL_STMT: - internal_args.extend([parser, node]) - - for i, arg_name in enumerate(pos_only): - internal_args.append(reader.get_pos_only_arg(i + 1, arg_name)) - for i, arg_info in enumerate(kwargs): - arg_name, default = arg_info - internal_args.append(reader.get_kwarg(i + 1 + len(pos_only), arg_name, default=default)) +def register(inputs): + """Register Intrin/ScopeHandler/SpecialStmt""" + if inspect.isfunction(inputs): + from .intrin import Intrin - if varargs is not None: - internal_args.extend(reader.get_varargs(len(pos_only) + len(kwargs) + 1)) + def create_new_intrin(func): + class NewIntrin(Intrin): + def __init__(self): + super().__init__(func) - return func_to_register(*internal_args) - - return wrap_func - - -def get_arg_list(origin_func, category, with_var=False): - """Helper function to get the argument list of Function - Parameters - ---------- - origin_func: function - The function to get the argument list - category: Category - The category of registered function - with_var: bool, optional - Whether the with scope handler neeeds optional vars - """ - full_arg_spec = inspect.getfullargspec(origin_func) - - args, defaults = full_arg_spec.args, full_arg_spec.defaults - - if defaults is None: - defaults = tuple() - - if category == Category.WITH_SCOPE: - if not with_var: - if len(args) < 3 or args[0] != "parser" or args[1] != "node" or args[2] != "body": - raise RuntimeError( - "TVM Script register error : the first three arguments of " - "this with scope handler must be parser, node, body" - ) - args = args[3:] - else: - if len(args) < 2 or args[0] != "parser" or args[1] != "node": - raise RuntimeError( - "TVM Script register error : the first two arguments of " - "this with scope handler must be parser, node" - ) - args = args[2:] - elif category == Category.FOR_SCOPE: - if ( - len(args) < 4 - or args[0] != "parser" - or args[1] != "node" - or args[2] != "body" - or args[3] != "loop_vars" - ): - raise RuntimeError( - "TVM Script register error : the first three arguments of for scope handler" - "must be parser, node, body, loop_vars" - ) - args = args[4:] - elif category == Category.SPECIAL_STMT: - if len(args) < 2 or args[0] != "parser" or args[1] != "node": - raise RuntimeError( - "TVM Script register error : the first three arguments of special stmt" - "must be parser, node" - ) - args = args[2:] - - if full_arg_spec.varkw is not None: - raise RuntimeError( - "TVM Script register error : variable keyword argument is not supported now" - ) - if not len(full_arg_spec.kwonlyargs) == 0: - raise RuntimeError("TVM Script register error : keyword only argument is not supported now") - - pos_only = list() - for arg in args[: len(args) - len(defaults)]: - pos_only.append(arg) - kwargs = list() - for default, arg in zip(defaults, args[len(args) - len(defaults) :]): - kwargs.append((arg, default)) - - return pos_only, kwargs, full_arg_spec.varargs - - -def register_intrin(name=None): - """Decorator to register function under category intrin - Parameters - ---------- - name: str, optional - registered name for the function - Example - ------ - .. code-block:: python - @register_intrin - def broadcast(value, lanes): - lanes = lanes.value if not isinstance(lanes, int) else lanes - return tvm.tir.Broadcast(value, lanes) - """ - - def decorate(origin_func): - func_name = "tir." + origin_func.__qualname__ if name is None else name - Registry.functions[func_name] = ( - func_wrapper( - func_name, origin_func, get_arg_list(origin_func, Category.INTRIN), Category.INTRIN - ), - Category.INTRIN, - ) - return origin_func - - return decorate - - -def register_with_scope(concise=False, with_var=False, name=None): - """Decorator to register function under with scope handler - Parameters - ---------- - concise: bool, optional - whether this with scope handler is allowed in concise scoping - with_var: bool, optional - whether this with scope handler neeeds optional vars - name: str, optional - registered name for the function - Example - ------ - .. code-block:: python - @register_scope_handler(concise=True) - def attr(parser, node, attr_node, attr_key, value, body): - return tvm.tir.AttrStmt(attr_node, attr_key, tvm.runtime.convert(value), body) - """ - - def decorate(origin_func): - """Register function under category with_scope""" - func_name = "tir." + origin_func.__qualname__ if name is None else name - Registry.functions[func_name] = ( - func_wrapper( - func_name, - origin_func, - get_arg_list(origin_func, Category.WITH_SCOPE, with_var), - Category.WITH_SCOPE, - concise=concise, - with_var=with_var, - ), - Category.WITH_SCOPE, - ) - return origin_func - - return decorate - - -def register_for_scope(name=None): - """Decorator to register function under for scope handler - Parameters - ---------- - name: str, optional - registered name for the function - """ - - def decorate(origin_func): - func_name = "tir." + origin_func.__qualname__ if name is None else name - Registry.functions[func_name] = ( - func_wrapper( - func_name, - origin_func, - get_arg_list(origin_func, Category.FOR_SCOPE), - Category.FOR_SCOPE, - ), - Category.FOR_SCOPE, - ) - return origin_func - - return decorate - - -def register_special_stmt(name=None): - """Decorator to register function under category special_stmt - Parameters - ---------- - name: str, optional - registered name for the function - Example - ------- - @register_special_stmt - def buffer_decl(parser, node, shape, dtype="float32", data=None, strides=[], elem_offset=None, - scope="global", align=-1, offset_factor=0, buffer_type="default"): - align = align.value if not isinstance(align, int) else align - offset_factor = offset_factor.value if not isinstance(offset_factor, int) else offset_factor - buffer = tvm.tir.decl_buffer(shape, dtype, parser.assign_target, data, strides, - elem_offset, scope, align, offset_factor, buffer_type) - return buffer - """ + return NewIntrin - def decorate(origin_func): - func_name = "tir." + origin_func.__qualname__ if name is None else name - Registry.functions[func_name] = ( - func_wrapper( - func_name, - origin_func, - get_arg_list(origin_func, Category.SPECIAL_STMT), - Category.SPECIAL_STMT, - ), - Category.SPECIAL_STMT, - ) - return origin_func + registration = create_new_intrin(inputs) + elif inspect.isclass(inputs): + registration = inputs + else: + raise ValueError() - return decorate + key = registration().signature()[0] + Registry.registrations[key] = registration + return registration diff --git a/python/tvm/script/scope_handler.py b/python/tvm/script/scope_handler.py index 08cd7ca84eb9..251df8c6d6cb 100644 --- a/python/tvm/script/scope_handler.py +++ b/python/tvm/script/scope_handler.py @@ -14,182 +14,248 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""TVM Script Parser Scope Handler Functions -This module provides the functions registered into parser under with_scope or for_scope category. -Scope handler nodes are StmtNodes with body, which are used to handle such scenarios. -1. For scope handler -When registering a for scope handler, the first 4 arguments must be parser, node, body, loop_vars -and these arguments will provided by TVM Script parser automatically -.. code-block:: python - for loop_vars in tir.xxx(): -2. With scope handler -There are 4 subtypes of with scope handlers, classified by - 1) with or without as - 2) allow concise scoping or not -1) with as & concise -the first 2 arguments must be parser, node -Need to parse the body manually -Example : tir.alloc_with_scope -.. code-block:: python - target = tir.xxx() - with tir.xxx() as target: -2) with as & not concise -the first 2 arguments must be parser, node -Need to parse the body manually -Example : None atm -.. code-block:: python - with tir.xxx() as target: -3) without as & concise -the first 3 arguments must be parser, node, body -TVM Script parser will parse the body automatically -Example : tir.allocate()/tir.realize()/tir.attr() -.. code-block:: python - tir.xxx() - with tir.xxx(): -4) without as & not concise -the first 3 arguments must be parser, node, body -TVM Script parser will parse the body automatically -Example : tir.assert()/tir.let() -.. code-block:: python - with tir.xxx(): -""" -# pylint: disable=redefined-builtin, unused-argument, invalid-name +"""TVM Script Parser Scope Handler Classes""" +# pylint: disable=redefined-builtin, unused-argument, invalid-name, relative-beyond-top-level from typed_ast import ast3 as ast import tvm.tir -from .registry import register_with_scope, register_for_scope - - -# With scope handler -@register_with_scope(concise=True, with_var=True) -def allocate(parser, node, extents, dtype, scope, condition=True): - """ With scope handler function tir.alloc_with_scope(var, extents, dtype, scope, condition) """ - # defining buffer var and parse the body manually - - buffer_var = tvm.te.var(parser.target[0], "handle") - # (TODO) Uncomment this line if we have richer type info for buffer var - # buffer_var = tvm.te.var(parser.target[0], tvm.ir.PointerType(tvm.ir.PrimType(dtype))) - if isinstance(node, ast.With): - parser.scope_emitter.new_scope() - parser.scope_emitter.update_symbol(buffer_var.name, buffer_var) - parser.scope_emitter.node_stack[-1].extend(reversed(node.body)) - body = parser.get_body() - parser.scope_emitter.pop_scope() - else: - parser.scope_emitter.update_symbol(buffer_var.name, buffer_var) - body = parser.get_body() - condition = tvm.runtime.convert(condition) - scope = tvm.runtime.convert(scope) - body = tvm.tir.Allocate(buffer_var, dtype, extents, condition, body) - return tvm.tir.AttrStmt(buffer_var, "storage_scope", scope, body) - - -@register_with_scope(concise=True) -def launch_thread(parser, node, body, env_var, extent): - extent = tvm.runtime.convert(extent) - return tvm.tir.AttrStmt( - tvm.tir.IterVar( - None, env_var, getattr(tvm.tir.IterVar, "ThreadIndex"), parser.var_env_dict[env_var] - ), - "thread_extent", - extent, - body, - ) - - -@register_with_scope(concise=True) -def realize(parser, node, body, buffer_bounds, scope, condition=True): - """ With scope handler function tir.realize(buffer_bounds, scope, condition) """ - buffer, bounds = buffer_bounds - scope = tvm.runtime.convert(scope) - return tvm.tir.AttrStmt( - buffer, "realize_scope", scope, tvm.tir.BufferRealize(buffer, bounds, condition, body) - ) - - -@register_with_scope(concise=True) -def attr(parser, node, body, attr_node, attr_key, value): - """ With scope handler function tir.attr(attr_node, attr_key, value) """ - attr_node = tvm.runtime.convert(attr_node) - value = tvm.runtime.convert(value) - return tvm.tir.AttrStmt(attr_node, attr_key, value, body) - - -@register_with_scope(concise=False) -def Assert(parser, node, body, condition, message): - """ With scope handler function tir.Assert(condition, message) """ - return tvm.tir.AssertStmt(condition, tvm.runtime.convert(message), body) - - -@register_with_scope(concise=False) -def let(parser, node, body, var, value): - """ With scope handler function tir.let(var, value) """ - return tvm.tir.LetStmt(var, value, body) - - -# For scope handler -@register_for_scope() -def serial(parser, node, body, loop_vars, begin, end): - """ For scope handler function tir.serial(begin, end)""" - if len(loop_vars) != 1: - parser.report_error("Expect exact 1 loop var") - ana = tvm.arith.Analyzer() - extent = end if begin == 0 else ana.simplify(end - begin) - return tvm.tir.For(loop_vars[0], begin, extent, 0, 0, body) - - -@register_for_scope() -def parallel(parser, node, body, loop_vars, begin, end): - """ For scope handler function tir.parallel(begin, end)""" - if len(loop_vars) != 1: - parser.report_error("Expect exact 1 loop var") - ana = tvm.arith.Analyzer() - extent = end if begin == 0 else ana.simplify(end - begin) - return tvm.tir.For(loop_vars[0], begin, extent, 1, 0, body) - - -@register_for_scope() -def vectorized(parser, node, body, loop_vars, begin, end): - """ For scope handler function tir.vectorized(begin, end)""" - if len(loop_vars) != 1: - parser.report_error("Expect exact 1 loop var") - ana = tvm.arith.Analyzer() - extent = end if begin == 0 else ana.simplify(end - begin) - return tvm.tir.For(loop_vars[0], begin, extent, 2, 0, body) - - -@register_for_scope() -def unroll(parser, node, body, loop_vars, begin, end): - """ For scope handler function tir.unroll(begin, end)""" - if len(loop_vars) != 1: - parser.report_error("Expect exact 1 loop var") - ana = tvm.arith.Analyzer() - extent = end if begin == 0 else ana.simplify(end - begin) - return tvm.tir.For(loop_vars[0], begin, extent, 3, 0, body) - - -@register_for_scope(name="range") -def Range(parser, node, body, loop_vars, begin, end, annotation=None): - """ For scope handler function range(begin, end, annotation)""" - if len(loop_vars) != 1: - parser.report_error("Expect exact 1 loop var") - ana = tvm.arith.Analyzer() - extent = end if begin == 0 else ana.simplify(end - begin) - if annotation is None: - annotation = [] - else: - annotation = [ - tvm.tir.Annotation(key, tvm.runtime.convert(val) if isinstance(val, str) else val) - for key, val in annotation.items() - ] - return tvm.tir.Loop(loop_vars[0], begin, extent, annotation, body) - - -@register_for_scope() -def grid(parser, node, body, loop_vars, *extents): - """ For scope handler function tir.grid(*extents) """ - if len(loop_vars) != len(extents): - parser.report_error("Inconsitent number of loop vars and extents") - for loop_var, extent in zip(reversed(loop_vars), reversed(extents)): - body = tvm.tir.Loop(loop_var, 0, extent, [], body) - return body +from .utils import get_param_list +from .registry import register + + +class ScopeHandler: + """Base class for all scope handlers""" + + def __init__(self, func): + self.func = func + self.body = None + self.node = None + self.context = None + + def signature(self): + return "tir." + self.func.__name__, get_param_list(self.func) + + def enter_scope(self, node, context): + pass + + def exit_scope(self, node, context, arg_list): + self.node = node + self.context = context + return self.func(*arg_list) + + +class WithScopeHandler(ScopeHandler): + """Base class for all with scope handlers""" + + def __init__(self, func, concise_scope, def_symbol): + super().__init__(func) + self.concise_scope = concise_scope + self.def_symbol = def_symbol + + @staticmethod + def get_optional_var_names(node, context): + """Get list of names from ast.With's optional_vars""" + assert isinstance(node, ast.With) + + var_names = None + if isinstance(node.items[0].optional_vars, ast.Name): + var_names = [node.items[0].optional_vars.id] + elif isinstance(node.items[0].optional_vars, (ast.List, ast.Tuple)): + for var in node.items[0].optional_vars.elts: + if not isinstance(var, ast.Name): + context.report_error("Invalid optional var definition") + var_names = [var.id for var in node.items[0].optional_vars.elts] + else: + context.report_error("Invalid optional var definition") + return var_names + + +@register +class Allocate(WithScopeHandler): + """ With scope handler tir.alloc_with_scope(var, extents, dtype, scope, condition) """ + + def __init__(self): + def allocate(extents, dtype, scope, condition=True): + condition = tvm.runtime.convert(condition) + scope = tvm.runtime.convert(scope) + body = tvm.tir.Allocate(self.buffer_var, dtype, extents, condition, self.body) + return tvm.tir.AttrStmt(self.buffer_var, "storage_scope", scope, body) + + super().__init__(allocate, concise_scope=True, def_symbol=True) + self.buffer_var = None + + def enter_scope(self, node, context): + # define buffer vars in symbol table + if isinstance(node, ast.With): + names = WithScopeHandler.get_optional_var_names(node, context) + if len(names) != 1: + context.report_error("Unexpected number of vars") + name = names[0] + elif isinstance(node, ast.Assign): + name = node.targets[0].id + else: + raise Exception("Internal Bug") + + self.buffer_var = tvm.te.var(name, "handle") + context.update_symbol(name, self.buffer_var) + + +@register +class LaunchThread(WithScopeHandler): + """ With scope handler tir.launch_thread(env_var, extent) """ + + def __init__(self): + def launch_thread(env_var, extent): + extent = tvm.runtime.convert(extent) + return tvm.tir.AttrStmt( + tvm.tir.IterVar( + None, + env_var, + getattr(tvm.tir.IterVar, "ThreadIndex"), + self.context.func_var_env_dict[env_var], + ), + "thread_extent", + extent, + self.body, + ) + + super().__init__(launch_thread, concise_scope=True, def_symbol=False) + + +@register +class Realize(WithScopeHandler): + """ With scope handler tir.realize(buffer_bounds, scope, condition) """ + + def __init__(self): + def realize(buffer_bounds, scope, condition=True): + buffer, bounds = buffer_bounds + scope = tvm.runtime.convert(scope) + return tvm.tir.AttrStmt( + buffer, + "realize_scope", + scope, + tvm.tir.BufferRealize(buffer, bounds, condition, self.body), + ) + + super().__init__(realize, concise_scope=True, def_symbol=False) + + +@register +class Attr(WithScopeHandler): + """ With scope handler tir.attr(attr_node, attr_key, value) """ + + def __init__(self): + def attr(attr_node, attr_key, value): + attr_node = tvm.runtime.convert(attr_node) + value = tvm.runtime.convert(value) + return tvm.tir.AttrStmt(attr_node, attr_key, value, self.body) + + super().__init__(attr, concise_scope=True, def_symbol=False) + + +@register +class AssertHandler(WithScopeHandler): + """ With scope handler tir.Assert(condition, message) """ + + def __init__(self): + def Assert(condition, message): + return tvm.tir.AssertStmt(condition, tvm.runtime.convert(message), self.body) + + super().__init__(Assert, concise_scope=True, def_symbol=False) + + +@register +class Let(WithScopeHandler): + """ With scope handler tir.let(var, value) """ + + def __init__(self): + def let(var, value): + return tvm.tir.LetStmt(var, value, self.body) + + super().__init__(let, concise_scope=False, def_symbol=False) + + +class ForScopeHandler(ScopeHandler): + """Base class for all for scope handlers""" + + def __init__(self, func): + super().__init__(func) + self.loop_vars = None + + def enter_scope(self, node, context): + assert isinstance(node, ast.For) + + loop_var_names = list() + if isinstance(node.target, ast.Name): + loop_var_names.append(node.target.id) + elif isinstance(node.target, ast.Tuple): + for elt in node.target.elts: + if not isinstance(elt, ast.Name): + context.report_error("Invalid loop var") + loop_var_names.append(elt.id) + else: + context.report_error("Invalid loop var") + + self.loop_vars = [tvm.te.var(name, dtype="int32") for name in loop_var_names] + for loop_var in self.loop_vars: + context.update_symbol(loop_var.name, loop_var) + + +@register +class Serial(ForScopeHandler): + """ For scope handler tir.serial(begin, end)""" + + def __init__(self): + def serial(begin, end): + if len(self.loop_vars) != 1: + self.context.report_error("Expect exact 1 loop var") + ana = tvm.arith.Analyzer() + extent = end if begin == 0 else ana.simplify(end - begin) + return tvm.tir.For(self.loop_vars[0], begin, extent, 0, 0, self.body) + + super().__init__(serial) + + +@register +class Parallel(ForScopeHandler): + """ For scope handler tir.parallel(begin, end)""" + + def __init__(self): + def parallel(begin, end): + if len(self.loop_vars) != 1: + self.context.report_error("Expect exact 1 loop var") + ana = tvm.arith.Analyzer() + extent = end if begin == 0 else ana.simplify(end - begin) + return tvm.tir.For(self.loop_vars[0], begin, extent, 1, 0, self.body) + + super().__init__(parallel) + + +@register +class Vectorized(ForScopeHandler): + """ For scope handler tir.vectorized(begin, end)""" + + def __init__(self): + def vectorized(begin, end): + if len(self.loop_vars) != 1: + self.context.report_error("Expect exact 1 loop var") + ana = tvm.arith.Analyzer() + extent = end if begin == 0 else ana.simplify(end - begin) + return tvm.tir.For(self.loop_vars[0], begin, extent, 2, 0, self.body) + + super().__init__(vectorized) + + +@register +class Unroll(ForScopeHandler): + """ For scope handler tir.unroll(begin, end)""" + + def __init__(self): + def unroll(begin, end): + if len(self.loop_vars) != 1: + self.context.report_error("Expect exact 1 loop var") + ana = tvm.arith.Analyzer() + extent = end if begin == 0 else ana.simplify(end - begin) + return tvm.tir.For(self.loop_vars[0], begin, extent, 3, 0, self.body) + + super().__init__(unroll) diff --git a/python/tvm/script/special_stmt.py b/python/tvm/script/special_stmt.py index 53c01d49d371..31fe0ed7cebf 100644 --- a/python/tvm/script/special_stmt.py +++ b/python/tvm/script/special_stmt.py @@ -14,130 +14,172 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""TVM Script Parser Special Stmt Functions -This module provides the functions registered into parser under special_stmt category. -special_stmt functions don't correspond to an IRNode in the AST directly. It is usually -used for some information that is not suitable to be printed directly. -special_stmt can appear as 2 formats -.. code-block:: python - target = tir.name(): - tir.name() -When registering a special stmt, the first two arguments must be parser, node -""" +"""TVM Script Parser Special Stmt Classes""" # pylint: disable=unused-argument, no-self-argument, inconsistent-return-statements +# pylint: disable=relative-beyond-top-level +from typed_ast import ast3 as ast import tvm.tir from tvm import te -from .registry import register_special_stmt - - -@register_special_stmt() -def match_buffer( - parser, - node, - param, - shape, - dtype="float32", - data=None, - strides=None, - elem_offset=None, - scope="global", - align=-1, - offset_factor=0, - buffer_type="default", -): - """Special function match_buffer(var, shape, dtype, data, strides, elem_offset, scope, align, - offset_factor, buffer_type) +from .utils import get_param_list +from .registry import register + + +class SpecialStmt: + """Base class for all Special Stmts""" + + def __init__(self, func, def_symbol): + self.func = func + self.def_symbol = def_symbol + self.node = None + self.context = None + + def signature(self): + return "tir." + self.func.__name__, get_param_list(self.func) + + def handle(self, node, context, arg_list): + self.node = node + self.context = context + return self.func(*arg_list) + + +@register +class MatchBuffer(SpecialStmt): + """Special Stmt match_buffer(var, shape, dtype, data, strides, elem_offset, scope, align, + offset_factor, buffer_type) Example ------- .. code-block:: python A = tir.match_buffer(a, (128, 128), dtype="float32") """ - if param not in parser.params: - parser.report_error("Can not bind non-input param to buffer") - if strides is None: - strides = [] - align = align.value if not isinstance(align, int) else align - offset_factor = offset_factor.value if not isinstance(offset_factor, int) else offset_factor - buffer = tvm.tir.decl_buffer( - shape, - dtype, - parser.target[0], - data, - strides, - elem_offset, - scope, - align, - offset_factor, - buffer_type, - ) - parser.buffer_map[param] = buffer - return buffer - - -@register_special_stmt() -def buffer_decl( - parser, - node, - shape, - dtype="float32", - data=None, - strides=None, - elem_offset=None, - scope="global", - align=-1, - offset_factor=0, - buffer_type="default", -): - """Special function buffer_decl(shape, dtype, data, strides, elem_offset, scope, align, - offset_factor, buffer_type) + def __init__(self): + def match_buffer( + param, + shape, + dtype="float32", + data=None, + strides=None, + elem_offset=None, + scope="global", + align=-1, + offset_factor=0, + buffer_type="default", + ): + assert isinstance(self.node, ast.Assign) + + if param not in self.context.func_params: + self.context.report_error("Can not bind non-input param to buffer") + if strides is None: + strides = [] + align = align.value if not isinstance(align, int) else align + offset_factor = ( + offset_factor.value if not isinstance(offset_factor, int) else offset_factor + ) + buffer = tvm.tir.decl_buffer( + shape, + dtype, + self.node.targets[0].id, + data, + strides, + elem_offset, + scope, + align, + offset_factor, + buffer_type, + ) + self.context.func_buffer_map[param] = buffer + self.context.update_symbol(self.node.targets[0].id, buffer) + + super().__init__(match_buffer, def_symbol=True) + + +@register +class BufferDeclare(SpecialStmt): + """Special Stmt buffer_decl(shape, dtype, data, strides, elem_offset, scope, align, + offset_factor, buffer_type) Example ------- .. code-block:: python A = tir.buffer_decl((128, 128), dtype="float32") """ - if strides is None: - strides = [] - align = align.value if not isinstance(align, int) else align - offset_factor = offset_factor.value if not isinstance(offset_factor, int) else offset_factor - buffer = tvm.tir.decl_buffer( - shape, - dtype, - parser.target[0], - data, - strides, - elem_offset, - scope, - align, - offset_factor, - buffer_type, - ) - return buffer - - -@register_special_stmt() -def var(parser, node, dtype): + def __init__(self): + def buffer_decl( + shape, + dtype="float32", + data=None, + strides=None, + elem_offset=None, + scope="global", + align=-1, + offset_factor=0, + buffer_type="default", + ): + assert isinstance(self.node, ast.Assign) + + if strides is None: + strides = [] + align = align.value if not isinstance(align, int) else align + offset_factor = ( + offset_factor.value if not isinstance(offset_factor, int) else offset_factor + ) + buffer = tvm.tir.decl_buffer( + shape, + dtype, + self.node.targets[0].id, + data, + strides, + elem_offset, + scope, + align, + offset_factor, + buffer_type, + ) + self.context.update_symbol(self.node.targets[0].id, buffer) + return buffer + + super().__init__(buffer_decl, def_symbol=True) + + +@register +class VarDef(SpecialStmt): """ Special function for defining a Var""" - return te.var(parser.target[0], dtype) + def __init__(self): + def var(dtype): + assert isinstance(self.node, ast.Assign) + v = te.var(self.node.targets[0].id, dtype) + self.context.update_symbol(v.name, v) -@register_special_stmt() -def env_thread(parser, node, env_name): + super().__init__(var, def_symbol=True) + + +@register +class EnvThread(SpecialStmt): """ Bind a var to thread env """ - v = te.var(parser.target[0]) - parser.var_env_dict[v] = env_name - return v + def __init__(self): + def env_thread(env_name): + assert isinstance(self.node, ast.Assign) + v = te.var(self.node.targets[0].id) + self.context.func_var_env_dict[v] = env_name + self.context.update_symbol(v.name, v) + + super().__init__(env_thread, def_symbol=True) -@register_special_stmt() -def func_attr(parser, node, dict_attr): - """Special function for declaring the DictAttr of PrimFunc + +@register +class FuncAttr(SpecialStmt): + """Special Stmt for declaring the DictAttr of PrimFunc Example ------- .. code-block:: python tir.func_attr({"tir.noalias": True, "global_symbol"}) """ - parser.dict_attr = dict_attr + def __init__(self): + def func_attr(dict_attr): + self.context.func_dict_attr = dict_attr + + super().__init__(func_attr, def_symbol=False) diff --git a/python/tvm/script/ty.py b/python/tvm/script/ty.py index 430a746fff40..1d7871624eb5 100644 --- a/python/tvm/script/ty.py +++ b/python/tvm/script/ty.py @@ -23,14 +23,15 @@ import tvm -class TypeGeneric: +class TypeGeneric: # pylint: disable=too-few-public-methods """Base class for all the TVM script typing class""" def evaluate(self): + """Return an actual ir.Type Object that this Generic class wraps""" raise TypeError("Cannot get tvm.Type from a generic type") -class ConcreteType(TypeGeneric): +class ConcreteType(TypeGeneric): # pylint: disable=too-few-public-methods """TVM script typing class for uniform Type objects""" def __init__(self, vtype): diff --git a/python/tvm/script/utils.py b/python/tvm/script/utils.py index f510ddb906aa..ef6736f3e98b 100644 --- a/python/tvm/script/utils.py +++ b/python/tvm/script/utils.py @@ -17,93 +17,29 @@ """Helper functions in TVM Script Parser""" import inspect -from tvm import IRModule -from . import _ffi_api -from .parser import from_source +def get_param_list(func): + """Get the parameter list from definition of function""" + full_arg_spec = inspect.getfullargspec(func) -def create_module(functions=None): - """Construct a module from list of functions. + args, defaults = full_arg_spec.args, full_arg_spec.defaults - Parameters - ----------- - functions: Optional[dict]. - Map of GlobalVar or str to PrimFunc + if defaults is None: + defaults = tuple() - Returns - ------- - mod : IRModule - An IRModule containing the passed definitions - """ + if full_arg_spec.varkw is not None: + raise RuntimeError( + "TVM Script register error : variable keyword argument is not supported now" + ) + if not len(full_arg_spec.kwonlyargs) == 0: + raise RuntimeError("TVM Script register error : keyword only argument is not supported now") - return IRModule(functions=functions) + pos_only = list() + for arg in args[: len(args) - len(defaults)]: + pos_only.append(arg) + kwargs = list() + for default, arg in zip(defaults, args[len(args) - len(defaults) :]): + kwargs.append((arg, default)) - -def asscript(input_ir, show_meta=False): - """Transform a PrimFunc or IRModule to python syntax script - - Parameters - ---------- - input_ir : Union[PrimFunc, IRModule] - The PrimFunc or IRModule to be dumped - - show_meta : bool - Whether show meta - - Returns - ------- - script : str - The Python script - """ - - return _ffi_api.AsTVMScript(input_ir, show_meta) - - -def tir(script_in): - """Decorate a python function or class as tvm script. - - The tvm function or parsing support parsing to the internal TIR. - - Returns - ------- - output : Union[Function, Module] - The Function or Module in IR. - """ - - if inspect.isfunction(script_in): - return _parse(script_in) - - if inspect.isclass(script_in): - return TVMScriptClass(script_in) - - raise TypeError("Only function and class are supported") - - -def module(script_in): - """Decorate a python function or class as tvm script. - - Alias for tvm.script.tir for now. - - Returns - ------- - output : Union[Function, Module] - The Function or Module in IR. - """ - return tir(script_in) - - -class TVMScriptClass: - """Helper class for decorating a class""" - - def __init__(self, script_in): - self.script = script_in - - def __call__(self, *args, **kwargs): - # call the parser to transform tvm script into TIR - return _parse(self.script) - - -def _parse(script_in): - """Helper function to parse TVM script into TIR""" - return from_source(inspect.getsource(script_in), inspect.getsourcelines(script_in)[1]) + return pos_only, kwargs, full_arg_spec.varargs