diff --git a/python/tvm/hybrid/parser.py b/python/tvm/hybrid/parser.py
index 0eaf20d813..96e11bf203 100644
--- a/python/tvm/hybrid/parser.py
+++ b/python/tvm/hybrid/parser.py
@@ -97,7 +97,7 @@ def __init__(self, src, base_lienno):
         self.meta = None
         self.functions = {}
 
-        self.assign_target = None
+        self.target = None
 
     def init_function_parsing_env(self):
         """Initialize function parsing environment"""
@@ -304,6 +304,8 @@ def visit_Assign(self, node):
             1.3 Var = tir.var()
         2. (BufferStore) Buffer[PrimExpr, PrimExpr, ..., PrimExpr] = PrimExpr
         3. (Store) Var[PrimExpr] = PrimExpr
+        4. with scope handlers with concise scoping and var def
+            4.1 var = tir.alloc_with_scope()
         """
 
         if not len(node.targets) == 1:
@@ -311,22 +313,29 @@
         target = node.targets[0]
 
         if isinstance(target, ast.Name):
-            # scenario 1
-            self.assign_target = target.id
-            rhs = self.visit(node.value)
+            # scenario 1&4
+            self.target = [target.id]
             if not isinstance(node.value, ast.Call):
                 self.report_error("Unsupported assign stmt")
-            self.scope_emitter.update_symbol(target.id, rhs)
+            func = self.visit(node.value.func)
+            if Registry.is_with_scope(func):
+                # scenario 4
+                return self.visit(node.value)
+            else:
+                # scenario 1
+                rhs = self.visit(node.value)
+                self.scope_emitter.update_symbol(target.id, rhs)
         elif isinstance(target, ast.Subscript):
             # scenario 2&3
             symbol, indexes = self.visit(target)
-            self.assign_target = (symbol, indexes)
             rhs = self.visit(node.value)
             if isinstance(symbol, tvm.tir.Buffer):
+                # BufferStore
                 return tvm.tir.BufferStore(symbol, tvm.runtime.convert(rhs), indexes)
             else:
                 if len(indexes) != 1:
                     self.report_error("Invalid Store stmt")
+                # Store
                 return tvm.tir.Store(symbol, tvm.runtime.convert(rhs), indexes[0],
                                      tvm.runtime.convert(True))
         else:
@@ -393,9 +402,8 @@ def visit_With(self, node):
            With(withitem* items, stmt* body, string? type_comment)
            withitem = (expr context_expr, expr? optional_vars)
        By now 2 types of With is supported:
-            1. with tir.block(*axes) as block_vars:
-            2. with tir.allocate() as
-            2. with tir.let/tir.Assert()/tir.attr()/tir.allocate()/tir.realize()
+            1. with tir.block(*axes)/tir.allocate() as targets:
+            2. with tir.let()/tir.Assert()/tir.attr()/tir.realize()
        """
        if not len(node.items) == 1:
            self.report_error("Only one with element is supported now")
@@ -405,42 +413,27 @@
         func_call = node.items[0].context_expr
         func_node = func_call.func
         func = self.visit(func_node)
-        is_tir_block = (isinstance(func_node, ast.Attribute)
-                        and isinstance(func_node.value, ast.Name)
-                        and func_node.value.id == "tir"
-                        and func_node.attr == "block")
-        if not is_tir_block and node.items[0].optional_vars is not None:
-            self.report_error("Now only tir.block allows optional var")
         if not Registry.is_with_scope(func):
             self.report_error("Function not allowed in with scope")
 
-        if is_tir_block:
-            # preprocess block_var definitions
+        self.target = []
+        if node.items[0].optional_vars is not None:
+            # preprocess optional var names
             if isinstance(node.items[0].optional_vars, ast.Name):
-                block_vars = [node.items[0].optional_vars.id]
+                self.target = [node.items[0].optional_vars.id]
             elif isinstance(node.items[0].optional_vars, (ast.List, ast.Tuple)):
                 for var in node.items[0].optional_vars.elts:
                     if not isinstance(var, ast.Name):
-                        self.report_error("Invalid block var definition")
-                block_vars = [var.id for var in node.items[0].optional_vars.elts]
-            elif node.items[0].optional_vars is None:
-                block_vars = []
+                        self.report_error("Invalid optional var definition")
+                self.target = [var.id for var in node.items[0].optional_vars.elts]
             else:
-                self.report_error("Invalid block var definition")
-            # update block vars into symbol table
-            block_vars = [tvm.te.var(name) for name in block_vars]
-            self.scope_emitter.new_scope(is_block=True)
-            for block_var in block_vars:
-                self.scope_emitter.update_symbol(block_var.name, block_var)
-
+                self.report_error("Invalid optional var definition")
+        # parse other arguments
         args = [self.visit(arg) for arg in func_call.args]
         kw_args = [self.visit(keyword) for keyword in func_call.keywords]
         kw_args = {kw_arg[0]: kw_arg[1] for kw_arg in kw_args}
-        if is_tir_block:
-            args = [block_vars] + args
-
         old_lineno, old_col_offset = self.current_lineno, self.current_col_offset
         self.current_lineno, self.current_col_offset = \
             self.base_lineno + func_call.lineno - 1, func_call.col_offset
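For context: taken together, the visit_Assign and visit_With changes let a with scope
handler registered with optional vars be entered either concisely or explicitly. A
usage sketch of the two accepted spellings, using the allocate arguments that appear
in the updated round-trip tests further down:

.. code-block:: python

    # scenario 4: concise scoping with a var definition; the assigned name
    # reaches the handler via parser.target
    packedB = tir.allocate([32768], "float32x32", "global")

    # equivalent explicit form: the target comes from the `with ... as` vars
    with tir.allocate([32768], "float32x32", "global") as packedB:
        ...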
diff --git a/python/tvm/hybrid/registry.py b/python/tvm/hybrid/registry.py
index 03c609a347..77f0d434ef 100644
--- a/python/tvm/hybrid/registry.py
+++ b/python/tvm/hybrid/registry.py
@@ -117,7 +117,7 @@ def auto_insert_body(self, pos, body):
         self.kwargs["body"] = body
 
 
-def func_wrapper(func_name, func_to_register, arg_list, category, concise):
+def func_wrapper(func_name, func_to_register, arg_list, category, concise=False, with_var=False):
     """Helper function to wrap a function to be registered """
 
     def wrap_func(parser, node, args, kwargs):
@@ -145,25 +145,34 @@ def wrap_func(parser, node, args, kwargs):
             parser.scope_emitter.loop_stack[-1].pop()
             parser.scope_emitter.pop_scope()
         elif category == Category.WITH_SCOPE:
-            # automatically parse body for with_scope handlers
-            if isinstance(node, ast.With):
-                # the scope handler is used inside with context/for context
-                parser.scope_emitter.new_scope()
-                parser.scope_emitter.node_stack[-1].extend(reversed(node.body))
-                body = parser.get_body()
-                parser.scope_emitter.pop_scope()
+            if not with_var:
+                if isinstance(node, ast.With) and node.items[0].optional_vars is not None:
+                    parser.report_error("Function " + func_name + " expects no optional vars")
+                # automatically parse body for with_scope handlers without optional vars
+                if isinstance(node, ast.With):
+                    parser.scope_emitter.new_scope()
+                    parser.scope_emitter.node_stack[-1].extend(reversed(node.body))
+                    body = parser.get_body()
+                    parser.scope_emitter.pop_scope()
+                else:
+                    body = parser.get_body()
             else:
-                # the with scope handler is used in concise scoping
-                if not concise:
-                    parser.report_error("Concise scoping is not allowed here")
-                body = parser.get_body()
+                if isinstance(node, ast.With) and node.items[0].optional_vars is None:
+                    parser.report_error("Function " + func_name + " expects optional vars")
+                body = None
+
+            if not isinstance(node, ast.With) and not concise:
+                parser.report_error("Concise scoping is not allowed here")
 
         reader = CallArgumentReader(func_name, args, kwargs, parser)
         pos_only, kwargs, varargs = arg_list
 
         internal_args = list()
         if category == Category.WITH_SCOPE:
-            internal_args.extend([parser, node, body])
+            if not with_var:
+                internal_args.extend([parser, node, body])
+            else:
+                internal_args.extend([parser, node])
         elif category == Category.FOR_SCOPE:
             internal_args.extend([parser, node, body, loop_vars])
         elif category == Category.SPECIAL_STMT:
@@ -184,7 +193,7 @@
     return wrap_func
 
 
-def get_arg_list(origin_func, category):
+def get_arg_list(origin_func, category, with_var=False):
     """Helper function to get the argument list of Function
 
     Parameters
@@ -192,7 +201,9 @@
     origin_func: function
         The function to get the argument list
     category: Category
-        The category of registerde function
+        The category of registered function
+    with_var: bool, optional
+        Whether the with scope handler needs optional vars
     """
 
     full_arg_spec = inspect.getfullargspec(origin_func)
@@ -202,11 +213,18 @@
         defaults = tuple()
 
     if category == Category.WITH_SCOPE:
-        if len(args) < 3 or args[0] != "parser" or args[1] != "node" or args[2] != "body":
-            raise RuntimeError(
-                "TVM Hybrid Script register error : the first three arguments of with scope handler"
-                "must be parser, node, body")
-        args = args[3:]
+        if not with_var:
+            if len(args) < 3 or args[0] != "parser" or args[1] != "node" or args[2] != "body":
+                raise RuntimeError(
+                    "TVM Hybrid Script register error : the first three arguments of "
+                    "this with scope handler must be parser, node, body")
+            args = args[3:]
+        else:
+            if len(args) < 2 or args[0] != "parser" or args[1] != "node":
+                raise RuntimeError(
+                    "TVM Hybrid Script register error : the first two arguments of "
+                    "this with scope handler must be parser, node")
+            args = args[2:]
     elif category == Category.FOR_SCOPE:
         if len(args) < 4 or args[0] != "parser" or \
                 args[1] != "node" or args[2] != "body" or args[3] != "loop_vars":
@@ -259,19 +277,21 @@ def decorate(origin_func):
         func_name = "tir." + origin_func.__qualname__ if name is None else name
         Registry.functions[func_name] = \
             func_wrapper(func_name, origin_func, get_arg_list(origin_func, Category.INTRIN),
-                         Category.INTRIN, concise=False), Category.INTRIN
+                         Category.INTRIN), Category.INTRIN
         return origin_func
     return decorate
 
 
-def register_with_scope(concise=False, name=None):
+def register_with_scope(concise=False, with_var=False, name=None):
     """Decorator to register function under with scope handler
 
     Parameters
     ----------
     concise: bool, optional
-        whether this scope handler is allowed in concise scoping
+        whether this with scope handler is allowed in concise scoping
+    with_var: bool, optional
+        whether this with scope handler needs optional vars
     name: str, optional
         registered name for the function
@@ -288,8 +308,10 @@ def decorate(origin_func):
         """Register function under category with_scope"""
         func_name = "tir." + origin_func.__qualname__ if name is None else name
         Registry.functions[func_name] = \
-            func_wrapper(func_name, origin_func, get_arg_list(origin_func, Category.WITH_SCOPE),
-                         Category.WITH_SCOPE, concise=concise), Category.WITH_SCOPE
+            func_wrapper(func_name, origin_func,
+                         get_arg_list(origin_func, Category.WITH_SCOPE, with_var),
+                         Category.WITH_SCOPE, concise=concise, with_var=with_var), \
+            Category.WITH_SCOPE
         return origin_func
     return decorate
@@ -308,7 +330,7 @@ def decorate(origin_func):
         func_name = "tir." + origin_func.__qualname__ if name is None else name
         Registry.functions[func_name] = \
             func_wrapper(func_name, origin_func, get_arg_list(origin_func, Category.FOR_SCOPE),
-                         Category.FOR_SCOPE, concise=False), Category.FOR_SCOPE
+                         Category.FOR_SCOPE), Category.FOR_SCOPE
         return origin_func
     return decorate
@@ -337,7 +359,7 @@ def decorate(origin_func):
         func_name = "tir." + origin_func.__qualname__ if name is None else name
         Registry.functions[func_name] = \
             func_wrapper(func_name, origin_func, get_arg_list(origin_func, Category.SPECIAL_STMT),
-                         Category.SPECIAL_STMT, concise=False), Category.SPECIAL_STMT
+                         Category.SPECIAL_STMT), Category.SPECIAL_STMT
         return origin_func
     return decorate
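For context: the registration-side contract now differs by `with_var`. A hypothetical
sketch of the two handler signatures (`my_alloc`/`my_attr` and their parameters are
placeholders for illustration, not part of this patch):

.. code-block:: python

    # with_var=True: the wrapper passes only (parser, node); the handler reads
    # the target name(s) from parser.target and parses its own body
    @register_with_scope(concise=True, with_var=True)
    def my_alloc(parser, node, extents, dtype, scope):
        ...

    # with_var=False (default): the wrapper parses the body and passes it in
    @register_with_scope(concise=True)
    def my_attr(parser, node, body, attr_node, attr_key, value):
        ...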
diff --git a/python/tvm/hybrid/scope_handler.py b/python/tvm/hybrid/scope_handler.py
index 71682a9912..87b084bbe2 100644
--- a/python/tvm/hybrid/scope_handler.py
+++ b/python/tvm/hybrid/scope_handler.py
@@ -18,25 +18,67 @@
 This module provides the functions registered into parser under with_scope or for_scope category.
 Scope handler nodes are StmtNodes with body, which are used to handle such scenarios.
 
+1. For scope handler
+When registering a for scope handler, the first 4 arguments must be parser, node, body, loop_vars
+and these arguments will be provided by the Hybrid Script parser automatically
+
+.. code-block:: python
+
+    for loop_vars in tir.xxx():
+
+2. With scope handler
+There are 4 subtypes of with scope handlers, classified by
+    1) with or without as
+    2) whether concise scoping is allowed
+1) with as & concise
+the first 2 arguments must be parser, node
+Need to parse the body manually
+Example : tir.alloc_with_scope
+
+.. code-block:: python
+
+    target = tir.xxx()
+    with tir.xxx() as target:
+
+2) with as & not concise
+the first 2 arguments must be parser, node
+Need to parse the body manually
+Example : tir.block
+
+.. code-block:: python
+
+    with tir.xxx() as target:
+
+3) without as & concise
+the first 3 arguments must be parser, node, body
+Hybrid Script parser will parse the body automatically
+Example : tir.realize()/tir.attr()
+
+.. code-block:: python
+
+    tir.xxx()
+    with tir.xxx():
+
+4) without as & not concise
+the first 3 arguments must be parser, node, body
+Hybrid Script parser will parse the body automatically
+Example : tir.Assert()/tir.let()
+
 .. code-block:: python
 
-    for x in tir.name():
-    with tir.name():
-    tir.name()  # with scope handlers + concise scoping
+    with tir.xxx():
 
-When registering a with scope handler, the first three arguments must be parser, node, body
-When registering a for scope handler, the first four arguments must be parser, node, body, loop_vars
-These parameters will handled by Hybrid Script parser automatically
 """
 # pylint: disable=redefined-builtin, unused-argument, invalid-name
+from typed_ast import ast3 as ast
 import tvm.tir
 from .registry import register_with_scope, register_for_scope
 
 
 # With scope handler
-@register_with_scope(concise=False)
-def block(parser, node, body, block_vars, axes=None, name=""):
+@register_with_scope(concise=False, with_var=True)
+def block(parser, node, axes=None, name=""):
     """ With scope handler function block(axes, name)
 
     Example
@@ -47,6 +89,15 @@
 
     """
+    # define block vars and parse the body manually
+    block_vars = [tvm.te.var(name) for name in parser.target]
+    parser.scope_emitter.new_scope(is_block=True)
+    for block_var in block_vars:
+        parser.scope_emitter.update_symbol(block_var.name, block_var)
+    parser.scope_emitter.node_stack[-1].extend(reversed(node.body))
+    body = parser.get_body()
+    block_info = parser.scope_emitter.pop_scope(is_block=True)
+    # create block iter vars
     if axes is None:
         axes = []
     if len(axes) != len(block_vars):
@@ -64,9 +115,7 @@
             block_iters.append(tvm.tir.IterVar(axis.dom, block_vars[i], axis.iter_type))
         else:
             parser.report_error("Invalid argument of tir.block()")
-
-    block_info = parser.scope_emitter.pop_scope(is_block=True)
-
+    # create block IO info
     if block_info.reads is None:
         reads = None
     else:
@@ -95,7 +144,7 @@
     inner = tvm.tir.Block(block_iters, reads, writes, body,
                           block_info.allocates, block_info.annotations, name)
-
+    # create block var binding
     if not block_info.binding:
         values = parser.scope_emitter.loop_stack[-1].copy()
         if len(values) == 0:
@@ -112,10 +161,44 @@
     return body
 
 
+@register_with_scope(concise=True, with_var=True)
+def allocate(parser, node, extents, dtype, scope, condition=True):
+    """ With scope handler function tir.alloc_with_scope(var, extents, dtype, scope, condition) """
+    # define buffer var and parse the body manually
+
+    buffer_var = tvm.te.var(parser.target[0], "handle")
+    # TODO: uncomment this line once we have richer type info for the buffer var
+    # buffer_var = tvm.te.var(parser.target[0], tvm.ir.PointerType(tvm.ir.PrimType(dtype)))
+    if isinstance(node, ast.With):
+        parser.scope_emitter.new_scope()
+        parser.scope_emitter.update_symbol(buffer_var.name, buffer_var)
+        parser.scope_emitter.node_stack[-1].extend(reversed(node.body))
+        body = parser.get_body()
+        parser.scope_emitter.pop_scope()
+    else:
+        parser.scope_emitter.update_symbol(buffer_var.name, buffer_var)
+        body = parser.get_body()
+    condition = tvm.runtime.convert(condition)
+    scope = tvm.runtime.convert(scope)
+    body = tvm.tir.Allocate(buffer_var, dtype, extents, condition, body)
+    return tvm.tir.AttrStmt(buffer_var, "storage_scope", scope, body)
+
+
+@register_with_scope(concise=True)
+def realize(parser, node, body, buffer_bounds, scope, condition=True):
+    """ With scope handler function tir.realize(buffer_bounds, scope, condition) """
+    buffer, bounds = buffer_bounds.buffer, buffer_bounds.region
+    scope = tvm.runtime.convert(scope)
+    return tvm.tir.AttrStmt(buffer, "realize_scope", scope,
+                            tvm.tir.BufferRealize(buffer, bounds, condition, body))
+
+
 @register_with_scope(concise=True)
-def allocate(parser, node, body, buffer_var, dtype, extents, condition=True):
-    """ With scope handler function tir.allocate(buffer_var, dtype, extents, condition) """
-    return tvm.tir.Allocate(buffer_var, dtype, extents, tvm.runtime.convert(condition), body)
+def attr(parser, node, body, attr_node, attr_key, value):
+    """ With scope handler function tir.attr(attr_node, attr_key, value) """
+    attr_node = tvm.runtime.convert(attr_node)
+    value = tvm.runtime.convert(value)
+    return tvm.tir.AttrStmt(attr_node, attr_key, value, body)
 
 
 @register_with_scope(concise=False)
@@ -130,21 +213,6 @@ def let(parser, node, body, var, value):
     return tvm.tir.LetStmt(var, value, body)
 
 
-@register_with_scope(concise=True)
-def realize(parser, node, body, buffer_bounds, condition=True):
-    """ With scope handler function tir.realize(buffer_bounds, condition) """
-    buffer, bounds = buffer_bounds.buffer, buffer_bounds.region
-    return tvm.tir.BufferRealize(buffer, bounds, condition, body)
-
-
-@register_with_scope(concise=True)
-def attr(parser, node, body, attr_node, attr_key, value):
-    """ With scope handler function tir.attr(attr_node, attr_key, value) """
-    attr_node = tvm.runtime.convert(attr_node)
-    value = tvm.runtime.convert(value)
-    return tvm.tir.AttrStmt(attr_node, attr_key, value, body)
-
-
 # For scope handler
 @register_for_scope()
 def serial(parser, node, body, loop_vars, begin, end):
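For context: the new allocate handler emits the same IR that the old attr/allocate
pair spelled out by hand. A sketch of the correspondence (the buffer name and sizes
below are illustrative only):

.. code-block:: python

    # script, either spelling
    A = tir.allocate([1024], "float32", "global")

    # TIR built by the handler above:
    #   AttrStmt(A, "storage_scope", "global",
    #            Allocate(A, "float32", [1024], condition=True, body))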
diff --git a/python/tvm/hybrid/special_stmt.py b/python/tvm/hybrid/special_stmt.py
index 23d6468592..db671fffe6 100644
--- a/python/tvm/hybrid/special_stmt.py
+++ b/python/tvm/hybrid/special_stmt.py
@@ -53,7 +53,7 @@ def match_buffer(parser, node, param, shape, dtype="float32", data=None, strides
         strides = []
     align = align.value if not isinstance(align, int) else align
     offset_factor = offset_factor.value if not isinstance(offset_factor, int) else offset_factor
-    buffer = tvm.tir.decl_buffer(shape, dtype, parser.assign_target, data, strides, elem_offset,
+    buffer = tvm.tir.decl_buffer(shape, dtype, parser.target[0], data, strides, elem_offset,
                                  scope, align, offset_factor, buffer_type)
     parser.buffer_map[param] = buffer
     return buffer
@@ -77,7 +77,7 @@ def buffer_allocate(parser, node, shape, dtype="float32", data=None, strides=Non
         strides = []
     align = align.value if not isinstance(align, int) else align
     offset_factor = offset_factor.value if not isinstance(offset_factor, int) else offset_factor
-    buffer = tvm.tir.decl_buffer(shape, dtype, parser.assign_target, data, strides, elem_offset,
+    buffer = tvm.tir.decl_buffer(shape, dtype, parser.target[0], data, strides, elem_offset,
                                  scope, align, offset_factor, buffer_type)
     parser.scope_emitter.block_scope().allocates.append(tvm.tir.BufferAllocate(buffer, scope))
     return buffer
@@ -124,7 +124,7 @@ def buffer_decl(parser, node, shape, dtype="float32", data=None, strides=None, e
         strides = []
     align = align.value if not isinstance(align, int) else align
     offset_factor = offset_factor.value if not isinstance(offset_factor, int) else offset_factor
-    buffer = tvm.tir.decl_buffer(shape, dtype, parser.assign_target, data, strides, elem_offset,
+    buffer = tvm.tir.decl_buffer(shape, dtype, parser.target[0], data, strides, elem_offset,
                                  scope, align, offset_factor, buffer_type)
     return buffer
 
@@ -132,7 +132,7 @@ def buffer_decl(parser, node, shape, dtype="float32", data=None, strides=None, e
 @register_special_stmt()
 def var(parser, node, dtype):
     """ Special function for defining a Var"""
-    return te.var(parser.assign_target, dtype)
+    return te.var(parser.target[0], dtype)
 
 
 class HybridLambda:
diff --git a/src/printer/tir_hybrid_printer.cc b/src/printer/tir_hybrid_printer.cc
index 2e3aa2ec72..b1c239788c 100644
--- a/src/printer/tir_hybrid_printer.cc
+++ b/src/printer/tir_hybrid_printer.cc
@@ -526,6 +526,55 @@ Doc TIRHybridPrinter::VisitStmt_(const LetStmtNode* op) {
 
 Doc TIRHybridPrinter::VisitStmt_(const AttrStmtNode* op) {
   Doc doc;
+  // merge attr with allocate when possible
+  if (op->node->IsInstance<VarNode>() && op->attr_key == "storage_scope"
+      && op->body->IsInstance<AllocateNode>()) {
+    const auto* alloc = Downcast<Allocate>(op->body).get();
+    if (alloc->buffer_var.same_as(op->node)) {
+      var_not_in_headers.insert(alloc->buffer_var.get());
+      if (current_num_ != num_child_ - 1) {
+        doc << "with tir.allocate(" << Print(alloc->extents) << ", "
+            << PrintDType(alloc->dtype) << ", " << Print(op->value);
+        if (!is_one(alloc->condition)) {
+          doc << ", " << Print(alloc->condition);
+        }
+        doc << ") as " << Print(op->node) << ":";
+        doc << Doc::Indent(4, Doc::NewLine() << PrintBody(alloc->body));
+      } else {
+        doc << Print(op->node) << " = tir.allocate(" << Print(alloc->extents) << ", "
+            << PrintDType(alloc->dtype) << ", " << Print(op->value);
+        if (!is_one(alloc->condition)) {
+          doc << ", " << Print(alloc->condition);
+        }
+        doc << ")" << Doc::NewLine() << PrintBody(alloc->body);
+      }
+      return doc;
+    }
+  }
+
+  if (op->node->IsInstance<BufferNode>() && op->attr_key == "realize_scope"
+      && op->body->IsInstance<BufferRealizeNode>()) {
+    const auto* realize = Downcast<BufferRealize>(op->body).get();
+    if (realize->buffer.same_as(op->node)) {
+      if (current_num_ != num_child_ - 1) {
+        doc << "with tir.realize(" << Print(realize->buffer) << Print(realize->bounds)
+            << ", " << Print(op->value);
+        if (!is_one(realize->condition)) {
+          doc << ", " << Print(realize->condition);
+        }
+        doc << "):" << Doc::Indent(4, Doc::NewLine() << PrintBody(realize->body));
+      } else {
+        doc << "tir.realize(" << Print(realize->buffer) << Print(realize->bounds)
+            << ", " << Print(op->value);
+        if (!is_one(realize->condition)) {
+          doc << ", " << Print(realize->condition);
+        }
+        doc << ")" << Doc::NewLine() << PrintBody(realize->body);
+      }
+      return doc;
+    }
+  }
+
   if (current_num_ != num_child_ - 1) {
     doc << "with tir.attr(" << Print(op->node) << ", " << Doc::StrLiteral(op->attr_key) << ", "
         << Print(op->value) << "):";
@@ -562,35 +611,13 @@ Doc TIRHybridPrinter::VisitStmt_(const StoreNode* op) {
 }
 
 Doc TIRHybridPrinter::VisitStmt_(const BufferRealizeNode* op) {
-  Doc doc;
-  if (current_num_ != num_child_ - 1) {
-    doc << "with tir.realize(" << Print(op->buffer) << Print(op->bounds);
-    if (!is_one(op->condition)) {
-      doc << ", " << Print(op->condition);
-    }
-    doc << "):" << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body));
-  } else {
-    doc << "tir.realize(" << Print(op->buffer) << Print(op->bounds);
-    if (!is_one(op->condition)) {
-      doc << ", " << Print(op->condition);
-    }
-    doc << ")" << Doc::NewLine() << PrintBody(op->body);
-  }
-  return doc;
+  LOG(FATAL) << "Hybrid Printer Internal Error: all BufferRealize nodes should be folded with an AttrStmt";
+  return Doc();
 }
 
 Doc TIRHybridPrinter::VisitStmt_(const AllocateNode* op) {
-  Doc doc;
-  if (current_num_ != num_child_ - 1) {
-    doc << "with tir.allocate(" << Print(op->buffer_var) << ", "
-        << PrintDType(op->dtype) << ", " << Print(op->extents) << "):";
-    doc << Doc::Indent(4, PrintBody(op->body));
-  } else {
-    doc << "tir.allocate(" << Print(op->buffer_var) << ", "
-        << PrintDType(op->dtype) << ", " << Print(op->extents) << ")";
-    doc << Doc::NewLine() << PrintBody(op->body);
-  }
-  return doc;
+  LOG(FATAL) << "Hybrid Printer Internal Error: all Allocate nodes should be folded with an AttrStmt";
+  return Doc();
 }
 
 Doc TIRHybridPrinter::VisitStmt_(const IfThenElseNode* op) {
@@ -719,15 +746,12 @@ Doc TIRHybridPrinter::VisitStmt_(const BlockRealizeNode* op) {
   }
   doc << PrintSep(block_var_docs, Doc::Text(", ")) << "], ";
   doc << Doc::StrLiteral(block_op->tag) << ")";
-  if (!block_op->iter_vars.empty()) {
-    std::vector<Doc> block_var_names;
-    for (const auto& iter_var : block_op->iter_vars) {
-      var_not_in_headers.insert(iter_var->var.get());
-      block_var_names.push_back(Print(iter_var->var));
-    }
-    doc << " as [" << PrintSep(block_var_names, Doc::Text(", ")) << "]";
+  std::vector<Doc> block_var_names;
+  for (const auto& iter_var : block_op->iter_vars) {
+    var_not_in_headers.insert(iter_var->var.get());
+    block_var_names.push_back(Print(iter_var->var));
   }
-  doc << ":";
+  doc << " as [" << PrintSep(block_var_names, Doc::Text(", ")) << "]:";
   Doc block_attr_doc;
   // print predicate, binding, read/write tensor region, annotations
   if (!is_one(op->predicate)) {
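For context: with the folding above, the parser and printer now agree on a single
canonical spelling, and a standalone Allocate or BufferRealize is a printer-internal
error. The expected printed forms, as exercised by the updated tests below:

.. code-block:: python

    packedB = tir.allocate([32768], "float32x32", "global")   # storage_scope attr + Allocate
    tir.realize(C_1[0:1024, 0:1024], "")                      # realize_scope attr + BufferRealize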
tir.load("float32x32", B_1.data, tir.ramp(((y*1024) + (x*32)), 1, 32), tir.broadcast(True, 32)), tir.broadcast(True, 32)) for x_outer in tir.parallel(0, 32): - tir.attr(C_global, "storage_scope", "global") - tir.allocate(C_global, "float32", [1024]) + C_global = tir.allocate([1024], "float32", "global") for y_outer in tir.serial(0, 32): for x_c_init in tir.serial(0, 32): tir.store(C_global, tir.ramp((x_c_init*32), 1, 32), tir.broadcast(tir.float32(0), 32), tir.broadcast(True, 32)) @@ -238,30 +230,26 @@ def opt_conv_tensorcore_normalize(A: ty.handle, W: ty.handle, Conv: ty.handle) - W_1 = tir.match_buffer(W, [3, 3, 16, 32, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1) Conv_1 = tir.match_buffer(Conv, [16, 14, 14, 32, 16, 16], elem_offset=0, align=128, offset_factor=1) # body - tir.attr(Conv_1, "realize_scope", "") - tir.realize(Conv_1[0:16, 0:14, 0:14, 0:32, 0:16, 0:16]) + tir.realize(Conv_1[0:16, 0:14, 0:14, 0:32, 0:16, 0:16], "") tir.attr(tir.iter_var(blockIdx_z, None, "ThreadIndex", "blockIdx.z"), "thread_extent", 196) tir.attr(tir.iter_var(blockIdx_x, None, "ThreadIndex", "blockIdx.x"), "thread_extent", 2) tir.attr(tir.iter_var(blockIdx_y, None, "ThreadIndex", "blockIdx.y"), "thread_extent", 4) tir.attr(tir.iter_var(threadIdx_y, None, "ThreadIndex", "threadIdx.y"), "thread_extent", 4) tir.attr(tir.iter_var(threadIdx_z, None, "ThreadIndex", "threadIdx.z"), "thread_extent", 2) - tir.attr(Conv_wmma_accumulator, "realize_scope", "wmma.accumulator") - tir.realize(Conv_wmma_accumulator[((blockIdx_x*8) + (threadIdx_y*2)):(((blockIdx_x*8) + (threadIdx_y*2)) + 2), tir.floordiv(blockIdx_z, 14):(tir.floordiv(blockIdx_z, 14) + 1), tir.floormod(blockIdx_z, 14):(tir.floormod(blockIdx_z, 14) + 1), ((blockIdx_y*8) + (threadIdx_z*4)):(((blockIdx_y*8) + (threadIdx_z*4)) + 4), 0:16, 0:16]) + tir.realize(Conv_wmma_accumulator[((blockIdx_x*8) + (threadIdx_y*2)):(((blockIdx_x*8) + (threadIdx_y*2)) + 2), tir.floordiv(blockIdx_z, 14):(tir.floordiv(blockIdx_z, 14) + 1), tir.floormod(blockIdx_z, 14):(tir.floormod(blockIdx_z, 14) + 1), ((blockIdx_y*8) + (threadIdx_z*4)):(((blockIdx_y*8) + (threadIdx_z*4)) + 4), 0:16, 0:16], "wmma.accumulator") for n_c_init in tir.serial(0, 2): for o_c_init in tir.serial(0, 4): tir.attr([BC, Conv_wmma_accumulator], "buffer_bind_scope", tir.tvm_tuple((n_c_init + ((blockIdx_x*8) + (threadIdx_y*2))), 1, tir.floordiv(blockIdx_z, 14), 1, tir.floormod(blockIdx_z, 14), 1, (o_c_init + ((blockIdx_y*8) + (threadIdx_z*4))), 1, 0, 16, 0, 16, dtype="handle")) tir.evaluate(tir.tvm_fill_fragment(BC.data, 16, 16, 16, tir.floordiv(BC.elem_offset, 256), tir.float32(0), dtype="handle")) for ic_outer in tir.serial(0, 8): for kh in tir.serial(0, 3): - tir.attr(Apad_shared, "realize_scope", "shared") - tir.realize(Apad_shared[(blockIdx_x*8):((blockIdx_x*8) + 8), (tir.floordiv(blockIdx_z, 14) + kh):((tir.floordiv(blockIdx_z, 14) + kh) + 1), tir.floormod(blockIdx_z, 14):(tir.floormod(blockIdx_z, 14) + 3), (ic_outer*2):((ic_outer*2) + 2), 0:16, 0:16]) + tir.realize(Apad_shared[(blockIdx_x*8):((blockIdx_x*8) + 8), (tir.floordiv(blockIdx_z, 14) + kh):((tir.floordiv(blockIdx_z, 14) + kh) + 1), tir.floormod(blockIdx_z, 14):(tir.floormod(blockIdx_z, 14) + 3), (ic_outer*2):((ic_outer*2) + 2), 0:16, 0:16], "shared") for ax2 in tir.serial(0, 3): for ax3 in tir.serial(0, 2): for ax4_ax5_fused_outer in tir.serial(0, 8): tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32) Apad_shared[((threadIdx_z + (threadIdx_y*2)) + (blockIdx_x*8)), 
                         Apad_shared[((threadIdx_z + (threadIdx_y*2)) + (blockIdx_x*8)), (tir.floordiv(blockIdx_z, 14) + kh), (ax2 + tir.floormod(blockIdx_z, 14)), (ax3 + (ic_outer*2)), tir.floordiv((threadIdx_x + (ax4_ax5_fused_outer*32)), 16), tir.floormod((threadIdx_x + (ax4_ax5_fused_outer*32)), 16)] = tir.if_then_else((((((tir.floordiv(blockIdx_z, 14) + kh) >= 1) and (((tir.floordiv(blockIdx_z, 14) + kh) - 1) < 14)) and ((ax2 + tir.floormod(blockIdx_z, 14)) >= 1)) and (((ax2 + tir.floormod(blockIdx_z, 14)) - 1) < 14)), A_1[((threadIdx_z + (threadIdx_y*2)) + (blockIdx_x*8)), ((tir.floordiv(blockIdx_z, 14) + kh) - 1), ((ax2 + tir.floormod(blockIdx_z, 14)) - 1), (ax3 + (ic_outer*2)), tir.floordiv((threadIdx_x + (ax4_ax5_fused_outer*32)), 16), tir.floormod((threadIdx_x + (ax4_ax5_fused_outer*32)), 16)], tir.float16(0), dtype="float16")
-            tir.attr(W_shared, "realize_scope", "shared")
-            tir.realize(W_shared[kh:(kh + 1), 0:3, (ic_outer*2):((ic_outer*2) + 2), (blockIdx_y*8):((blockIdx_y*8) + 8), 0:16, 0:16])
+            tir.realize(W_shared[kh:(kh + 1), 0:3, (ic_outer*2):((ic_outer*2) + 2), (blockIdx_y*8):((blockIdx_y*8) + 8), 0:16, 0:16], "shared")
             for ax1 in tir.serial(0, 3):
                 for ax2_1 in tir.serial(0, 2):
                     tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32)
@@ -269,14 +257,12 @@
                     W_shared[kh, ax1, (ax2_1 + (ic_outer*2)), ((threadIdx_z + (threadIdx_y*2)) + (blockIdx_y*8)), tir.floordiv((ax4_ax5_fused_inner + (threadIdx_x*8)), 16), tir.floormod((ax4_ax5_fused_inner + (threadIdx_x*8)), 16)] = W_1[kh, ax1, (ax2_1 + (ic_outer*2)), ((threadIdx_z + (threadIdx_y*2)) + (blockIdx_y*8)), tir.floordiv((ax4_ax5_fused_inner + (threadIdx_x*8)), 16), tir.floormod((ax4_ax5_fused_inner + (threadIdx_x*8)), 16)]
             for ic_inner in tir.serial(0, 2):
                 for kw in tir.serial(0, 3):
-                    tir.attr(Apad_shared_wmma_matrix_a, "realize_scope", "wmma.matrix_a")
-                    tir.realize(Apad_shared_wmma_matrix_a[((blockIdx_x*8) + (threadIdx_y*2)):(((blockIdx_x*8) + (threadIdx_y*2)) + 2), (tir.floordiv(blockIdx_z, 14) + kh):((tir.floordiv(blockIdx_z, 14) + kh) + 1), (kw + tir.floormod(blockIdx_z, 14)):((kw + tir.floormod(blockIdx_z, 14)) + 1), ((ic_outer*2) + ic_inner):(((ic_outer*2) + ic_inner) + 1), 0:16, 0:16])
+                    tir.realize(Apad_shared_wmma_matrix_a[((blockIdx_x*8) + (threadIdx_y*2)):(((blockIdx_x*8) + (threadIdx_y*2)) + 2), (tir.floordiv(blockIdx_z, 14) + kh):((tir.floordiv(blockIdx_z, 14) + kh) + 1), (kw + tir.floormod(blockIdx_z, 14)):((kw + tir.floormod(blockIdx_z, 14)) + 1), ((ic_outer*2) + ic_inner):(((ic_outer*2) + ic_inner) + 1), 0:16, 0:16], "wmma.matrix_a")
                     for ax0 in tir.serial(0, 2):
                         tir.attr([buffer, Apad_shared], "buffer_bind_scope", tir.tvm_tuple((ax0 + ((blockIdx_x*8) + (threadIdx_y*2))), 1, (tir.floordiv(blockIdx_z, 14) + kh), 1, (kw + tir.floormod(blockIdx_z, 14)), 1, ((ic_outer*2) + ic_inner), 1, 0, 16, 0, 16, dtype="handle"))
                         tir.attr([buffer_1, Apad_shared_wmma_matrix_a], "buffer_bind_scope", tir.tvm_tuple((ax0 + ((blockIdx_x*8) + (threadIdx_y*2))), 1, (tir.floordiv(blockIdx_z, 14) + kh), 1, (kw + tir.floormod(blockIdx_z, 14)), 1, ((ic_outer*2) + ic_inner), 1, 0, 16, 0, 16, dtype="handle"))
                         tir.evaluate(tir.tvm_load_matrix_sync(buffer_1.data, 16, 16, 16, tir.floordiv(buffer_1.elem_offset, 256), tir.tvm_access_ptr(tir.type_annotation(dtype="float16"), buffer.data, buffer.elem_offset, 256, 1, dtype="handle"), 16, "row_major", dtype="handle"))
-                    tir.attr(W_shared_wmma_matrix_b, "realize_scope", "wmma.matrix_b")
-                    tir.realize(W_shared_wmma_matrix_b[kh:(kh + 1), kw:(kw + 1), ((ic_outer*2) + ic_inner):(((ic_outer*2) + ic_inner) + 1), ((blockIdx_y*8) + (threadIdx_z*4)):(((blockIdx_y*8) + (threadIdx_z*4)) + 4), 0:16, 0:16])
+                    tir.realize(W_shared_wmma_matrix_b[kh:(kh + 1), kw:(kw + 1), ((ic_outer*2) + ic_inner):(((ic_outer*2) + ic_inner) + 1), ((blockIdx_y*8) + (threadIdx_z*4)):(((blockIdx_y*8) + (threadIdx_z*4)) + 4), 0:16, 0:16], "wmma.matrix_b")
                     for ax3_1 in tir.serial(0, 4):
                         tir.attr([buffer_2, W_shared], "buffer_bind_scope", tir.tvm_tuple(kh, 1, kw, 1, ((ic_outer*2) + ic_inner), 1, (ax3_1 + ((blockIdx_y*8) + (threadIdx_z*4))), 1, 0, 16, 0, 16, dtype="handle"))
                         tir.attr([buffer_3, W_shared_wmma_matrix_b], "buffer_bind_scope", tir.tvm_tuple(kh, 1, kw, 1, ((ic_outer*2) + ic_inner), 1, (ax3_1 + ((blockIdx_y*8) + (threadIdx_z*4))), 1, 0, 16, 0, 16, dtype="handle"))
@@ -293,6 +279,7 @@
                         tir.attr([buffer_5, Conv_1], "buffer_bind_scope", tir.tvm_tuple(((((blockIdx_x*4) + threadIdx_y)*2) + n_inner), 1, tir.floordiv(blockIdx_z, 14), 1, tir.floormod(blockIdx_z, 14), 1, ((((blockIdx_y*2) + threadIdx_z)*4) + o_inner), 1, 0, 16, 0, 16, dtype="handle"))
                         tir.evaluate(tir.tvm_store_matrix_sync(buffer_4.data, 16, 16, 16, tir.floordiv(buffer_4.elem_offset, 256), tir.tvm_access_ptr(tir.type_annotation(dtype="float32"), buffer_5.data, buffer_5.elem_offset, 256, 2, dtype="handle"), 16, "row_major", dtype="handle"))
+
 
 def test_opt_conv_tensorcore_normalize():
     mod = opt_conv_tensorcore_normalize
     rt_mod = tvm.hybrid.from_source(tvm.hybrid.ashybrid(mod, True))
@@ -304,11 +291,6 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No
     # function attr dict
     tir.func_attr({"global_symbol": "default_function", "tir.noalias": True})
     # var definition
-    Apad_shared = tir.var("handle")
-    Apad_shared_wmma_matrix_a = tir.var("handle")
-    Conv_wmma_accumulator = tir.var("handle")
-    W_shared = tir.var("handle")
-    W_shared_wmma_matrix_b = tir.var("handle")
     blockIdx_x = tir.var("int32")
     blockIdx_y = tir.var("int32")
     blockIdx_z = tir.var("int32")
@@ -320,16 +302,11 @@ def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> No
     Conv_1 = tir.match_buffer(Conv, [16, 14, 14, 32, 16, 16], elem_offset=0, align=128, offset_factor=1)
     # body
     tir.attr(tir.iter_var(blockIdx_z, None, "ThreadIndex", "blockIdx.z"), "thread_extent", 196)
-    tir.attr(Conv_wmma_accumulator, "storage_scope", "wmma.accumulator")
-    tir.allocate(Conv_wmma_accumulator, "float32", [2048])
-    tir.attr(Apad_shared, "storage_scope", "shared")
-    tir.allocate(Apad_shared, "float16", [12288])
-    tir.attr(W_shared, "storage_scope", "shared")
-    tir.allocate(W_shared, "float16", [12288])
-    tir.attr(Apad_shared_wmma_matrix_a, "storage_scope", "wmma.matrix_a")
-    tir.allocate(Apad_shared_wmma_matrix_a, "float16", [512])
-    tir.attr(W_shared_wmma_matrix_b, "storage_scope", "wmma.matrix_b")
-    tir.allocate(W_shared_wmma_matrix_b, "float16", [1024])
+    Conv_wmma_accumulator = tir.allocate([2048], "float32", "wmma.accumulator")
+    Apad_shared = tir.allocate([12288], "float16", "shared")
+    W_shared = tir.allocate([12288], "float16", "shared")
+    Apad_shared_wmma_matrix_a = tir.allocate([512], "float16", "wmma.matrix_a")
+    W_shared_wmma_matrix_b = tir.allocate([1024], "float16", "wmma.matrix_b")
     tir.attr(tir.iter_var(blockIdx_x, None, "ThreadIndex", "blockIdx.x"), "thread_extent", 2)
     tir.attr(tir.iter_var(blockIdx_y, None, "ThreadIndex", "blockIdx.y"), "thread_extent", 4)
"thread_extent", 4) diff --git a/tests/python/tir/test_tir_buffer_flatten.py b/tests/python/tir/test_tir_buffer_flatten.py index 90295ac621..a53dea55d4 100644 --- a/tests/python/tir/test_tir_buffer_flatten.py +++ b/tests/python/tir/test_tir_buffer_flatten.py @@ -71,17 +71,16 @@ def compute_at_element_wise(a: ty.handle, c: ty.handle) -> None: A = tir.match_buffer(a, (128, 128), "float32", name="A") C = tir.match_buffer(c, (128, 128), "float32", name="C") - with tir.block(name="root"): - B = tir.buffer_allocate((128, 128), "float32", name="B") + B = tir.buffer_allocate((128, 128), "float32", name="B") - for i in range(0, 128): - for j in range(0, 128): - with tir.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 + for i in range(0, 128): + for j in range(0, 128): + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 - for j in range(0, 128): - with tir.block([128, 128], "C") as [vi, vj]: - C[vi, vj] = B[vi, vj] + 1.0 + for j in range(0, 128): + with tir.block([128, 128], "C") as [vi, vj]: + C[vi, vj] = B[vi, vj] + 1.0 def test_local_allocate():