diff --git a/python/tvm/hybrid/_ffi_api.py b/python/tvm/hybrid/_ffi_api.py index d59302a95dd1..929a65c03049 100644 --- a/python/tvm/hybrid/_ffi_api.py +++ b/python/tvm/hybrid/_ffi_api.py @@ -17,5 +17,4 @@ """FFI APIs for tvm.hybrid""" import tvm._ffi - -tvm._ffi._init_api("tir.hybrid", __name__) +tvm._ffi._init_api("hybrid", __name__) diff --git a/python/tvm/hybrid/intrin.py b/python/tvm/hybrid/intrin.py index 3dc46a280b72..fdd48f37d1bd 100644 --- a/python/tvm/hybrid/intrin.py +++ b/python/tvm/hybrid/intrin.py @@ -23,114 +23,146 @@ from .registry import register_intrin -@register_intrin +@register_intrin() def bool(imm): - return tvm.tir.const(imm.value, "bool") + return tvm.tir.const(imm, "bool") -@register_intrin +@register_intrin() def int8(imm): - return tvm.tir.const(imm.value, "int8") + return tvm.tir.const(imm, "int8") -@register_intrin +@register_intrin() def int16(imm): - return tvm.tir.const(imm.value, "int16") + return tvm.tir.const(imm, "int16") -@register_intrin +@register_intrin() def int32(imm): - return tvm.tir.const(imm.value, "int32") + return tvm.tir.const(imm, "int32") -@register_intrin +@register_intrin() def int64(imm): - return tvm.tir.const(imm.value, "int64") + return tvm.tir.const(imm, "int64") -@register_intrin +@register_intrin() def uint8(imm): - return tvm.tir.const(imm.value, "uint8") + return tvm.tir.const(imm, "uint8") -@register_intrin +@register_intrin() def uint16(imm): - return tvm.tir.const(imm.value, "uint16") + return tvm.tir.const(imm, "uint16") -@register_intrin +@register_intrin() def uint32(imm): - return tvm.tir.const(imm.value, "uint32") + return tvm.tir.const(imm, "uint32") -@register_intrin +@register_intrin() def uint64(imm): - return tvm.tir.const(imm.value, "uint64") + return tvm.tir.const(imm, "uint64") -@register_intrin +@register_intrin() def float8(imm): - return tvm.tir.const(imm.value, "float8") + return tvm.tir.const(imm, "float8") -@register_intrin +@register_intrin() def float16(imm): - return tvm.tir.const(imm.value, "float16") + return tvm.tir.const(imm, "float16") -@register_intrin +@register_intrin() def float32(imm): - return tvm.tir.const(imm.value, "float32") + return tvm.tir.const(imm, "float32") -@register_intrin +@register_intrin() def float64(imm): - return tvm.tir.const(imm.value, "float64") + return tvm.tir.const(imm, "float64") -@register_intrin +@register_intrin() def floordiv(x, y): return tvm.tir.floordiv(x, y) -@register_intrin +@register_intrin() def floormod(x, y): return tvm.tir.floormod(x, y) -@register_intrin +@register_intrin() def load(dtype, var, index, predicate=True): return tvm.tir.Load(dtype, var, index, predicate) -@register_intrin -def cast(dtype, value): +@register_intrin() +def cast(value, dtype): return tvm.tir.Cast(dtype, value) -@register_intrin +@register_intrin() def ramp(base, stride, lanes): - lanes = lanes.value if not isinstance(lanes, int) else lanes return tvm.tir.Ramp(base, stride, lanes) -@register_intrin +@register_intrin() def broadcast(value, lanes): - lanes = lanes.value if not isinstance(lanes, int) else lanes return tvm.tir.Broadcast(value, lanes) -@register_intrin +@register_intrin() def evaluate(value): return tvm.tir.Evaluate(value) -@register_intrin +@register_intrin() def store(var, index, value, predicate=True): return tvm.tir.Store(var, value, index, predicate) -@register_intrin +@register_intrin() def iter_var(var, dom, iter_type, thread_tag): iter_type = getattr(tvm.tir.IterVar, iter_type) return tvm.tir.IterVar(dom, var, iter_type, thread_tag) + + +@register_intrin() +def max(a, b): # pylint: disable=redefined-builtin + return tvm.tir.Max(a, b) + + +def get_axis(begin, end, iter_type): + ana = tvm.arith.Analyzer() + extent = ana.simplify(end - begin) + block_var_dom = tvm.ir.Range.from_min_extent(begin, extent) + + iter_type_dict = {"data_par": 0, "reduce": 2, "scan": 3, "opaque": 4} + return tvm.tir.IterVar(block_var_dom, "bv", iter_type_dict[iter_type]) + + +@register_intrin() +def range(begin, end): + return get_axis(begin, end, "data_par") + + +@register_intrin() +def reduce_axis(begin, end): + return get_axis(begin, end, "reduce") + + +@register_intrin() +def scan_axis(begin, end): + return get_axis(begin, end, "scan") + + +@register_intrin() +def opaque_axis(begin, end): + return get_axis(begin, end, "opaque") diff --git a/python/tvm/hybrid/parser.py b/python/tvm/hybrid/parser.py index b3b042a6de7d..a1aa652c79d9 100644 --- a/python/tvm/hybrid/parser.py +++ b/python/tvm/hybrid/parser.py @@ -19,7 +19,6 @@ # pylint: disable=unnecessary-comprehension, unused-argument, import-outside-toplevel # pylint: disable=unused-import import json -import numbers import operator from typed_ast import ast3 as ast @@ -30,9 +29,10 @@ from tvm.tir import all as _all from tvm.tir import expr as _expr -from . import scope_emitter, special_stmt, scope_handler, intrin +from . import scope_emitter, special_stmt, scope_handler, intrin, ty from .meta_unparser import MetaUnparser from .registry import Registry +from . import _ffi_api class HybridParserError(RuntimeError): @@ -45,13 +45,14 @@ class HybridParser(ast.NodeVisitor): 1. To support new types of AST nodes. Add a function visit_xxx(). 2. To support new functions We divide allowed function calls in hybrid script into 3 categories, - which is scope_handler, intrin and special_stmt. - 1) scope_handler: scope_handler functions correspond to StmtNodes without body, which can be - further classified into 2 categories: with scope handler can for scope handlers - 2) intrin: intrin functions corresponds to the remaining IRNodes (StmtNodes without body, - PrimExprNodes and more) - 3) special_stmt: special_stmt functions don't correspond to an IRNode in the AST directly. - It is usually used for some information that is not suitable to be printed directly. + which is intrin, scope_handler and special_stmt. + 1) intrin functions ought to have return value. + User can also register intrin category function into parser. + 2) scope_handler functions have no return value and accepts parser and AST node + as its arguments, which is used in for scope and with scope. + 3) special_stmt functions have return value and accepts parser and AST node as its arguments + When visiting Call node, we check special_stmt registry at first. If no registered function + is found, we then check intrin. When visiting With node, we check with_scope registry. When visiting For node, we check for_scope registry. """ @@ -83,6 +84,7 @@ def __init__(self, src, base_lienno): self.buffer_map = None self.dict_attr = None self.scope_emitter = None + self.var_env_dict = None self.src = src.split("\n") self.base_lineno = base_lienno @@ -91,9 +93,7 @@ def __init__(self, src, base_lienno): self.meta = None self.functions = {} - - self._in_with_func_arg = False - self._assign_target = None + self.target = None def init_function_parsing_env(self): """Initialize function parsing environment""" @@ -101,6 +101,7 @@ def init_function_parsing_env(self): self.buffer_map = {} # buffer map self.dict_attr = {} # dict attr self.scope_emitter = scope_emitter.ScopeEmitter(self) # scope emitter + self.var_env_dict = {} # map from var to thread env name @staticmethod def is_meta(node): @@ -169,15 +170,6 @@ def report_error(self, message, lineno=None, col_offset=None): col_offset = self.current_col_offset raise HybridParserError(self.wrap_line_col(message, lineno, col_offset)) - def get_type_name(self, vtype): - if ( - isinstance(vtype, ast.Attribute) - and isinstance(vtype.value, ast.Name) - and vtype.value.id == "ty" - ): - return vtype.attr - self.report_error("invalid type annotation") - def get_body(self): body = [] while len(self.scope_emitter.node_stack[-1]) > 0: @@ -186,31 +178,19 @@ def get_body(self): body.append(res) return tvm.tir.SeqStmt(body) if len(body) > 1 else body[0] - def parse_type(self, vtype): - """ Parse type annotation AST into Type object """ - if isinstance(vtype, ast.NameConstant) and vtype.value is None: - return tvm.ir.TupleType([]) - elif isinstance(vtype, ast.Attribute): - return tvm.ir.PrimType(self.get_type_name(vtype)) - elif isinstance(vtype, ast.Subscript) and isinstance(vtype.slice, ast.Index): - type_name = self.get_type_name(vtype.value) - if isinstance(vtype.slice.value, ast.Tuple): - args = [self.parse_type(element) for element in vtype.slice.value.elts] - else: - args = [self.parse_type(vtype.slice.value)] - if type_name == "Ptr": - return tvm.ir.PointerType(*args) - elif type_name == "Tuple": - return tvm.ir.TupleType(args) - - self.report_error("invalid type annotation") + def get_type(self, type_node): + """ Parse type """ + if type_node is None: + self.report_error("missing type annotation") + res_type = self.visit(type_node) + return tvm.ir.TupleType([]) if res_type is None else res_type.evaluate() def generic_visit(self, node): """Override method in ast.NodeVisitor. To directly filter out invalidate type of stmt. """ - self.report_error(type(node).__name__ + " stmt is not supported now") + self.report_error(type(node).__name__ + " AST node is not supported now") def visit_Module(self, node): """Module visitor @@ -304,7 +284,7 @@ def visit_FunctionDef(self, node): self.init_function_parsing_env() # add parameters of function for arg in node.args.args: - arg_var = tvm.te.var(arg.arg, self.parse_type(arg.annotation)) + arg_var = tvm.te.var(arg.arg, self.get_type(arg.annotation)) self.scope_emitter.update_symbol(arg.arg, arg_var) self.params.append(arg_var) @@ -315,7 +295,7 @@ def visit_FunctionDef(self, node): func = tvm.tir.PrimFunc( self.params, self.get_body(), - ret_type=self.parse_type(node.returns), + ret_type=self.get_type(node.returns), buffer_map=self.buffer_map, attrs=tvm.ir.make_node("DictAttrs", **self.dict_attr), ) @@ -326,12 +306,15 @@ def visit_Assign(self, node): """Assign visitor AST abstract grammar: Assign(expr* targets, expr value, string? type_comment) - By now only 2 types of Assign is supported: - 1. special stmts that appear as assign stmt + By now only 3 types of Assign is supported: + 1. special stmts with return value 1.1 Buffer = tir.buffer_bind()/tir.buffer_decl() 1.2 Var = tir.var() + 1.3 Var = tir.env_thread() 2. (BufferStore) Buffer[PrimExpr, PrimExpr, ..., PrimExpr] = PrimExpr - 3. (Store) Var[PrimExpr] = PrimExpr + 3. (Store) Var[PrimExpr] = PrimExpr + 4. with scope handlers with concise scoping and var def + 4.1 var = tir.alloc_with_scope() """ if not len(node.targets) == 1: @@ -339,22 +322,29 @@ def visit_Assign(self, node): target = node.targets[0] if isinstance(target, ast.Name): - # scenario 1 - self._assign_target = target.id - rhs = self.visit(node.value) + # scenario 1&4 + self.target = [target.id] if not isinstance(node.value, ast.Call): - self.report_error("Unsupported Assign stmt") - self.scope_emitter.update_symbol(target.id, rhs) + self.report_error("Unsupported assign stmt") + func = self.visit(node.value.func) + if Registry.is_with_scope(func): + # scenario 4 + return self.visit(node.value) + else: + # scenario 1 + rhs = self.visit(node.value) + self.scope_emitter.update_symbol(target.id, rhs) elif isinstance(target, ast.Subscript): # scenario 2&3 symbol, indexes = self.visit(target) - self._assign_target = (symbol, indexes) rhs = self.visit(node.value) if isinstance(symbol, tvm.tir.Buffer): + # BufferStore return tvm.tir.BufferStore(symbol, tvm.runtime.convert(rhs), indexes) else: if len(indexes) != 1: self.report_error("Invalid Store stmt") + # Store return tvm.tir.Store( symbol, tvm.runtime.convert(rhs), indexes[0], tvm.runtime.convert(True) ) @@ -370,7 +360,7 @@ def visit_AnnAssign(self, node): if isinstance(node.target, ast.Name): value = self.visit(node.value) - var = tvm.te.var(node.target.id, self.parse_type(node.annotation)) + var = tvm.te.var(node.target.id, self.get_type(node.annotation)) self.scope_emitter.update_symbol(var.name, var) return tvm.tir.LetStmt(var, value, self.visit(self.scope_emitter.node_stack[-1].pop())) else: @@ -394,40 +384,26 @@ def visit_For(self, node): AST abstract grammar: For(expr target, expr iter, stmt* body, stmt* orelse, string? type_comment) By now only 1 type of For is supported: - 1. for name in tir.range(begin, end, for_type) + 1. for name in tir.serial/parallel/vectorized/unroll(begin, end) """ - if not isinstance(node.target, ast.Name): - self.report_error("The loop variable should be a name variable") - # check node.iter, which is a tir Call + # check node.iter, which is a Call if not isinstance(node.iter, ast.Call): self.report_error("The loop iter should be a Call") - if ( - not isinstance(node.iter.func, ast.Attribute) - or not isinstance(node.iter.func.value, ast.Name) - or node.iter.func.value.id != "tir" - ): - self.report_error("The loop iter Call should be tir.name()") - - func_name = node.iter.func.attr + func = self.visit(node.iter.func) + if not Registry.is_for_scope(func): + self.report_error("Function not allowed in for scope") # collect arguments args = [self.visit(arg) for arg in node.iter.args] kw_args = [self.visit(keyword) for keyword in node.iter.keywords] kw_args = {kw_arg[0]: kw_arg[1] for kw_arg in kw_args} - # All the functions supported in For stmt are registered in scope_handler.ForScope - if func_name not in Registry.for_scope: - self.report_error( - "Function " + func_name + " used in For stmt is not supported now", - self.current_lineno, - node.iter.col_offset, - ) old_lineno, old_col_offset = self.current_lineno, self.current_col_offset self.current_lineno, self.current_col_offset = ( self.base_lineno + node.iter.lineno - 1, node.iter.col_offset, ) - res = Registry.for_scope.get(func_name)(self, node, args, kw_args) + res = func(self, node, args, kw_args) self.current_lineno, self.current_col_offset = old_lineno, old_col_offset return res @@ -436,37 +412,45 @@ def visit_With(self, node): AST abstract grammar: With(withitem* items, stmt* body, string? type_comment) withitem = (expr context_expr, expr? optional_vars) - By now only 1 type of With is supported: - 1. with tir.let/tir.Assert()/tir.attr()/tir.allocate()/tir.realize() + By now 2 types of With is supported: + 1. with tir.allocate() as targets: + 2. with tir.let()/tir.Assert()/tir.attr()//tir.realize() """ - - if len(node.items) != 1: + if not len(node.items) == 1: self.report_error("Only one with element is supported now") if not isinstance(node.items[0].context_expr, ast.Call): self.report_error("The context expression of with should be a Call") + func_call = node.items[0].context_expr - if ( - not isinstance(func_call.func, ast.Attribute) - or not isinstance(func_call.func.value, ast.Name) - or func_call.func.value.id != "tir" - ): - self.report_error("The context expression of with should be tir.name()") - - func_name = func_call.func.attr - # collect arguments + func_node = func_call.func + func = self.visit(func_node) + + if not Registry.is_with_scope(func): + self.report_error("Function not allowed in with scope") + + self.target = [] + if node.items[0].optional_vars is not None: + # preprocess optional var names + if isinstance(node.items[0].optional_vars, ast.Name): + self.target = [node.items[0].optional_vars.id] + elif isinstance(node.items[0].optional_vars, (ast.List, ast.Tuple)): + for var in node.items[0].optional_vars.elts: + if not isinstance(var, ast.Name): + self.report_error("Invalid optional var definition") + self.target = [var.id for var in node.items[0].optional_vars.elts] + else: + self.report_error("Invalid optional var definition") + # parse other arguments args = [self.visit(arg) for arg in func_call.args] kw_args = [self.visit(keyword) for keyword in func_call.keywords] kw_args = {kw_arg[0]: kw_arg[1] for kw_arg in kw_args} - if func_name not in Registry.with_scope: - self.report_error("Function " + func_name + " used in With stmt is not supported now") - # All the functions supported in With stmt are registered in scope_handler.WithScope old_lineno, old_col_offset = self.current_lineno, self.current_col_offset self.current_lineno, self.current_col_offset = ( self.base_lineno + func_call.lineno - 1, func_call.col_offset, ) - res = Registry.with_scope.get(func_name)(self, node, args, kw_args) + res = func(self, node, args, kw_args) self.current_lineno, self.current_col_offset = old_lineno, old_col_offset return res @@ -498,49 +482,41 @@ def visit_Call(self, node): AST abstract grammar: Call(expr func, expr* args, keyword* keywords) keyword = (identifier? arg, expr value) + All the functions used outside With and For are registered in special_stmt or intrin """ + func = self.visit(node.func) # collect arguments args = [self.visit(arg) for arg in node.args] kw_args = [self.visit(keyword) for keyword in node.keywords] kw_args = {kw_arg[0]: kw_arg[1] for kw_arg in kw_args} - maybe_intrin = False - if isinstance(node.func, ast.Attribute) and isinstance(node.func.value, ast.Name): - if node.func.value.id == "tir": - func_name = node.func.attr - maybe_intrin = True + if callable(func): + if Registry.is_registered(func): + return func(self, node, args, kw_args) else: - self.report_error("Unsupported Attribute typed function call") - else: - self.report_error("Unsupported function call") + return func(*args, **kw_args) + elif isinstance(func, tvm.tir.op.Op): + return tvm.tir.Call(kw_args["dtype"], func, args) - if func_name in Registry.special_stmt: - return Registry.special_stmt.get(func_name)(self, node, args, kw_args) - if func_name in Registry.intrin: - return Registry.intrin.get(func_name)(self, node, args, kw_args) - if func_name in Registry.with_scope: - return Registry.with_scope.get(func_name)(self, node, args, kw_args) - if maybe_intrin: - return tvm.tir.Call(kw_args["dtype"], tvm.ir.op.Op.get("tir." + func_name), args) - - self.report_error("Function " + func_name + " is not supported now") + self.report_error("Unsupported function call") def visit_Expr(self, node): """Expr visitor AST abstract grammar: Expr(expr value) - - Now only 2 types of Expr stmt is allowed: - 1. Concise mode of with scope handlers - tir.attr()/tir.assert()/tir.allocate()/tir.realize() - 2. special stmts appear as a call - tir.set_func_attr() + Now only 3 types of `Expr` stmt is allowed: + 1. reducer.step()/tir.store() + 2. tir.attr()/tir.assert()/tir.allocate()/tir.realize() + 3. tir.set_func_attr() """ if not isinstance(node.value, ast.Call): self.report_error("Unsupported Expr stmt") - return self.visit(node.value) + res = self.visit(node.value) + if res is None or isinstance(res, tvm.tir.Stmt): + return res + self.report_error("Invalid Expr stmt") def visit_BinOp(self, node): """BinOp visitor @@ -602,16 +578,14 @@ def visit_Subscript(self, node): 2. meta[type_key][index], Meta info access """ - if isinstance(node.value, (ast.Name, ast.Attribute)): - symbol = self.visit(node.value) + symbol = self.visit(node.value) + if symbol is None: + self.report_error(node.value.id + " is not defined") + if isinstance(symbol, (tvm.tir.expr.Var, tvm.tir.Buffer)): if isinstance(node.slice, ast.Index): - # BufferLoad & BufferStore - if isinstance(node.slice.value, ast.Tuple): - # Buffer/Var[index, index, ...] - indexes = [self.visit(element) for element in node.slice.value.elts] - else: - # Buffer/Var[index] - indexes = [self.visit(node.slice.value)] + # BufferLoad & BufferStore, Buffer/Var[index, index, ...] + indexes = self.visit(node.slice.value) + indexes = list(indexes) if isinstance(indexes, tuple) else [indexes] if isinstance(node.ctx, ast.Load): if isinstance(symbol, tir.expr.Var): return tvm.tir.Load("float32", symbol, indexes, True) @@ -643,42 +617,11 @@ def visit_Subscript(self, node): extent = ana.simplify(extent) doms.append(tvm.ir.Range.from_min_extent(lower, extent)) return symbol, doms - - elif ( - isinstance(node.value, ast.Subscript) - and isinstance(node.value.value, ast.Name) - and node.value.value.id == "meta" - ): - # meta[type_key][index] - if not ( - isinstance(node.slice, ast.Index) and isinstance(node.slice.value, ast.Num) - ) or not ( - isinstance(node.value.slice, ast.Index) - and isinstance(node.value.slice.value, ast.Name) - ): - self.report_error("The meta access format ought to be meta[type_key][index]") - type_key = node.value.slice.value.id - index = node.slice.value.n - node_list = self.meta[type_key] - if node_list is None: - self.report_error("type_key " + type_key + " in meta not found") - if len(node_list) <= index: - self.report_error("index " + index + " out of range " + len(node_list)) - return node_list[index] else: - self.report_error("Only buffer variable and meta can be subscriptable") - - def visit_Name(self, node): - """Name visitor - AST abstract grammar: - Name(identifier id, expr_context ctx) - """ - - name = node.id - symbol = self.scope_emitter.lookup_symbol(name) - if symbol is None: - self.report_error("Unknown symbol %s" % name) - return symbol + res = symbol[self.visit(slice)] + if res is None: + self.report_error("Only buffer variable and meta can be subscriptable") + return res def visit_Attribute(self, node): """Attribute visitor @@ -686,15 +629,28 @@ def visit_Attribute(self, node): Attribute(expr value, identifier attr, expr_context ctx) """ - if not isinstance(node.value, ast.Name): - self.report_error("The value of Attribute ought to a Name") - name = node.value.id - symbol = self.scope_emitter.lookup_symbol(name) - if symbol is None or not isinstance(symbol, tvm.tir.Buffer): + if isinstance(node.value, ast.Name): + if node.value.id == "tir": + func_name = "tir." + node.attr + res = Registry.look_up_function(func_name) + if res is not None: + return res + try: + return tvm.ir.op.Op.get(func_name) + except AttributeError: + self.report_error("Unregistered function tir." + node.attr) + elif node.value.id == "ty": + if not hasattr(ty, node.attr): + self.report_error("invalid type annotation ty." + node.attr) + return getattr(ty, node.attr) + + symbol = self.visit(node.value) + if symbol is None: self.report_error("Unsupported Attribute expression") if not hasattr(symbol, node.attr): self.report_error("Type " + type(symbol) + " has not attr " + node.attr) - return getattr(symbol, node.attr) + res = getattr(symbol, node.attr) + return res def visit_Dict(self, node): """Dict visitor @@ -731,20 +687,32 @@ def visit_keyword(self, node): return node.arg, self.visit(node.value) - def visit_NameConstant(self, node): - return tvm.runtime.convert(node.value) + def visit_Name(self, node): + """Name visitor + AST abstract grammar: + Name(identifier id, expr_context ctx) + """ + + name = node.id + if name == "meta": + return self.meta + symbol = Registry.look_up_function(name) + if symbol is not None: + return symbol + symbol = self.scope_emitter.lookup_symbol(name) + if symbol is not None: + return symbol + self.report_error("Unknown identifier %s" % name) + # note that after Python3.8, ast.NameConstant, ast.Num, ast.Str are no longer used def visit_Constant(self, node): - return tvm.runtime.convert(node.value) + return node.value + + def visit_NameConstant(self, node): + return node.value def visit_Num(self, node): - if isinstance(node.n, numbers.Integral): - dtype = "int32" - elif isinstance(node.n, float): - dtype = "float32" - else: - self.report_error("The data type should be one of (int, float)") - return tvm.tir.const(node.n, dtype) + return node.n def visit_Str(self, node): return node.s @@ -787,4 +755,4 @@ def from_source(src, func_lineno=0): raise HybridParserError(inject_e) -tvm._ffi._init_api("tvm.hybrid.parser") +tvm._ffi._init_api("hybrid", __name__) diff --git a/python/tvm/hybrid/registry.py b/python/tvm/hybrid/registry.py index 9f5c39161150..a1b2b3cd4e39 100644 --- a/python/tvm/hybrid/registry.py +++ b/python/tvm/hybrid/registry.py @@ -17,18 +17,62 @@ """Hybrid Script Parser Function Registry """ # pylint: disable=inconsistent-return-statements import inspect +from enum import Enum from typed_ast import ast3 as ast +import tvm + + +class Category(Enum): + INTRIN = 0 + WITH_SCOPE = 1 + FOR_SCOPE = 2 + SPECIAL_STMT = 3 + class Registry(object): """Registration map All these maps are static """ - intrin = dict() - with_scope = dict() - for_scope = dict() - special_stmt = dict() + functions = dict() + + @staticmethod + def look_up_function(func_name): + """look up a registered function by name""" + if func_name in Registry.functions: + return Registry.functions[func_name][0] + return None + + @staticmethod + def is_intrin(func): + """check whether a function belongs to intrin""" + return (func, Category.INTRIN) in Registry.functions.values() + + @staticmethod + def is_with_scope(func): + """check whether a function belongs to with scope handlers""" + return (func, Category.WITH_SCOPE) in Registry.functions.values() + + @staticmethod + def is_for_scope(func): + """check whether a function belongs to for scope handlers""" + return (func, Category.FOR_SCOPE) in Registry.functions.values() + + @staticmethod + def is_special_stmt(func): + """check whether a function belongs to special stmts""" + return (func, Category.SPECIAL_STMT) in Registry.functions.values() + + @staticmethod + def is_registered(func): + """check whether a function is registered""" + return ( + Registry.is_intrin(func) + or Registry.is_with_scope(func) + or Registry.is_for_scope(func) + or Registry.is_special_stmt(func) + ) class CallArgumentReader(object): @@ -40,99 +84,171 @@ def __init__(self, func_name, args, kwargs, parser): self.kwargs = kwargs self.parser = parser - def get_func_compulsory_arg(self, pos, name): - """Get corresponding function argument from argument list which is compulsory""" - + def get_pos_only_arg(self, pos, name): + """Get corresponding position only function argument from argument list""" if len(self.args) >= pos: arg = self.args[pos - 1] - elif name not in self.kwargs.keys(): + elif name not in self.kwargs: self.parser.report_error(self.func_name + " misses argument " + name) else: arg = self.kwargs[name] return arg - def get_func_optional_arg(self, pos, name, default): - """Get corresponding function argument from argument list which is optional. + def get_kwarg(self, pos, name, default): + """Get corresponding keyword function argument from argument list If user doesn't provide the argument, set it to default value """ - if len(self.args) >= pos: arg = self.args[pos - 1] - elif name in self.kwargs.keys(): + elif name in self.kwargs: arg = self.kwargs[name] else: return default return arg + def get_varargs(self, pos): + """Get corresponding variable argument from argument list""" + if len(self.args) >= pos and len(self.kwargs) == 0: + return self.args[pos - 1 :] + return [] -def func_wrapper(func_name, func_to_register, arg_list, need_parser_and_node, need_body, concise): + def auto_insert_body(self, pos, body): + """Automatically provide body as function call argument""" + if len(self.args) >= pos: + self.args.insert(pos - 1, body) + else: + self.kwargs["body"] = body + + +def func_wrapper(func_name, func_to_register, arg_list, category, concise=False, with_var=False): """Helper function to wrap a function to be registered """ def wrap_func(parser, node, args, kwargs): - reader = CallArgumentReader(func_name, args, kwargs, parser) - internal_args = list() - - if need_body and not isinstance(node, ast.For): - # automatically parse body for with scope handlers - if isinstance(node, ast.With): - # the with scope handler is used inside with context - parser.scope_emitter.new_scope() - parser.scope_emitter.node_stack[-1].extend(reversed(node.body)) - body = parser.get_body() - parser.scope_emitter.pop_scope() + if category == Category.FOR_SCOPE: + # automatically parse loop vars and body for for_scope handlers + loop_var_names = list() + if isinstance(node.target, ast.Name): + loop_var_names.append(node.target.id) + elif isinstance(node.target, ast.Tuple): + for elt in node.target.elts: + if not isinstance(elt, ast.Name): + parser.report_error("Invalid loop var") + loop_var_names.append(elt.id) else: - # the with scope handler is used in concise scoping - if not concise: - parser.report_error("Concise scoping is not allowed here") - body = parser.get_body() - - if need_parser_and_node: - internal_args.append(parser) - internal_args.append(node) - - for i, arg_info in enumerate(arg_list): - if len(arg_info) == 1: - (arg_name,) = arg_info - if need_body and arg_name == "body": - internal_args.append(body) + parser.report_error("Invalid loop var") + loop_vars = [tvm.te.var(name, dtype="int32") for name in loop_var_names] + + parser.scope_emitter.new_scope() + parser.scope_emitter.node_stack[-1].extend(reversed(node.body)) + for loop_var in loop_vars: + parser.scope_emitter.update_symbol(loop_var.name, loop_var) + body = parser.get_body() + parser.scope_emitter.pop_scope() + elif category == Category.WITH_SCOPE: + if not with_var: + if isinstance(node, ast.With) and node.items[0].optional_vars is not None: + parser.report_error("Function " + func_name + " expects no optional vars") + # automatically parse body for with_scope handlers without optional vars + if isinstance(node, ast.With): + parser.scope_emitter.new_scope() + parser.scope_emitter.node_stack[-1].extend(reversed(node.body)) + body = parser.get_body() + parser.scope_emitter.pop_scope() else: - internal_args.append(reader.get_func_compulsory_arg(i + 1, arg_name)) + body = parser.get_body() else: - arg_name, default = arg_info - internal_args.append(reader.get_func_optional_arg(i + 1, arg_name, default=default)) + if isinstance(node, ast.With) and node.items[0].optional_vars is None: + parser.report_error("Function " + func_name + " expects optional vars") + body = None + + if not isinstance(node, ast.With) and not concise: + parser.report_error("Concise scoping is not allowed here") + + reader = CallArgumentReader(func_name, args, kwargs, parser) + pos_only, kwargs, varargs = arg_list + + internal_args = list() + if category == Category.WITH_SCOPE: + if not with_var: + internal_args.extend([parser, node, body]) + else: + internal_args.extend([parser, node]) + elif category == Category.FOR_SCOPE: + internal_args.extend([parser, node, body, loop_vars]) + elif category == Category.SPECIAL_STMT: + internal_args.extend([parser, node]) + + for i, arg_name in enumerate(pos_only): + internal_args.append(reader.get_pos_only_arg(i + 1, arg_name)) + + for i, arg_info in enumerate(kwargs): + arg_name, default = arg_info + internal_args.append(reader.get_kwarg(i + 1 + len(pos_only), arg_name, default=default)) + + if varargs is not None: + internal_args.extend(reader.get_varargs(len(pos_only) + len(kwargs) + 1)) return func_to_register(*internal_args) return wrap_func -def get_arg_list(origin_func, need_parser_and_node): +def get_arg_list(origin_func, category, with_var=False): """Helper function to get the argument list of Function - Parameters ---------- origin_func: function The function to get the argument list - - need_parser_and_node: bool - Whether the function need parser and node in its arguments + category: Category + The category of registered function + with_var: bool, optional + Whether the with scope handler neeeds optional vars """ - full_arg_spec = inspect.getfullargspec(origin_func) args, defaults = full_arg_spec.args, full_arg_spec.defaults if defaults is None: defaults = tuple() - if need_parser_and_node: + + if category == Category.WITH_SCOPE: + if not with_var: + if len(args) < 3 or args[0] != "parser" or args[1] != "node" or args[2] != "body": + raise RuntimeError( + "TVM Hybrid Script register error : the first three arguments of " + "this with scope handler must be parser, node, body" + ) + args = args[3:] + else: + if len(args) < 2 or args[0] != "parser" or args[1] != "node": + raise RuntimeError( + "TVM Hybrid Script register error : the first two arguments of " + "this with scope handler must be parser, node" + ) + args = args[2:] + elif category == Category.FOR_SCOPE: + if ( + len(args) < 4 + or args[0] != "parser" + or args[1] != "node" + or args[2] != "body" + or args[3] != "loop_vars" + ): + raise RuntimeError( + "TVM Hybrid Script register error : the first three arguments of for scope handler" + "must be parser, node, body, loop_vars" + ) + args = args[4:] + elif category == Category.SPECIAL_STMT: + if len(args) < 2 or args[0] != "parser" or args[1] != "node": + raise RuntimeError( + "TVM Hybrid Script register error : the first three arguments of special stmt" + "must be parser, node" + ) args = args[2:] - if full_arg_spec.varargs is not None: - raise RuntimeError( - "TVM Hybrid Script register error : variable argument is not supported now" - ) if full_arg_spec.varkw is not None: raise RuntimeError( "TVM Hybrid Script register error : variable keyword argument is not supported now" @@ -142,95 +258,111 @@ def get_arg_list(origin_func, need_parser_and_node): "TVM Hybrid Script register error : keyword only argument is not supported now" ) - arg_list = list() + pos_only = list() for arg in args[: len(args) - len(defaults)]: - arg_list.append((arg,)) + pos_only.append(arg) + kwargs = list() for default, arg in zip(defaults, args[len(args) - len(defaults) :]): - arg_list.append((arg, default)) + kwargs.append((arg, default)) - return arg_list + return pos_only, kwargs, full_arg_spec.varargs -def register_intrin(origin_func): +def register_intrin(name=None): """Decorator to register function under category intrin - + Parameters + ---------- + name: str, optional + registered name for the function Example ------ .. code-block:: python - @register_intrin def broadcast(value, lanes): lanes = lanes.value if not isinstance(lanes, int) else lanes return tvm.tir.Broadcast(value, lanes) """ - func_name = origin_func.__qualname__ - Registry.intrin[func_name] = func_wrapper( - func_name, - origin_func, - get_arg_list(origin_func, False), - need_parser_and_node=False, - need_body=False, - concise=False, - ) - return origin_func - - -def register_with_scope(concise=False): - """Decorator to register function under with scope handler + def decorate(origin_func): + func_name = "tir." + origin_func.__qualname__ if name is None else name + Registry.functions[func_name] = ( + func_wrapper( + func_name, origin_func, get_arg_list(origin_func, Category.INTRIN), Category.INTRIN + ), + Category.INTRIN, + ) + return origin_func + + return decorate + + +def register_with_scope(concise=False, with_var=False, name=None): + """Decorator to register function under with scope handler Parameters ---------- - concise: bool - whether this scope handler is allowed in concise scoping - + concise: bool, optional + whether this with scope handler is allowed in concise scoping + with_var: bool, optional + whether this with scope handler neeeds optional vars + name: str, optional + registered name for the function Example ------ .. code-block:: python - @register_scope_handler(concise=True) def attr(parser, node, attr_node, attr_key, value, body): - return tvm.tir.AttrStmt(attr_node, attr_key, tvm.runtime.convert(value), body) """ def decorate(origin_func): """Register function under category with_scope""" - func_name = origin_func.__qualname__ - Registry.with_scope[func_name] = func_wrapper( - func_name, - origin_func, - get_arg_list(origin_func, True), - need_parser_and_node=True, - need_body=True, - concise=concise, + func_name = "tir." + origin_func.__qualname__ if name is None else name + Registry.functions[func_name] = ( + func_wrapper( + func_name, + origin_func, + get_arg_list(origin_func, Category.WITH_SCOPE, with_var), + Category.WITH_SCOPE, + concise=concise, + with_var=with_var, + ), + Category.WITH_SCOPE, ) return origin_func return decorate -def register_for_scope(): - """Decorator to register function under for scope handler""" +def register_for_scope(name=None): + """Decorator to register function under for scope handler + Parameters + ---------- + name: str, optional + registered name for the function + """ def decorate(origin_func): - """Register function under category for_scope""" - func_name = origin_func.__qualname__ - Registry.for_scope[func_name] = func_wrapper( - func_name, - origin_func, - get_arg_list(origin_func, True), - need_parser_and_node=True, - need_body=True, - concise=False, + func_name = "tir." + origin_func.__qualname__ if name is None else name + Registry.functions[func_name] = ( + func_wrapper( + func_name, + origin_func, + get_arg_list(origin_func, Category.FOR_SCOPE), + Category.FOR_SCOPE, + ), + Category.FOR_SCOPE, ) return origin_func return decorate -def register_special_stmt(origin_func): +def register_special_stmt(name=None): """Decorator to register function under category special_stmt - + Parameters + ---------- + name: str, optional + registered name for the function Example ------- @register_special_stmt @@ -238,19 +370,22 @@ def buffer_decl(parser, node, shape, dtype="float32", data=None, strides=[], ele scope="global", align=-1, offset_factor=0, buffer_type="default"): align = align.value if not isinstance(align, int) else align offset_factor = offset_factor.value if not isinstance(offset_factor, int) else offset_factor - buffer = tvm.tir.decl_buffer(shape, dtype, parser._assign_target, data, strides, + buffer = tvm.tir.decl_buffer(shape, dtype, parser.assign_target, data, strides, elem_offset, scope, align, offset_factor, buffer_type) return buffer - """ - func_name = origin_func.__qualname__ - Registry.special_stmt[func_name] = func_wrapper( - func_name, - origin_func, - get_arg_list(origin_func, True), - need_parser_and_node=True, - need_body=False, - concise=False, - ) - return origin_func + def decorate(origin_func): + func_name = "tir." + origin_func.__qualname__ if name is None else name + Registry.functions[func_name] = ( + func_wrapper( + func_name, + origin_func, + get_arg_list(origin_func, Category.SPECIAL_STMT), + Category.SPECIAL_STMT, + ), + Category.SPECIAL_STMT, + ) + return origin_func + + return decorate diff --git a/python/tvm/hybrid/scope_handler.py b/python/tvm/hybrid/scope_handler.py index 3b1b7a2c5987..126a3dc2432e 100644 --- a/python/tvm/hybrid/scope_handler.py +++ b/python/tvm/hybrid/scope_handler.py @@ -15,75 +15,181 @@ # specific language governing permissions and limitations # under the License. """Hybrid Script Parser Scope Handler Functions - This module provides the functions registered into parser under with_scope or for_scope category. Scope handler nodes are StmtNodes with body, which are used to handle such scenarios. - +1. For scope handler +When registering a for scope handler, the first 4 arguments must be parser, node, body, loop_vars +and these arguments will provided by Hybrid Script parser automatically .. code-block:: python - - for x in tir.name(): - with tir.name(): - tir.name() # with scope handlers + concise scoping - + for loop_vars in tir.xxx(): +2. With scope handler +There are 4 subtypes of with scope handlers, classified by + 1) with or without as + 2) allow concise scoping or not +1) with as & concise +the first 2 arguments must be parser, node +Need to parse the body manually +Example : tir.alloc_with_scope +.. code-block:: python + target = tir.xxx() + with tir.xxx() as target: +2) with as & not concise +the first 2 arguments must be parser, node +Need to parse the body manually +Example : None atm +.. code-block:: python + with tir.xxx() as target: +3) without as & concise +the first 3 arguments must be parser, node, body +Hybrid Script parser will parse the body automatically +Example : tir.allocate()/tir.realize()/tir.attr() +.. code-block:: python + tir.xxx() + with tir.xxx(): +4) without as & not concise +the first 3 arguments must be parser, node, body +Hybrid Script parser will parse the body automatically +Example : tir.assert()/tir.let() +.. code-block:: python + with tir.xxx(): """ # pylint: disable=redefined-builtin, unused-argument, invalid-name + +from typed_ast import ast3 as ast import tvm.tir from .registry import register_with_scope, register_for_scope # With scope handler -@register_with_scope(concise=False) -def Assert(parser, node, condition, message, body): - """ With scope handler function assert(condition, message, body) """ - - return tvm.tir.AssertStmt(condition, tvm.runtime.convert(message), body) - +@register_with_scope(concise=True, with_var=True) +def allocate(parser, node, extents, dtype, scope, condition=True): + """ With scope handler function tir.alloc_with_scope(var, extents, dtype, scope, condition) """ + # defining buffer var and parse the body manually + + buffer_var = tvm.te.var(parser.target[0], "handle") + # (TODO) Uncomment this line if we have richer type info for buffer var + # buffer_var = tvm.te.var(parser.target[0], tvm.ir.PointerType(tvm.ir.PrimType(dtype))) + if isinstance(node, ast.With): + parser.scope_emitter.new_scope() + parser.scope_emitter.update_symbol(buffer_var.name, buffer_var) + parser.scope_emitter.node_stack[-1].extend(reversed(node.body)) + body = parser.get_body() + parser.scope_emitter.pop_scope() + else: + parser.scope_emitter.update_symbol(buffer_var.name, buffer_var) + body = parser.get_body() + condition = tvm.runtime.convert(condition) + scope = tvm.runtime.convert(scope) + body = tvm.tir.Allocate(buffer_var, dtype, extents, condition, body) + return tvm.tir.AttrStmt(buffer_var, "storage_scope", scope, body) -@register_with_scope(concise=False) -def let(parser, node, var, value, body): - """ With scope handler function let(var, value, body) """ - return tvm.tir.LetStmt(var, value, body) +@register_with_scope(concise=True) +def launch_thread(parser, node, body, env_var, extent): + extent = tvm.runtime.convert(extent) + return tvm.tir.AttrStmt( + tvm.tir.IterVar( + None, env_var, getattr(tvm.tir.IterVar, "ThreadIndex"), parser.var_env_dict[env_var] + ), + "thread_extent", + extent, + body, + ) @register_with_scope(concise=True) -def realize(parser, node, buffer_bounds, body, condition=True): - """ With scope handler function realize(buffer_bounds, condition, body) """ - +def realize(parser, node, body, buffer_bounds, scope, condition=True): + """ With scope handler function tir.realize(buffer_bounds, scope, condition) """ buffer, bounds = buffer_bounds - return tvm.tir.BufferRealize(buffer, bounds, condition, body) + scope = tvm.runtime.convert(scope) + return tvm.tir.AttrStmt( + buffer, "realize_scope", scope, tvm.tir.BufferRealize(buffer, bounds, condition, body) + ) @register_with_scope(concise=True) -def attr(parser, node, attr_node, attr_key, value, body): - """ With scope handler function attr(attr_node, attr_key, value, body) """ +def attr(parser, node, body, attr_node, attr_key, value): + """ With scope handler function tir.attr(attr_node, attr_key, value) """ + attr_node = tvm.runtime.convert(attr_node) + value = tvm.runtime.convert(value) + return tvm.tir.AttrStmt(attr_node, attr_key, value, body) - return tvm.tir.AttrStmt(attr_node, attr_key, tvm.runtime.convert(value), body) +@register_with_scope(concise=False) +def Assert(parser, node, body, condition, message): + """ With scope handler function tir.Assert(condition, message) """ + return tvm.tir.AssertStmt(condition, tvm.runtime.convert(message), body) -@register_with_scope(concise=True) -def allocate(parser, node, buffer_var, dtype, extents, body, condition=True): - """ With scope handler function allocate(buffer_var, dtype, extents, condition, body) """ - return tvm.tir.Allocate(buffer_var, dtype, extents, tvm.runtime.convert(condition), body) +@register_with_scope(concise=False) +def let(parser, node, body, var, value): + """ With scope handler function tir.let(var, value) """ + return tvm.tir.LetStmt(var, value, body) # For scope handler @register_for_scope() -def range(parser, node, begin, end, for_type="serial"): +def serial(parser, node, body, loop_vars, begin, end): + """ For scope handler function tir.serial(begin, end)""" + if len(loop_vars) != 1: + parser.report_error("Expect exact 1 loop var") + ana = tvm.arith.Analyzer() + extent = end if begin == 0 else ana.simplify(end - begin) + return tvm.tir.For(loop_vars[0], begin, extent, 0, 0, body) + + +@register_for_scope() +def parallel(parser, node, body, loop_vars, begin, end): + """ For scope handler function tir.parallel(begin, end)""" + if len(loop_vars) != 1: + parser.report_error("Expect exact 1 loop var") + ana = tvm.arith.Analyzer() + extent = end if begin == 0 else ana.simplify(end - begin) + return tvm.tir.For(loop_vars[0], begin, extent, 1, 0, body) + + +@register_for_scope() +def vectorized(parser, node, body, loop_vars, begin, end): + """ For scope handler function tir.vectorized(begin, end)""" + if len(loop_vars) != 1: + parser.report_error("Expect exact 1 loop var") + ana = tvm.arith.Analyzer() + extent = end if begin == 0 else ana.simplify(end - begin) + return tvm.tir.For(loop_vars[0], begin, extent, 2, 0, body) + + +@register_for_scope() +def unroll(parser, node, body, loop_vars, begin, end): + """ For scope handler function tir.unroll(begin, end)""" + if len(loop_vars) != 1: + parser.report_error("Expect exact 1 loop var") + ana = tvm.arith.Analyzer() + extent = end if begin == 0 else ana.simplify(end - begin) + return tvm.tir.For(loop_vars[0], begin, extent, 3, 0, body) + + +@register_for_scope(name="range") +def Range(parser, node, body, loop_vars, begin, end, annotation=None): """ For scope handler function range(begin, end, annotation)""" + if len(loop_vars) != 1: + parser.report_error("Expect exact 1 loop var") ana = tvm.arith.Analyzer() extent = end if begin == 0 else ana.simplify(end - begin) - loop_var_name = node.target.id - loop_var = tvm.te.var(loop_var_name, dtype="int32") - - parser.scope_emitter.new_scope() - parser.scope_emitter.update_symbol(loop_var_name, loop_var) - parser.scope_emitter.node_stack[-1].extend(reversed(node.body)) - body = parser.get_body() - parser.scope_emitter.pop_scope() - - for_type_dict = {"serial": 0, "parallel": 1, "vectorized": 2, "unroll": 3} - if for_type not in for_type_dict: - parser.report_error("unknown for type " + for_type) - return tvm.tir.For(loop_var, begin, extent, for_type_dict[for_type], 0, body) + if annotation is None: + annotation = [] + else: + annotation = [ + tvm.tir.Annotation(key, tvm.runtime.convert(val) if isinstance(val, str) else val) + for key, val in annotation.items() + ] + return tvm.tir.Loop(loop_vars[0], begin, extent, annotation, body) + + +@register_for_scope() +def grid(parser, node, body, loop_vars, *extents): + """ For scope handler function tir.grid(*extents) """ + if len(loop_vars) != len(extents): + parser.report_error("Inconsitent number of loop vars and extents") + for loop_var, extent in zip(reversed(loop_vars), reversed(extents)): + body = tvm.tir.Loop(loop_var, 0, extent, [], body) + return body diff --git a/python/tvm/hybrid/special_stmt.py b/python/tvm/hybrid/special_stmt.py index 129354db1443..f080071666f2 100644 --- a/python/tvm/hybrid/special_stmt.py +++ b/python/tvm/hybrid/special_stmt.py @@ -15,27 +15,24 @@ # specific language governing permissions and limitations # under the License. """Hybrid Script Parser Special Stmt Functions - This module provides the functions registered into parser under special_stmt category. special_stmt functions don't correspond to an IRNode in the AST directly. It is usually used for some information that is not suitable to be printed directly. - special_stmt can appear as 2 formats - .. code-block:: python - target = tir.name(): tir.name() - +When registering a special stmt, the first two arguments must be parser, node """ -# pylint: disable=unused-argument +# pylint: disable=unused-argument, no-self-argument, inconsistent-return-statements + import tvm.tir from tvm import te from .registry import register_special_stmt -@register_special_stmt -def buffer_bind( +@register_special_stmt() +def match_buffer( parser, node, param, @@ -49,15 +46,12 @@ def buffer_bind( offset_factor=0, buffer_type="default", ): - """Special function buffer_bind(var, shape, dtype, data, strides, elem_offset, scope, align, - offset_factor, buffer_type) - + """Special function match_buffer(var, shape, dtype, data, strides, elem_offset, scope, align, + offset_factor, buffer_type) Example ------- .. code-block:: python - - A = tir.buffer_bind(a, (128, 128), dtype="float32") - + A = tir.match_buffer(a, (128, 128), dtype="float32") """ if param not in parser.params: @@ -69,7 +63,7 @@ def buffer_bind( buffer = tvm.tir.decl_buffer( shape, dtype, - parser._assign_target, + parser.target[0], data, strides, elem_offset, @@ -82,7 +76,7 @@ def buffer_bind( return buffer -@register_special_stmt +@register_special_stmt() def buffer_decl( parser, node, @@ -98,14 +92,12 @@ def buffer_decl( ): """Special function buffer_decl(shape, dtype, data, strides, elem_offset, scope, align, offset_factor, buffer_type) - Example ------- .. code-block:: python - A = tir.buffer_decl((128, 128), dtype="float32") - """ + if strides is None: strides = [] align = align.value if not isinstance(align, int) else align @@ -113,7 +105,7 @@ def buffer_decl( buffer = tvm.tir.decl_buffer( shape, dtype, - parser._assign_target, + parser.target[0], data, strides, elem_offset, @@ -125,20 +117,26 @@ def buffer_decl( return buffer -@register_special_stmt +@register_special_stmt() def var(parser, node, dtype): """ Special function for defining a Var""" - return te.var(parser._assign_target, dtype) + return te.var(parser.target[0], dtype) -@register_special_stmt +@register_special_stmt() +def env_thread(parser, node, env_name): + """ Bind a var to thread env """ + v = te.var(parser.target[0]) + parser.var_env_dict[v] = env_name + return v + + +@register_special_stmt() def func_attr(parser, node, dict_attr): """Special function for declaring the DictAttr of PrimFunc - Example ------- .. code-block:: python - tir.func_attr({"tir.noalias": True, "global_symbol"}) """ diff --git a/python/tvm/hybrid/ty.py b/python/tvm/hybrid/ty.py index a3319474825a..c309fbe9104e 100644 --- a/python/tvm/hybrid/ty.py +++ b/python/tvm/hybrid/ty.py @@ -37,7 +37,7 @@ def __init__(self, vtype): self.type = vtype def evaluate(self): - return self.type + return tvm.ir.PrimType(self.type) class GenericPtrType(TypeGeneric): @@ -61,7 +61,7 @@ def __getitem__(self, vtypes): return ConcreteType(tvm.ir.TupleType([vtype.evaluate() for vtype in vtypes])) -int32 = ConcreteType(tvm.ir.PrimType("int32")) -handle = ConcreteType(tvm.ir.PrimType("handle")) +int32 = ConcreteType("int32") +handle = ConcreteType("handle") Ptr = GenericPtrType() Tuple = GenericTupleType() diff --git a/src/printer/tir_hybrid_printer.cc b/src/printer/tir_hybrid_printer.cc index 8f6b37a2bbc4..b58e5fc2e9a0 100644 --- a/src/printer/tir_hybrid_printer.cc +++ b/src/printer/tir_hybrid_printer.cc @@ -19,10 +19,9 @@ /*! * \file printer/tir_hybrid_printer.cc - * \brief Printer class to print Te IR to python syntax script + * \brief Printer class to print Tensor IR to python syntax script */ -#include #include #include #include @@ -34,6 +33,7 @@ #include #include +#include #include "doc.h" #include "meta_data.h" @@ -48,7 +48,7 @@ class TIRHybridPrinter : public StmtFunctor, public: explicit TIRHybridPrinter(bool show_meta, runtime::TypedPackedFunc annotate = nullptr) - : show_meta_(show_meta), annotate_(annotate), meta_collector_(&meta_) {} + : show_meta_(show_meta), annotate_(std::move(annotate)), meta_collector_(&meta_) {} /*! \brief Print the node */ TVM_DLL Doc Print(const ObjectRef& node); @@ -68,6 +68,8 @@ class TIRHybridPrinter : public StmtFunctor, std::unordered_set var_not_in_headers; /*! \brief buffer collector (buffer defined in BufferMap and BufferAllocation)*/ std::unordered_set buf_not_in_headers; + /*! \breif Map from Var to thread env name */ + std::unordered_map var_env_map_; /*! \brief Map from Var to Doc */ std::unordered_map memo_var_; /*! \brief Map from Buffer to Doc */ @@ -356,7 +358,11 @@ Doc TIRHybridPrinter::VisitExpr_(const StringImmNode* op) { return Doc::StrLiter Doc TIRHybridPrinter::VisitExpr_(const CastNode* op) { Doc doc; - doc << "tir.cast(" << PrintDType(op->dtype) << ", " << Print(op->value) << ")"; + if (cast(op->dtype, op->value)->IsInstance()) { + doc << Print(op->value) << ".astype(" << PrintDType(op->dtype) << ")"; + } else { + doc << "tir.cast(" << Print(op->value) << ", " << PrintDType(op->dtype) << ")"; + } return doc; } @@ -509,6 +515,70 @@ Doc TIRHybridPrinter::VisitStmt_(const LetStmtNode* op) { Doc TIRHybridPrinter::VisitStmt_(const AttrStmtNode* op) { Doc doc; + // merge attr with allocate when possible + if (op->node->IsInstance() && op->attr_key == "storage_scope" && + op->body->IsInstance()) { + const auto* alloc = Downcast(op->body).get(); + if (alloc->buffer_var.same_as(op->node)) { + var_not_in_headers.insert(alloc->buffer_var.get()); + if (current_num_ != num_child_ - 1) { + doc << "with tir.allocate(" << Print(alloc->extents) << ", " << PrintDType(alloc->dtype) + << ", " << Print(op->value); + if (!is_one(alloc->condition)) { + doc << ", " << Print(alloc->condition); + } + doc << ") as " << Print(op->node) << ":"; + doc << Doc::Indent(4, Doc::NewLine() << PrintBody(alloc->body)); + } else { + doc << Print(op->node) << " = tir.allocate(" << Print(alloc->extents) << ", " + << PrintDType(alloc->dtype) << ", " << Print(op->value); + if (!is_one(alloc->condition)) { + doc << ", " << Print(alloc->condition); + } + doc << ")" << Doc::NewLine() << PrintBody(alloc->body); + } + return doc; + } + } + // merge attr with realize when possible + if (op->node->IsInstance() && op->attr_key == "realize_scope" && + op->body->IsInstance()) { + const auto* realize = Downcast(op->body).get(); + if (realize->buffer.same_as(op->node)) { + if (current_num_ != num_child_ - 1) { + doc << "with tir.realize(" << Print(realize->buffer) << Print(realize->bounds) << ", " + << Print(op->value); + if (!is_one(realize->condition)) { + doc << ", " << Print(realize->condition); + } + doc << "):" << Doc::Indent(4, Doc::NewLine() << PrintBody(realize->body)); + } else { + doc << "tir.realize(" << Print(realize->buffer) << Print(realize->bounds) << ", " + << Print(op->value); + if (!is_one(realize->condition)) { + doc << ", " << Print(realize->condition); + } + doc << ")" << Doc::NewLine() << PrintBody(realize->body); + } + return doc; + } + } + // concise thread env + if (op->node->IsInstance() && op->attr_key == "thread_extent") { + const auto* iter_var = Downcast(op->node).get(); + CHECK(!iter_var->dom.defined()); + var_not_in_headers.insert(iter_var->var.get()); + var_env_map_[iter_var->var] = iter_var->thread_tag; + if (current_num_ != num_child_ - 1) { + doc << "with tir.launch_thread(" << Print(iter_var->var) << ", " << Print(op->value) << "):"; + doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body)); + } else { + doc << "tir.launch_thread(" << Print(iter_var->var) << ", " << Print(op->value) << ")"; + doc << Doc::NewLine() << PrintBody(op->body); + } + return doc; + } + // default if (current_num_ != num_child_ - 1) { doc << "with tir.attr(" << Print(op->node) << ", " << Doc::StrLiteral(op->attr_key) << ", " << Print(op->value) << "):"; @@ -545,35 +615,13 @@ Doc TIRHybridPrinter::VisitStmt_(const StoreNode* op) { } Doc TIRHybridPrinter::VisitStmt_(const BufferRealizeNode* op) { - Doc doc; - if (current_num_ != num_child_ - 1) { - doc << "with tir.realize(" << Print(op->buffer) << Print(op->bounds); - if (!is_one(op->condition)) { - doc << ", " << Print(op->condition); - } - doc << "):" << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body)); - } else { - doc << "tir.realize(" << Print(op->buffer) << Print(op->bounds); - if (!is_one(op->condition)) { - doc << ", " << Print(op->condition); - } - doc << ")" << Doc::NewLine() << PrintBody(op->body); - } - return doc; + LOG(FATAL) << "Hybrid Printer Internal Error: All the BufferRealize should be folded with Attr"; + return Doc(); } Doc TIRHybridPrinter::VisitStmt_(const AllocateNode* op) { - Doc doc; - if (current_num_ != num_child_ - 1) { - doc << "with tir.allocate(" << Print(op->buffer_var) << ", " << PrintDType(op->dtype) << ", " - << Print(op->extents) << "):"; - doc << Doc::Indent(4, PrintBody(op->body)); - } else { - doc << "tir.allocate(" << Print(op->buffer_var) << ", " << PrintDType(op->dtype) << ", " - << Print(op->extents) << ")"; - doc << Doc::NewLine() << PrintBody(op->body); - } - return doc; + LOG(FATAL) << "Hybrid Printer Internal Error: All the Allocate should be folded with Attr"; + return Doc(); } Doc TIRHybridPrinter::VisitStmt_(const IfThenElseNode* op) { @@ -618,12 +666,10 @@ inline const char* ForType2String(ForType t) { Doc TIRHybridPrinter::VisitStmt_(const ForNode* op) { Doc doc; var_not_in_headers.insert(op->loop_var.get()); - doc << "for " << Print(op->loop_var) << " in tir.range(" << Print(op->min) << ", " - << Print(op->min + op->extent); - if (op->for_type != ForType::Serial) { - doc << ", " << Doc::StrLiteral(ForType2String(op->for_type)); - } - doc << "):" << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body)); + doc << "for " << Print(op->loop_var) + << " in tir." + std::string(ForType2String(op->for_type)) + "(" << Print(op->min) << ", " + << Print(op->min + op->extent) + << "):" << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body)); return doc; } @@ -734,7 +780,7 @@ Doc TIRHybridPrinter::PrintPrimFunc(const PrimFunc& primFunc) { // print buffer_bind for (const auto& it : op->buffer_map) { buf_not_in_headers.insert(it.second.get()); - body << Print(it.second) << " = tir.buffer_bind("; + body << Print(it.second) << " = tir.match_buffer("; body << Print(it.first) << ", " << memo_buf_decl_[it.second]; body << ")" << Doc::NewLine(); } @@ -785,8 +831,14 @@ Doc TIRHybridPrinter::PrintPrimFunc(const PrimFunc& primFunc) { vars.push_back(it.first.get()); } } - if (!vars.empty()) { + if (!var_env_map_.empty()) { header_var << Doc::NewLine() << "# var definition"; + for (const auto& it : var_env_map_) { + header_var << Doc::NewLine() << Print(it.first) << " = tir.env_thread(" + << Doc::StrLiteral(it.second) << ")"; + } + } + if (!vars.empty()) { std::sort(vars.begin(), vars.end(), [&](const VarNode* a, const VarNode* b) { return memo_var_[GetRef(a)].str() < memo_var_[GetRef(b)].str(); }); @@ -834,7 +886,7 @@ Doc TIRHybridPrinter::PrintBuffer(const BufferNode* op) { return meta_.InMeta(buffer) ? meta_.GetMetaNode(buffer) : AllocBuf(buffer); } -TVM_REGISTER_GLOBAL("tir.hybrid.AsHybrid") +TVM_REGISTER_GLOBAL("hybrid.AsHybrid") .set_body_typed([](const ObjectRef& functions, bool show_meta) { CHECK(functions.as() != nullptr || functions.as() != nullptr); diff --git a/tests/python/unittest/test_hybrid_error_report.py b/tests/python/unittest/test_hybrid_error_report.py index 39b8bfc0251b..9b3c5bb663eb 100644 --- a/tests/python/unittest/test_hybrid_error_report.py +++ b/tests/python/unittest/test_hybrid_error_report.py @@ -26,30 +26,30 @@ @tvm.hybrid.script class Module1: def buffer_bind_missing_args(a: ty.handle) -> None: - A = tir.buffer_bind((16, 16), "float32") + A = tir.match_buffer((16, 16), "float32") @tvm.hybrid.script class Module2: def range_missing_args(a: ty.handle) -> None: - A = tir.buffer_bind(a, (16, 16), "float32") + A = tir.match_buffer(a, (16, 16), "float32") tir.attr(A, "realize_scope", "") tir.realize(A[0:16, 0:16]) - for i in tir.range(16): - for j in tir.range(0, 16): + for i in tir.serial(16): + for j in tir.serial(0, 16): A[i, j] = 0.0 @tvm.hybrid.script class Module3: def undefined_buffer(a: ty.handle) -> None: - A = tir.buffer_bind(a, (16, 16), "float32") + A = tir.match_buffer(a, (16, 16), "float32") tir.attr(A, "realize_scope", "") tir.realize(C[0:16, 0:16]) - for i in tir.range(16): - for j in tir.range(0, 16): + for i in tir.serial(16): + for j in tir.serial(0, 16): A[i, j] = 0.0 @@ -63,12 +63,12 @@ def unsupported_stmt(a: ty.int32) -> None: @tvm.hybrid.script class Module5: def unsupported_function_call(a: ty.handle) -> None: - A = tir.buffer_bind(a, (16, 16), "float32") + A = tir.match_buffer(a, (16, 16), "float32") tir.attr(A, "realize_scope", "") tir.realize(A[0:16, 0:16]) for i in tir.const_range(16): - for j in tir.range(0, 16): + for j in tir.serial(0, 16): A[i, j] = 0.0 @@ -85,6 +85,31 @@ def invalid_concise_scoping() -> None: tir.evaluate(0.0) +@tvm.hybrid.script +class Module8: + def invalid_expr_stmt() -> None: + tir.max(1, 2) + + +@tvm.hybrid.script +class Module9: + def invalid_for_function(a: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16), "float32") + + for i in tir.evaluate(0.0): + for j in tir.serial(0, 16): + A[i, j] = 0.0 + + +@tvm.hybrid.script +class Module10: + def invalid_block_function(a: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16), "float32") + + with tir.evaluate(0.0): + pass + + def wrap_error(module, lineno): with pytest.raises(HybridParserError) as error: mod = module() @@ -103,3 +128,6 @@ def wrap_error(module, lineno): wrap_error(Module5, 70) wrap_error(Module6, 77) wrap_error(Module7, 84) + wrap_error(Module8, 91) + wrap_error(Module9, 99) + wrap_error(Module10, 109) diff --git a/tests/python/unittest/test_hybrid_roundtrip.py b/tests/python/unittest/test_hybrid_roundtrip.py index 90d76a2b9875..ea67a4e6f549 100644 --- a/tests/python/unittest/test_hybrid_roundtrip.py +++ b/tests/python/unittest/test_hybrid_roundtrip.py @@ -28,36 +28,34 @@ def mmult(A: ty.handle, B: ty.handle, C: ty.handle) -> None: # buffer definition C_global = tir.buffer_decl([1024, 1024], elem_offset=0, align=128, offset_factor=1) packedB = tir.buffer_decl([32, 1024, 32], elem_offset=0, align=128, offset_factor=1) - A_1 = tir.buffer_bind(A, [1024, 1024], elem_offset=0, align=128, offset_factor=1) - B_1 = tir.buffer_bind(B, [1024, 1024], elem_offset=0, align=128, offset_factor=1) - C_1 = tir.buffer_bind(C, [1024, 1024], elem_offset=0, align=128, offset_factor=1) + A_1 = tir.match_buffer(A, [1024, 1024], elem_offset=0, align=128, offset_factor=1) + B_1 = tir.match_buffer(B, [1024, 1024], elem_offset=0, align=128, offset_factor=1) + C_1 = tir.match_buffer(C, [1024, 1024], elem_offset=0, align=128, offset_factor=1) # body - tir.attr(packedB, "realize_scope", "") - tir.realize(packedB[0:32, 0:1024, 0:32]) - for x in tir.range(0, 32, "parallel"): - for y in tir.range(0, 1024): - for z in tir.range(0, 32, "vectorized"): + tir.realize(packedB[0:32, 0:1024, 0:32], "") + for x in tir.parallel(0, 32): + for y in tir.serial(0, 1024): + for z in tir.vectorized(0, 32): packedB[x, y, z] = B_1[y, ((x * 32) + z)] - tir.attr(C_1, "realize_scope", "") - tir.realize(C_1[0:1024, 0:1024]) - for x_outer in tir.range(0, 32, "parallel"): - for y_outer in tir.range(0, 32): - tir.attr(C_global, "realize_scope", "global") + tir.realize(C_1[0:1024, 0:1024], "") + for x_outer in tir.parallel(0, 32): + for y_outer in tir.serial(0, 32): tir.realize( C_global[ (x_outer * 32) : ((x_outer * 32) + 32), (y_outer * 32) : ((y_outer * 32) + 32), - ] + ], + "global", ) - for x_c_init in tir.range(0, 32): - for y_c_init in tir.range(0, 32, "vectorized"): + for x_c_init in tir.serial(0, 32): + for y_c_init in tir.vectorized(0, 32): C_global[ (x_c_init + (x_outer * 32)), (y_c_init + (y_outer * 32)) ] = tir.float32(0) - for k_outer in tir.range(0, 256): - for x_c in tir.range(0, 32): - for k_inner in tir.range(0, 4, "unroll"): - for y_c in tir.range(0, 32, "vectorized"): + for k_outer in tir.serial(0, 256): + for x_c in tir.serial(0, 32): + for k_inner in tir.unroll(0, 4): + for y_c in tir.vectorized(0, 32): C_global[(x_c + (x_outer * 32)), (y_c + (y_outer * 32))] = C_global[ (x_c + (x_outer * 32)), (y_c + (y_outer * 32)) ] + ( @@ -68,8 +66,8 @@ def mmult(A: ty.handle, B: ty.handle, C: ty.handle) -> None: tir.floormod((y_c + (y_outer * 32)), 32), ] ) - for x_inner in tir.range(0, 32): - for y_inner in tir.range(0, 32): + for x_inner in tir.serial(0, 32): + for y_inner in tir.serial(0, 32): C_1[(x_inner + (x_outer * 32)), (y_inner + (y_outer * 32))] = C_global[ (x_inner + (x_outer * 32)), (y_inner + (y_outer * 32)) ] @@ -86,19 +84,13 @@ class Module2: def mmult(A: ty.handle, B: ty.handle, C: ty.handle) -> None: # function attr dict tir.func_attr({"global_symbol": "mmult", "tir.noalias": True}) - # var definition - C_global = tir.var("handle") - packedB = tir.var("handle") - A_1 = tir.buffer_bind(A, [1024, 1024], elem_offset=0, align=128, offset_factor=1) - B_1 = tir.buffer_bind(B, [1024, 1024], elem_offset=0, align=128, offset_factor=1) - C_1 = tir.buffer_bind(C, [1024, 1024], elem_offset=0, align=128, offset_factor=1) + A_1 = tir.match_buffer(A, [1024, 1024], elem_offset=0, align=128, offset_factor=1) + B_1 = tir.match_buffer(B, [1024, 1024], elem_offset=0, align=128, offset_factor=1) + C_1 = tir.match_buffer(C, [1024, 1024], elem_offset=0, align=128, offset_factor=1) # body - tir.attr(packedB, "storage_scope", "global") - tir.allocate(packedB, "float32x32", [32768]) - tir.attr(C_global, "storage_scope", "global") - tir.allocate(C_global, "float32", [1024]) - for x in tir.range(0, 32, "parallel"): - for y in tir.range(0, 1024): + packedB = tir.allocate([32768], "float32x32", "global") + for x in tir.parallel(0, 32): + for y in tir.serial(0, 1024): tir.store( packedB, tir.ramp(((x * 32768) + (y * 32)), 1, 32), @@ -110,17 +102,18 @@ def mmult(A: ty.handle, B: ty.handle, C: ty.handle) -> None: ), tir.broadcast(True, 32), ) - for x_outer in tir.range(0, 32): - for y_outer in tir.range(0, 32): - for x_c_init in tir.range(0, 32): + for x_outer in tir.parallel(0, 32): + C_global = tir.allocate([1024], "float32", "global") + for y_outer in tir.serial(0, 32): + for x_c_init in tir.serial(0, 32): tir.store( C_global, tir.ramp((x_c_init * 32), 1, 32), tir.broadcast(tir.float32(0), 32), tir.broadcast(True, 32), ) - for k_outer in tir.range(0, 256): - for x_c in tir.range(0, 32): + for k_outer in tir.serial(0, 256): + for x_c in tir.serial(0, 32): tir.store( C_global, tir.ramp((x_c * 32), 1, 32), @@ -252,8 +245,8 @@ def mmult(A: ty.handle, B: ty.handle, C: ty.handle) -> None: ), tir.broadcast(True, 32), ) - for x_inner in tir.range(0, 32): - for y_inner in tir.range(0, 32): + for x_inner in tir.serial(0, 32): + for y_inner in tir.serial(0, 32): C_1.data[ ((((x_outer * 32768) + (x_inner * 1024)) + (y_outer * 32)) + y_inner) ] = tir.load("float32", C_global, ((x_inner * 32) + y_inner)) @@ -329,14 +322,14 @@ def mmult( tir.tvm_struct_get(arg0, 0, 7, dtype="uint16") == tir.uint16(1) ), "arg0.dtype is expected to be float32" assert 1024 == tir.cast( - "int32", tir.load("int64", arg0_shape, 0) + tir.load("int64", arg0_shape, 0), "int32" ), "Argument arg0.shape[0] has an unsatisfied constraint" assert 1024 == tir.cast( - "int32", tir.load("int64", arg0_shape, 1) + tir.load("int64", arg0_shape, 1), "int32" ), "Argument arg0.shape[1] has an unsatisfied constraint" if not (tir.isnullptr(arg0_strides, dtype="bool")): - assert (1 == tir.cast("int32", tir.load("int64", arg0_strides, 1))) and ( - 1024 == tir.cast("int32", tir.load("int64", arg0_strides, 0)) + assert (1 == tir.cast(tir.load("int64", arg0_strides, 1), "int32")) and ( + 1024 == tir.cast(tir.load("int64", arg0_strides, 0), "int32") ), "arg0.strides: expected to be compact array" tir.evaluate(0) assert tir.uint64(0) == tir.tvm_struct_get( @@ -358,14 +351,14 @@ def mmult( tir.tvm_struct_get(arg1, 0, 7, dtype="uint16") == tir.uint16(1) ), "arg1.dtype is expected to be float32" assert 1024 == tir.cast( - "int32", tir.load("int64", arg1_shape, 0) + tir.load("int64", arg1_shape, 0), "int32" ), "Argument arg1.shape[0] has an unsatisfied constraint" assert 1024 == tir.cast( - "int32", tir.load("int64", arg1_shape, 1) + tir.load("int64", arg1_shape, 1), "int32" ), "Argument arg1.shape[1] has an unsatisfied constraint" if not (tir.isnullptr(arg1_strides, dtype="bool")): - assert (1 == tir.cast("int32", tir.load("int64", arg1_strides, 1))) and ( - 1024 == tir.cast("int32", tir.load("int64", arg1_strides, 0)) + assert (1 == tir.cast(tir.load("int64", arg1_strides, 1), "int32")) and ( + 1024 == tir.cast(tir.load("int64", arg1_strides, 0), "int32") ), "arg1.strides: expected to be compact array" tir.evaluate(0) assert tir.uint64(0) == tir.tvm_struct_get( @@ -390,14 +383,14 @@ def mmult( tir.tvm_struct_get(arg2, 0, 7, dtype="uint16") == tir.uint16(1) ), "arg2.dtype is expected to be float32" assert 1024 == tir.cast( - "int32", tir.load("int64", arg2_shape, 0) + tir.load("int64", arg2_shape, 0), "int32" ), "Argument arg2.shape[0] has an unsatisfied constraint" assert 1024 == tir.cast( - "int32", tir.load("int64", arg2_shape, 1) + tir.load("int64", arg2_shape, 1), "int32" ), "Argument arg2.shape[1] has an unsatisfied constraint" if not (tir.isnullptr(arg2_strides, dtype="bool")): - assert (1 == tir.cast("int32", tir.load("int64", arg2_strides, 1))) and ( - 1024 == tir.cast("int32", tir.load("int64", arg2_strides, 0)) + assert (1 == tir.cast(tir.load("int64", arg2_strides, 1), "int32")) and ( + 1024 == tir.cast(tir.load("int64", arg2_strides, 0), "int32") ), "arg2.strides: expected to be compact array" tir.evaluate(0) assert tir.uint64(0) == tir.tvm_struct_get( @@ -418,8 +411,8 @@ def mmult( ): if tir.isnullptr(packedB, dtype="bool"): tir.evaluate(tir.tvm_throw_last_error(dtype="int32")) - for x in tir.range(0, 32, "parallel"): - for y in tir.range(0, 1024): + for x in tir.parallel(0, 32): + for y in tir.serial(0, 1024): tir.store( packedB, tir.ramp(((x * 32768) + (y * 32)), 1, 32), @@ -431,7 +424,7 @@ def mmult( ), tir.broadcast(True, 32), ) - for x_outer in tir.range(0, 32, "parallel"): + for x_outer in tir.parallel(0, 32): tir.attr(C_global, "storage_scope", "global") tir.attr(C_global, "storage_alignment", 128) with tir.let( @@ -442,16 +435,16 @@ def mmult( ): if tir.isnullptr(C_global, dtype="bool"): tir.evaluate(tir.tvm_throw_last_error(dtype="int32")) - for y_outer in tir.range(0, 32): - for x_c_init in tir.range(0, 32): + for y_outer in tir.serial(0, 32): + for x_c_init in tir.serial(0, 32): tir.store( C_global, tir.ramp((x_c_init * 32), 1, 32), tir.broadcast(tir.float32(0), 32), tir.broadcast(True, 32), ) - for k_outer in tir.range(0, 256): - for x_c in tir.range(0, 32): + for k_outer in tir.serial(0, 256): + for x_c in tir.serial(0, 32): tir.store( C_global, tir.ramp((x_c * 32), 1, 32), @@ -599,8 +592,8 @@ def mmult( ), tir.broadcast(True, 32), ) - for x_inner in tir.range(0, 32): - for y_inner in tir.range(0, 32): + for x_inner in tir.serial(0, 32): + for y_inner in tir.serial(0, 32): C[ ( (((x_outer * 32768) + (x_inner * 1024)) + (y_outer * 32)) @@ -624,12 +617,12 @@ def opt_conv_tensorcore_normalize(A: ty.handle, W: ty.handle, Conv: ty.handle) - # function attr dict tir.func_attr({"global_symbol": "default_function", "tir.noalias": True}) # var definition - blockIdx_x = tir.var("int32") - blockIdx_y = tir.var("int32") - blockIdx_z = tir.var("int32") - threadIdx_x = tir.var("int32") - threadIdx_y = tir.var("int32") - threadIdx_z = tir.var("int32") + bx = tir.env_thread("blockIdx.x") + by = tir.env_thread("blockIdx.y") + bz = tir.env_thread("blockIdx.z") + tx = tir.env_thread("threadIdx.x") + ty = tir.env_thread("threadIdx.y") + tz = tir.env_thread("threadIdx.z") # buffer definition Apad_shared = tir.buffer_decl( [16, 16, 16, 16, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1 @@ -665,47 +658,46 @@ def opt_conv_tensorcore_normalize(A: ty.handle, W: ty.handle, Conv: ty.handle) - ) buffer_4 = tir.buffer_decl([16, 16], scope="wmma.accumulator", align=32, offset_factor=256) buffer_5 = tir.buffer_decl([16, 16], align=32, offset_factor=256) - A_1 = tir.buffer_bind( + A_1 = tir.match_buffer( A, [16, 14, 14, 16, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1 ) - W_1 = tir.buffer_bind( + W_1 = tir.match_buffer( W, [3, 3, 16, 32, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1 ) - Conv_1 = tir.buffer_bind( + Conv_1 = tir.match_buffer( Conv, [16, 14, 14, 32, 16, 16], elem_offset=0, align=128, offset_factor=1 ) # body - tir.attr(Conv_1, "realize_scope", "") - tir.realize(Conv_1[0:16, 0:14, 0:14, 0:32, 0:16, 0:16]) - tir.attr(tir.iter_var(blockIdx_z, None, "ThreadIndex", "blockIdx.z"), "thread_extent", 196) - tir.attr(tir.iter_var(blockIdx_x, None, "ThreadIndex", "blockIdx.x"), "thread_extent", 2) - tir.attr(tir.iter_var(blockIdx_y, None, "ThreadIndex", "blockIdx.y"), "thread_extent", 4) - tir.attr(tir.iter_var(threadIdx_y, None, "ThreadIndex", "threadIdx.y"), "thread_extent", 4) - tir.attr(tir.iter_var(threadIdx_z, None, "ThreadIndex", "threadIdx.z"), "thread_extent", 2) - tir.attr(Conv_wmma_accumulator, "realize_scope", "wmma.accumulator") + tir.realize(Conv_1[0:16, 0:14, 0:14, 0:32, 0:16, 0:16], "") + tir.launch_thread(bz, 196) + tir.launch_thread(bx, 2) + tir.launch_thread(by, 4) + tir.launch_thread(ty, 4) + tir.launch_thread(tz, 2) tir.realize( Conv_wmma_accumulator[ - ((blockIdx_x * 8) + (threadIdx_y * 2)) : (((blockIdx_x * 8) + (threadIdx_y * 2)) + 2), - tir.floordiv(blockIdx_z, 14) : (tir.floordiv(blockIdx_z, 14) + 1), - tir.floormod(blockIdx_z, 14) : (tir.floormod(blockIdx_z, 14) + 1), - ((blockIdx_y * 8) + (threadIdx_z * 4)) : (((blockIdx_y * 8) + (threadIdx_z * 4)) + 4), + ((bx * 8) + (ty * 2)) : (((bx * 8) + (ty * 2)) + 2), + tir.floordiv(bz, 14) : (tir.floordiv(bz, 14) + 1), + tir.floormod(bz, 14) : (tir.floormod(bz, 14) + 1), + ((by * 8) + (tz * 4)) : (((by * 8) + (tz * 4)) + 4), 0:16, 0:16, - ] + ], + "wmma.accumulator", ) - for n_c_init in tir.range(0, 2): - for o_c_init in tir.range(0, 4): + for n_c_init in tir.serial(0, 2): + for o_c_init in tir.serial(0, 4): tir.attr( [BC, Conv_wmma_accumulator], "buffer_bind_scope", tir.tvm_tuple( - (n_c_init + ((blockIdx_x * 8) + (threadIdx_y * 2))), + (n_c_init + ((bx * 8) + (ty * 2))), 1, - tir.floordiv(blockIdx_z, 14), + tir.floordiv(bz, 14), 1, - tir.floormod(blockIdx_z, 14), + tir.floormod(bz, 14), 1, - (o_c_init + ((blockIdx_y * 8) + (threadIdx_z * 4))), + (o_c_init + ((by * 8) + (tz * 4))), 1, 0, 16, @@ -725,119 +717,105 @@ def opt_conv_tensorcore_normalize(A: ty.handle, W: ty.handle, Conv: ty.handle) - dtype="handle", ) ) - for ic_outer in tir.range(0, 8): - for kh in tir.range(0, 3): - tir.attr(Apad_shared, "realize_scope", "shared") + for ic_outer in tir.serial(0, 8): + for kh in tir.serial(0, 3): tir.realize( Apad_shared[ - (blockIdx_x * 8) : ((blockIdx_x * 8) + 8), - (tir.floordiv(blockIdx_z, 14) + kh) : ((tir.floordiv(blockIdx_z, 14) + kh) + 1), - tir.floormod(blockIdx_z, 14) : (tir.floormod(blockIdx_z, 14) + 3), + (bx * 8) : ((bx * 8) + 8), + (tir.floordiv(bz, 14) + kh) : ((tir.floordiv(bz, 14) + kh) + 1), + tir.floormod(bz, 14) : (tir.floormod(bz, 14) + 3), (ic_outer * 2) : ((ic_outer * 2) + 2), 0:16, 0:16, - ] + ], + "shared", ) - for ax2 in tir.range(0, 3): - for ax3 in tir.range(0, 2): - for ax4_ax5_fused_outer in tir.range(0, 8): - tir.attr( - tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), - "thread_extent", - 32, - ) + for ax2 in tir.serial(0, 3): + for ax3 in tir.serial(0, 2): + for ax4_ax5_fused_outer in tir.serial(0, 8): + tir.launch_thread(tx, 32) Apad_shared[ - ((threadIdx_z + (threadIdx_y * 2)) + (blockIdx_x * 8)), - (tir.floordiv(blockIdx_z, 14) + kh), - (ax2 + tir.floormod(blockIdx_z, 14)), + ((tz + (ty * 2)) + (bx * 8)), + (tir.floordiv(bz, 14) + kh), + (ax2 + tir.floormod(bz, 14)), (ax3 + (ic_outer * 2)), - tir.floordiv((threadIdx_x + (ax4_ax5_fused_outer * 32)), 16), - tir.floormod((threadIdx_x + (ax4_ax5_fused_outer * 32)), 16), + tir.floordiv((tx + (ax4_ax5_fused_outer * 32)), 16), + tir.floormod((tx + (ax4_ax5_fused_outer * 32)), 16), ] = tir.if_then_else( ( ( ( - ((tir.floordiv(blockIdx_z, 14) + kh) >= 1) - and (((tir.floordiv(blockIdx_z, 14) + kh) - 1) < 14) + ((tir.floordiv(bz, 14) + kh) >= 1) + and (((tir.floordiv(bz, 14) + kh) - 1) < 14) ) - and ((ax2 + tir.floormod(blockIdx_z, 14)) >= 1) + and ((ax2 + tir.floormod(bz, 14)) >= 1) ) - and (((ax2 + tir.floormod(blockIdx_z, 14)) - 1) < 14) + and (((ax2 + tir.floormod(bz, 14)) - 1) < 14) ), A_1[ - ((threadIdx_z + (threadIdx_y * 2)) + (blockIdx_x * 8)), - ((tir.floordiv(blockIdx_z, 14) + kh) - 1), - ((ax2 + tir.floormod(blockIdx_z, 14)) - 1), + ((tz + (ty * 2)) + (bx * 8)), + ((tir.floordiv(bz, 14) + kh) - 1), + ((ax2 + tir.floormod(bz, 14)) - 1), (ax3 + (ic_outer * 2)), - tir.floordiv((threadIdx_x + (ax4_ax5_fused_outer * 32)), 16), - tir.floormod((threadIdx_x + (ax4_ax5_fused_outer * 32)), 16), + tir.floordiv((tx + (ax4_ax5_fused_outer * 32)), 16), + tir.floormod((tx + (ax4_ax5_fused_outer * 32)), 16), ], tir.float16(0), dtype="float16", ) - tir.attr(W_shared, "realize_scope", "shared") tir.realize( W_shared[ kh : (kh + 1), 0:3, (ic_outer * 2) : ((ic_outer * 2) + 2), - (blockIdx_y * 8) : ((blockIdx_y * 8) + 8), + (by * 8) : ((by * 8) + 8), 0:16, 0:16, - ] + ], + "shared", ) - for ax1 in tir.range(0, 3): - for ax2_1 in tir.range(0, 2): - tir.attr( - tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), - "thread_extent", - 32, - ) - for ax4_ax5_fused_inner in tir.range(0, 8, "vectorized"): + for ax1 in tir.serial(0, 3): + for ax2_1 in tir.serial(0, 2): + tir.launch_thread(tx, 32) + for ax4_ax5_fused_inner in tir.vectorized(0, 8): W_shared[ kh, ax1, (ax2_1 + (ic_outer * 2)), - ((threadIdx_z + (threadIdx_y * 2)) + (blockIdx_y * 8)), - tir.floordiv((ax4_ax5_fused_inner + (threadIdx_x * 8)), 16), - tir.floormod((ax4_ax5_fused_inner + (threadIdx_x * 8)), 16), + ((tz + (ty * 2)) + (by * 8)), + tir.floordiv((ax4_ax5_fused_inner + (tx * 8)), 16), + tir.floormod((ax4_ax5_fused_inner + (tx * 8)), 16), ] = W_1[ kh, ax1, (ax2_1 + (ic_outer * 2)), - ((threadIdx_z + (threadIdx_y * 2)) + (blockIdx_y * 8)), - tir.floordiv((ax4_ax5_fused_inner + (threadIdx_x * 8)), 16), - tir.floormod((ax4_ax5_fused_inner + (threadIdx_x * 8)), 16), + ((tz + (ty * 2)) + (by * 8)), + tir.floordiv((ax4_ax5_fused_inner + (tx * 8)), 16), + tir.floormod((ax4_ax5_fused_inner + (tx * 8)), 16), ] - for ic_inner in tir.range(0, 2): - for kw in tir.range(0, 3): - tir.attr(Apad_shared_wmma_matrix_a, "realize_scope", "wmma.matrix_a") + for ic_inner in tir.serial(0, 2): + for kw in tir.serial(0, 3): tir.realize( Apad_shared_wmma_matrix_a[ - ((blockIdx_x * 8) + (threadIdx_y * 2)) : ( - ((blockIdx_x * 8) + (threadIdx_y * 2)) + 2 - ), - (tir.floordiv(blockIdx_z, 14) + kh) : ( - (tir.floordiv(blockIdx_z, 14) + kh) + 1 - ), - (kw + tir.floormod(blockIdx_z, 14)) : ( - (kw + tir.floormod(blockIdx_z, 14)) + 1 - ), + ((bx * 8) + (ty * 2)) : (((bx * 8) + (ty * 2)) + 2), + (tir.floordiv(bz, 14) + kh) : ((tir.floordiv(bz, 14) + kh) + 1), + (kw + tir.floormod(bz, 14)) : ((kw + tir.floormod(bz, 14)) + 1), ((ic_outer * 2) + ic_inner) : (((ic_outer * 2) + ic_inner) + 1), 0:16, 0:16, - ] + ], + "wmma.matrix_a", ) - for ax0 in tir.range(0, 2): + for ax0 in tir.serial(0, 2): tir.attr( [buffer, Apad_shared], "buffer_bind_scope", tir.tvm_tuple( - (ax0 + ((blockIdx_x * 8) + (threadIdx_y * 2))), + (ax0 + ((bx * 8) + (ty * 2))), 1, - (tir.floordiv(blockIdx_z, 14) + kh), + (tir.floordiv(bz, 14) + kh), 1, - (kw + tir.floormod(blockIdx_z, 14)), + (kw + tir.floormod(bz, 14)), 1, ((ic_outer * 2) + ic_inner), 1, @@ -852,11 +830,11 @@ def opt_conv_tensorcore_normalize(A: ty.handle, W: ty.handle, Conv: ty.handle) - [buffer_1, Apad_shared_wmma_matrix_a], "buffer_bind_scope", tir.tvm_tuple( - (ax0 + ((blockIdx_x * 8) + (threadIdx_y * 2))), + (ax0 + ((bx * 8) + (ty * 2))), 1, - (tir.floordiv(blockIdx_z, 14) + kh), + (tir.floordiv(bz, 14) + kh), 1, - (kw + tir.floormod(blockIdx_z, 14)), + (kw + tir.floormod(bz, 14)), 1, ((ic_outer * 2) + ic_inner), 1, @@ -887,20 +865,18 @@ def opt_conv_tensorcore_normalize(A: ty.handle, W: ty.handle, Conv: ty.handle) - dtype="handle", ) ) - tir.attr(W_shared_wmma_matrix_b, "realize_scope", "wmma.matrix_b") tir.realize( W_shared_wmma_matrix_b[ kh : (kh + 1), kw : (kw + 1), ((ic_outer * 2) + ic_inner) : (((ic_outer * 2) + ic_inner) + 1), - ((blockIdx_y * 8) + (threadIdx_z * 4)) : ( - ((blockIdx_y * 8) + (threadIdx_z * 4)) + 4 - ), + ((by * 8) + (tz * 4)) : (((by * 8) + (tz * 4)) + 4), 0:16, 0:16, - ] + ], + "wmma.matrix_b", ) - for ax3_1 in tir.range(0, 4): + for ax3_1 in tir.serial(0, 4): tir.attr( [buffer_2, W_shared], "buffer_bind_scope", @@ -911,7 +887,7 @@ def opt_conv_tensorcore_normalize(A: ty.handle, W: ty.handle, Conv: ty.handle) - 1, ((ic_outer * 2) + ic_inner), 1, - (ax3_1 + ((blockIdx_y * 8) + (threadIdx_z * 4))), + (ax3_1 + ((by * 8) + (tz * 4))), 1, 0, 16, @@ -930,7 +906,7 @@ def opt_conv_tensorcore_normalize(A: ty.handle, W: ty.handle, Conv: ty.handle) - 1, ((ic_outer * 2) + ic_inner), 1, - (ax3_1 + ((blockIdx_y * 8) + (threadIdx_z * 4))), + (ax3_1 + ((by * 8) + (tz * 4))), 1, 0, 16, @@ -959,17 +935,17 @@ def opt_conv_tensorcore_normalize(A: ty.handle, W: ty.handle, Conv: ty.handle) - dtype="handle", ) ) - for n_c in tir.range(0, 2): - for o_c in tir.range(0, 4): + for n_c in tir.serial(0, 2): + for o_c in tir.serial(0, 4): tir.attr( [BA, Apad_shared_wmma_matrix_a], "buffer_bind_scope", tir.tvm_tuple( - (n_c + ((blockIdx_x * 8) + (threadIdx_y * 2))), + (n_c + ((bx * 8) + (ty * 2))), 1, - (tir.floordiv(blockIdx_z, 14) + kh), + (tir.floordiv(bz, 14) + kh), 1, - (tir.floormod(blockIdx_z, 14) + kw), + (tir.floormod(bz, 14) + kw), 1, ((ic_outer * 2) + ic_inner), 1, @@ -990,7 +966,7 @@ def opt_conv_tensorcore_normalize(A: ty.handle, W: ty.handle, Conv: ty.handle) - 1, ((ic_outer * 2) + ic_inner), 1, - (o_c + ((blockIdx_y * 8) + (threadIdx_z * 4))), + (o_c + ((by * 8) + (tz * 4))), 1, 0, 16, @@ -1003,13 +979,13 @@ def opt_conv_tensorcore_normalize(A: ty.handle, W: ty.handle, Conv: ty.handle) - [BC, Conv_wmma_accumulator], "buffer_bind_scope", tir.tvm_tuple( - (n_c + ((blockIdx_x * 8) + (threadIdx_y * 2))), + (n_c + ((bx * 8) + (ty * 2))), 1, - tir.floordiv(blockIdx_z, 14), + tir.floordiv(bz, 14), 1, - tir.floormod(blockIdx_z, 14), + tir.floormod(bz, 14), 1, - (o_c + ((blockIdx_y * 8) + (threadIdx_z * 4))), + (o_c + ((by * 8) + (tz * 4))), 1, 0, 16, @@ -1031,19 +1007,19 @@ def opt_conv_tensorcore_normalize(A: ty.handle, W: ty.handle, Conv: ty.handle) - dtype="handle", ) ) - for n_inner in tir.range(0, 2): - for o_inner in tir.range(0, 4): + for n_inner in tir.serial(0, 2): + for o_inner in tir.serial(0, 4): tir.attr( [buffer_4, Conv_wmma_accumulator], "buffer_bind_scope", tir.tvm_tuple( - ((((blockIdx_x * 4) + threadIdx_y) * 2) + n_inner), + ((((bx * 4) + ty) * 2) + n_inner), 1, - tir.floordiv(blockIdx_z, 14), + tir.floordiv(bz, 14), 1, - tir.floormod(blockIdx_z, 14), + tir.floormod(bz, 14), 1, - ((((blockIdx_y * 2) + threadIdx_z) * 4) + o_inner), + ((((by * 2) + tz) * 4) + o_inner), 1, 0, 16, @@ -1056,13 +1032,13 @@ def opt_conv_tensorcore_normalize(A: ty.handle, W: ty.handle, Conv: ty.handle) - [buffer_5, Conv_1], "buffer_bind_scope", tir.tvm_tuple( - ((((blockIdx_x * 4) + threadIdx_y) * 2) + n_inner), + ((((bx * 4) + ty) * 2) + n_inner), 1, - tir.floordiv(blockIdx_z, 14), + tir.floordiv(bz, 14), 1, - tir.floormod(blockIdx_z, 14), + tir.floormod(bz, 14), 1, - ((((blockIdx_y * 2) + threadIdx_z) * 4) + o_inner), + ((((by * 2) + tz) * 4) + o_inner), 1, 0, 16, @@ -1103,43 +1079,32 @@ def test_opt_conv_tensorcore_normalize(): def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> None: # function attr dict tir.func_attr({"global_symbol": "default_function", "tir.noalias": True}) - # var definition - Apad_shared = tir.var("handle") - Apad_shared_wmma_matrix_a = tir.var("handle") - Conv_wmma_accumulator = tir.var("handle") - W_shared = tir.var("handle") - W_shared_wmma_matrix_b = tir.var("handle") - blockIdx_x = tir.var("int32") - blockIdx_y = tir.var("int32") - blockIdx_z = tir.var("int32") - threadIdx_x = tir.var("int32") - threadIdx_y = tir.var("int32") - threadIdx_z = tir.var("int32") - A_1 = tir.buffer_bind( + # body + A_1 = tir.match_buffer( A, [16, 14, 14, 16, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1 ) - W_1 = tir.buffer_bind( + W_1 = tir.match_buffer( W, [3, 3, 16, 32, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1 ) - Conv_1 = tir.buffer_bind( + Conv_1 = tir.match_buffer( Conv, [16, 14, 14, 32, 16, 16], elem_offset=0, align=128, offset_factor=1 ) - # body - tir.attr(tir.iter_var(blockIdx_z, None, "ThreadIndex", "blockIdx.z"), "thread_extent", 196) - tir.attr(Conv_wmma_accumulator, "storage_scope", "wmma.accumulator") - tir.allocate(Conv_wmma_accumulator, "float32", [2048]) - tir.attr(Apad_shared, "storage_scope", "shared") - tir.allocate(Apad_shared, "float16", [12288]) - tir.attr(W_shared, "storage_scope", "shared") - tir.allocate(W_shared, "float16", [12288]) - tir.attr(Apad_shared_wmma_matrix_a, "storage_scope", "wmma.matrix_a") - tir.allocate(Apad_shared_wmma_matrix_a, "float16", [512]) - tir.attr(W_shared_wmma_matrix_b, "storage_scope", "wmma.matrix_b") - tir.allocate(W_shared_wmma_matrix_b, "float16", [1024]) - tir.attr(tir.iter_var(blockIdx_x, None, "ThreadIndex", "blockIdx.x"), "thread_extent", 2) - tir.attr(tir.iter_var(blockIdx_y, None, "ThreadIndex", "blockIdx.y"), "thread_extent", 4) - tir.attr(tir.iter_var(threadIdx_y, None, "ThreadIndex", "threadIdx.y"), "thread_extent", 4) - tir.attr(tir.iter_var(threadIdx_z, None, "ThreadIndex", "threadIdx.z"), "thread_extent", 2) + bx = tir.env_thread("blockIdx.x") + by = tir.env_thread("blockIdx.y") + bz = tir.env_thread("blockIdx.z") + tx = tir.env_thread("threadIdx.x") + ty = tir.env_thread("threadIdx.y") + tz = tir.env_thread("threadIdx.z") + tir.launch_thread(bz, 196) + Conv_wmma_accumulator = tir.allocate([2048], "float32", "wmma.accumulator") + Apad_shared = tir.allocate([12288], "float16", "shared") + W_shared = tir.allocate([12288], "float16", "shared") + Apad_shared_wmma_matrix_a = tir.allocate([512], "float16", "wmma.matrix_a") + W_shared_wmma_matrix_b = tir.allocate([1024], "float16", "wmma.matrix_b") + tir.launch_thread(bx, 2) + tir.launch_thread(by, 4) + tir.launch_thread(ty, 4) + tir.launch_thread(tz, 2) tir.evaluate( tir.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, 0, tir.float32(0), dtype="handle") ) @@ -1164,29 +1129,22 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No tir.evaluate( tir.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, 7, tir.float32(0), dtype="handle") ) - for ic_outer in tir.range(0, 8): - for kh in tir.range(0, 3): - for ax2 in tir.range(0, 3): - with tir.attr( - tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), - "thread_extent", - 32, - ): + for ic_outer in tir.serial(0, 8): + for kh in tir.serial(0, 3): + for ax2 in tir.serial(0, 3): + with tir.launch_thread(tx, 32): Apad_shared[ - ( - (((threadIdx_y * 3072) + (threadIdx_z * 1536)) + (ax2 * 512)) - + threadIdx_x - ) + ((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) ] = tir.if_then_else( ( ( ( - (1 <= (tir.floordiv(blockIdx_z, 14) + kh)) - and ((tir.floordiv(blockIdx_z, 14) + kh) < 15) + (1 <= (tir.floordiv(bz, 14) + kh)) + and ((tir.floordiv(bz, 14) + kh) < 15) ) - and (1 <= (ax2 + tir.floormod(blockIdx_z, 14))) + and (1 <= (ax2 + tir.floormod(bz, 14))) ) - and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15) + and ((ax2 + tir.floormod(bz, 14)) < 15) ), tir.load( "float16", @@ -1198,21 +1156,18 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No ( ( ( - ( - (blockIdx_x * 6422528) - + (threadIdx_y * 1605632) - ) - + (threadIdx_z * 802816) + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) ) + (kh * 57344) ) - + (blockIdx_z * 4096) + + (bz * 4096) ) + (ax2 * 4096) ) + (ic_outer * 512) ) - + threadIdx_x + + tx ) - 61440 ), @@ -1220,29 +1175,19 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No tir.float16(0), dtype="float16", ) - with tir.attr( - tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), - "thread_extent", - 32, - ): + with tir.launch_thread(tx, 32): Apad_shared[ - ( - ( - (((threadIdx_y * 3072) + (threadIdx_z * 1536)) + (ax2 * 512)) - + threadIdx_x - ) - + 32 - ) + (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 32) ] = tir.if_then_else( ( ( ( - (1 <= (tir.floordiv(blockIdx_z, 14) + kh)) - and ((tir.floordiv(blockIdx_z, 14) + kh) < 15) + (1 <= (tir.floordiv(bz, 14) + kh)) + and ((tir.floordiv(bz, 14) + kh) < 15) ) - and (1 <= (ax2 + tir.floormod(blockIdx_z, 14))) + and (1 <= (ax2 + tir.floormod(bz, 14))) ) - and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15) + and ((ax2 + tir.floormod(bz, 14)) < 15) ), tir.load( "float16", @@ -1254,21 +1199,18 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No ( ( ( - ( - (blockIdx_x * 6422528) - + (threadIdx_y * 1605632) - ) - + (threadIdx_z * 802816) + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) ) + (kh * 57344) ) - + (blockIdx_z * 4096) + + (bz * 4096) ) + (ax2 * 4096) ) + (ic_outer * 512) ) - + threadIdx_x + + tx ) - 61408 ), @@ -1276,29 +1218,19 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No tir.float16(0), dtype="float16", ) - with tir.attr( - tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), - "thread_extent", - 32, - ): + with tir.launch_thread(tx, 32): Apad_shared[ - ( - ( - (((threadIdx_y * 3072) + (threadIdx_z * 1536)) + (ax2 * 512)) - + threadIdx_x - ) - + 64 - ) + (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 64) ] = tir.if_then_else( ( ( ( - (1 <= (tir.floordiv(blockIdx_z, 14) + kh)) - and ((tir.floordiv(blockIdx_z, 14) + kh) < 15) + (1 <= (tir.floordiv(bz, 14) + kh)) + and ((tir.floordiv(bz, 14) + kh) < 15) ) - and (1 <= (ax2 + tir.floormod(blockIdx_z, 14))) + and (1 <= (ax2 + tir.floormod(bz, 14))) ) - and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15) + and ((ax2 + tir.floormod(bz, 14)) < 15) ), tir.load( "float16", @@ -1310,21 +1242,18 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No ( ( ( - ( - (blockIdx_x * 6422528) - + (threadIdx_y * 1605632) - ) - + (threadIdx_z * 802816) + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) ) + (kh * 57344) ) - + (blockIdx_z * 4096) + + (bz * 4096) ) + (ax2 * 4096) ) + (ic_outer * 512) ) - + threadIdx_x + + tx ) - 61376 ), @@ -1332,29 +1261,19 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No tir.float16(0), dtype="float16", ) - with tir.attr( - tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), - "thread_extent", - 32, - ): + with tir.launch_thread(tx, 32): Apad_shared[ - ( - ( - (((threadIdx_y * 3072) + (threadIdx_z * 1536)) + (ax2 * 512)) - + threadIdx_x - ) - + 96 - ) + (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 96) ] = tir.if_then_else( ( ( ( - (1 <= (tir.floordiv(blockIdx_z, 14) + kh)) - and ((tir.floordiv(blockIdx_z, 14) + kh) < 15) + (1 <= (tir.floordiv(bz, 14) + kh)) + and ((tir.floordiv(bz, 14) + kh) < 15) ) - and (1 <= (ax2 + tir.floormod(blockIdx_z, 14))) + and (1 <= (ax2 + tir.floormod(bz, 14))) ) - and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15) + and ((ax2 + tir.floormod(bz, 14)) < 15) ), tir.load( "float16", @@ -1366,21 +1285,18 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No ( ( ( - ( - (blockIdx_x * 6422528) - + (threadIdx_y * 1605632) - ) - + (threadIdx_z * 802816) + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) ) + (kh * 57344) ) - + (blockIdx_z * 4096) + + (bz * 4096) ) + (ax2 * 4096) ) + (ic_outer * 512) ) - + threadIdx_x + + tx ) - 61344 ), @@ -1388,29 +1304,19 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No tir.float16(0), dtype="float16", ) - with tir.attr( - tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), - "thread_extent", - 32, - ): + with tir.launch_thread(tx, 32): Apad_shared[ - ( - ( - (((threadIdx_y * 3072) + (threadIdx_z * 1536)) + (ax2 * 512)) - + threadIdx_x - ) - + 128 - ) + (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 128) ] = tir.if_then_else( ( ( ( - (1 <= (tir.floordiv(blockIdx_z, 14) + kh)) - and ((tir.floordiv(blockIdx_z, 14) + kh) < 15) + (1 <= (tir.floordiv(bz, 14) + kh)) + and ((tir.floordiv(bz, 14) + kh) < 15) ) - and (1 <= (ax2 + tir.floormod(blockIdx_z, 14))) + and (1 <= (ax2 + tir.floormod(bz, 14))) ) - and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15) + and ((ax2 + tir.floormod(bz, 14)) < 15) ), tir.load( "float16", @@ -1422,21 +1328,18 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No ( ( ( - ( - (blockIdx_x * 6422528) - + (threadIdx_y * 1605632) - ) - + (threadIdx_z * 802816) + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) ) + (kh * 57344) ) - + (blockIdx_z * 4096) + + (bz * 4096) ) + (ax2 * 4096) ) + (ic_outer * 512) ) - + threadIdx_x + + tx ) - 61312 ), @@ -1444,29 +1347,19 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No tir.float16(0), dtype="float16", ) - with tir.attr( - tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), - "thread_extent", - 32, - ): + with tir.launch_thread(tx, 32): Apad_shared[ - ( - ( - (((threadIdx_y * 3072) + (threadIdx_z * 1536)) + (ax2 * 512)) - + threadIdx_x - ) - + 160 - ) + (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 160) ] = tir.if_then_else( ( ( ( - (1 <= (tir.floordiv(blockIdx_z, 14) + kh)) - and ((tir.floordiv(blockIdx_z, 14) + kh) < 15) + (1 <= (tir.floordiv(bz, 14) + kh)) + and ((tir.floordiv(bz, 14) + kh) < 15) ) - and (1 <= (ax2 + tir.floormod(blockIdx_z, 14))) + and (1 <= (ax2 + tir.floormod(bz, 14))) ) - and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15) + and ((ax2 + tir.floormod(bz, 14)) < 15) ), tir.load( "float16", @@ -1478,21 +1371,18 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No ( ( ( - ( - (blockIdx_x * 6422528) - + (threadIdx_y * 1605632) - ) - + (threadIdx_z * 802816) + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) ) + (kh * 57344) ) - + (blockIdx_z * 4096) + + (bz * 4096) ) + (ax2 * 4096) ) + (ic_outer * 512) ) - + threadIdx_x + + tx ) - 61280 ), @@ -1500,29 +1390,19 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No tir.float16(0), dtype="float16", ) - with tir.attr( - tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), - "thread_extent", - 32, - ): + with tir.launch_thread(tx, 32): Apad_shared[ - ( - ( - (((threadIdx_y * 3072) + (threadIdx_z * 1536)) + (ax2 * 512)) - + threadIdx_x - ) - + 192 - ) + (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 192) ] = tir.if_then_else( ( ( ( - (1 <= (tir.floordiv(blockIdx_z, 14) + kh)) - and ((tir.floordiv(blockIdx_z, 14) + kh) < 15) + (1 <= (tir.floordiv(bz, 14) + kh)) + and ((tir.floordiv(bz, 14) + kh) < 15) ) - and (1 <= (ax2 + tir.floormod(blockIdx_z, 14))) + and (1 <= (ax2 + tir.floormod(bz, 14))) ) - and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15) + and ((ax2 + tir.floormod(bz, 14)) < 15) ), tir.load( "float16", @@ -1534,21 +1414,18 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No ( ( ( - ( - (blockIdx_x * 6422528) - + (threadIdx_y * 1605632) - ) - + (threadIdx_z * 802816) + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) ) + (kh * 57344) ) - + (blockIdx_z * 4096) + + (bz * 4096) ) + (ax2 * 4096) ) + (ic_outer * 512) ) - + threadIdx_x + + tx ) - 61248 ), @@ -1556,29 +1433,19 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No tir.float16(0), dtype="float16", ) - with tir.attr( - tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), - "thread_extent", - 32, - ): + with tir.launch_thread(tx, 32): Apad_shared[ - ( - ( - (((threadIdx_y * 3072) + (threadIdx_z * 1536)) + (ax2 * 512)) - + threadIdx_x - ) - + 224 - ) + (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 224) ] = tir.if_then_else( ( ( ( - (1 <= (tir.floordiv(blockIdx_z, 14) + kh)) - and ((tir.floordiv(blockIdx_z, 14) + kh) < 15) + (1 <= (tir.floordiv(bz, 14) + kh)) + and ((tir.floordiv(bz, 14) + kh) < 15) ) - and (1 <= (ax2 + tir.floormod(blockIdx_z, 14))) + and (1 <= (ax2 + tir.floormod(bz, 14))) ) - and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15) + and ((ax2 + tir.floormod(bz, 14)) < 15) ), tir.load( "float16", @@ -1590,21 +1457,18 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No ( ( ( - ( - (blockIdx_x * 6422528) - + (threadIdx_y * 1605632) - ) - + (threadIdx_z * 802816) + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) ) + (kh * 57344) ) - + (blockIdx_z * 4096) + + (bz * 4096) ) + (ax2 * 4096) ) + (ic_outer * 512) ) - + threadIdx_x + + tx ) - 61216 ), @@ -1612,29 +1476,19 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No tir.float16(0), dtype="float16", ) - with tir.attr( - tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), - "thread_extent", - 32, - ): + with tir.launch_thread(tx, 32): Apad_shared[ - ( - ( - (((threadIdx_y * 3072) + (threadIdx_z * 1536)) + (ax2 * 512)) - + threadIdx_x - ) - + 256 - ) + (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 256) ] = tir.if_then_else( ( ( ( - (1 <= (tir.floordiv(blockIdx_z, 14) + kh)) - and ((tir.floordiv(blockIdx_z, 14) + kh) < 15) + (1 <= (tir.floordiv(bz, 14) + kh)) + and ((tir.floordiv(bz, 14) + kh) < 15) ) - and (1 <= (ax2 + tir.floormod(blockIdx_z, 14))) + and (1 <= (ax2 + tir.floormod(bz, 14))) ) - and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15) + and ((ax2 + tir.floormod(bz, 14)) < 15) ), tir.load( "float16", @@ -1646,21 +1500,18 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No ( ( ( - ( - (blockIdx_x * 6422528) - + (threadIdx_y * 1605632) - ) - + (threadIdx_z * 802816) + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) ) + (kh * 57344) ) - + (blockIdx_z * 4096) + + (bz * 4096) ) + (ax2 * 4096) ) + (ic_outer * 512) ) - + threadIdx_x + + tx ) - 61184 ), @@ -1668,29 +1519,19 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No tir.float16(0), dtype="float16", ) - with tir.attr( - tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), - "thread_extent", - 32, - ): + with tir.launch_thread(tx, 32): Apad_shared[ - ( - ( - (((threadIdx_y * 3072) + (threadIdx_z * 1536)) + (ax2 * 512)) - + threadIdx_x - ) - + 288 - ) + (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 288) ] = tir.if_then_else( ( ( ( - (1 <= (tir.floordiv(blockIdx_z, 14) + kh)) - and ((tir.floordiv(blockIdx_z, 14) + kh) < 15) + (1 <= (tir.floordiv(bz, 14) + kh)) + and ((tir.floordiv(bz, 14) + kh) < 15) ) - and (1 <= (ax2 + tir.floormod(blockIdx_z, 14))) + and (1 <= (ax2 + tir.floormod(bz, 14))) ) - and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15) + and ((ax2 + tir.floormod(bz, 14)) < 15) ), tir.load( "float16", @@ -1702,21 +1543,18 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No ( ( ( - ( - (blockIdx_x * 6422528) - + (threadIdx_y * 1605632) - ) - + (threadIdx_z * 802816) + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) ) + (kh * 57344) ) - + (blockIdx_z * 4096) + + (bz * 4096) ) + (ax2 * 4096) ) + (ic_outer * 512) ) - + threadIdx_x + + tx ) - 61152 ), @@ -1724,29 +1562,19 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No tir.float16(0), dtype="float16", ) - with tir.attr( - tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), - "thread_extent", - 32, - ): + with tir.launch_thread(tx, 32): Apad_shared[ - ( - ( - (((threadIdx_y * 3072) + (threadIdx_z * 1536)) + (ax2 * 512)) - + threadIdx_x - ) - + 320 - ) + (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 320) ] = tir.if_then_else( ( ( ( - (1 <= (tir.floordiv(blockIdx_z, 14) + kh)) - and ((tir.floordiv(blockIdx_z, 14) + kh) < 15) + (1 <= (tir.floordiv(bz, 14) + kh)) + and ((tir.floordiv(bz, 14) + kh) < 15) ) - and (1 <= (ax2 + tir.floormod(blockIdx_z, 14))) + and (1 <= (ax2 + tir.floormod(bz, 14))) ) - and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15) + and ((ax2 + tir.floormod(bz, 14)) < 15) ), tir.load( "float16", @@ -1758,21 +1586,18 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No ( ( ( - ( - (blockIdx_x * 6422528) - + (threadIdx_y * 1605632) - ) - + (threadIdx_z * 802816) + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) ) + (kh * 57344) ) - + (blockIdx_z * 4096) + + (bz * 4096) ) + (ax2 * 4096) ) + (ic_outer * 512) ) - + threadIdx_x + + tx ) - 61120 ), @@ -1780,29 +1605,19 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No tir.float16(0), dtype="float16", ) - with tir.attr( - tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), - "thread_extent", - 32, - ): + with tir.launch_thread(tx, 32): Apad_shared[ - ( - ( - (((threadIdx_y * 3072) + (threadIdx_z * 1536)) + (ax2 * 512)) - + threadIdx_x - ) - + 352 - ) + (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 352) ] = tir.if_then_else( ( ( ( - (1 <= (tir.floordiv(blockIdx_z, 14) + kh)) - and ((tir.floordiv(blockIdx_z, 14) + kh) < 15) + (1 <= (tir.floordiv(bz, 14) + kh)) + and ((tir.floordiv(bz, 14) + kh) < 15) ) - and (1 <= (ax2 + tir.floormod(blockIdx_z, 14))) + and (1 <= (ax2 + tir.floormod(bz, 14))) ) - and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15) + and ((ax2 + tir.floormod(bz, 14)) < 15) ), tir.load( "float16", @@ -1814,21 +1629,18 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No ( ( ( - ( - (blockIdx_x * 6422528) - + (threadIdx_y * 1605632) - ) - + (threadIdx_z * 802816) + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) ) + (kh * 57344) ) - + (blockIdx_z * 4096) + + (bz * 4096) ) + (ax2 * 4096) ) + (ic_outer * 512) ) - + threadIdx_x + + tx ) - 61088 ), @@ -1836,29 +1648,19 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No tir.float16(0), dtype="float16", ) - with tir.attr( - tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), - "thread_extent", - 32, - ): + with tir.launch_thread(tx, 32): Apad_shared[ - ( - ( - (((threadIdx_y * 3072) + (threadIdx_z * 1536)) + (ax2 * 512)) - + threadIdx_x - ) - + 384 - ) + (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 384) ] = tir.if_then_else( ( ( ( - (1 <= (tir.floordiv(blockIdx_z, 14) + kh)) - and ((tir.floordiv(blockIdx_z, 14) + kh) < 15) + (1 <= (tir.floordiv(bz, 14) + kh)) + and ((tir.floordiv(bz, 14) + kh) < 15) ) - and (1 <= (ax2 + tir.floormod(blockIdx_z, 14))) + and (1 <= (ax2 + tir.floormod(bz, 14))) ) - and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15) + and ((ax2 + tir.floormod(bz, 14)) < 15) ), tir.load( "float16", @@ -1870,21 +1672,18 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No ( ( ( - ( - (blockIdx_x * 6422528) - + (threadIdx_y * 1605632) - ) - + (threadIdx_z * 802816) + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) ) + (kh * 57344) ) - + (blockIdx_z * 4096) + + (bz * 4096) ) + (ax2 * 4096) ) + (ic_outer * 512) ) - + threadIdx_x + + tx ) - 61056 ), @@ -1892,29 +1691,19 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No tir.float16(0), dtype="float16", ) - with tir.attr( - tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), - "thread_extent", - 32, - ): + with tir.launch_thread(tx, 32): Apad_shared[ - ( - ( - (((threadIdx_y * 3072) + (threadIdx_z * 1536)) + (ax2 * 512)) - + threadIdx_x - ) - + 416 - ) + (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 416) ] = tir.if_then_else( ( ( ( - (1 <= (tir.floordiv(blockIdx_z, 14) + kh)) - and ((tir.floordiv(blockIdx_z, 14) + kh) < 15) + (1 <= (tir.floordiv(bz, 14) + kh)) + and ((tir.floordiv(bz, 14) + kh) < 15) ) - and (1 <= (ax2 + tir.floormod(blockIdx_z, 14))) + and (1 <= (ax2 + tir.floormod(bz, 14))) ) - and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15) + and ((ax2 + tir.floormod(bz, 14)) < 15) ), tir.load( "float16", @@ -1926,21 +1715,18 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No ( ( ( - ( - (blockIdx_x * 6422528) - + (threadIdx_y * 1605632) - ) - + (threadIdx_z * 802816) + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) ) + (kh * 57344) ) - + (blockIdx_z * 4096) + + (bz * 4096) ) + (ax2 * 4096) ) + (ic_outer * 512) ) - + threadIdx_x + + tx ) - 61024 ), @@ -1948,29 +1734,19 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No tir.float16(0), dtype="float16", ) - with tir.attr( - tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), - "thread_extent", - 32, - ): + with tir.launch_thread(tx, 32): Apad_shared[ - ( - ( - (((threadIdx_y * 3072) + (threadIdx_z * 1536)) + (ax2 * 512)) - + threadIdx_x - ) - + 448 - ) + (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 448) ] = tir.if_then_else( ( ( ( - (1 <= (tir.floordiv(blockIdx_z, 14) + kh)) - and ((tir.floordiv(blockIdx_z, 14) + kh) < 15) + (1 <= (tir.floordiv(bz, 14) + kh)) + and ((tir.floordiv(bz, 14) + kh) < 15) ) - and (1 <= (ax2 + tir.floormod(blockIdx_z, 14))) + and (1 <= (ax2 + tir.floormod(bz, 14))) ) - and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15) + and ((ax2 + tir.floormod(bz, 14)) < 15) ), tir.load( "float16", @@ -1982,21 +1758,18 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No ( ( ( - ( - (blockIdx_x * 6422528) - + (threadIdx_y * 1605632) - ) - + (threadIdx_z * 802816) + ((bx * 6422528) + (ty * 1605632)) + + (tz * 802816) ) + (kh * 57344) ) - + (blockIdx_z * 4096) + + (bz * 4096) ) + (ax2 * 4096) ) + (ic_outer * 512) ) - + threadIdx_x + + tx ) - 60992 ), @@ -2004,29 +1777,19 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No tir.float16(0), dtype="float16", ) - tir.attr( - tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), - "thread_extent", - 32, - ) + tir.launch_thread(tx, 32) Apad_shared[ - ( - ( - (((threadIdx_y * 3072) + (threadIdx_z * 1536)) + (ax2 * 512)) - + threadIdx_x - ) - + 480 - ) + (((((ty * 3072) + (tz * 1536)) + (ax2 * 512)) + tx) + 480) ] = tir.if_then_else( ( ( ( - (1 <= (tir.floordiv(blockIdx_z, 14) + kh)) - and ((tir.floordiv(blockIdx_z, 14) + kh) < 15) + (1 <= (tir.floordiv(bz, 14) + kh)) + and ((tir.floordiv(bz, 14) + kh) < 15) ) - and (1 <= (ax2 + tir.floormod(blockIdx_z, 14))) + and (1 <= (ax2 + tir.floormod(bz, 14))) ) - and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15) + and ((ax2 + tir.floormod(bz, 14)) < 15) ), tir.load( "float16", @@ -2037,22 +1800,16 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No ( ( ( - ( - ( - (blockIdx_x * 6422528) - + (threadIdx_y * 1605632) - ) - + (threadIdx_z * 802816) - ) + (((bx * 6422528) + (ty * 1605632)) + (tz * 802816)) + (kh * 57344) ) - + (blockIdx_z * 4096) + + (bz * 4096) ) + (ax2 * 4096) ) + (ic_outer * 512) ) - + threadIdx_x + + tx ) - 60960 ), @@ -2060,14 +1817,10 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No tir.float16(0), dtype="float16", ) - with tir.attr( - tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32 - ): + with tir.launch_thread(tx, 32): tir.store( W_shared, - tir.ramp( - (((threadIdx_y * 512) + (threadIdx_z * 256)) + (threadIdx_x * 8)), 1, 8 - ), + tir.ramp((((ty * 512) + (tz * 256)) + (tx * 8)), 1, 8), tir.load( "float16x8", W_1.data, @@ -2075,12 +1828,12 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No ( ( ( - (((kh * 393216) + (ic_outer * 16384)) + (blockIdx_y * 2048)) - + (threadIdx_y * 512) + (((kh * 393216) + (ic_outer * 16384)) + (by * 2048)) + + (ty * 512) ) - + (threadIdx_z * 256) + + (tz * 256) ) - + (threadIdx_x * 8) + + (tx * 8) ), 1, 8, @@ -2089,16 +1842,10 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No ), tir.broadcast(True, 8), ) - with tir.attr( - tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32 - ): + with tir.launch_thread(tx, 32): tir.store( W_shared, - tir.ramp( - ((((threadIdx_y * 512) + (threadIdx_z * 256)) + (threadIdx_x * 8)) + 2048), - 1, - 8, - ), + tir.ramp(((((ty * 512) + (tz * 256)) + (tx * 8)) + 2048), 1, 8), tir.load( "float16x8", W_1.data, @@ -2107,15 +1854,12 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No ( ( ( - ( - ((kh * 393216) + (ic_outer * 16384)) - + (blockIdx_y * 2048) - ) - + (threadIdx_y * 512) + (((kh * 393216) + (ic_outer * 16384)) + (by * 2048)) + + (ty * 512) ) - + (threadIdx_z * 256) + + (tz * 256) ) - + (threadIdx_x * 8) + + (tx * 8) ) + 8192 ), @@ -2126,16 +1870,10 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No ), tir.broadcast(True, 8), ) - with tir.attr( - tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32 - ): + with tir.launch_thread(tx, 32): tir.store( W_shared, - tir.ramp( - ((((threadIdx_y * 512) + (threadIdx_z * 256)) + (threadIdx_x * 8)) + 4096), - 1, - 8, - ), + tir.ramp(((((ty * 512) + (tz * 256)) + (tx * 8)) + 4096), 1, 8), tir.load( "float16x8", W_1.data, @@ -2144,15 +1882,12 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No ( ( ( - ( - ((kh * 393216) + (ic_outer * 16384)) - + (blockIdx_y * 2048) - ) - + (threadIdx_y * 512) + (((kh * 393216) + (ic_outer * 16384)) + (by * 2048)) + + (ty * 512) ) - + (threadIdx_z * 256) + + (tz * 256) ) - + (threadIdx_x * 8) + + (tx * 8) ) + 131072 ), @@ -2163,16 +1898,10 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No ), tir.broadcast(True, 8), ) - with tir.attr( - tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32 - ): + with tir.launch_thread(tx, 32): tir.store( W_shared, - tir.ramp( - ((((threadIdx_y * 512) + (threadIdx_z * 256)) + (threadIdx_x * 8)) + 6144), - 1, - 8, - ), + tir.ramp(((((ty * 512) + (tz * 256)) + (tx * 8)) + 6144), 1, 8), tir.load( "float16x8", W_1.data, @@ -2181,15 +1910,12 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No ( ( ( - ( - ((kh * 393216) + (ic_outer * 16384)) - + (blockIdx_y * 2048) - ) - + (threadIdx_y * 512) + (((kh * 393216) + (ic_outer * 16384)) + (by * 2048)) + + (ty * 512) ) - + (threadIdx_z * 256) + + (tz * 256) ) - + (threadIdx_x * 8) + + (tx * 8) ) + 139264 ), @@ -2200,16 +1926,10 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No ), tir.broadcast(True, 8), ) - with tir.attr( - tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32 - ): + with tir.launch_thread(tx, 32): tir.store( W_shared, - tir.ramp( - ((((threadIdx_y * 512) + (threadIdx_z * 256)) + (threadIdx_x * 8)) + 8192), - 1, - 8, - ), + tir.ramp(((((ty * 512) + (tz * 256)) + (tx * 8)) + 8192), 1, 8), tir.load( "float16x8", W_1.data, @@ -2218,15 +1938,12 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No ( ( ( - ( - ((kh * 393216) + (ic_outer * 16384)) - + (blockIdx_y * 2048) - ) - + (threadIdx_y * 512) + (((kh * 393216) + (ic_outer * 16384)) + (by * 2048)) + + (ty * 512) ) - + (threadIdx_z * 256) + + (tz * 256) ) - + (threadIdx_x * 8) + + (tx * 8) ) + 262144 ), @@ -2237,16 +1954,10 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No ), tir.broadcast(True, 8), ) - with tir.attr( - tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32 - ): + with tir.launch_thread(tx, 32): tir.store( W_shared, - tir.ramp( - ((((threadIdx_y * 512) + (threadIdx_z * 256)) + (threadIdx_x * 8)) + 10240), - 1, - 8, - ), + tir.ramp(((((ty * 512) + (tz * 256)) + (tx * 8)) + 10240), 1, 8), tir.load( "float16x8", W_1.data, @@ -2255,15 +1966,12 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No ( ( ( - ( - ((kh * 393216) + (ic_outer * 16384)) - + (blockIdx_y * 2048) - ) - + (threadIdx_y * 512) + (((kh * 393216) + (ic_outer * 16384)) + (by * 2048)) + + (ty * 512) ) - + (threadIdx_z * 256) + + (tz * 256) ) - + (threadIdx_x * 8) + + (tx * 8) ) + 270336 ), @@ -2274,8 +1982,8 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No ), tir.broadcast(True, 8), ) - for ic_inner in tir.range(0, 2): - for kw in tir.range(0, 3): + for ic_inner in tir.serial(0, 2): + for kw in tir.serial(0, 3): tir.evaluate( tir.tvm_load_matrix_sync( Apad_shared_wmma_matrix_a, @@ -2286,7 +1994,7 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No tir.tvm_access_ptr( tir.type_annotation(dtype="float16"), Apad_shared, - (((threadIdx_y * 3072) + (kw * 512)) + (ic_inner * 256)), + (((ty * 3072) + (kw * 512)) + (ic_inner * 256)), 256, 1, dtype="handle", @@ -2306,7 +2014,7 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No tir.tvm_access_ptr( tir.type_annotation(dtype="float16"), Apad_shared, - ((((threadIdx_y * 3072) + (kw * 512)) + (ic_inner * 256)) + 1536), + ((((ty * 3072) + (kw * 512)) + (ic_inner * 256)) + 1536), 256, 1, dtype="handle", @@ -2326,7 +2034,7 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No tir.tvm_access_ptr( tir.type_annotation(dtype="float16"), W_shared, - (((kw * 4096) + (ic_inner * 2048)) + (threadIdx_z * 1024)), + (((kw * 4096) + (ic_inner * 2048)) + (tz * 1024)), 256, 1, dtype="handle", @@ -2346,7 +2054,7 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No tir.tvm_access_ptr( tir.type_annotation(dtype="float16"), W_shared, - ((((kw * 4096) + (ic_inner * 2048)) + (threadIdx_z * 1024)) + 256), + ((((kw * 4096) + (ic_inner * 2048)) + (tz * 1024)) + 256), 256, 1, dtype="handle", @@ -2366,7 +2074,7 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No tir.tvm_access_ptr( tir.type_annotation(dtype="float16"), W_shared, - ((((kw * 4096) + (ic_inner * 2048)) + (threadIdx_z * 1024)) + 512), + ((((kw * 4096) + (ic_inner * 2048)) + (tz * 1024)) + 512), 256, 1, dtype="handle", @@ -2386,7 +2094,7 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No tir.tvm_access_ptr( tir.type_annotation(dtype="float16"), W_shared, - ((((kw * 4096) + (ic_inner * 2048)) + (threadIdx_z * 1024)) + 768), + ((((kw * 4096) + (ic_inner * 2048)) + (tz * 1024)) + 768), 256, 1, dtype="handle", @@ -2510,13 +2218,7 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No tir.tvm_access_ptr( tir.type_annotation(dtype="float32"), Conv_1.data, - ( - ( - (((blockIdx_x * 12845056) + (threadIdx_y * 3211264)) + (blockIdx_z * 8192)) - + (blockIdx_y * 2048) - ) - + (threadIdx_z * 1024) - ), + (((((bx * 12845056) + (ty * 3211264)) + (bz * 8192)) + (by * 2048)) + (tz * 1024)), 256, 2, dtype="handle", @@ -2538,14 +2240,8 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No Conv_1.data, ( ( - ( - ( - ((blockIdx_x * 12845056) + (threadIdx_y * 3211264)) - + (blockIdx_z * 8192) - ) - + (blockIdx_y * 2048) - ) - + (threadIdx_z * 1024) + ((((bx * 12845056) + (ty * 3211264)) + (bz * 8192)) + (by * 2048)) + + (tz * 1024) ) + 256 ), @@ -2570,14 +2266,8 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No Conv_1.data, ( ( - ( - ( - ((blockIdx_x * 12845056) + (threadIdx_y * 3211264)) - + (blockIdx_z * 8192) - ) - + (blockIdx_y * 2048) - ) - + (threadIdx_z * 1024) + ((((bx * 12845056) + (ty * 3211264)) + (bz * 8192)) + (by * 2048)) + + (tz * 1024) ) + 512 ), @@ -2602,14 +2292,8 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No Conv_1.data, ( ( - ( - ( - ((blockIdx_x * 12845056) + (threadIdx_y * 3211264)) - + (blockIdx_z * 8192) - ) - + (blockIdx_y * 2048) - ) - + (threadIdx_z * 1024) + ((((bx * 12845056) + (ty * 3211264)) + (bz * 8192)) + (by * 2048)) + + (tz * 1024) ) + 768 ), @@ -2634,14 +2318,8 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No Conv_1.data, ( ( - ( - ( - ((blockIdx_x * 12845056) + (threadIdx_y * 3211264)) - + (blockIdx_z * 8192) - ) - + (blockIdx_y * 2048) - ) - + (threadIdx_z * 1024) + ((((bx * 12845056) + (ty * 3211264)) + (bz * 8192)) + (by * 2048)) + + (tz * 1024) ) + 1605632 ), @@ -2666,14 +2344,8 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No Conv_1.data, ( ( - ( - ( - ((blockIdx_x * 12845056) + (threadIdx_y * 3211264)) - + (blockIdx_z * 8192) - ) - + (blockIdx_y * 2048) - ) - + (threadIdx_z * 1024) + ((((bx * 12845056) + (ty * 3211264)) + (bz * 8192)) + (by * 2048)) + + (tz * 1024) ) + 1605888 ), @@ -2698,14 +2370,8 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No Conv_1.data, ( ( - ( - ( - ((blockIdx_x * 12845056) + (threadIdx_y * 3211264)) - + (blockIdx_z * 8192) - ) - + (blockIdx_y * 2048) - ) - + (threadIdx_z * 1024) + ((((bx * 12845056) + (ty * 3211264)) + (bz * 8192)) + (by * 2048)) + + (tz * 1024) ) + 1606144 ), @@ -2730,14 +2396,8 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No Conv_1.data, ( ( - ( - ( - ((blockIdx_x * 12845056) + (threadIdx_y * 3211264)) - + (blockIdx_z * 8192) - ) - + (blockIdx_y * 2048) - ) - + (threadIdx_z * 1024) + ((((bx * 12845056) + (ty * 3211264)) + (bz * 8192)) + (by * 2048)) + + (tz * 1024) ) + 1606400 ), @@ -2817,38 +2477,38 @@ def opt_conv_tensorcore_mod_host( tir.tvm_struct_get(arg0, 0, 7, dtype="uint16") == tir.uint16(1) ), "arg0.dtype is expected to be float16" assert 16 == tir.cast( - "int32", tir.load("int64", arg0_shape, 0) + tir.load("int64", arg0_shape, 0), "int32" ), "Argument arg0.shape[0] has an unsatisfied constraint" assert 14 == tir.cast( - "int32", tir.load("int64", arg0_shape, 1) + tir.load("int64", arg0_shape, 1), "int32" ), "Argument arg0.shape[1] has an unsatisfied constraint" assert 14 == tir.cast( - "int32", tir.load("int64", arg0_shape, 2) + tir.load("int64", arg0_shape, 2), "int32" ), "Argument arg0.shape[2] has an unsatisfied constraint" assert 16 == tir.cast( - "int32", tir.load("int64", arg0_shape, 3) + tir.load("int64", arg0_shape, 3), "int32" ), "Argument arg0.shape[3] has an unsatisfied constraint" assert 16 == tir.cast( - "int32", tir.load("int64", arg0_shape, 4) + tir.load("int64", arg0_shape, 4), "int32" ), "Argument arg0.shape[4] has an unsatisfied constraint" assert 16 == tir.cast( - "int32", tir.load("int64", arg0_shape, 5) + tir.load("int64", arg0_shape, 5), "int32" ), "Argument arg0.shape[5] has an unsatisfied constraint" if not (tir.isnullptr(arg0_strides, dtype="bool")): assert ( ( ( ( - (1 == tir.cast("int32", tir.load("int64", arg0_strides, 5))) - and (16 == tir.cast("int32", tir.load("int64", arg0_strides, 4))) + (1 == tir.cast(tir.load("int64", arg0_strides, 5), "int32")) + and (16 == tir.cast(tir.load("int64", arg0_strides, 4), "int32")) ) - and (256 == tir.cast("int32", tir.load("int64", arg0_strides, 3))) + and (256 == tir.cast(tir.load("int64", arg0_strides, 3), "int32")) ) - and (4096 == tir.cast("int32", tir.load("int64", arg0_strides, 2))) + and (4096 == tir.cast(tir.load("int64", arg0_strides, 2), "int32")) ) - and (57344 == tir.cast("int32", tir.load("int64", arg0_strides, 1))) + and (57344 == tir.cast(tir.load("int64", arg0_strides, 1), "int32")) ) and ( - 802816 == tir.cast("int32", tir.load("int64", arg0_strides, 0)) + 802816 == tir.cast(tir.load("int64", arg0_strides, 0), "int32") ), "arg0.strides: expected to be compact array" tir.evaluate(0) assert tir.uint64(0) == tir.tvm_struct_get( @@ -2866,38 +2526,38 @@ def opt_conv_tensorcore_mod_host( tir.tvm_struct_get(arg1, 0, 7, dtype="uint16") == tir.uint16(1) ), "arg1.dtype is expected to be float16" assert 3 == tir.cast( - "int32", tir.load("int64", arg1_shape, 0) + tir.load("int64", arg1_shape, 0), "int32" ), "Argument arg1.shape[0] has an unsatisfied constraint" assert 3 == tir.cast( - "int32", tir.load("int64", arg1_shape, 1) + tir.load("int64", arg1_shape, 1), "int32" ), "Argument arg1.shape[1] has an unsatisfied constraint" assert 16 == tir.cast( - "int32", tir.load("int64", arg1_shape, 2) + tir.load("int64", arg1_shape, 2), "int32" ), "Argument arg1.shape[2] has an unsatisfied constraint" assert 32 == tir.cast( - "int32", tir.load("int64", arg1_shape, 3) + tir.load("int64", arg1_shape, 3), "int32" ), "Argument arg1.shape[3] has an unsatisfied constraint" assert 16 == tir.cast( - "int32", tir.load("int64", arg1_shape, 4) + tir.load("int64", arg1_shape, 4), "int32" ), "Argument arg1.shape[4] has an unsatisfied constraint" assert 16 == tir.cast( - "int32", tir.load("int64", arg1_shape, 5) + tir.load("int64", arg1_shape, 5), "int32" ), "Argument arg1.shape[5] has an unsatisfied constraint" if not (tir.isnullptr(arg1_strides, dtype="bool")): assert ( ( ( ( - (1 == tir.cast("int32", tir.load("int64", arg1_strides, 5))) - and (16 == tir.cast("int32", tir.load("int64", arg1_strides, 4))) + (1 == tir.cast(tir.load("int64", arg1_strides, 5), "int32")) + and (16 == tir.cast(tir.load("int64", arg1_strides, 4), "int32")) ) - and (256 == tir.cast("int32", tir.load("int64", arg1_strides, 3))) + and (256 == tir.cast(tir.load("int64", arg1_strides, 3), "int32")) ) - and (8192 == tir.cast("int32", tir.load("int64", arg1_strides, 2))) + and (8192 == tir.cast(tir.load("int64", arg1_strides, 2), "int32")) ) - and (131072 == tir.cast("int32", tir.load("int64", arg1_strides, 1))) + and (131072 == tir.cast(tir.load("int64", arg1_strides, 1), "int32")) ) and ( - 393216 == tir.cast("int32", tir.load("int64", arg1_strides, 0)) + 393216 == tir.cast(tir.load("int64", arg1_strides, 0), "int32") ), "arg1.strides: expected to be compact array" tir.evaluate(0) assert tir.uint64(0) == tir.tvm_struct_get( @@ -2918,38 +2578,38 @@ def opt_conv_tensorcore_mod_host( tir.tvm_struct_get(arg2, 0, 7, dtype="uint16") == tir.uint16(1) ), "arg2.dtype is expected to be float32" assert 16 == tir.cast( - "int32", tir.load("int64", arg2_shape, 0) + tir.load("int64", arg2_shape, 0), "int32" ), "Argument arg2.shape[0] has an unsatisfied constraint" assert 14 == tir.cast( - "int32", tir.load("int64", arg2_shape, 1) + tir.load("int64", arg2_shape, 1), "int32" ), "Argument arg2.shape[1] has an unsatisfied constraint" assert 14 == tir.cast( - "int32", tir.load("int64", arg2_shape, 2) + tir.load("int64", arg2_shape, 2), "int32" ), "Argument arg2.shape[2] has an unsatisfied constraint" assert 32 == tir.cast( - "int32", tir.load("int64", arg2_shape, 3) + tir.load("int64", arg2_shape, 3), "int32" ), "Argument arg2.shape[3] has an unsatisfied constraint" assert 16 == tir.cast( - "int32", tir.load("int64", arg2_shape, 4) + tir.load("int64", arg2_shape, 4), "int32" ), "Argument arg2.shape[4] has an unsatisfied constraint" assert 16 == tir.cast( - "int32", tir.load("int64", arg2_shape, 5) + tir.load("int64", arg2_shape, 5), "int32" ), "Argument arg2.shape[5] has an unsatisfied constraint" if not (tir.isnullptr(arg2_strides, dtype="bool")): assert ( ( ( ( - (1 == tir.cast("int32", tir.load("int64", arg2_strides, 5))) - and (16 == tir.cast("int32", tir.load("int64", arg2_strides, 4))) + (1 == tir.cast(tir.load("int64", arg2_strides, 5), "int32")) + and (16 == tir.cast(tir.load("int64", arg2_strides, 4), "int32")) ) - and (256 == tir.cast("int32", tir.load("int64", arg2_strides, 3))) + and (256 == tir.cast(tir.load("int64", arg2_strides, 3), "int32")) ) - and (8192 == tir.cast("int32", tir.load("int64", arg2_strides, 2))) + and (8192 == tir.cast(tir.load("int64", arg2_strides, 2), "int32")) ) - and (114688 == tir.cast("int32", tir.load("int64", arg2_strides, 1))) + and (114688 == tir.cast(tir.load("int64", arg2_strides, 1), "int32")) ) and ( - 1605632 == tir.cast("int32", tir.load("int64", arg2_strides, 0)) + 1605632 == tir.cast(tir.load("int64", arg2_strides, 0), "int32") ), "arg2.strides: expected to be compact array" tir.evaluate(0) assert tir.uint64(0) == tir.tvm_struct_get( @@ -2961,9 +2621,9 @@ def opt_conv_tensorcore_mod_host( assert dev_id == tir.tvm_struct_get( arg2, 0, 9, dtype="int32" ), "Argument arg2.device_id has an unsatisfied constraint" - tir.evaluate(tir.tvm_struct_set(stack_value, 0, 12, tir.cast("int64", 2), dtype="int32")) + tir.evaluate(tir.tvm_struct_set(stack_value, 0, 12, tir.cast(2, "int64"), dtype="int32")) stack_tcode[0] = 0 - tir.evaluate(tir.tvm_struct_set(stack_value, 1, 12, tir.cast("int64", dev_id), dtype="int32")) + tir.evaluate(tir.tvm_struct_set(stack_value, 1, 12, tir.cast(dev_id, "int64"), dtype="int32")) stack_tcode[1] = 0 tir.evaluate( tir.tvm_call_packed_lowered( @@ -2977,17 +2637,17 @@ def opt_conv_tensorcore_mod_host( stack_tcode[1] = 3 tir.evaluate(tir.tvm_struct_set(stack_value, 2, 12, Conv, dtype="int32")) stack_tcode[2] = 3 - tir.evaluate(tir.tvm_struct_set(stack_value, 3, 12, tir.cast("int64", 196), dtype="int32")) + tir.evaluate(tir.tvm_struct_set(stack_value, 3, 12, tir.cast(196, "int64"), dtype="int32")) stack_tcode[3] = 0 - tir.evaluate(tir.tvm_struct_set(stack_value, 4, 12, tir.cast("int64", 2), dtype="int32")) + tir.evaluate(tir.tvm_struct_set(stack_value, 4, 12, tir.cast(2, "int64"), dtype="int32")) stack_tcode[4] = 0 - tir.evaluate(tir.tvm_struct_set(stack_value, 5, 12, tir.cast("int64", 4), dtype="int32")) + tir.evaluate(tir.tvm_struct_set(stack_value, 5, 12, tir.cast(4, "int64"), dtype="int32")) stack_tcode[5] = 0 - tir.evaluate(tir.tvm_struct_set(stack_value, 6, 12, tir.cast("int64", 4), dtype="int32")) + tir.evaluate(tir.tvm_struct_set(stack_value, 6, 12, tir.cast(4, "int64"), dtype="int32")) stack_tcode[6] = 0 - tir.evaluate(tir.tvm_struct_set(stack_value, 7, 12, tir.cast("int64", 2), dtype="int32")) + tir.evaluate(tir.tvm_struct_set(stack_value, 7, 12, tir.cast(2, "int64"), dtype="int32")) stack_tcode[7] = 0 - tir.evaluate(tir.tvm_struct_set(stack_value, 8, 12, tir.cast("int64", 32), dtype="int32")) + tir.evaluate(tir.tvm_struct_set(stack_value, 8, 12, tir.cast(32, "int64"), dtype="int32")) stack_tcode[8] = 0 tir.evaluate( tir.tvm_call_packed_lowered(