From 8ab60b9e8b657b65d41e0d34f2627d4a4a1a5dac Mon Sep 17 00:00:00 2001 From: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Date: Mon, 10 Aug 2020 23:25:01 +0800 Subject: [PATCH] [TIR][Hybrid] Hybrid Script Support for TIR (#6227) --- python/tvm/__init__.py | 3 + python/tvm/hybrid/__init__.py | 20 + python/tvm/hybrid/_ffi_api.py | 21 + python/tvm/hybrid/intrin.py | 136 +++ python/tvm/hybrid/meta_unparser.py | 50 ++ python/tvm/hybrid/parser.py | 755 ++++++++++++++++ python/tvm/hybrid/registry.py | 231 +++++ python/tvm/hybrid/scope_emitter.py | 62 ++ python/tvm/hybrid/scope_handler.py | 89 ++ python/tvm/hybrid/special_stmt.py | 102 +++ python/tvm/hybrid/ty.py | 63 ++ python/tvm/hybrid/utils.py | 96 ++ src/printer/tir_hybrid_printer.cc | 845 ++++++++++++++++++ .../unittest/test_hybrid_error_report.py | 105 +++ .../python/unittest/test_hybrid_roundtrip.py | 536 +++++++++++ 15 files changed, 3114 insertions(+) create mode 100644 python/tvm/hybrid/__init__.py create mode 100644 python/tvm/hybrid/_ffi_api.py create mode 100644 python/tvm/hybrid/intrin.py create mode 100644 python/tvm/hybrid/meta_unparser.py create mode 100644 python/tvm/hybrid/parser.py create mode 100644 python/tvm/hybrid/registry.py create mode 100644 python/tvm/hybrid/scope_emitter.py create mode 100644 python/tvm/hybrid/scope_handler.py create mode 100644 python/tvm/hybrid/special_stmt.py create mode 100644 python/tvm/hybrid/ty.py create mode 100644 python/tvm/hybrid/utils.py create mode 100644 src/printer/tir_hybrid_printer.cc create mode 100644 tests/python/unittest/test_hybrid_error_report.py create mode 100644 tests/python/unittest/test_hybrid_roundtrip.py diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py index cb1f4d2c20a0..2474ae8b5b1a 100644 --- a/python/tvm/__init__.py +++ b/python/tvm/__init__.py @@ -60,6 +60,9 @@ # tvm.parser from . import parser +# tvm tir hybrid script +from . import hybrid + # others from . import arith diff --git a/python/tvm/hybrid/__init__.py b/python/tvm/hybrid/__init__.py new file mode 100644 index 000000000000..7c3ef758d34f --- /dev/null +++ b/python/tvm/hybrid/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Hybrid Script APIs of TVM Python Package, aimed to support TIR""" + +from .utils import create_module, ashybrid, script +from .parser import from_source diff --git a/python/tvm/hybrid/_ffi_api.py b/python/tvm/hybrid/_ffi_api.py new file mode 100644 index 000000000000..d59302a95dd1 --- /dev/null +++ b/python/tvm/hybrid/_ffi_api.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""FFI APIs for tvm.hybrid""" +import tvm._ffi + + +tvm._ffi._init_api("tir.hybrid", __name__) diff --git a/python/tvm/hybrid/intrin.py b/python/tvm/hybrid/intrin.py new file mode 100644 index 000000000000..3dc46a280b72 --- /dev/null +++ b/python/tvm/hybrid/intrin.py @@ -0,0 +1,136 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Hybrid Script Parser Intrinsic Functions + +IRNodes (StmtNodes without body, PrimExprNodes and more) are called intrins +""" +# pylint: disable=redefined-builtin +import tvm.tir +from .registry import register_intrin + + +@register_intrin +def bool(imm): + return tvm.tir.const(imm.value, "bool") + + +@register_intrin +def int8(imm): + return tvm.tir.const(imm.value, "int8") + + +@register_intrin +def int16(imm): + return tvm.tir.const(imm.value, "int16") + + +@register_intrin +def int32(imm): + return tvm.tir.const(imm.value, "int32") + + +@register_intrin +def int64(imm): + return tvm.tir.const(imm.value, "int64") + + +@register_intrin +def uint8(imm): + return tvm.tir.const(imm.value, "uint8") + + +@register_intrin +def uint16(imm): + return tvm.tir.const(imm.value, "uint16") + + +@register_intrin +def uint32(imm): + return tvm.tir.const(imm.value, "uint32") + + +@register_intrin +def uint64(imm): + return tvm.tir.const(imm.value, "uint64") + + +@register_intrin +def float8(imm): + return tvm.tir.const(imm.value, "float8") + + +@register_intrin +def float16(imm): + return tvm.tir.const(imm.value, "float16") + + +@register_intrin +def float32(imm): + return tvm.tir.const(imm.value, "float32") + + +@register_intrin +def float64(imm): + return tvm.tir.const(imm.value, "float64") + + +@register_intrin +def floordiv(x, y): + return tvm.tir.floordiv(x, y) + + +@register_intrin +def floormod(x, y): + return tvm.tir.floormod(x, y) + + +@register_intrin +def load(dtype, var, index, predicate=True): + return tvm.tir.Load(dtype, var, index, predicate) + + +@register_intrin +def cast(dtype, value): + return tvm.tir.Cast(dtype, value) + + +@register_intrin +def ramp(base, stride, lanes): + lanes = lanes.value if not isinstance(lanes, int) else lanes + return tvm.tir.Ramp(base, stride, lanes) + + +@register_intrin +def broadcast(value, lanes): + lanes = lanes.value if not isinstance(lanes, int) else lanes + return tvm.tir.Broadcast(value, lanes) + + +@register_intrin +def evaluate(value): + return tvm.tir.Evaluate(value) + + +@register_intrin +def store(var, index, value, predicate=True): + return tvm.tir.Store(var, value, index, predicate) + + +@register_intrin +def iter_var(var, dom, iter_type, thread_tag): + iter_type = getattr(tvm.tir.IterVar, iter_type) + return tvm.tir.IterVar(dom, var, iter_type, thread_tag) diff --git a/python/tvm/hybrid/meta_unparser.py b/python/tvm/hybrid/meta_unparser.py new file mode 100644 index 000000000000..d56fbad3d1e3 --- /dev/null +++ b/python/tvm/hybrid/meta_unparser.py @@ -0,0 +1,50 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Unparse meta AST node into a dict""" +# pylint: disable=invalid-name + +from typed_ast import ast3 as ast + + +class MetaUnparser(ast.NodeVisitor): + """Python AST Visitor to unparse meta AST node into a dict""" + + def visit_Dict(self, node): + keys = [self.visit(key) for key in node.keys] + values = [self.visit(value) for value in node.values] + return dict(zip(keys, values)) + + def visit_Tuple(self, node): + return tuple(self.visit(element) for element in node.elts) + + def visit_List(self, node): + return [self.visit(element) for element in node.elts] + + def visit_keyword(self, node): + return node.arg, self.visit(node.value) + + def visit_NameConstant(self, node): + return node.value + + def visit_Constant(self, node): + return node.value + + def visit_Num(self, node): + return node.n + + def visit_Str(self, node): + return node.s diff --git a/python/tvm/hybrid/parser.py b/python/tvm/hybrid/parser.py new file mode 100644 index 000000000000..bf8466f2e126 --- /dev/null +++ b/python/tvm/hybrid/parser.py @@ -0,0 +1,755 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Hybrid Script Parser For TIR""" +# pylint: disable=invalid-name, missing-docstring, inconsistent-return-statements, no-else-return +# pylint: disable=unnecessary-comprehension, unused-argument, import-outside-toplevel +# pylint: disable=unused-import +import json +import numbers +import operator +from typed_ast import ast3 as ast + +import tvm._ffi +from tvm import tir +from tvm._ffi.base import TVMError +from tvm.ir import GlobalVar +from tvm.tir import all as _all +from tvm.tir import expr as _expr + +from . import scope_emitter, special_stmt, scope_handler, intrin +from .meta_unparser import MetaUnparser +from .registry import Registry + + +class HybridParserError(RuntimeError): + """Hybrid Parser Runtime Error""" + + +class HybridParser(ast.NodeVisitor): + """Python AST visitor pass which finally lowers it to TIR + Notes for extension: + 1. To support new types of AST nodes. Add a function visit_xxx(). + 2. To support new functions + We divide allowed function calls in hybrid script into 3 categories, + which is scope_handler, intrin and special_stmt. + 1) scope_handler: scope_handler functions correspond to StmtNodes without body, which can be + further classified into 2 categories: with scope handler can for scope handlers + 2) intrin: intrin functions corresponds to the remaining IRNodes (StmtNodes without body, + PrimExprNodes and more) + 3) special_stmt: special_stmt functions don't correspond to an IRNode in the AST directly. + It is usually used for some information that is not suitable to be printed directly. + When visiting With node, we check with_scope registry. + When visiting For node, we check for_scope registry. + """ + + _binop_maker = { + ast.Add: tir.Add, + ast.Sub: tir.Sub, + ast.Mult: tir.Mul, + ast.Div: tir.Div, + ast.FloorDiv: tir.FloorDiv, + ast.Mod: tir.FloorMod, + ast.BitOr: operator.or_, + ast.BitAnd: operator.and_, + ast.BitXor: operator.xor, + ast.Gt: tir.GT, + ast.GtE: tir.GE, + ast.Lt: tir.LT, + ast.LtE: tir.LE, + ast.Eq: tir.EQ, + ast.NotEq: tir.NE, + ast.And: tir.And, + ast.Or: tir.Or, + } + + _unaryop_maker = { + ast.USub: operator.neg, + ast.Invert: operator.invert, + ast.Not: tir.Not + } + + def __init__(self, src, base_lienno): + self.params = None + self.buffer_map = None + self.dict_attr = None + self.scope_emitter = None + + self.src = src.split('\n') + self.base_lineno = base_lienno + self.current_lineno = 0 + self.current_col_offset = 0 + self.meta = None + + self.functions = {} + + self._in_with_func_arg = False + self._assign_target = None + + def init_function_parsing_env(self): + """Initialize function parsing environment""" + self.params = [] # parameter list + self.buffer_map = {} # buffer map + self.dict_attr = {} # dict attr + self.scope_emitter = scope_emitter.ScopeEmitter(self) # scope emitter + + @staticmethod + def is_meta(node): + """Judge whether an AST node is META""" + return isinstance(node, ast.Assign) and len(node.targets) == 1 \ + and isinstance(node.targets[0], ast.Name) and node.targets[0].id == "__tvm_meta__" + + def init_meta(self, meta_dict): + if meta_dict is not None: + self.meta = tvm.ir.load_json(json.dumps(meta_dict)) + + def visit(self, node): + """Override method in ast.NodeVisitor""" + old_lineno, old_col_offset = self.current_lineno, self.current_col_offset + + if hasattr(node, "lineno"): + self.current_lineno = self.base_lineno + node.lineno - 1 + if hasattr(node, "col_offset"): + self.current_col_offset = node.col_offset + + method = 'visit_' + node.__class__.__name__ + visitor = getattr(self, method, self.generic_visit) + visit_res = visitor(node) + + self.current_lineno, self.current_col_offset = old_lineno, old_col_offset + + return visit_res + + def wrap_line_col(self, message, lineno, col_offset): + """Wrap the message with line number and column offset""" + src_line = self.src[lineno - self.base_lineno] + leading_space = len(src_line) - len(src_line.lstrip(' ')) + col_offset = col_offset - leading_space + src_line = src_line[leading_space:] + return "\n " + src_line + "\n " + " " * col_offset + "^\n" + "ParserError in line " \ + + str(lineno) + " : " + message + + def report_error(self, message, lineno=None, col_offset=None): + """ Report an error occur in line lineno and column col_offset + Parameters + ---------- + message : str + Error message + lineno : int + Line number of error line + col_offset : int + Column offset of error line + """ + + if lineno is None: + lineno = self.current_lineno + if col_offset is None: + col_offset = self.current_col_offset + raise HybridParserError(self.wrap_line_col(message, lineno, col_offset)) + + def get_type_name(self, vtype): + if isinstance(vtype, ast.Attribute) \ + and isinstance(vtype.value, ast.Name) and vtype.value.id == 'ty': + return vtype.attr + self.report_error("invalid type annotation") + + def get_body(self): + body = [] + while len(self.scope_emitter.node_stack[-1]) > 0: + res = self.visit(self.scope_emitter.node_stack[-1].pop()) + if res is not None: + body.append(res) + return tvm.tir.SeqStmt(body) if len(body) > 1 else body[0] + + def parse_type(self, vtype): + """ Parse type annotation AST into Type object """ + if isinstance(vtype, ast.NameConstant) and vtype.value is None: + return tvm.ir.TupleType([]) + elif isinstance(vtype, ast.Attribute): + return tvm.ir.PrimType(self.get_type_name(vtype)) + elif isinstance(vtype, ast.Subscript) and isinstance(vtype.slice, ast.Index): + type_name = self.get_type_name(vtype.value) + if isinstance(vtype.slice.value, ast.Tuple): + args = [self.parse_type(element) for element in vtype.slice.value.elts] + else: + args = [self.parse_type(vtype.slice.value)] + if type_name == "Ptr": + return tvm.ir.PointerType(*args) + elif type_name == "Tuple": + return tvm.ir.TupleType(args) + + self.report_error("invalid type annotation") + + def generic_visit(self, node): + """ Override method in ast.NodeVisitor. + To directly filter out invalidate type of stmt. + """ + + self.report_error(type(node).__name__ + " stmt is not supported now") + + def visit_Module(self, node): + """ Module visitor + AST abstract grammar: + Module(stmt* body, type_ignore* type_ignore) + By now we support two format of hybrid script shown below. + + Example + ------- + 1. Generate a Function(If the code is printed, then it may bring meta) + .. code-block:: python + + import tvm + + @tvm.hybrid.script + def A(...): + ... + + # call hybrid parser when call this function, get a Function + func = A + + 2. Generate an IRModule + .. code-block:: python + + import tvm + + @tvm.hybrid.script + class MyMod(): + def A(...): + ... + + def B(...): + ... + + __tvm_meta__ = ... + + # call hybrid parser during construction, get an IRModule + mod = MyMod() + """ + + if len(node.body) == 1 and isinstance(node.body[0], (ast.ClassDef, ast.FunctionDef)): + # class or single function + return self.visit(node.body[0]) + elif len(node.body) == 2: + if isinstance(node.body[0], ast.Assign): + node.body[0], node.body[1] = node.body[1], node.body[0] + if isinstance(node.body[0], ast.FunctionDef) and HybridParser.is_meta(node.body[1]): + # function with meta + self.init_meta(MetaUnparser().visit(node.body[1].value)) + return self.visit(node.body[0]) + self.report_error( + "Only one-function, one-class or function-with-meta source code is allowed") + + def visit_ClassDef(self, node): + """ ClassDef visitor + AST abstract grammar: + ClassDef(identifier name, expr* bases, keyword* keywords, stmt* body, + expr* decorator_list) + """ + + # parse meta + count = False + for body_element in node.body: + if isinstance(body_element, ast.FunctionDef): + pass + elif HybridParser.is_meta(body_element) and not count: + count = True + self.init_meta(MetaUnparser().visit(body_element.value)) + else: + self.report_error("invalid class member") + + # parse member functions + for body_element in node.body: + if isinstance(body_element, ast.FunctionDef): + self.visit(body_element) + from .utils import create_module + return create_module(self.functions) + + def visit_FunctionDef(self, node): + """ FunctionDef visitor + AST abstract grammar: + FunctionDef(identifier name, arguments args, stmt* body, expr* decorator_list, + expr? returns, string? type_comment) + arguments = (arg* posonlyargs, arg* args, arg? vararg, arg* kwonlyargs, + expr* kw_defaults, arg? kwarg, expr* defaults) + arg = (identifier arg, expr? annotation, string? type_comment) + """ + + self.init_function_parsing_env() + # add parameters of function + for arg in node.args.args: + arg_var = tvm.te.var(arg.arg, self.parse_type(arg.annotation)) + self.scope_emitter.update_symbol(arg.arg, arg_var) + self.params.append(arg_var) + + # visit the body of function + self.scope_emitter.node_stack[-1].extend(reversed(node.body)) + + # fetch the body and return a tir.PrimFunc + func = tvm.tir.PrimFunc(self.params, self.get_body(), + ret_type=self.parse_type(node.returns), + buffer_map=self.buffer_map, + attrs=tvm.ir.make_node("DictAttrs", **self.dict_attr)) + self.functions[GlobalVar(node.name)] = func + return func + + def visit_Assign(self, node): + """ Assign visitor + AST abstract grammar: + Assign(expr* targets, expr value, string? type_comment) + By now only 2 types of Assign is supported: + 1. special stmts that appear as assign stmt + 1.1 Buffer = tir.buffer_bind()/tir.buffer_decl() + 1.2 Var = tir.var() + 2. (BufferStore) Buffer[PrimExpr, PrimExpr, ..., PrimExpr] = PrimExpr + 3. (Store) Var[PrimExpr] = PrimExpr + """ + + if not len(node.targets) == 1: + self.report_error("Only one-valued assignment is supported now") + target = node.targets[0] + + if isinstance(target, ast.Name): + # scenario 1 + self._assign_target = target.id + rhs = self.visit(node.value) + if not isinstance(node.value, ast.Call): + self.report_error("Unsupported Assign stmt") + self.scope_emitter.update_symbol(target.id, rhs) + elif isinstance(target, ast.Subscript): + # scenario 2&3 + symbol, indexes = self.visit(target) + self._assign_target = (symbol, indexes) + rhs = self.visit(node.value) + if isinstance(symbol, tvm.tir.Buffer): + return tvm.tir.BufferStore(symbol, tvm.runtime.convert(rhs), indexes) + else: + if len(indexes) != 1: + self.report_error("Invalid Store stmt") + return tvm.tir.Store(symbol, tvm.runtime.convert(rhs), indexes[0], + tvm.runtime.convert(True)) + else: + self.report_error("Unsupported Assign stmt") + + def visit_AnnAssign(self, node): + """ AnnAssign visitor + AST abstract grammar: + AnnAssign(expr target, expr annotation, expr? value, int simple) + Corresponds to concise mode of with tir.let() + """ + + if isinstance(node.target, ast.Name): + value = self.visit(node.value) + var = tvm.te.var(node.target.id, self.parse_type(node.annotation)) + self.scope_emitter.update_symbol(var.name, var) + return tvm.tir.LetStmt(var, value, self.visit(self.scope_emitter.node_stack[-1].pop())) + else: + self.report_error("Unsupported AnnAssign stmt") + + def visit_Assert(self, node): + """ Assert visitor + AST abstract grammar: + Assert(expr test, expr? msg) + Corresponds to concise mode of with tir.assert() + """ + + condition = self.visit(node.test) + if node.msg is None: + self.report_error("Message of AssertStmt can't be None") + message = self.visit(node.msg) + return tvm.tir.AssertStmt(condition, tvm.runtime.convert(message), self.get_body()) + + def visit_For(self, node): + """ For visitor + AST abstract grammar: + For(expr target, expr iter, stmt* body, stmt* orelse, string? type_comment) + By now only 1 type of For is supported: + 1. for name in tir.range(begin, end, for_type) + """ + + if not isinstance(node.target, ast.Name): + self.report_error("The loop variable should be a name variable") + # check node.iter, which is a tir Call + if not isinstance(node.iter, ast.Call): + self.report_error("The loop iter should be a Call") + if not isinstance(node.iter.func, ast.Attribute) \ + or not isinstance(node.iter.func.value, ast.Name) \ + or node.iter.func.value.id != "tir": + self.report_error("The loop iter Call should be tir.name()") + + func_name = node.iter.func.attr + # collect arguments + args = [self.visit(arg) for arg in node.iter.args] + kw_args = [self.visit(keyword) for keyword in node.iter.keywords] + kw_args = {kw_arg[0]: kw_arg[1] for kw_arg in kw_args} + # All the functions supported in For stmt are registered in scope_handler.ForScope + if func_name not in Registry.for_scope: + self.report_error("Function " + func_name + " used in For stmt is not supported now", + self.current_lineno, + node.iter.col_offset) + + old_lineno, old_col_offset = self.current_lineno, self.current_col_offset + self.current_lineno, self.current_col_offset = \ + self.base_lineno + node.iter.lineno - 1, node.iter.col_offset + res = Registry.for_scope.get(func_name)(self, node, args, kw_args) + self.current_lineno, self.current_col_offset = old_lineno, old_col_offset + return res + + def visit_With(self, node): + """ With visitor + AST abstract grammar: + With(withitem* items, stmt* body, string? type_comment) + withitem = (expr context_expr, expr? optional_vars) + By now only 1 type of With is supported: + 1. with tir.let/tir.Assert()/tir.attr()/tir.allocate()/tir.realize() + """ + + if len(node.items) != 1: + self.report_error("Only one with element is supported now") + if not isinstance(node.items[0].context_expr, ast.Call): + self.report_error("The context expression of with should be a Call") + func_call = node.items[0].context_expr + if not isinstance(func_call.func, ast.Attribute) \ + or not isinstance(func_call.func.value, ast.Name) \ + or func_call.func.value.id != "tir": + self.report_error("The context expression of with should be tir.name()") + + func_name = func_call.func.attr + # collect arguments + args = [self.visit(arg) for arg in func_call.args] + kw_args = [self.visit(keyword) for keyword in func_call.keywords] + kw_args = {kw_arg[0]: kw_arg[1] for kw_arg in kw_args} + if func_name not in Registry.with_scope: + self.report_error("Function " + func_name + " used in With stmt is not supported now") + + # All the functions supported in With stmt are registered in scope_handler.WithScope + old_lineno, old_col_offset = self.current_lineno, self.current_col_offset + self.current_lineno, self.current_col_offset = \ + self.base_lineno + func_call.lineno - 1, func_call.col_offset + res = Registry.with_scope.get(func_name)(self, node, args, kw_args) + self.current_lineno, self.current_col_offset = old_lineno, old_col_offset + return res + + def visit_If(self, node): + """ If visitor + AST abstract grammar: + If(expr test, stmt* body, stmt* orelse) + """ + + condition = self.visit(node.test) + # then body + self.scope_emitter.new_scope() + self.scope_emitter.node_stack[-1].extend(reversed(node.body)) + then_body = self.get_body() + self.scope_emitter.pop_scope() + + # else body + if len(node.orelse) > 0: + self.scope_emitter.new_scope() + self.scope_emitter.node_stack[-1].extend(reversed(node.orelse)) + else_body = self.get_body() + self.scope_emitter.pop_scope() + else: + else_body = None + return tvm.tir.IfThenElse(condition, then_body, else_body) + + def visit_Call(self, node): + """ Call visitor + AST abstract grammar: + Call(expr func, expr* args, keyword* keywords) + keyword = (identifier? arg, expr value) + """ + + # collect arguments + args = [self.visit(arg) for arg in node.args] + kw_args = [self.visit(keyword) for keyword in node.keywords] + kw_args = {kw_arg[0]: kw_arg[1] for kw_arg in kw_args} + + maybe_intrin = False + if isinstance(node.func, ast.Attribute) and isinstance(node.func.value, ast.Name): + if node.func.value.id == "tir": + func_name = node.func.attr + maybe_intrin = True + else: + self.report_error("Unsupported Attribute typed function call") + else: + self.report_error("Unsupported function call") + + if func_name in Registry.special_stmt: + return Registry.special_stmt.get(func_name)(self, node, args, kw_args) + if func_name in Registry.intrin: + return Registry.intrin.get(func_name)(self, node, args, kw_args) + if func_name in Registry.with_scope: + return Registry.with_scope.get(func_name)(self, node, args, kw_args) + if maybe_intrin: + return tvm.tir.Call(kw_args["dtype"], tvm.ir.op.Op.get("tir." + func_name), args) + + self.report_error("Function " + func_name + " is not supported now") + + def visit_Expr(self, node): + """ Expr visitor + AST abstract grammar: + Expr(expr value) + + Now only 2 types of Expr stmt is allowed: + 1. Concise mode of with scope handlers + tir.attr()/tir.assert()/tir.allocate()/tir.realize() + 2. special stmts appear as a call + tir.set_func_attr() + """ + + if not isinstance(node.value, ast.Call): + self.report_error("Unsupported Expr stmt") + return self.visit(node.value) + + def visit_BinOp(self, node): + """ BinOp visitor + AST abstract grammar: + BinOp(expr left, operator op, expr right) + """ + + lhs = self.visit(node.left) + rhs = self.visit(node.right) + if not isinstance(node.op, tuple(HybridParser._binop_maker.keys())): + self.report_error("BinOp " + str(type(node.op)) + " is not supported now") + return HybridParser._binop_maker[type(node.op)](lhs, rhs) + + def visit_Compare(self, node): + """ Compare visitor + AST abstract grammar: + Compare(expr left, expr right, ops=) + """ + + ops = [self.visit(node.left)] + ops += [self.visit(comparator) for comparator in node.comparators] + res = [] + for i in range(len(node.ops)): + lhs = ops[i] + rhs = ops[i + 1] + res.append(HybridParser._binop_maker[type(node.ops[i])](lhs, rhs)) + return _all(*res) + + def visit_BoolOp(self, node): + """ BoolOp visitor + AST abstract grammar: + BoolOp(boolop op, expr* values) + """ + + values = [self.visit(value) for value in node.values] + return HybridParser._binop_maker[type(node.op)](*values) + + def visit_UnaryOp(self, node): + """ UnaryOp visitor + AST abstract grammar: + UnaryOp(unaryop op, expr operand) + """ + + operand = self.visit(node.operand) + if not isinstance(node.op, tuple(HybridParser._unaryop_maker.keys())): + self.report_error("UnaryOp " + str(type(node.op)) + " is not supported now") + return HybridParser._unaryop_maker[type(node.op)](operand) + + def visit_Subscript(self, node): + """ Subscript visitor + AST abstract grammar: + Subscript(expr value, slice slice, expr_context ctx) + slice = Slice(expr? lower, expr? upper, expr? step) + | ExtSlice(slice* dims) + | Index(expr value) + By now only 2 types of Subscript are supported: + 1. Buffer[index, index, ...], Buffer element access(BufferLoad & BufferStore) + Var[index] Buffer element access() + 2. meta[type_key][index], Meta info access + """ + + if isinstance(node.value, (ast.Name, ast.Attribute)): + symbol = self.visit(node.value) + if isinstance(node.slice, ast.Index): + # BufferLoad & BufferStore + if isinstance(node.slice.value, ast.Tuple): + # Buffer/Var[index, index, ...] + indexes = [self.visit(element) for element in node.slice.value.elts] + else: + # Buffer/Var[index] + indexes = [self.visit(node.slice.value)] + if isinstance(node.ctx, ast.Load): + if isinstance(symbol, tir.expr.Var): + return tvm.tir.Load("float32", symbol, indexes, True) + else: + return tvm.tir.BufferLoad(symbol, indexes) + else: + return symbol, indexes + else: + # Buffer Region, now used in tir.realize(buffer[bounds]) + doms = [] + slice_nodes = [] + if isinstance(node.slice, ast.Slice): + # Buffer[begin:end] + slice_nodes.append(node.slice) + elif isinstance(node.slice, ast.ExtSlice): + # Buffer[begin:end, begin:end] + slice_nodes.extend(node.slice.dims) + + for dim in slice_nodes: + if not hasattr(dim, "step"): + self.report_error("slice of Buffer Region ought to be begin:end") + if dim.step is not None: + self.report_error("step is not allowed in Buffer Region") + upper = self.visit(dim.upper) + lower = self.visit(dim.lower) + extent = upper - lower + if isinstance(extent, _expr.PrimExpr): + ana = tvm.arith.Analyzer() + extent = ana.simplify(extent) + doms.append(tvm.ir.Range.from_min_extent(lower, extent)) + return symbol, doms + + elif isinstance(node.value, ast.Subscript) and isinstance(node.value.value, ast.Name) \ + and node.value.value.id == 'meta': + # meta[type_key][index] + if not (isinstance(node.slice, ast.Index) and isinstance(node.slice.value, ast.Num)) \ + or not (isinstance(node.value.slice, ast.Index) \ + and isinstance(node.value.slice.value, ast.Name)): + self.report_error("The meta access format ought to be meta[type_key][index]") + type_key = node.value.slice.value.id + index = node.slice.value.n + node_list = self.meta[type_key] + if node_list is None: + self.report_error("type_key " + type_key + " in meta not found") + if len(node_list) <= index: + self.report_error("index " + index + " out of range " + len(node_list)) + return node_list[index] + else: + self.report_error("Only buffer variable and meta can be subscriptable") + + def visit_Name(self, node): + """ Name visitor + AST abstract grammar: + Name(identifier id, expr_context ctx) + """ + + name = node.id + symbol = self.scope_emitter.lookup_symbol(name) + if symbol is None: + self.report_error("Unknown symbol %s" % name) + return symbol + + def visit_Attribute(self, node): + """ Attribute visitor + AST abstract grammar: + Attribute(expr value, identifier attr, expr_context ctx) + """ + + if not isinstance(node.value, ast.Name): + self.report_error("The value of Attribute ought to a Name") + name = node.value.id + symbol = self.scope_emitter.lookup_symbol(name) + if symbol is None or not isinstance(symbol, tvm.tir.Buffer): + self.report_error("Unsupported Attribute expression") + if not hasattr(symbol, node.attr): + self.report_error("Type " + type(symbol) + " has not attr " + node.attr) + return getattr(symbol, node.attr) + + def visit_Dict(self, node): + """ Dict visitor + AST abstract grammar: + Dict(expr* keys, expr* values) + """ + + keys = [self.visit(key) for key in node.keys] + values = [self.visit(value) for value in node.values] + + return {key: value for key, value in zip(keys, values)} + + def visit_Tuple(self, node): + """ Tuple visitor + AST abstract grammar: + Tuple(expr* elts, expr_context ctx) + """ + + return tuple(self.visit(element) for element in node.elts) + + def visit_List(self, node): + """ List visitor + AST abstract grammar: + List(expr* elts, expr_context ctx) + """ + + return [self.visit(element) for element in node.elts] + + def visit_keyword(self, node): + """ Keyword visitor + AST abstract grammar: + keyword = (identifier? arg, expr value) + """ + + return node.arg, self.visit(node.value) + + def visit_NameConstant(self, node): + return tvm.runtime.convert(node.value) + + def visit_Constant(self, node): + return tvm.runtime.convert(node.value) + + def visit_Num(self, node): + if isinstance(node.n, numbers.Integral): + dtype = "int32" + elif isinstance(node.n, float): + dtype = "float32" + else: + self.report_error("The data type should be one of (int, float)") + return tvm.tir.const(node.n, dtype) + + def visit_Str(self, node): + return node.s + + +def from_source(src, func_lineno=0): + """ Parse the src into TIR + + Parameters + ---------- + src : str + Pruned source of original script + func_lineno : Optional[int] + The line number of the first line of the script to be parsed + Returns + ------- + functions : PrimFunc or IRModule + The PrimFunc or IRModule in IR. + """ + + root = ast.parse(src) + parser = HybridParser(src, func_lineno) + + try: + return parser.visit(root) + except HybridParserError as e: + raise e + except TVMError as e: + # TVM internal c++ error, we have to process the error message and inject line info + inject_e = str(e).split('\n') + msg = inject_e[-1].split(':', maxsplit=1)[1].strip() + inject_e = inject_e[:-1] + inject_e.extend( + parser.wrap_line_col(msg, parser.current_lineno, parser.current_col_offset).split('\n')) + inject_e[-1] = "TVM" + inject_e[-1][6:] + raise TVMError('\n'.join(inject_e)) + except Exception as e: + inject_e = parser.wrap_line_col(str(e), parser.current_lineno, parser.current_col_offset) + raise HybridParserError(inject_e) + + +tvm._ffi._init_api("tvm.hybrid.parser") diff --git a/python/tvm/hybrid/registry.py b/python/tvm/hybrid/registry.py new file mode 100644 index 000000000000..f33e03d11470 --- /dev/null +++ b/python/tvm/hybrid/registry.py @@ -0,0 +1,231 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Hybrid Script Parser Function Registry """ +# pylint: disable=inconsistent-return-statements +import inspect +from typed_ast import ast3 as ast + + +class Registry(object): + """Registration map + All these maps are static + """ + intrin = dict() + with_scope = dict() + for_scope = dict() + special_stmt = dict() + + +class CallArgumentReader(object): + """A helper class which read required argument from passed arguments""" + + def __init__(self, func_name, args, kwargs, parser): + self.func_name = func_name + self.args = args + self.kwargs = kwargs + self.parser = parser + + def get_func_compulsory_arg(self, pos, name): + """Get corresponding function argument from argument list which is compulsory""" + + if len(self.args) >= pos: + arg = self.args[pos - 1] + elif name not in self.kwargs.keys(): + self.parser.report_error(self.func_name + " misses argument " + name) + else: + arg = self.kwargs[name] + + return arg + + def get_func_optional_arg(self, pos, name, default): + """Get corresponding function argument from argument list which is optional. + If user doesn't provide the argument, set it to default value + """ + + if len(self.args) >= pos: + arg = self.args[pos - 1] + elif name in self.kwargs.keys(): + arg = self.kwargs[name] + else: + return default + + return arg + + +def func_wrapper(func_name, func_to_register, arg_list, need_parser_and_node, need_body, concise): + """Helper function to wrap a function to be registered """ + + def wrap_func(parser, node, args, kwargs): + reader = CallArgumentReader(func_name, args, kwargs, parser) + internal_args = list() + + if need_body and not isinstance(node, ast.For): + # automatically parse body for with scope handlers + if isinstance(node, ast.With): + # the with scope handler is used inside with context + parser.scope_emitter.new_scope() + parser.scope_emitter.node_stack[-1].extend(reversed(node.body)) + body = parser.get_body() + parser.scope_emitter.pop_scope() + else: + # the with scope handler is used in concise scoping + if not concise: + parser.report_error("Concise scoping is not allowed here") + body = parser.get_body() + + if need_parser_and_node: + internal_args.append(parser) + internal_args.append(node) + + for i, arg_info in enumerate(arg_list): + if len(arg_info) == 1: + arg_name, = arg_info + if need_body and arg_name == "body": + internal_args.append(body) + else: + internal_args.append(reader.get_func_compulsory_arg(i + 1, arg_name)) + else: + arg_name, default = arg_info + internal_args.append(reader.get_func_optional_arg(i + 1, arg_name, default=default)) + + return func_to_register(*internal_args) + + return wrap_func + + +def get_arg_list(origin_func, need_parser_and_node): + """Helper function to get the argument list of Function + + Parameters + ---------- + origin_func: function + The function to get the argument list + + need_parser_and_node: bool + Whether the function need parser and node in its arguments + """ + + full_arg_spec = inspect.getfullargspec(origin_func) + + args, defaults = full_arg_spec.args, full_arg_spec.defaults + + if defaults is None: + defaults = tuple() + if need_parser_and_node: + args = args[2:] + + if full_arg_spec.varargs is not None: + raise RuntimeError( + "TVM Hybrid Script register error : variable argument is not supported now") + if full_arg_spec.varkw is not None: + raise RuntimeError( + "TVM Hybrid Script register error : variable keyword argument is not supported now") + if not len(full_arg_spec.kwonlyargs) == 0: + raise RuntimeError( + "TVM Hybrid Script register error : keyword only argument is not supported now") + + arg_list = list() + for arg in args[: len(args) - len(defaults)]: + arg_list.append((arg,)) + for default, arg in zip(defaults, args[len(args) - len(defaults):]): + arg_list.append((arg, default)) + + return arg_list + + +def register_intrin(origin_func): + """ Decorator to register function under category intrin + + Example + ------ + .. code-block:: python + + @register_intrin + def broadcast(value, lanes): + lanes = lanes.value if not isinstance(lanes, int) else lanes + return tvm.tir.Broadcast(value, lanes) + """ + func_name = origin_func.__qualname__ + Registry.intrin[func_name] = \ + func_wrapper(func_name, origin_func, get_arg_list(origin_func, False), + need_parser_and_node=False, need_body=False, concise=False) + return origin_func + + +def register_with_scope(concise=False): + """Decorator to register function under with scope handler + + Parameters + ---------- + concise: bool + whether this scope handler is allowed in concise scoping + + Example + ------ + .. code-block:: python + + @register_scope_handler(concise=True) + def attr(parser, node, attr_node, attr_key, value, body): + + return tvm.tir.AttrStmt(attr_node, attr_key, tvm.runtime.convert(value), body) + """ + + def decorate(origin_func): + """Register function under category with_scope""" + func_name = origin_func.__qualname__ + Registry.with_scope[func_name] = \ + func_wrapper(func_name, origin_func, get_arg_list(origin_func, True), + need_parser_and_node=True, need_body=True, concise=concise) + return origin_func + + return decorate + + +def register_for_scope(): + """Decorator to register function under for scope handler""" + def decorate(origin_func): + """Register function under category for_scope""" + func_name = origin_func.__qualname__ + Registry.for_scope[func_name] = \ + func_wrapper(func_name, origin_func, get_arg_list(origin_func, True), + need_parser_and_node=True, need_body=True, concise=False) + return origin_func + + return decorate + + +def register_special_stmt(origin_func): + """ Decorator to register function under category special_stmt + + Example + ------- + @register_special_stmt + def buffer_decl(parser, node, shape, dtype="float32", data=None, strides=[], elem_offset=None, + scope="global", align=-1, offset_factor=0, buffer_type="default"): + align = align.value if not isinstance(align, int) else align + offset_factor = offset_factor.value if not isinstance(offset_factor, int) else offset_factor + buffer = tvm.tir.decl_buffer(shape, dtype, parser._assign_target, data, strides, + elem_offset, scope, align, offset_factor, buffer_type) + return buffer + + """ + + func_name = origin_func.__qualname__ + Registry.special_stmt[func_name] = \ + func_wrapper(func_name, origin_func, get_arg_list(origin_func, True), + need_parser_and_node=True, need_body=False, concise=False) + return origin_func diff --git a/python/tvm/hybrid/scope_emitter.py b/python/tvm/hybrid/scope_emitter.py new file mode 100644 index 000000000000..629f44ba5473 --- /dev/null +++ b/python/tvm/hybrid/scope_emitter.py @@ -0,0 +1,62 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Hybrid Script Scope Emitter for TIR""" + +from tvm.te import schedule + + +class ScopeEmitter: + """Maintain the nodes and symbols of scopes""" + + def __init__(self, parser): + self.node_stack = [[]] # AST nodes of scopes + self.symbols = [dict()] # Symbols of scopes + self.parser = parser + + def pop_scope(self): + """Pop the inner most scope""" + self.symbols.pop() + self.node_stack.pop() + + def new_scope(self): + """ Creating a new scope """ + self.node_stack.append([]) + self.symbols.append(dict()) + + def update_symbol(self, name, symbol): + """Append a symbol into current scope""" + if isinstance(symbol, schedule.Buffer): + if name in self.symbols[0]: + self.parser.report_error("Duplicate Buffer name") + self.symbols[0][name] = symbol + else: + self.symbols[-1][name] = symbol + + def remove_symbol(self, name): + """Remove a symbol""" + for symbols in reversed(self.symbols): + if name in symbols: + symbols.pop(name) + return + raise RuntimeError("Internal error of hybrid parser: no symbol named" + name) + + def lookup_symbol(self, name): + """Look up symbol by name""" + for symbols in reversed(self.symbols): + if name in symbols: + return symbols[name] + return None diff --git a/python/tvm/hybrid/scope_handler.py b/python/tvm/hybrid/scope_handler.py new file mode 100644 index 000000000000..3b1b7a2c5987 --- /dev/null +++ b/python/tvm/hybrid/scope_handler.py @@ -0,0 +1,89 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Hybrid Script Parser Scope Handler Functions + +This module provides the functions registered into parser under with_scope or for_scope category. +Scope handler nodes are StmtNodes with body, which are used to handle such scenarios. + +.. code-block:: python + + for x in tir.name(): + with tir.name(): + tir.name() # with scope handlers + concise scoping + +""" +# pylint: disable=redefined-builtin, unused-argument, invalid-name +import tvm.tir +from .registry import register_with_scope, register_for_scope + + +# With scope handler +@register_with_scope(concise=False) +def Assert(parser, node, condition, message, body): + """ With scope handler function assert(condition, message, body) """ + + return tvm.tir.AssertStmt(condition, tvm.runtime.convert(message), body) + + +@register_with_scope(concise=False) +def let(parser, node, var, value, body): + """ With scope handler function let(var, value, body) """ + + return tvm.tir.LetStmt(var, value, body) + + +@register_with_scope(concise=True) +def realize(parser, node, buffer_bounds, body, condition=True): + """ With scope handler function realize(buffer_bounds, condition, body) """ + + buffer, bounds = buffer_bounds + return tvm.tir.BufferRealize(buffer, bounds, condition, body) + + +@register_with_scope(concise=True) +def attr(parser, node, attr_node, attr_key, value, body): + """ With scope handler function attr(attr_node, attr_key, value, body) """ + + return tvm.tir.AttrStmt(attr_node, attr_key, tvm.runtime.convert(value), body) + + +@register_with_scope(concise=True) +def allocate(parser, node, buffer_var, dtype, extents, body, condition=True): + """ With scope handler function allocate(buffer_var, dtype, extents, condition, body) """ + + return tvm.tir.Allocate(buffer_var, dtype, extents, tvm.runtime.convert(condition), body) + + +# For scope handler +@register_for_scope() +def range(parser, node, begin, end, for_type="serial"): + """ For scope handler function range(begin, end, annotation)""" + ana = tvm.arith.Analyzer() + extent = end if begin == 0 else ana.simplify(end - begin) + loop_var_name = node.target.id + loop_var = tvm.te.var(loop_var_name, dtype="int32") + + parser.scope_emitter.new_scope() + parser.scope_emitter.update_symbol(loop_var_name, loop_var) + parser.scope_emitter.node_stack[-1].extend(reversed(node.body)) + body = parser.get_body() + parser.scope_emitter.pop_scope() + + for_type_dict = {"serial": 0, "parallel": 1, "vectorized": 2, "unroll": 3} + if for_type not in for_type_dict: + parser.report_error("unknown for type " + for_type) + return tvm.tir.For(loop_var, begin, extent, for_type_dict[for_type], 0, body) diff --git a/python/tvm/hybrid/special_stmt.py b/python/tvm/hybrid/special_stmt.py new file mode 100644 index 000000000000..03b3cca20f65 --- /dev/null +++ b/python/tvm/hybrid/special_stmt.py @@ -0,0 +1,102 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Hybrid Script Parser Special Stmt Functions + +This module provides the functions registered into parser under special_stmt category. +special_stmt functions don't correspond to an IRNode in the AST directly. It is usually +used for some information that is not suitable to be printed directly. + +special_stmt can appear as 2 formats + +.. code-block:: python + + target = tir.name(): + tir.name() + +""" +# pylint: disable=unused-argument +import tvm.tir +from tvm import te +from .registry import register_special_stmt + + +@register_special_stmt +def buffer_bind(parser, node, param, shape, dtype="float32", data=None, strides=None, + elem_offset=None, scope="global", align=-1, offset_factor=0, buffer_type="default"): + """ Special function buffer_bind(var, shape, dtype, data, strides, elem_offset, scope, align, + offset_factor, buffer_type) + + Example + ------- + .. code-block:: python + + A = tir.buffer_bind(a, (128, 128), dtype="float32") + + """ + + if param not in parser.params: + parser.report_error("Can not bind non-input param to buffer") + if strides is None: + strides = [] + align = align.value if not isinstance(align, int) else align + offset_factor = offset_factor.value if not isinstance(offset_factor, int) else offset_factor + buffer = tvm.tir.decl_buffer(shape, dtype, parser._assign_target, data, strides, elem_offset, + scope, align, offset_factor, buffer_type) + parser.buffer_map[param] = buffer + return buffer + + +@register_special_stmt +def buffer_decl(parser, node, shape, dtype="float32", data=None, strides=None, elem_offset=None, + scope="global", align=-1, offset_factor=0, buffer_type="default"): + """ Special function buffer_decl(shape, dtype, data, strides, elem_offset, scope, align, + offset_factor, buffer_type) + + Example + ------- + .. code-block:: python + + A = tir.buffer_decl((128, 128), dtype="float32") + + """ + if strides is None: + strides = [] + align = align.value if not isinstance(align, int) else align + offset_factor = offset_factor.value if not isinstance(offset_factor, int) else offset_factor + buffer = tvm.tir.decl_buffer(shape, dtype, parser._assign_target, data, strides, elem_offset, + scope, align, offset_factor, buffer_type) + return buffer + + +@register_special_stmt +def var(parser, node, dtype): + """ Special function for defining a Var""" + return te.var(parser._assign_target, dtype) + + +@register_special_stmt +def func_attr(parser, node, dict_attr): + """ Special function for declaring the DictAttr of PrimFunc + + Example + ------- + .. code-block:: python + + tir.func_attr({"tir.noalias": True, "global_symbol"}) + """ + + parser.dict_attr = dict_attr diff --git a/python/tvm/hybrid/ty.py b/python/tvm/hybrid/ty.py new file mode 100644 index 000000000000..ee33805aa3b2 --- /dev/null +++ b/python/tvm/hybrid/ty.py @@ -0,0 +1,63 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Hybrid Script Parser Typing Class + +This module provides typing class for hybrid script type annotation usage, it can be viewed as +a wrapper for uniform Type system in IR +""" +# pylint: disable=invalid-name +import tvm + + +class TypeGeneric: + """Base class for all the hybrid script typing class""" + def evaluate(self): + raise TypeError("Cannot get tvm.Type from a generic type") + + +class ConcreteType(TypeGeneric): + """Hybrid script typing class for uniform Type objects""" + def __init__(self, vtype): + self.type = vtype + + def evaluate(self): + return self.type + + +class GenericPtrType(TypeGeneric): + """Hybrid script typing class generator for PtrType + + [] operator is overloaded, accepts a ConcreteType and returns a ConcreteType wrapping PtrType + """ + def __getitem__(self, vtype): + return ConcreteType(tvm.ir.PointerType(vtype.evaluate())) + + +class GenericTupleType(TypeGeneric): + """Hybrid script typing class generator for TupleType + + [] operator is overloaded, accepts a list of ConcreteType and returns a ConcreteType + wrapping TupleType + """ + def __getitem__(self, vtypes): + return ConcreteType(tvm.ir.TupleType([vtype.evaluate() for vtype in vtypes])) + + +int32 = ConcreteType(tvm.ir.PrimType("int32")) +handle = ConcreteType(tvm.ir.PrimType("handle")) +Ptr = GenericPtrType() +Tuple = GenericTupleType() diff --git a/python/tvm/hybrid/utils.py b/python/tvm/hybrid/utils.py new file mode 100644 index 000000000000..7880fd7c90cf --- /dev/null +++ b/python/tvm/hybrid/utils.py @@ -0,0 +1,96 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Helper functions in Hybrid Script Parser""" + +import inspect +from tvm import IRModule + +from . import _ffi_api +from .parser import from_source + + +def create_module(functions=None): + """Construct a module from list of functions. + + Parameters + ----------- + functions: Optional[dict]. + Map of GlobalVar or str to PrimFunc + + Returns + ------- + mod : IRModule + An IRModule containing the passed definitions + """ + + return IRModule(functions=functions) + + +def ashybrid(input_ir, show_meta=False): + """Transform a PrimFunc or IRModule to python syntax script + + Parameters + ---------- + input_ir : Union[PrimFunc, IRModule] + The PrimFunc or IRModule to be dumped + + show_meta : bool + Whether show meta + + Returns + ------- + script : str + The Python script + """ + + return _ffi_api.AsHybrid(input_ir, show_meta) + + +def script(script_in): + """Decorate a python function or class as hybrid script. + + The hybrid function or parsing support parsing to the internal TIR. + + Returns + ------- + output : Union[Function, Module] + The Function or Module in IR. + """ + + if inspect.isfunction(script_in): + return _parse(script_in) + + if inspect.isclass(script_in): + return HybridClass(script_in) + + raise TypeError("Only function and class are supported") + + +class HybridClass: + """Helper class for decorating a class""" + + def __init__(self, script_in): + self.script = script_in + + def __call__(self, *args, **kwargs): + # call the parser to transform hybrid script into TIR + return _parse(self.script) + + +def _parse(script_in): + """Helper function to parse hybrid_script into TIR""" + return from_source(inspect.getsource(script_in), inspect.getsourcelines(script_in)[1]) diff --git a/src/printer/tir_hybrid_printer.cc b/src/printer/tir_hybrid_printer.cc new file mode 100644 index 000000000000..8f6b37a2bbc4 --- /dev/null +++ b/src/printer/tir_hybrid_printer.cc @@ -0,0 +1,845 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file printer/tir_hybrid_printer.cc + * \brief Printer class to print Te IR to python syntax script + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "doc.h" +#include "meta_data.h" +#include "text_printer.h" + +namespace tvm { +namespace tir { + +class TIRHybridPrinter : public StmtFunctor, + public ExprFunctor, + public TypeFunctor { + public: + explicit TIRHybridPrinter(bool show_meta, + runtime::TypedPackedFunc annotate = nullptr) + : show_meta_(show_meta), annotate_(annotate), meta_collector_(&meta_) {} + + /*! \brief Print the node */ + TVM_DLL Doc Print(const ObjectRef& node); + + private: + /*! \brief whether show meta data */ + bool show_meta_; + /*! \brief additional comment function */ + runtime::TypedPackedFunc annotate_; + /*! \brief meta data context */ + TextMetaDataContext meta_; + /*! \brief meta collector */ + MetaCollector meta_collector_; + /*! \brief map from Function to GlobalVar */ + std::unordered_map func2var_; + /*! \brief var collector (var defined by For/Loop/Block) */ + std::unordered_set var_not_in_headers; + /*! \brief buffer collector (buffer defined in BufferMap and BufferAllocation)*/ + std::unordered_set buf_not_in_headers; + /*! \brief Map from Var to Doc */ + std::unordered_map memo_var_; + /*! \brief Map from Buffer to Doc */ + std::unordered_map memo_buf_; + /*! \brief Map from Buffer to Declaration Doc */ + std::unordered_map memo_buf_decl_; + /*! \brief Map from CommReducer to Doc */ + std::unordered_map memo_reducer_; + /*! \brief name allocation map */ + std::unordered_map name_alloc_map_; + /*! \brief number of children of current node's parent */ + int num_child_; + /*! \brief the number of current node */ + int current_num_; + + Doc VisitExpr_(const CastNode* op) override; + Doc VisitExpr_(const VarNode* op) override; + Doc VisitExpr_(const AddNode* op) override; + Doc VisitExpr_(const SubNode* op) override; + Doc VisitExpr_(const MulNode* op) override; + Doc VisitExpr_(const DivNode* op) override; + Doc VisitExpr_(const ModNode* op) override; + Doc VisitExpr_(const FloorDivNode* op) override; + Doc VisitExpr_(const FloorModNode* op) override; + Doc VisitExpr_(const MinNode* op) override; + Doc VisitExpr_(const MaxNode* op) override; + Doc VisitExpr_(const EQNode* op) override; + Doc VisitExpr_(const NENode* op) override; + Doc VisitExpr_(const LTNode* op) override; + Doc VisitExpr_(const LENode* op) override; + Doc VisitExpr_(const GTNode* op) override; + Doc VisitExpr_(const GENode* op) override; + Doc VisitExpr_(const AndNode* op) override; + Doc VisitExpr_(const OrNode* op) override; + Doc VisitExpr_(const NotNode* op) override; + Doc VisitExpr_(const SelectNode* op) override; + Doc VisitExpr_(const IntImmNode* op) override; + Doc VisitExpr_(const FloatImmNode* op) override; + Doc VisitExpr_(const StringImmNode* op) override; + Doc VisitExpr_(const BufferLoadNode* op) override; + Doc VisitExpr_(const LoadNode* op) override; + Doc VisitExpr_(const RampNode* op) override; + Doc VisitExpr_(const BroadcastNode* op) override; + Doc VisitExpr_(const LetNode* op) override; + Doc VisitExpr_(const CallNode* op) override; + Doc VisitExpr_(const ShuffleNode* op) override; + Doc VisitExpr_(const ReduceNode* op) override; + Doc VisitExprDefault_(const Object* op) override; + + Doc VisitStmt_(const LetStmtNode* op) override; + Doc VisitStmt_(const AttrStmtNode* op) override; + Doc VisitStmt_(const AssertStmtNode* op) override; + Doc VisitStmt_(const StoreNode* op) override; + Doc VisitStmt_(const BufferStoreNode* op) override; + Doc VisitStmt_(const BufferRealizeNode* op) override; + Doc VisitStmt_(const AllocateNode* op) override; + Doc VisitStmt_(const IfThenElseNode* op) override; + Doc VisitStmt_(const SeqStmtNode* op) override; + Doc VisitStmt_(const ForNode* op) override; + Doc VisitStmt_(const PrefetchNode* op) override; + Doc VisitStmt_(const EvaluateNode* op) override; + Doc VisitStmtDefault_(const Object* op) override; + + Doc VisitType_(const PrimTypeNode* node) override; + Doc VisitType_(const PointerTypeNode* node) override; + Doc VisitType_(const TupleTypeNode* node) override; + + Doc PrintBody(const Stmt& body); + Doc PrintIRModule(const IRModule& module); + Doc PrintPrimFunc(const PrimFunc& primFunc); + Doc PrintIterVar(const IterVarNode* op); + Doc PrintRange(const RangeNode* op); + Doc PrintArray(const ArrayNode* op); + Doc PrintBuffer(const BufferNode* op); + Doc AllocBufferDeclaration(const Buffer& buf); + static Doc PrintString(const StringObj* op) { return Doc::StrLiteral(op->data); } + + Doc GetUniqueName(std::string prefix); + Doc AllocVar(const Var& var); + Doc AllocBuf(const Buffer& buffer); + + /*! + * \brief Print additional info about expr in comment. + * \param expr The expression. + */ + Doc PrintOptionalInfo(const Stmt& stmt) { + Doc doc; + // default annotations + if (annotate_ != nullptr) { + std::string annotated_stmt = annotate_(stmt); + if (!annotated_stmt.empty()) { + doc << "# " << annotated_stmt << Doc::NewLine(); + } + } + return doc; + } + + /*! + * \brief special method to render vectors of docs with a separator + * \param vec vector of docs + * \param sep separator + */ + static Doc PrintSep(const std::vector& vec, const Doc& sep) { + Doc seq; + if (vec.size() != 0) { + seq = vec[0]; + for (size_t i = 1; i < vec.size(); i++) { + seq << sep << vec[i]; + } + } + return seq; + } + + /*! + * \brief dump meta info + * \return Doc with meta info + */ + Doc DumpMeta() { + if (show_meta_) { + return Doc::Text("__tvm_meta__ = ") + << (meta_.empty() ? Doc::Text("None") : meta_.GetMetaSection()); + } else { + return Doc::Text(""); + } + } + + /*! + * \brief special method to print out data type + * \param dtype The data type + */ + static Doc PrintDType(DataType dtype) { + return Doc::StrLiteral(runtime::DLDataType2String(dtype)); + } + + /*! + * \brief special method to print out const scalar + * \param dtype The data type + * \param data The pointer to hold the data. + */ + template + static Doc PrintConstScalar(DataType dtype, const T* data) { + Doc doc; + std::ostringstream os; + os << data[0]; + if (dtype == DataType::Int(32)) { + doc << Doc::Text(os.str()); + } else if (dtype == DataType::Bool()) { + doc << Doc::Text(data[0] ? "True" : "False"); + } else { + doc << "tir." << runtime::DLDataType2String(dtype) << "(" << Doc::Text(os.str()) << ")"; + } + return doc; + } +}; + +Doc TIRHybridPrinter::GetUniqueName(std::string prefix) { + std::replace(prefix.begin(), prefix.end(), '.', '_'); + std::string unique_prefix = prefix; + auto it = name_alloc_map_.find(prefix); + if (it != name_alloc_map_.end()) { + while (name_alloc_map_.count(unique_prefix = prefix + "_" + std::to_string(++it->second)) > 0) { + } + } + name_alloc_map_[unique_prefix] = 0; + return Doc::Text(unique_prefix); +} + +Doc TIRHybridPrinter::AllocVar(const Var& var) { + const auto& it = memo_var_.find(var); + if (it != memo_var_.end()) { + return it->second; + } + std::string name = var->name_hint.operator std::string(); + if (name.length() == 0 || !std::isalpha(name[0])) { + name = "v" + name; + } + Doc val = GetUniqueName(name); + memo_var_[var] = val; + return val; +} + +Doc TIRHybridPrinter::AllocBufferDeclaration(const Buffer& buf) { + Doc doc = Print(buf->shape); + if (!runtime::TypeEqual(buf->dtype, DataType::Float(32))) { + doc << ", dtype=" << PrintDType(buf->dtype); + } + if (memo_var_.find(buf->data) != memo_var_.end()) { + doc << ", data=" << Print(buf->data); + } else { + // implicitly define data + memo_var_[buf->data] = Doc::Text(memo_buf_[buf].str() + ".data"); + var_not_in_headers.insert(buf->data.get()); + } + if (!buf->strides.empty()) { + doc << ", strides=" << Print(buf->strides); + } + if (buf->offset_factor != 0 && buf->elem_offset->IsInstance()) { + Var elem_offset = Downcast(buf->elem_offset); + if (memo_var_.find(elem_offset) != memo_var_.end()) { + doc << ", elem_offset=" << Print(buf->elem_offset); + } else { + // implicitly define elem_offset + memo_var_[elem_offset] = Doc::Text(memo_buf_[buf].str() + ".elem_offset"); + var_not_in_headers.insert(elem_offset.get()); + } + } else { + doc << ", elem_offset=" << Print(buf->elem_offset); + } + if (buf->scope != "global") { + doc << ", scope=" << Doc::StrLiteral(buf->scope); + } + if (buf->data_alignment != -1) { + doc << ", align=" << buf->data_alignment; + } + if (buf->offset_factor != 0) { + doc << ", offset_factor=" << buf->offset_factor; + } + if (buf->buffer_type != 1) { + doc << ", type=" << Doc::StrLiteral("auto"); + } + return doc; +} + +Doc TIRHybridPrinter::AllocBuf(const Buffer& buffer) { + const auto& it = memo_buf_.find(buffer); + if (it != memo_buf_.end()) { + return it->second; + } + std::string name = buffer->name; + if (name.length() == 0 || !std::isalpha(name[0])) { + name = "buf_" + name; + } + Doc val = GetUniqueName(name); + memo_buf_[buffer] = val; + memo_buf_decl_[buffer] = AllocBufferDeclaration(buffer); + return val; +} + +Doc TIRHybridPrinter::Print(const ObjectRef& node) { + if (!node.defined()) return Doc::Text("None"); + if (node->IsInstance()) { + return PrintOptionalInfo(Downcast(node)) << VisitStmt(Downcast(node)); + } else if (node->IsInstance()) { + return VisitExpr(Downcast(node)); + } else if (node->IsInstance()) { + return VisitType(Downcast(node)); + } else if (node->IsInstance()) { + return PrintPrimFunc(Downcast(node)); + } else if (node->IsInstance()) { + return PrintIRModule(Downcast(node)); + } else if (node->IsInstance()) { + return PrintArray(node.as()); + } else if (node->IsInstance()) { + return PrintBuffer(node.as()); + } else if (node->IsInstance()) { + return PrintString(node.as()); + } else if (node->IsInstance()) { + return PrintIterVar(node.as()); + } else if (node->IsInstance()) { + return PrintRange(node.as()); + } else { + meta_collector_.Collect(node); + return this->meta_.GetMetaNode(node); + } +} + +Doc TIRHybridPrinter::VisitExprDefault_(const Object* op) { + meta_collector_.Collect(GetRef(op)); + return this->meta_.GetMetaNode(GetRef(op)); +} + +Doc TIRHybridPrinter::VisitStmtDefault_(const Object* op) { + meta_collector_.Collect(GetRef(op)); + return this->meta_.GetMetaNode(GetRef(op)); +} + +Doc TIRHybridPrinter::VisitExpr_(const IntImmNode* op) { + return PrintConstScalar(op->dtype, &(op->value)); +} + +Doc TIRHybridPrinter::VisitExpr_(const FloatImmNode* op) { + return PrintConstScalar(op->dtype, &(op->value)); +} + +Doc TIRHybridPrinter::VisitExpr_(const StringImmNode* op) { return Doc::StrLiteral(op->value); } + +Doc TIRHybridPrinter::VisitExpr_(const CastNode* op) { + Doc doc; + doc << "tir.cast(" << PrintDType(op->dtype) << ", " << Print(op->value) << ")"; + return doc; +} + +Doc TIRHybridPrinter::VisitExpr_(const VarNode* op) { + const Var& var = GetRef(op); + return meta_.InMeta(var) ? meta_.GetMetaNode(var) : AllocVar(GetRef(op)); +} + +#define TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(OpName, OpString) \ + Doc TIRHybridPrinter::VisitExpr_(const OpName* op) { \ + Doc doc; \ + doc << '(' << Print(op->a) << OpString << Print(op->b) << ")"; \ + return doc; \ + } + +TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(AddNode, " + ") +TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(SubNode, " - ") +TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(MulNode, "*") +TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(DivNode, " / ") +TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(ModNode, " % ") +TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(EQNode, " == ") +TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(NENode, " != ") +TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(LTNode, " < ") +TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(LENode, " <= ") +TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(GTNode, " > ") +TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(GENode, " >= ") +TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(AndNode, " and ") +TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(OrNode, " or ") + +Doc TIRHybridPrinter::VisitExpr_(const FloorDivNode* op) { + Doc doc; + doc << "tir.floordiv(" << Print(op->a) << ", " << Print(op->b) << ")"; + return doc; +} + +Doc TIRHybridPrinter::VisitExpr_(const FloorModNode* op) { + Doc doc; + doc << "tir.floormod(" << Print(op->a) << ", " << Print(op->b) << ")"; + return doc; +} + +Doc TIRHybridPrinter::VisitExpr_(const MinNode* op) { + Doc doc; + doc << "tir.min(" << Print(op->a) << ", " << Print(op->b) << ")"; + return doc; +} + +Doc TIRHybridPrinter::VisitExpr_(const MaxNode* op) { + Doc doc; + doc << "tir.max(" << Print(op->a) << ", " << Print(op->b) << ")"; + return doc; +} + +Doc TIRHybridPrinter::VisitExpr_(const NotNode* op) { + Doc doc; + doc << "not (" << Print(op->a) << ")"; + return doc; +} + +Doc TIRHybridPrinter::VisitExpr_(const SelectNode* op) { + Doc doc; + doc << "tir.select(" << Print(op->condition) << ", " << Print(op->true_value) << ", " + << Print(op->false_value) << ")"; + return doc; +} + +Doc TIRHybridPrinter::VisitExpr_(const BufferLoadNode* op) { + Doc doc; + doc << Print(op->buffer) << Print(op->indices); + return doc; +} + +Doc TIRHybridPrinter::VisitExpr_(const LoadNode* op) { + Doc doc; + if (op->dtype == DataType::Float(32) && is_one(op->predicate) && + op->buffer_var->dtype == DataType::Float(32)) { + doc << Print(op->buffer_var) << "[" << Print(op->index) << "]"; + } else { + doc << "tir.load(" << PrintDType(op->dtype) << ", " << Print(op->buffer_var) << ", " + << Print(op->index); + if (!is_one(op->predicate) || op->dtype.lanes() != 1) { + doc << ", " << Print(op->predicate); + } + doc << ")"; + } + return doc; +} + +Doc TIRHybridPrinter::VisitExpr_(const RampNode* op) { + Doc doc; + doc << "tir.ramp(" << Print(op->base) << ", " << Print(op->stride) << ", " << op->lanes << ")"; + return doc; +} + +Doc TIRHybridPrinter::VisitExpr_(const BroadcastNode* op) { + Doc doc; + doc << "tir.broadcast(" << Print(op->value) << ", " << op->lanes << ")"; + return doc; +} + +Doc TIRHybridPrinter::VisitExpr_(const LetNode* op) { + Doc doc; + doc << "tir.let(" << Print(op->var) << ", " << Print(op->value) << ", " << Print(op->body) << ")"; + return doc; +} + +Doc TIRHybridPrinter::VisitExpr_(const CallNode* op) { + Doc doc; + if (auto* ptr_op = op->op.as()) { + doc << Doc::Text(ptr_op->name) << "("; + } else { + auto* op_gvar = op->op.as(); + CHECK(op_gvar != nullptr); + doc << Doc::Text(op_gvar->name_hint) << "("; + } + std::vector args; + for (const auto& arg : op->args) { + args.push_back(Print(arg)); + } + args.push_back(Doc::Text("dtype=") << PrintDType(op->dtype)); + doc << PrintSep(args, Doc::Text(", ")) << ")"; + return doc; +} + +Doc TIRHybridPrinter::VisitExpr_(const ShuffleNode* op) { + Doc doc; + doc << "tir.shuffle(" << Print(op->vectors) << ", " << Print(op->indices) << ")"; + return doc; +} + +Doc TIRHybridPrinter::VisitExpr_(const ReduceNode* op) { + Doc doc; + doc << "tir.reduce(" << Print(op->combiner) << ", " << Print(op->source) << ", " + << Print(op->axis) << ", " << op->value_index << ")"; + return doc; +} + +Doc TIRHybridPrinter::VisitStmt_(const LetStmtNode* op) { + Doc doc; + if (current_num_ != num_child_ - 1) { + doc << "with tir.let(" << Print(op->var) << ", " << Print(op->value) << "):"; + doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body)); + } else { + if (memo_var_.find(op->var) == memo_var_.end()) var_not_in_headers.insert(op->var.get()); + doc << Print(op->var) << ": " << Print(GetType(op->var)) << " = " << Print(op->value) + << Doc::NewLine() << PrintBody(op->body); + } + return doc; +} + +Doc TIRHybridPrinter::VisitStmt_(const AttrStmtNode* op) { + Doc doc; + if (current_num_ != num_child_ - 1) { + doc << "with tir.attr(" << Print(op->node) << ", " << Doc::StrLiteral(op->attr_key) << ", " + << Print(op->value) << "):"; + doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body)); + } else { + doc << "tir.attr(" << Print(op->node) << ", " << Doc::StrLiteral(op->attr_key) << ", " + << Print(op->value) << ")"; + doc << Doc::NewLine() << PrintBody(op->body); + } + return doc; +} + +Doc TIRHybridPrinter::VisitStmt_(const AssertStmtNode* op) { + Doc doc; + if (current_num_ != num_child_ - 1) { + doc << "with tir.Assert(" << Print(op->condition) << ", " << Print(op->message) << "):"; + doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body)); + } else { + doc << "assert " << Print(op->condition) << ", " << Print(op->message); + doc << Doc::NewLine() << PrintBody(op->body); + } + return doc; +} + +Doc TIRHybridPrinter::VisitStmt_(const StoreNode* op) { + Doc doc; + if (!is_one(op->predicate) || op->value.dtype().lanes() != 1) { + doc << "tir.store(" << Print(op->buffer_var) << ", " << Print(op->index) << ", " + << Print(op->value) << ", " << Print(op->predicate) << ")"; + } else { + doc << Print(op->buffer_var) << "[" << Print(op->index) << "] = " << Print(op->value); + } + return doc; +} + +Doc TIRHybridPrinter::VisitStmt_(const BufferRealizeNode* op) { + Doc doc; + if (current_num_ != num_child_ - 1) { + doc << "with tir.realize(" << Print(op->buffer) << Print(op->bounds); + if (!is_one(op->condition)) { + doc << ", " << Print(op->condition); + } + doc << "):" << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body)); + } else { + doc << "tir.realize(" << Print(op->buffer) << Print(op->bounds); + if (!is_one(op->condition)) { + doc << ", " << Print(op->condition); + } + doc << ")" << Doc::NewLine() << PrintBody(op->body); + } + return doc; +} + +Doc TIRHybridPrinter::VisitStmt_(const AllocateNode* op) { + Doc doc; + if (current_num_ != num_child_ - 1) { + doc << "with tir.allocate(" << Print(op->buffer_var) << ", " << PrintDType(op->dtype) << ", " + << Print(op->extents) << "):"; + doc << Doc::Indent(4, PrintBody(op->body)); + } else { + doc << "tir.allocate(" << Print(op->buffer_var) << ", " << PrintDType(op->dtype) << ", " + << Print(op->extents) << ")"; + doc << Doc::NewLine() << PrintBody(op->body); + } + return doc; +} + +Doc TIRHybridPrinter::VisitStmt_(const IfThenElseNode* op) { + Doc doc; + doc << "if " << Print(op->condition) << ":"; + doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->then_case)); + if (!is_one(op->condition) && op->else_case.defined()) { + doc << "else:" << Doc::Indent(4, Doc::NewLine() << PrintBody(op->else_case)); + } + return doc; +} + +Doc TIRHybridPrinter::VisitStmt_(const SeqStmtNode* op) { + std::vector stmts; + for (Stmt stmt : op->seq) { + stmts.push_back(Print(stmt)); + } + return PrintSep(stmts, Doc::NewLine()); +} + +Doc TIRHybridPrinter::VisitStmt_(const EvaluateNode* op) { + Doc doc; + doc << "tir.evaluate(" << Print(op->value) << ")"; + return doc; +} + +inline const char* ForType2String(ForType t) { + switch (t) { + case ForType::Serial: + return "serial"; + case ForType::Parallel: + return "parallel"; + case ForType::Vectorized: + return "vectorized"; + case ForType::Unrolled: + return "unroll"; + } + LOG(FATAL) << "Unknown ForType"; + return "Unknown"; +} + +Doc TIRHybridPrinter::VisitStmt_(const ForNode* op) { + Doc doc; + var_not_in_headers.insert(op->loop_var.get()); + doc << "for " << Print(op->loop_var) << " in tir.range(" << Print(op->min) << ", " + << Print(op->min + op->extent); + if (op->for_type != ForType::Serial) { + doc << ", " << Doc::StrLiteral(ForType2String(op->for_type)); + } + doc << "):" << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body)); + return doc; +} + +Doc TIRHybridPrinter::VisitStmt_(const PrefetchNode* op) { + Doc doc; + doc << "tir.prefetch(" << Print(op->buffer) << ", " << Print(op->bounds) << ")"; + return doc; +} + +Doc TIRHybridPrinter::VisitType_(const PrimTypeNode* node) { + Doc doc; + doc << "ty." << runtime::DLDataType2String(node->dtype); + return doc; +} + +Doc TIRHybridPrinter::VisitType_(const PointerTypeNode* node) { + Doc doc; + doc << "ty.Ptr[" << Print(node->element_type) << "]"; + return doc; +} + +Doc TIRHybridPrinter::VisitType_(const TupleTypeNode* node) { + if (node->fields.empty()) { + return Doc::Text("None"); + } else { + std::vector fields; + for (Type field : node->fields) { + fields.push_back(Print(field)); + } + return Doc::Text("ty.Tuple[") << Doc::Concat(fields) << "]"; + } +} + +Doc TIRHybridPrinter::VisitStmt_(const BufferStoreNode* op) { + Doc doc; + doc << Print(op->buffer) << Print(op->indices) << " = " << Print(op->value); + return doc; +} + +Doc TIRHybridPrinter::PrintBody(const Stmt& body) { + int memo_num_child, memo_current_num; + std::swap(memo_num_child, num_child_); + std::swap(memo_current_num, current_num_); + + Doc doc; + if (body->IsInstance()) { + const auto& op = Downcast(body); + num_child_ = op->seq.size(); + current_num_ = 0; + std::vector stmts; + for (Stmt stmt : op->seq) { + stmts.push_back(Print(stmt)); + current_num_++; + } + doc = PrintSep(stmts, Doc::NewLine()); + } else { + num_child_ = 1; + current_num_ = 0; + doc = Print(body); + } + + std::swap(memo_num_child, num_child_); + std::swap(memo_current_num, current_num_); + return doc; +} + +Doc TIRHybridPrinter::PrintIRModule(const IRModule& module) { + auto* op = module.operator->(); + Doc doc; + doc << "class Module:"; + for (const auto& x : op->functions) { + func2var_[x.second.operator->()] = x.first; + } + Doc body = Doc::NewLine(); + std::vector functions; + for (auto it = op->functions.begin(); it != op->functions.end(); ++it) { + if ((*it).second.as()) { + functions.push_back(Print((*it).second)); + } + } + body << TIRHybridPrinter::PrintSep(functions, Doc::NewLine() << Doc::NewLine()); + body << Doc::NewLine() << DumpMeta(); + doc << Doc::Indent(4, body); + return doc; +} + +Doc TIRHybridPrinter::PrintPrimFunc(const PrimFunc& primFunc) { + auto* op = primFunc.operator->(); + // clear renaming map + memo_var_.clear(); + memo_buf_.clear(); + memo_buf_decl_.clear(); + memo_reducer_.clear(); + var_not_in_headers.clear(); + buf_not_in_headers.clear(); + // print signature + Doc doc; + doc << "def " << (func2var_.find(op) == func2var_.end() ? "func" : func2var_[op]->name_hint) + << "("; + std::vector params; + for (const auto& param : op->params) { + var_not_in_headers.insert(param.get()); + params.push_back(Print(param) << ": " << Print(GetType(param))); + } + doc << PrintSep(params, Doc::Text(", ")) << ") -> " << Print(primFunc->ret_type) << ":"; + + Doc body = Doc::NewLine(); + // print buffer_bind + for (const auto& it : op->buffer_map) { + buf_not_in_headers.insert(it.second.get()); + body << Print(it.second) << " = tir.buffer_bind("; + body << Print(it.first) << ", " << memo_buf_decl_[it.second]; + body << ")" << Doc::NewLine(); + } + // print comm_reducer + for (const auto& it : memo_reducer_) { + body << it.second << " = tir.comm_reducer("; + var_not_in_headers.insert(it.first->lhs[0].get()); + var_not_in_headers.insert(it.first->rhs[0].get()); + body << "lambda " << Print(it.first->lhs[0]) << ", " << Print(it.first->rhs[0]) << ": " + << Print(it.first->result[0]) << ", " << Print(it.first->identity_element[0]); + body << ")" << Doc::NewLine(); + } + // print body + body << "# body" << Doc::NewLine() << PrintBody(op->body); + // print func attrs + Doc header_attr; + if (primFunc->attrs.defined()) { + header_attr << Doc::NewLine() << "# function attr dict" << Doc::NewLine() << "tir.func_attr({"; + std::vector attrs; + for (const auto& it : op->attrs->dict) { + attrs.push_back(Doc::StrLiteral(it.first) << ": " << Print(it.second)); + } + header_attr << PrintSep(attrs, Doc::Text(", ")) << "})"; + } + // print buffer declarations(buffers not defined by buffer_bind or buffer_allocate) + Doc header_buf; + std::vector bufs; + for (const auto& it : memo_buf_) { + if (buf_not_in_headers.find(it.first.get()) == buf_not_in_headers.end()) { + bufs.push_back(it.first.get()); + } + } + if (!bufs.empty()) { + header_buf << Doc::NewLine() << "# buffer definition"; + std::sort(bufs.begin(), bufs.end(), [&](const BufferNode* a, const BufferNode* b) { + return memo_buf_[GetRef(a)].str() < memo_buf_[GetRef(b)].str(); + }); + for (const auto& buf : bufs) { + header_buf << Doc::NewLine() << Print(GetRef(buf)) << " = tir.buffer_decl("; + header_buf << memo_buf_decl_[GetRef(buf)] << ")"; + } + } + // print var declaration + Doc header_var; + std::vector vars; + for (const auto& it : memo_var_) { + if (var_not_in_headers.find(it.first.get()) == var_not_in_headers.end()) { + vars.push_back(it.first.get()); + } + } + if (!vars.empty()) { + header_var << Doc::NewLine() << "# var definition"; + std::sort(vars.begin(), vars.end(), [&](const VarNode* a, const VarNode* b) { + return memo_var_[GetRef(a)].str() < memo_var_[GetRef(b)].str(); + }); + for (const auto& var : vars) { + header_var << Doc::NewLine() << Print(GetRef(var)) << " = tir.var("; + header_var << PrintDType(var->dtype) << ")"; + } + } + doc << Doc::Indent(4, header_attr << header_var << header_buf << body); + return doc; +} + +Doc TIRHybridPrinter::PrintArray(const ArrayNode* op) { + Doc doc; + doc << '['; + for (size_t i = 0; i < op->size(); ++i) { + if (i != 0) { + doc << ", "; + } + doc << Print(op->at(i)); + } + doc << ']'; + return doc; +} + +Doc TIRHybridPrinter::PrintIterVar(const IterVarNode* op) { + Doc doc; + doc << "tir.iter_var(" << Print(op->var); + if (op->dom.defined()) { + doc << ", [" << Print(op->dom) << "], "; + } else { + doc << ", None, "; + } + doc << Doc::StrLiteral(IterVarType2String(op->iter_type)) << ", "; + doc << Doc::StrLiteral(op->thread_tag) << ")"; + return doc; +} + +Doc TIRHybridPrinter::PrintRange(const RangeNode* op) { + return Print(op->min) << ":" << Print(op->min + op->extent); +} + +Doc TIRHybridPrinter::PrintBuffer(const BufferNode* op) { + const Buffer& buffer = GetRef(op); + return meta_.InMeta(buffer) ? meta_.GetMetaNode(buffer) : AllocBuf(buffer); +} + +TVM_REGISTER_GLOBAL("tir.hybrid.AsHybrid") + .set_body_typed([](const ObjectRef& functions, + bool show_meta) { + CHECK(functions.as() != nullptr || functions.as() != nullptr); + return "@tvm.hybrid.script\n" + TIRHybridPrinter(show_meta).Print(functions).str() + "\n"; + }); + +} // namespace tir +} // namespace tvm diff --git a/tests/python/unittest/test_hybrid_error_report.py b/tests/python/unittest/test_hybrid_error_report.py new file mode 100644 index 000000000000..dd5d70840943 --- /dev/null +++ b/tests/python/unittest/test_hybrid_error_report.py @@ -0,0 +1,105 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest + +import tvm +from tvm import tir +from tvm.hybrid import ty +from tvm.hybrid.parser import HybridParserError + + +@tvm.hybrid.script +class Module1: + def buffer_bind_missing_args(a: ty.handle) -> None: + A = tir.buffer_bind((16, 16), "float32") + + +@tvm.hybrid.script +class Module2: + def range_missing_args(a: ty.handle) -> None: + A = tir.buffer_bind(a, (16, 16), "float32") + + tir.attr(A, "realize_scope", "") + tir.realize(A[0:16, 0:16]) + for i in tir.range(16): + for j in tir.range(0, 16): + A[i, j] = 0.0 + + +@tvm.hybrid.script +class Module3: + def undefined_buffer(a: ty.handle) -> None: + A = tir.buffer_bind(a, (16, 16), "float32") + + tir.attr(A, "realize_scope", "") + tir.realize(C[0:16, 0:16]) + for i in tir.range(16): + for j in tir.range(0, 16): + A[i, j] = 0.0 + + +@tvm.hybrid.script +class Module4: + def unsupported_stmt(a: ty.int32) -> None: + if a > 0: + print("I love tvm") + + +@tvm.hybrid.script +class Module5: + def unsupported_function_call(a: ty.handle) -> None: + A = tir.buffer_bind(a, (16, 16), "float32") + + tir.attr(A, "realize_scope", "") + tir.realize(A[0:16, 0:16]) + for i in tir.const_range(16): + for j in tir.range(0, 16): + A[i, j] = 0.0 + + +@tvm.hybrid.script +class Module6: + def missing_type_annotation(a) -> None: + pass + + +@tvm.hybrid.script +class Module7: + def invalid_concise_scoping() -> None: + tir.Assert(1.0 > 0.0, "aaaa") + tir.evaluate(0.0) + + +def wrap_error(module, lineno): + with pytest.raises(HybridParserError) as error: + mod = module() + assert error is not None + e = error.value + print(e) + msg = str(e).split('\n')[-1].split(':', maxsplit=1)[0].strip().split(' ')[-1].strip() + assert int(msg) == lineno + + +if __name__ == '__main__': + wrap_error(Module1, 29) + wrap_error(Module2, 39) + wrap_error(Module3, 50) + wrap_error(Module4, 60) + wrap_error(Module5, 70) + wrap_error(Module6, 77) + wrap_error(Module7, 84) \ No newline at end of file diff --git a/tests/python/unittest/test_hybrid_roundtrip.py b/tests/python/unittest/test_hybrid_roundtrip.py new file mode 100644 index 000000000000..7b706bdbd92a --- /dev/null +++ b/tests/python/unittest/test_hybrid_roundtrip.py @@ -0,0 +1,536 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +from tvm import tir +from tvm.hybrid import ty + + +@tvm.hybrid.script +class Module1: + def mmult(A: ty.handle, B: ty.handle, C: ty.handle) -> None: + # function attr dict + tir.func_attr({"global_symbol": "mmult", "tir.noalias": True}) + # buffer definition + C_global = tir.buffer_decl([1024, 1024], elem_offset=0, align=128, offset_factor=1) + packedB = tir.buffer_decl([32, 1024, 32], elem_offset=0, align=128, offset_factor=1) + A_1 = tir.buffer_bind(A, [1024, 1024], elem_offset=0, align=128, offset_factor=1) + B_1 = tir.buffer_bind(B, [1024, 1024], elem_offset=0, align=128, offset_factor=1) + C_1 = tir.buffer_bind(C, [1024, 1024], elem_offset=0, align=128, offset_factor=1) + # body + tir.attr(packedB, "realize_scope", "") + tir.realize(packedB[0:32, 0:1024, 0:32]) + for x in tir.range(0, 32, "parallel"): + for y in tir.range(0, 1024): + for z in tir.range(0, 32, "vectorized"): + packedB[x, y, z] = B_1[y, ((x*32) + z)] + tir.attr(C_1, "realize_scope", "") + tir.realize(C_1[0:1024, 0:1024]) + for x_outer in tir.range(0, 32, "parallel"): + for y_outer in tir.range(0, 32): + tir.attr(C_global, "realize_scope", "global") + tir.realize(C_global[(x_outer*32):((x_outer*32) + 32), (y_outer*32):((y_outer*32) + 32)]) + for x_c_init in tir.range(0, 32): + for y_c_init in tir.range(0, 32, "vectorized"): + C_global[(x_c_init + (x_outer*32)), (y_c_init + (y_outer*32))] = tir.float32(0) + for k_outer in tir.range(0, 256): + for x_c in tir.range(0, 32): + for k_inner in tir.range(0, 4, "unroll"): + for y_c in tir.range(0, 32, "vectorized"): + C_global[(x_c + (x_outer*32)), (y_c + (y_outer*32))] = (C_global[(x_c + (x_outer*32)), (y_c + (y_outer*32))] + (A_1[(x_c + (x_outer*32)), (k_inner + (k_outer*4))]*packedB[tir.floordiv((y_c + (y_outer*32)), 32), (k_inner + (k_outer*4)), tir.floormod((y_c + (y_outer*32)), 32)])) + for x_inner in tir.range(0, 32): + for y_inner in tir.range(0, 32): + C_1[(x_inner + (x_outer*32)), (y_inner + (y_outer*32))] = C_global[(x_inner + (x_outer*32)), (y_inner + (y_outer*32))] + + +def test_opt_gemm_normalize(): + mod = Module1() + rt_mod = tvm.hybrid.from_source(tvm.hybrid.ashybrid(mod, True)) + tvm.ir.assert_structural_equal(mod, rt_mod, True) + + +@tvm.hybrid.script +class Module2: + def mmult(A: ty.handle, B: ty.handle, C: ty.handle) -> None: + # function attr dict + tir.func_attr({"global_symbol": "mmult", "tir.noalias": True}) + # var definition + C_global = tir.var("handle") + packedB = tir.var("handle") + A_1 = tir.buffer_bind(A, [1024, 1024], elem_offset=0, align=128, offset_factor=1) + B_1 = tir.buffer_bind(B, [1024, 1024], elem_offset=0, align=128, offset_factor=1) + C_1 = tir.buffer_bind(C, [1024, 1024], elem_offset=0, align=128, offset_factor=1) + # body + tir.attr(packedB, "storage_scope", "global") + tir.allocate(packedB, "float32x32", [32768]) + tir.attr(C_global, "storage_scope", "global") + tir.allocate(C_global, "float32", [1024]) + for x in tir.range(0, 32, "parallel"): + for y in tir.range(0, 1024): + tir.store(packedB, tir.ramp(((x*32768) + (y*32)), 1, 32), tir.load("float32x32", B_1.data, tir.ramp(((y*1024) + (x*32)), 1, 32), tir.broadcast(True, 32)), tir.broadcast(True, 32)) + for x_outer in tir.range(0, 32): + for y_outer in tir.range(0, 32): + for x_c_init in tir.range(0, 32): + tir.store(C_global, tir.ramp((x_c_init*32), 1, 32), tir.broadcast(tir.float32(0), 32), tir.broadcast(True, 32)) + for k_outer in tir.range(0, 256): + for x_c in tir.range(0, 32): + tir.store(C_global, tir.ramp((x_c*32), 1, 32), (tir.load("float32x32", C_global, tir.ramp((x_c*32), 1, 32), tir.broadcast(True, 32)) + (tir.broadcast(tir.load("float32", A_1.data, (((x_outer*32768) + (x_c*1024)) + (k_outer*4))), 32)*tir.load("float32x32", packedB, tir.ramp(((y_outer*32768) + (k_outer*128)), 1, 32), tir.broadcast(True, 32)))), tir.broadcast(True, 32)) + tir.store(C_global, tir.ramp((x_c*32), 1, 32), (tir.load("float32x32", C_global, tir.ramp((x_c*32), 1, 32), tir.broadcast(True, 32)) + (tir.broadcast(tir.load("float32", A_1.data, ((((x_outer*32768) + (x_c*1024)) + (k_outer*4)) + 1)), 32)*tir.load("float32x32", packedB, tir.ramp((((y_outer*32768) + (k_outer*128)) + 32), 1, 32), tir.broadcast(True, 32)))), tir.broadcast(True, 32)) + tir.store(C_global, tir.ramp((x_c*32), 1, 32), (tir.load("float32x32", C_global, tir.ramp((x_c*32), 1, 32), tir.broadcast(True, 32)) + (tir.broadcast(tir.load("float32", A_1.data, ((((x_outer*32768) + (x_c*1024)) + (k_outer*4)) + 2)), 32)*tir.load("float32x32", packedB, tir.ramp((((y_outer*32768) + (k_outer*128)) + 64), 1, 32), tir.broadcast(True, 32)))), tir.broadcast(True, 32)) + tir.store(C_global, tir.ramp((x_c*32), 1, 32), (tir.load("float32x32", C_global, tir.ramp((x_c*32), 1, 32), tir.broadcast(True, 32)) + (tir.broadcast(tir.load("float32", A_1.data, ((((x_outer*32768) + (x_c*1024)) + (k_outer*4)) + 3)), 32)*tir.load("float32x32", packedB, tir.ramp((((y_outer*32768) + (k_outer*128)) + 96), 1, 32), tir.broadcast(True, 32)))), tir.broadcast(True, 32)) + for x_inner in tir.range(0, 32): + for y_inner in tir.range(0, 32): + C_1.data[((((x_outer*32768) + (x_inner*1024)) + (y_outer*32)) + y_inner)] = tir.load("float32", C_global, ((x_inner*32) + y_inner)) + + +def test_opt_gemm_lower(): + mod = Module2() + rt_mod = tvm.hybrid.from_source(tvm.hybrid.ashybrid(mod, True)) + tvm.ir.assert_structural_equal(mod, rt_mod, True) + + +@tvm.hybrid.script +class Module3: + def mmult(args: ty.handle, arg_type_ids: ty.handle, num_args: ty.int32, out_ret_value: ty.handle, out_ret_tcode: ty.handle) -> ty.int32: + # function attr dict + tir.func_attr({"tir.noalias": True, "global_symbol": "mmult", "tir.is_entry_func": True, "calling_conv": 1}) + # var definition + C_global = tir.var("handle") + packedB = tir.var("handle") + # body + assert (num_args == 3), "mmult: num_args should be 3" + arg0: ty.handle = tir.tvm_struct_get(args, 0, 12, dtype="handle") + arg0_code: ty.int32 = tir.load("int32", arg_type_ids, 0) + arg1: ty.handle = tir.tvm_struct_get(args, 1, 12, dtype="handle") + arg1_code: ty.int32 = tir.load("int32", arg_type_ids, 1) + arg2: ty.handle = tir.tvm_struct_get(args, 2, 12, dtype="handle") + arg2_code: ty.int32 = tir.load("int32", arg_type_ids, 2) + A: ty.handle = tir.tvm_struct_get(arg0, 0, 1, dtype="handle") + tir.attr(A, "storage_alignment", 128) + arg0_shape: ty.handle = tir.tvm_struct_get(arg0, 0, 2, dtype="handle") + arg0_strides: ty.handle = tir.tvm_struct_get(arg0, 0, 3, dtype="handle") + dev_id: ty.int32 = tir.tvm_struct_get(arg0, 0, 9, dtype="int32") + B: ty.handle = tir.tvm_struct_get(arg1, 0, 1, dtype="handle") + tir.attr(B, "storage_alignment", 128) + arg1_shape: ty.handle = tir.tvm_struct_get(arg1, 0, 2, dtype="handle") + arg1_strides: ty.handle = tir.tvm_struct_get(arg1, 0, 3, dtype="handle") + C: ty.handle = tir.tvm_struct_get(arg2, 0, 1, dtype="handle") + tir.attr(C, "storage_alignment", 128) + arg2_shape: ty.handle = tir.tvm_struct_get(arg2, 0, 2, dtype="handle") + arg2_strides: ty.handle = tir.tvm_struct_get(arg2, 0, 3, dtype="handle") + assert ((((arg0_code == 3) or (arg0_code == 13)) or (arg0_code == 7)) or (arg0_code == 4)), "mmult: Expect arg[0] to be pointer" + assert ((((arg1_code == 3) or (arg1_code == 13)) or (arg1_code == 7)) or (arg1_code == 4)), "mmult: Expect arg[1] to be pointer" + assert ((((arg2_code == 3) or (arg2_code == 13)) or (arg2_code == 7)) or (arg2_code == 4)), "mmult: Expect arg[2] to be pointer" + assert (2 == tir.tvm_struct_get(arg0, 0, 4, dtype="int32")), "arg0.ndim is expected to equal 2" + assert (2 == tir.tvm_struct_get(arg0, 0, 4, dtype="int32")), "arg0.ndim is expected to equal 2" + assert (((tir.tvm_struct_get(arg0, 0, 5, dtype="uint8") == tir.uint8(2)) and (tir.tvm_struct_get(arg0, 0, 6, dtype="uint8") == tir.uint8(32))) and (tir.tvm_struct_get(arg0, 0, 7, dtype="uint16") == tir.uint16(1))), "arg0.dtype is expected to be float32" + assert (1024 == tir.cast("int32", tir.load("int64", arg0_shape, 0))), "Argument arg0.shape[0] has an unsatisfied constraint" + assert (1024 == tir.cast("int32", tir.load("int64", arg0_shape, 1))), "Argument arg0.shape[1] has an unsatisfied constraint" + if not (tir.isnullptr(arg0_strides, dtype="bool")): + assert ((1 == tir.cast("int32", tir.load("int64", arg0_strides, 1))) and (1024 == tir.cast("int32", tir.load("int64", arg0_strides, 0)))), "arg0.strides: expected to be compact array" + tir.evaluate(0) + assert (tir.uint64(0) == tir.tvm_struct_get(arg0, 0, 8, dtype="uint64")), "Argument arg0.byte_offset has an unsatisfied constraint" + assert (1 == tir.tvm_struct_get(arg0, 0, 10, dtype="int32")), "Argument arg0.device_type has an unsatisfied constraint" + assert (2 == tir.tvm_struct_get(arg1, 0, 4, dtype="int32")), "arg1.ndim is expected to equal 2" + assert (2 == tir.tvm_struct_get(arg1, 0, 4, dtype="int32")), "arg1.ndim is expected to equal 2" + assert (((tir.tvm_struct_get(arg1, 0, 5, dtype="uint8") == tir.uint8(2)) and (tir.tvm_struct_get(arg1, 0, 6, dtype="uint8") == tir.uint8(32))) and (tir.tvm_struct_get(arg1, 0, 7, dtype="uint16") == tir.uint16(1))), "arg1.dtype is expected to be float32" + assert (1024 == tir.cast("int32", tir.load("int64", arg1_shape, 0))), "Argument arg1.shape[0] has an unsatisfied constraint" + assert (1024 == tir.cast("int32", tir.load("int64", arg1_shape, 1))), "Argument arg1.shape[1] has an unsatisfied constraint" + if not (tir.isnullptr(arg1_strides, dtype="bool")): + assert ((1 == tir.cast("int32", tir.load("int64", arg1_strides, 1))) and (1024 == tir.cast("int32", tir.load("int64", arg1_strides, 0)))), "arg1.strides: expected to be compact array" + tir.evaluate(0) + assert (tir.uint64(0) == tir.tvm_struct_get(arg1, 0, 8, dtype="uint64")), "Argument arg1.byte_offset has an unsatisfied constraint" + assert (1 == tir.tvm_struct_get(arg1, 0, 10, dtype="int32")), "Argument arg1.device_type has an unsatisfied constraint" + assert (dev_id == tir.tvm_struct_get(arg1, 0, 9, dtype="int32")), "Argument arg1.device_id has an unsatisfied constraint" + assert (2 == tir.tvm_struct_get(arg2, 0, 4, dtype="int32")), "arg2.ndim is expected to equal 2" + assert (2 == tir.tvm_struct_get(arg2, 0, 4, dtype="int32")), "arg2.ndim is expected to equal 2" + assert (((tir.tvm_struct_get(arg2, 0, 5, dtype="uint8") == tir.uint8(2)) and (tir.tvm_struct_get(arg2, 0, 6, dtype="uint8") == tir.uint8(32))) and (tir.tvm_struct_get(arg2, 0, 7, dtype="uint16") == tir.uint16(1))), "arg2.dtype is expected to be float32" + assert (1024 == tir.cast("int32", tir.load("int64", arg2_shape, 0))), "Argument arg2.shape[0] has an unsatisfied constraint" + assert (1024 == tir.cast("int32", tir.load("int64", arg2_shape, 1))), "Argument arg2.shape[1] has an unsatisfied constraint" + if not (tir.isnullptr(arg2_strides, dtype="bool")): + assert ((1 == tir.cast("int32", tir.load("int64", arg2_strides, 1))) and (1024 == tir.cast("int32", tir.load("int64", arg2_strides, 0)))), "arg2.strides: expected to be compact array" + tir.evaluate(0) + assert (tir.uint64(0) == tir.tvm_struct_get(arg2, 0, 8, dtype="uint64")), "Argument arg2.byte_offset has an unsatisfied constraint" + assert (1 == tir.tvm_struct_get(arg2, 0, 10, dtype="int32")), "Argument arg2.device_type has an unsatisfied constraint" + assert (dev_id == tir.tvm_struct_get(arg2, 0, 9, dtype="int32")), "Argument arg2.device_id has an unsatisfied constraint" + tir.attr(0, "compute_scope", "mmult_compute_") + tir.attr(packedB, "storage_scope", "global") + tir.attr(packedB, "storage_alignment", 128) + with tir.let(packedB, tir.TVMBackendAllocWorkspace(1, dev_id, tir.uint64(4194304), 2, 32, dtype="handle")): + if tir.isnullptr(packedB, dtype="bool"): + tir.evaluate(tir.tvm_throw_last_error(dtype="int32")) + for x in tir.range(0, 32, "parallel"): + for y in tir.range(0, 1024): + tir.store(packedB, tir.ramp(((x*32768) + (y*32)), 1, 32), tir.load("float32x32", B, tir.ramp(((y*1024) + (x*32)), 1, 32), tir.broadcast(True, 32)), tir.broadcast(True, 32)) + for x_outer in tir.range(0, 32, "parallel"): + tir.attr(C_global, "storage_scope", "global") + tir.attr(C_global, "storage_alignment", 128) + with tir.let(C_global, tir.TVMBackendAllocWorkspace(1, dev_id, tir.uint64(4096), 2, 32, dtype="handle")): + if tir.isnullptr(C_global, dtype="bool"): + tir.evaluate(tir.tvm_throw_last_error(dtype="int32")) + for y_outer in tir.range(0, 32): + for x_c_init in tir.range(0, 32): + tir.store(C_global, tir.ramp((x_c_init*32), 1, 32), tir.broadcast(tir.float32(0), 32), tir.broadcast(True, 32)) + for k_outer in tir.range(0, 256): + for x_c in tir.range(0, 32): + tir.store(C_global, tir.ramp((x_c*32), 1, 32), tir.call_llvm_pure_intrin(tir.uint32(97), tir.uint32(3), tir.broadcast(tir.load("float32", A, (((x_outer*32768) + (x_c*1024)) + (k_outer*4))), 32), tir.load("float32x32", packedB, tir.ramp(((y_outer*32768) + (k_outer*128)), 1, 32), tir.broadcast(True, 32)), tir.load("float32x32", C_global, tir.ramp((x_c*32), 1, 32), tir.broadcast(True, 32)), dtype="float32x32"), tir.broadcast(True, 32)) + tir.store(C_global, tir.ramp((x_c*32), 1, 32), tir.call_llvm_pure_intrin(tir.uint32(97), tir.uint32(3), tir.broadcast(tir.load("float32", A, ((((x_outer*32768) + (x_c*1024)) + (k_outer*4)) + 1)), 32), tir.load("float32x32", packedB, tir.ramp((((y_outer*32768) + (k_outer*128)) + 32), 1, 32), tir.broadcast(True, 32)), tir.load("float32x32", C_global, tir.ramp((x_c*32), 1, 32), tir.broadcast(True, 32)), dtype="float32x32"), tir.broadcast(True, 32)) + tir.store(C_global, tir.ramp((x_c*32), 1, 32), tir.call_llvm_pure_intrin(tir.uint32(97), tir.uint32(3), tir.broadcast(tir.load("float32", A, ((((x_outer*32768) + (x_c*1024)) + (k_outer*4)) + 2)), 32), tir.load("float32x32", packedB, tir.ramp((((y_outer*32768) + (k_outer*128)) + 64), 1, 32), tir.broadcast(True, 32)), tir.load("float32x32", C_global, tir.ramp((x_c*32), 1, 32), tir.broadcast(True, 32)), dtype="float32x32"), tir.broadcast(True, 32)) + tir.store(C_global, tir.ramp((x_c*32), 1, 32), tir.call_llvm_pure_intrin(tir.uint32(97), tir.uint32(3), tir.broadcast(tir.load("float32", A, ((((x_outer*32768) + (x_c*1024)) + (k_outer*4)) + 3)), 32), tir.load("float32x32", packedB, tir.ramp((((y_outer*32768) + (k_outer*128)) + 96), 1, 32), tir.broadcast(True, 32)), tir.load("float32x32", C_global, tir.ramp((x_c*32), 1, 32), tir.broadcast(True, 32)), dtype="float32x32"), tir.broadcast(True, 32)) + for x_inner in tir.range(0, 32): + for y_inner in tir.range(0, 32): + C[((((x_outer*32768) + (x_inner*1024)) + (y_outer*32)) + y_inner)] = tir.load("float32", C_global, ((x_inner*32) + y_inner)) + if (tir.TVMBackendFreeWorkspace(1, dev_id, C_global, dtype="int32") != 0): + tir.evaluate(tir.tvm_throw_last_error(dtype="int32")) + if (tir.TVMBackendFreeWorkspace(1, dev_id, packedB, dtype="int32") != 0): + tir.evaluate(tir.tvm_throw_last_error(dtype="int32")) + + +def test_opt_gemm_mod_host(): + mod = Module3() + rt_mod = tvm.hybrid.from_source(tvm.hybrid.ashybrid(mod, True)) + tvm.ir.assert_structural_equal(mod, rt_mod, True) + + +@tvm.hybrid.script +def opt_conv_tensorcore_normalize(A: ty.handle, W: ty.handle, Conv: ty.handle) -> None: + # function attr dict + tir.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + # var definition + blockIdx_x = tir.var("int32") + blockIdx_y = tir.var("int32") + blockIdx_z = tir.var("int32") + threadIdx_x = tir.var("int32") + threadIdx_y = tir.var("int32") + threadIdx_z = tir.var("int32") + # buffer definition + Apad_shared = tir.buffer_decl([16, 16, 16, 16, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1) + Apad_shared_wmma_matrix_a = tir.buffer_decl([16, 16, 16, 16, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1) + BA = tir.buffer_decl([16, 16], dtype="float16", scope="wmma.matrix_a", align=32, offset_factor=256) + BB = tir.buffer_decl([16, 16], dtype="float16", scope="wmma.matrix_b", align=32, offset_factor=256) + BC = tir.buffer_decl([16, 16], scope="wmma.accumulator", align=32, offset_factor=256) + Conv_wmma_accumulator = tir.buffer_decl([16, 14, 14, 32, 16, 16], elem_offset=0, align=128, offset_factor=1) + W_shared = tir.buffer_decl([3, 3, 16, 32, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1) + W_shared_wmma_matrix_b = tir.buffer_decl([3, 3, 16, 32, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1) + buffer = tir.buffer_decl([16, 16], dtype="float16", scope="shared", align=32, offset_factor=256) + buffer_1 = tir.buffer_decl([16, 16], dtype="float16", scope="wmma.matrix_a", align=32, offset_factor=256) + buffer_2 = tir.buffer_decl([16, 16], dtype="float16", scope="shared", align=32, offset_factor=256) + buffer_3 = tir.buffer_decl([16, 16], dtype="float16", scope="wmma.matrix_b", align=32, offset_factor=256) + buffer_4 = tir.buffer_decl([16, 16], scope="wmma.accumulator", align=32, offset_factor=256) + buffer_5 = tir.buffer_decl([16, 16], align=32, offset_factor=256) + A_1 = tir.buffer_bind(A, [16, 14, 14, 16, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1) + W_1 = tir.buffer_bind(W, [3, 3, 16, 32, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1) + Conv_1 = tir.buffer_bind(Conv, [16, 14, 14, 32, 16, 16], elem_offset=0, align=128, offset_factor=1) + # body + tir.attr(Conv_1, "realize_scope", "") + tir.realize(Conv_1[0:16, 0:14, 0:14, 0:32, 0:16, 0:16]) + tir.attr(tir.iter_var(blockIdx_z, None, "ThreadIndex", "blockIdx.z"), "thread_extent", 196) + tir.attr(tir.iter_var(blockIdx_x, None, "ThreadIndex", "blockIdx.x"), "thread_extent", 2) + tir.attr(tir.iter_var(blockIdx_y, None, "ThreadIndex", "blockIdx.y"), "thread_extent", 4) + tir.attr(tir.iter_var(threadIdx_y, None, "ThreadIndex", "threadIdx.y"), "thread_extent", 4) + tir.attr(tir.iter_var(threadIdx_z, None, "ThreadIndex", "threadIdx.z"), "thread_extent", 2) + tir.attr(Conv_wmma_accumulator, "realize_scope", "wmma.accumulator") + tir.realize(Conv_wmma_accumulator[((blockIdx_x*8) + (threadIdx_y*2)):(((blockIdx_x*8) + (threadIdx_y*2)) + 2), tir.floordiv(blockIdx_z, 14):(tir.floordiv(blockIdx_z, 14) + 1), tir.floormod(blockIdx_z, 14):(tir.floormod(blockIdx_z, 14) + 1), ((blockIdx_y*8) + (threadIdx_z*4)):(((blockIdx_y*8) + (threadIdx_z*4)) + 4), 0:16, 0:16]) + for n_c_init in tir.range(0, 2): + for o_c_init in tir.range(0, 4): + tir.attr([BC, Conv_wmma_accumulator], "buffer_bind_scope", tir.tvm_tuple((n_c_init + ((blockIdx_x*8) + (threadIdx_y*2))), 1, tir.floordiv(blockIdx_z, 14), 1, tir.floormod(blockIdx_z, 14), 1, (o_c_init + ((blockIdx_y*8) + (threadIdx_z*4))), 1, 0, 16, 0, 16, dtype="handle")) + tir.evaluate(tir.tvm_fill_fragment(BC.data, 16, 16, 16, tir.floordiv(BC.elem_offset, 256), tir.float32(0), dtype="handle")) + for ic_outer in tir.range(0, 8): + for kh in tir.range(0, 3): + tir.attr(Apad_shared, "realize_scope", "shared") + tir.realize(Apad_shared[(blockIdx_x*8):((blockIdx_x*8) + 8), (tir.floordiv(blockIdx_z, 14) + kh):((tir.floordiv(blockIdx_z, 14) + kh) + 1), tir.floormod(blockIdx_z, 14):(tir.floormod(blockIdx_z, 14) + 3), (ic_outer*2):((ic_outer*2) + 2), 0:16, 0:16]) + for ax2 in tir.range(0, 3): + for ax3 in tir.range(0, 2): + for ax4_ax5_fused_outer in tir.range(0, 8): + tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32) + Apad_shared[((threadIdx_z + (threadIdx_y*2)) + (blockIdx_x*8)), (tir.floordiv(blockIdx_z, 14) + kh), (ax2 + tir.floormod(blockIdx_z, 14)), (ax3 + (ic_outer*2)), tir.floordiv((threadIdx_x + (ax4_ax5_fused_outer*32)), 16), tir.floormod((threadIdx_x + (ax4_ax5_fused_outer*32)), 16)] = tir.if_then_else((((((tir.floordiv(blockIdx_z, 14) + kh) >= 1) and (((tir.floordiv(blockIdx_z, 14) + kh) - 1) < 14)) and ((ax2 + tir.floormod(blockIdx_z, 14)) >= 1)) and (((ax2 + tir.floormod(blockIdx_z, 14)) - 1) < 14)), A_1[((threadIdx_z + (threadIdx_y*2)) + (blockIdx_x*8)), ((tir.floordiv(blockIdx_z, 14) + kh) - 1), ((ax2 + tir.floormod(blockIdx_z, 14)) - 1), (ax3 + (ic_outer*2)), tir.floordiv((threadIdx_x + (ax4_ax5_fused_outer*32)), 16), tir.floormod((threadIdx_x + (ax4_ax5_fused_outer*32)), 16)], tir.float16(0), dtype="float16") + tir.attr(W_shared, "realize_scope", "shared") + tir.realize(W_shared[kh:(kh + 1), 0:3, (ic_outer*2):((ic_outer*2) + 2), (blockIdx_y*8):((blockIdx_y*8) + 8), 0:16, 0:16]) + for ax1 in tir.range(0, 3): + for ax2_1 in tir.range(0, 2): + tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32) + for ax4_ax5_fused_inner in tir.range(0, 8, "vectorized"): + W_shared[kh, ax1, (ax2_1 + (ic_outer*2)), ((threadIdx_z + (threadIdx_y*2)) + (blockIdx_y*8)), tir.floordiv((ax4_ax5_fused_inner + (threadIdx_x*8)), 16), tir.floormod((ax4_ax5_fused_inner + (threadIdx_x*8)), 16)] = W_1[kh, ax1, (ax2_1 + (ic_outer*2)), ((threadIdx_z + (threadIdx_y*2)) + (blockIdx_y*8)), tir.floordiv((ax4_ax5_fused_inner + (threadIdx_x*8)), 16), tir.floormod((ax4_ax5_fused_inner + (threadIdx_x*8)), 16)] + for ic_inner in tir.range(0, 2): + for kw in tir.range(0, 3): + tir.attr(Apad_shared_wmma_matrix_a, "realize_scope", "wmma.matrix_a") + tir.realize(Apad_shared_wmma_matrix_a[((blockIdx_x*8) + (threadIdx_y*2)):(((blockIdx_x*8) + (threadIdx_y*2)) + 2), (tir.floordiv(blockIdx_z, 14) + kh):((tir.floordiv(blockIdx_z, 14) + kh) + 1), (kw + tir.floormod(blockIdx_z, 14)):((kw + tir.floormod(blockIdx_z, 14)) + 1), ((ic_outer*2) + ic_inner):(((ic_outer*2) + ic_inner) + 1), 0:16, 0:16]) + for ax0 in tir.range(0, 2): + tir.attr([buffer, Apad_shared], "buffer_bind_scope", tir.tvm_tuple((ax0 + ((blockIdx_x*8) + (threadIdx_y*2))), 1, (tir.floordiv(blockIdx_z, 14) + kh), 1, (kw + tir.floormod(blockIdx_z, 14)), 1, ((ic_outer*2) + ic_inner), 1, 0, 16, 0, 16, dtype="handle")) + tir.attr([buffer_1, Apad_shared_wmma_matrix_a], "buffer_bind_scope", tir.tvm_tuple((ax0 + ((blockIdx_x*8) + (threadIdx_y*2))), 1, (tir.floordiv(blockIdx_z, 14) + kh), 1, (kw + tir.floormod(blockIdx_z, 14)), 1, ((ic_outer*2) + ic_inner), 1, 0, 16, 0, 16, dtype="handle")) + tir.evaluate(tir.tvm_load_matrix_sync(buffer_1.data, 16, 16, 16, tir.floordiv(buffer_1.elem_offset, 256), tir.tvm_access_ptr(tir.type_annotation(dtype="float16"), buffer.data, buffer.elem_offset, 256, 1, dtype="handle"), 16, "row_major", dtype="handle")) + tir.attr(W_shared_wmma_matrix_b, "realize_scope", "wmma.matrix_b") + tir.realize(W_shared_wmma_matrix_b[kh:(kh + 1), kw:(kw + 1), ((ic_outer*2) + ic_inner):(((ic_outer*2) + ic_inner) + 1), ((blockIdx_y*8) + (threadIdx_z*4)):(((blockIdx_y*8) + (threadIdx_z*4)) + 4), 0:16, 0:16]) + for ax3_1 in tir.range(0, 4): + tir.attr([buffer_2, W_shared], "buffer_bind_scope", tir.tvm_tuple(kh, 1, kw, 1, ((ic_outer*2) + ic_inner), 1, (ax3_1 + ((blockIdx_y*8) + (threadIdx_z*4))), 1, 0, 16, 0, 16, dtype="handle")) + tir.attr([buffer_3, W_shared_wmma_matrix_b], "buffer_bind_scope", tir.tvm_tuple(kh, 1, kw, 1, ((ic_outer*2) + ic_inner), 1, (ax3_1 + ((blockIdx_y*8) + (threadIdx_z*4))), 1, 0, 16, 0, 16, dtype="handle")) + tir.evaluate(tir.tvm_load_matrix_sync(buffer_3.data, 16, 16, 16, tir.floordiv(buffer_3.elem_offset, 256), tir.tvm_access_ptr(tir.type_annotation(dtype="float16"), buffer_2.data, buffer_2.elem_offset, 256, 1, dtype="handle"), 16, "row_major", dtype="handle")) + for n_c in tir.range(0, 2): + for o_c in tir.range(0, 4): + tir.attr([BA, Apad_shared_wmma_matrix_a], "buffer_bind_scope", tir.tvm_tuple((n_c + ((blockIdx_x*8) + (threadIdx_y*2))), 1, (tir.floordiv(blockIdx_z, 14) + kh), 1, (tir.floormod(blockIdx_z, 14) + kw), 1, ((ic_outer*2) + ic_inner), 1, 0, 16, 0, 16, dtype="handle")) + tir.attr([BB, W_shared_wmma_matrix_b], "buffer_bind_scope", tir.tvm_tuple(kh, 1, kw, 1, ((ic_outer*2) + ic_inner), 1, (o_c + ((blockIdx_y*8) + (threadIdx_z*4))), 1, 0, 16, 0, 16, dtype="handle")) + tir.attr([BC, Conv_wmma_accumulator], "buffer_bind_scope", tir.tvm_tuple((n_c + ((blockIdx_x*8) + (threadIdx_y*2))), 1, tir.floordiv(blockIdx_z, 14), 1, tir.floormod(blockIdx_z, 14), 1, (o_c + ((blockIdx_y*8) + (threadIdx_z*4))), 1, 0, 16, 0, 16, dtype="handle")) + tir.evaluate(tir.tvm_mma_sync(BC.data, tir.floordiv(BC.elem_offset, 256), BA.data, tir.floordiv(BA.elem_offset, 256), BB.data, tir.floordiv(BB.elem_offset, 256), BC.data, tir.floordiv(BC.elem_offset, 256), dtype="handle")) + for n_inner in tir.range(0, 2): + for o_inner in tir.range(0, 4): + tir.attr([buffer_4, Conv_wmma_accumulator], "buffer_bind_scope", tir.tvm_tuple(((((blockIdx_x*4) + threadIdx_y)*2) + n_inner), 1, tir.floordiv(blockIdx_z, 14), 1, tir.floormod(blockIdx_z, 14), 1, ((((blockIdx_y*2) + threadIdx_z)*4) + o_inner), 1, 0, 16, 0, 16, dtype="handle")) + tir.attr([buffer_5, Conv_1], "buffer_bind_scope", tir.tvm_tuple(((((blockIdx_x*4) + threadIdx_y)*2) + n_inner), 1, tir.floordiv(blockIdx_z, 14), 1, tir.floormod(blockIdx_z, 14), 1, ((((blockIdx_y*2) + threadIdx_z)*4) + o_inner), 1, 0, 16, 0, 16, dtype="handle")) + tir.evaluate(tir.tvm_store_matrix_sync(buffer_4.data, 16, 16, 16, tir.floordiv(buffer_4.elem_offset, 256), tir.tvm_access_ptr(tir.type_annotation(dtype="float32"), buffer_5.data, buffer_5.elem_offset, 256, 2, dtype="handle"), 16, "row_major", dtype="handle")) + + +def test_opt_conv_tensorcore_normalize(): + mod = opt_conv_tensorcore_normalize + rt_mod = tvm.hybrid.from_source(tvm.hybrid.ashybrid(mod, True)) + tvm.ir.assert_structural_equal(mod, rt_mod, True) + + +@tvm.hybrid.script +def opt_conv_tensorcore_lower(A: ty.handle, W: ty.handle, Conv: ty.handle) -> None: + # function attr dict + tir.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + # var definition + Apad_shared = tir.var("handle") + Apad_shared_wmma_matrix_a = tir.var("handle") + Conv_wmma_accumulator = tir.var("handle") + W_shared = tir.var("handle") + W_shared_wmma_matrix_b = tir.var("handle") + blockIdx_x = tir.var("int32") + blockIdx_y = tir.var("int32") + blockIdx_z = tir.var("int32") + threadIdx_x = tir.var("int32") + threadIdx_y = tir.var("int32") + threadIdx_z = tir.var("int32") + A_1 = tir.buffer_bind(A, [16, 14, 14, 16, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1) + W_1 = tir.buffer_bind(W, [3, 3, 16, 32, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1) + Conv_1 = tir.buffer_bind(Conv, [16, 14, 14, 32, 16, 16], elem_offset=0, align=128, offset_factor=1) + # body + tir.attr(tir.iter_var(blockIdx_z, None, "ThreadIndex", "blockIdx.z"), "thread_extent", 196) + tir.attr(Conv_wmma_accumulator, "storage_scope", "wmma.accumulator") + tir.allocate(Conv_wmma_accumulator, "float32", [2048]) + tir.attr(Apad_shared, "storage_scope", "shared") + tir.allocate(Apad_shared, "float16", [12288]) + tir.attr(W_shared, "storage_scope", "shared") + tir.allocate(W_shared, "float16", [12288]) + tir.attr(Apad_shared_wmma_matrix_a, "storage_scope", "wmma.matrix_a") + tir.allocate(Apad_shared_wmma_matrix_a, "float16", [512]) + tir.attr(W_shared_wmma_matrix_b, "storage_scope", "wmma.matrix_b") + tir.allocate(W_shared_wmma_matrix_b, "float16", [1024]) + tir.attr(tir.iter_var(blockIdx_x, None, "ThreadIndex", "blockIdx.x"), "thread_extent", 2) + tir.attr(tir.iter_var(blockIdx_y, None, "ThreadIndex", "blockIdx.y"), "thread_extent", 4) + tir.attr(tir.iter_var(threadIdx_y, None, "ThreadIndex", "threadIdx.y"), "thread_extent", 4) + tir.attr(tir.iter_var(threadIdx_z, None, "ThreadIndex", "threadIdx.z"), "thread_extent", 2) + tir.evaluate(tir.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, 0, tir.float32(0), dtype="handle")) + tir.evaluate(tir.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, 1, tir.float32(0), dtype="handle")) + tir.evaluate(tir.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, 2, tir.float32(0), dtype="handle")) + tir.evaluate(tir.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, 3, tir.float32(0), dtype="handle")) + tir.evaluate(tir.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, 4, tir.float32(0), dtype="handle")) + tir.evaluate(tir.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, 5, tir.float32(0), dtype="handle")) + tir.evaluate(tir.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, 6, tir.float32(0), dtype="handle")) + tir.evaluate(tir.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, 7, tir.float32(0), dtype="handle")) + for ic_outer in tir.range(0, 8): + for kh in tir.range(0, 3): + for ax2 in tir.range(0, 3): + with tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32): + Apad_shared[((((threadIdx_y*3072) + (threadIdx_z*1536)) + (ax2*512)) + threadIdx_x)] = tir.if_then_else(((((1 <= (tir.floordiv(blockIdx_z, 14) + kh)) and ((tir.floordiv(blockIdx_z, 14) + kh) < 15)) and (1 <= (ax2 + tir.floormod(blockIdx_z, 14)))) and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15)), tir.load("float16", A_1.data, (((((((((blockIdx_x*6422528) + (threadIdx_y*1605632)) + (threadIdx_z*802816)) + (kh*57344)) + (blockIdx_z*4096)) + (ax2*4096)) + (ic_outer*512)) + threadIdx_x) - 61440)), tir.float16(0), dtype="float16") + with tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32): + Apad_shared[(((((threadIdx_y*3072) + (threadIdx_z*1536)) + (ax2*512)) + threadIdx_x) + 32)] = tir.if_then_else(((((1 <= (tir.floordiv(blockIdx_z, 14) + kh)) and ((tir.floordiv(blockIdx_z, 14) + kh) < 15)) and (1 <= (ax2 + tir.floormod(blockIdx_z, 14)))) and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15)), tir.load("float16", A_1.data, (((((((((blockIdx_x*6422528) + (threadIdx_y*1605632)) + (threadIdx_z*802816)) + (kh*57344)) + (blockIdx_z*4096)) + (ax2*4096)) + (ic_outer*512)) + threadIdx_x) - 61408)), tir.float16(0), dtype="float16") + with tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32): + Apad_shared[(((((threadIdx_y*3072) + (threadIdx_z*1536)) + (ax2*512)) + threadIdx_x) + 64)] = tir.if_then_else(((((1 <= (tir.floordiv(blockIdx_z, 14) + kh)) and ((tir.floordiv(blockIdx_z, 14) + kh) < 15)) and (1 <= (ax2 + tir.floormod(blockIdx_z, 14)))) and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15)), tir.load("float16", A_1.data, (((((((((blockIdx_x*6422528) + (threadIdx_y*1605632)) + (threadIdx_z*802816)) + (kh*57344)) + (blockIdx_z*4096)) + (ax2*4096)) + (ic_outer*512)) + threadIdx_x) - 61376)), tir.float16(0), dtype="float16") + with tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32): + Apad_shared[(((((threadIdx_y*3072) + (threadIdx_z*1536)) + (ax2*512)) + threadIdx_x) + 96)] = tir.if_then_else(((((1 <= (tir.floordiv(blockIdx_z, 14) + kh)) and ((tir.floordiv(blockIdx_z, 14) + kh) < 15)) and (1 <= (ax2 + tir.floormod(blockIdx_z, 14)))) and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15)), tir.load("float16", A_1.data, (((((((((blockIdx_x*6422528) + (threadIdx_y*1605632)) + (threadIdx_z*802816)) + (kh*57344)) + (blockIdx_z*4096)) + (ax2*4096)) + (ic_outer*512)) + threadIdx_x) - 61344)), tir.float16(0), dtype="float16") + with tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32): + Apad_shared[(((((threadIdx_y*3072) + (threadIdx_z*1536)) + (ax2*512)) + threadIdx_x) + 128)] = tir.if_then_else(((((1 <= (tir.floordiv(blockIdx_z, 14) + kh)) and ((tir.floordiv(blockIdx_z, 14) + kh) < 15)) and (1 <= (ax2 + tir.floormod(blockIdx_z, 14)))) and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15)), tir.load("float16", A_1.data, (((((((((blockIdx_x*6422528) + (threadIdx_y*1605632)) + (threadIdx_z*802816)) + (kh*57344)) + (blockIdx_z*4096)) + (ax2*4096)) + (ic_outer*512)) + threadIdx_x) - 61312)), tir.float16(0), dtype="float16") + with tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32): + Apad_shared[(((((threadIdx_y*3072) + (threadIdx_z*1536)) + (ax2*512)) + threadIdx_x) + 160)] = tir.if_then_else(((((1 <= (tir.floordiv(blockIdx_z, 14) + kh)) and ((tir.floordiv(blockIdx_z, 14) + kh) < 15)) and (1 <= (ax2 + tir.floormod(blockIdx_z, 14)))) and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15)), tir.load("float16", A_1.data, (((((((((blockIdx_x*6422528) + (threadIdx_y*1605632)) + (threadIdx_z*802816)) + (kh*57344)) + (blockIdx_z*4096)) + (ax2*4096)) + (ic_outer*512)) + threadIdx_x) - 61280)), tir.float16(0), dtype="float16") + with tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32): + Apad_shared[(((((threadIdx_y*3072) + (threadIdx_z*1536)) + (ax2*512)) + threadIdx_x) + 192)] = tir.if_then_else(((((1 <= (tir.floordiv(blockIdx_z, 14) + kh)) and ((tir.floordiv(blockIdx_z, 14) + kh) < 15)) and (1 <= (ax2 + tir.floormod(blockIdx_z, 14)))) and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15)), tir.load("float16", A_1.data, (((((((((blockIdx_x*6422528) + (threadIdx_y*1605632)) + (threadIdx_z*802816)) + (kh*57344)) + (blockIdx_z*4096)) + (ax2*4096)) + (ic_outer*512)) + threadIdx_x) - 61248)), tir.float16(0), dtype="float16") + with tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32): + Apad_shared[(((((threadIdx_y*3072) + (threadIdx_z*1536)) + (ax2*512)) + threadIdx_x) + 224)] = tir.if_then_else(((((1 <= (tir.floordiv(blockIdx_z, 14) + kh)) and ((tir.floordiv(blockIdx_z, 14) + kh) < 15)) and (1 <= (ax2 + tir.floormod(blockIdx_z, 14)))) and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15)), tir.load("float16", A_1.data, (((((((((blockIdx_x*6422528) + (threadIdx_y*1605632)) + (threadIdx_z*802816)) + (kh*57344)) + (blockIdx_z*4096)) + (ax2*4096)) + (ic_outer*512)) + threadIdx_x) - 61216)), tir.float16(0), dtype="float16") + with tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32): + Apad_shared[(((((threadIdx_y*3072) + (threadIdx_z*1536)) + (ax2*512)) + threadIdx_x) + 256)] = tir.if_then_else(((((1 <= (tir.floordiv(blockIdx_z, 14) + kh)) and ((tir.floordiv(blockIdx_z, 14) + kh) < 15)) and (1 <= (ax2 + tir.floormod(blockIdx_z, 14)))) and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15)), tir.load("float16", A_1.data, (((((((((blockIdx_x*6422528) + (threadIdx_y*1605632)) + (threadIdx_z*802816)) + (kh*57344)) + (blockIdx_z*4096)) + (ax2*4096)) + (ic_outer*512)) + threadIdx_x) - 61184)), tir.float16(0), dtype="float16") + with tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32): + Apad_shared[(((((threadIdx_y*3072) + (threadIdx_z*1536)) + (ax2*512)) + threadIdx_x) + 288)] = tir.if_then_else(((((1 <= (tir.floordiv(blockIdx_z, 14) + kh)) and ((tir.floordiv(blockIdx_z, 14) + kh) < 15)) and (1 <= (ax2 + tir.floormod(blockIdx_z, 14)))) and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15)), tir.load("float16", A_1.data, (((((((((blockIdx_x*6422528) + (threadIdx_y*1605632)) + (threadIdx_z*802816)) + (kh*57344)) + (blockIdx_z*4096)) + (ax2*4096)) + (ic_outer*512)) + threadIdx_x) - 61152)), tir.float16(0), dtype="float16") + with tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32): + Apad_shared[(((((threadIdx_y*3072) + (threadIdx_z*1536)) + (ax2*512)) + threadIdx_x) + 320)] = tir.if_then_else(((((1 <= (tir.floordiv(blockIdx_z, 14) + kh)) and ((tir.floordiv(blockIdx_z, 14) + kh) < 15)) and (1 <= (ax2 + tir.floormod(blockIdx_z, 14)))) and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15)), tir.load("float16", A_1.data, (((((((((blockIdx_x*6422528) + (threadIdx_y*1605632)) + (threadIdx_z*802816)) + (kh*57344)) + (blockIdx_z*4096)) + (ax2*4096)) + (ic_outer*512)) + threadIdx_x) - 61120)), tir.float16(0), dtype="float16") + with tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32): + Apad_shared[(((((threadIdx_y*3072) + (threadIdx_z*1536)) + (ax2*512)) + threadIdx_x) + 352)] = tir.if_then_else(((((1 <= (tir.floordiv(blockIdx_z, 14) + kh)) and ((tir.floordiv(blockIdx_z, 14) + kh) < 15)) and (1 <= (ax2 + tir.floormod(blockIdx_z, 14)))) and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15)), tir.load("float16", A_1.data, (((((((((blockIdx_x*6422528) + (threadIdx_y*1605632)) + (threadIdx_z*802816)) + (kh*57344)) + (blockIdx_z*4096)) + (ax2*4096)) + (ic_outer*512)) + threadIdx_x) - 61088)), tir.float16(0), dtype="float16") + with tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32): + Apad_shared[(((((threadIdx_y*3072) + (threadIdx_z*1536)) + (ax2*512)) + threadIdx_x) + 384)] = tir.if_then_else(((((1 <= (tir.floordiv(blockIdx_z, 14) + kh)) and ((tir.floordiv(blockIdx_z, 14) + kh) < 15)) and (1 <= (ax2 + tir.floormod(blockIdx_z, 14)))) and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15)), tir.load("float16", A_1.data, (((((((((blockIdx_x*6422528) + (threadIdx_y*1605632)) + (threadIdx_z*802816)) + (kh*57344)) + (blockIdx_z*4096)) + (ax2*4096)) + (ic_outer*512)) + threadIdx_x) - 61056)), tir.float16(0), dtype="float16") + with tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32): + Apad_shared[(((((threadIdx_y*3072) + (threadIdx_z*1536)) + (ax2*512)) + threadIdx_x) + 416)] = tir.if_then_else(((((1 <= (tir.floordiv(blockIdx_z, 14) + kh)) and ((tir.floordiv(blockIdx_z, 14) + kh) < 15)) and (1 <= (ax2 + tir.floormod(blockIdx_z, 14)))) and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15)), tir.load("float16", A_1.data, (((((((((blockIdx_x*6422528) + (threadIdx_y*1605632)) + (threadIdx_z*802816)) + (kh*57344)) + (blockIdx_z*4096)) + (ax2*4096)) + (ic_outer*512)) + threadIdx_x) - 61024)), tir.float16(0), dtype="float16") + with tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32): + Apad_shared[(((((threadIdx_y*3072) + (threadIdx_z*1536)) + (ax2*512)) + threadIdx_x) + 448)] = tir.if_then_else(((((1 <= (tir.floordiv(blockIdx_z, 14) + kh)) and ((tir.floordiv(blockIdx_z, 14) + kh) < 15)) and (1 <= (ax2 + tir.floormod(blockIdx_z, 14)))) and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15)), tir.load("float16", A_1.data, (((((((((blockIdx_x*6422528) + (threadIdx_y*1605632)) + (threadIdx_z*802816)) + (kh*57344)) + (blockIdx_z*4096)) + (ax2*4096)) + (ic_outer*512)) + threadIdx_x) - 60992)), tir.float16(0), dtype="float16") + tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32) + Apad_shared[(((((threadIdx_y*3072) + (threadIdx_z*1536)) + (ax2*512)) + threadIdx_x) + 480)] = tir.if_then_else(((((1 <= (tir.floordiv(blockIdx_z, 14) + kh)) and ((tir.floordiv(blockIdx_z, 14) + kh) < 15)) and (1 <= (ax2 + tir.floormod(blockIdx_z, 14)))) and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15)), tir.load("float16", A_1.data, (((((((((blockIdx_x*6422528) + (threadIdx_y*1605632)) + (threadIdx_z*802816)) + (kh*57344)) + (blockIdx_z*4096)) + (ax2*4096)) + (ic_outer*512)) + threadIdx_x) - 60960)), tir.float16(0), dtype="float16") + with tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32): + tir.store(W_shared, tir.ramp((((threadIdx_y*512) + (threadIdx_z*256)) + (threadIdx_x*8)), 1, 8), tir.load("float16x8", W_1.data, tir.ramp(((((((kh*393216) + (ic_outer*16384)) + (blockIdx_y*2048)) + (threadIdx_y*512)) + (threadIdx_z*256)) + (threadIdx_x*8)), 1, 8), tir.broadcast(True, 8)), tir.broadcast(True, 8)) + with tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32): + tir.store(W_shared, tir.ramp(((((threadIdx_y*512) + (threadIdx_z*256)) + (threadIdx_x*8)) + 2048), 1, 8), tir.load("float16x8", W_1.data, tir.ramp((((((((kh*393216) + (ic_outer*16384)) + (blockIdx_y*2048)) + (threadIdx_y*512)) + (threadIdx_z*256)) + (threadIdx_x*8)) + 8192), 1, 8), tir.broadcast(True, 8)), tir.broadcast(True, 8)) + with tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32): + tir.store(W_shared, tir.ramp(((((threadIdx_y*512) + (threadIdx_z*256)) + (threadIdx_x*8)) + 4096), 1, 8), tir.load("float16x8", W_1.data, tir.ramp((((((((kh*393216) + (ic_outer*16384)) + (blockIdx_y*2048)) + (threadIdx_y*512)) + (threadIdx_z*256)) + (threadIdx_x*8)) + 131072), 1, 8), tir.broadcast(True, 8)), tir.broadcast(True, 8)) + with tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32): + tir.store(W_shared, tir.ramp(((((threadIdx_y*512) + (threadIdx_z*256)) + (threadIdx_x*8)) + 6144), 1, 8), tir.load("float16x8", W_1.data, tir.ramp((((((((kh*393216) + (ic_outer*16384)) + (blockIdx_y*2048)) + (threadIdx_y*512)) + (threadIdx_z*256)) + (threadIdx_x*8)) + 139264), 1, 8), tir.broadcast(True, 8)), tir.broadcast(True, 8)) + with tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32): + tir.store(W_shared, tir.ramp(((((threadIdx_y*512) + (threadIdx_z*256)) + (threadIdx_x*8)) + 8192), 1, 8), tir.load("float16x8", W_1.data, tir.ramp((((((((kh*393216) + (ic_outer*16384)) + (blockIdx_y*2048)) + (threadIdx_y*512)) + (threadIdx_z*256)) + (threadIdx_x*8)) + 262144), 1, 8), tir.broadcast(True, 8)), tir.broadcast(True, 8)) + with tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32): + tir.store(W_shared, tir.ramp(((((threadIdx_y*512) + (threadIdx_z*256)) + (threadIdx_x*8)) + 10240), 1, 8), tir.load("float16x8", W_1.data, tir.ramp((((((((kh*393216) + (ic_outer*16384)) + (blockIdx_y*2048)) + (threadIdx_y*512)) + (threadIdx_z*256)) + (threadIdx_x*8)) + 270336), 1, 8), tir.broadcast(True, 8)), tir.broadcast(True, 8)) + for ic_inner in tir.range(0, 2): + for kw in tir.range(0, 3): + tir.evaluate(tir.tvm_load_matrix_sync(Apad_shared_wmma_matrix_a, 16, 16, 16, 0, tir.tvm_access_ptr(tir.type_annotation(dtype="float16"), Apad_shared, (((threadIdx_y*3072) + (kw*512)) + (ic_inner*256)), 256, 1, dtype="handle"), 16, "row_major", dtype="handle")) + tir.evaluate(tir.tvm_load_matrix_sync(Apad_shared_wmma_matrix_a, 16, 16, 16, 1, tir.tvm_access_ptr(tir.type_annotation(dtype="float16"), Apad_shared, ((((threadIdx_y*3072) + (kw*512)) + (ic_inner*256)) + 1536), 256, 1, dtype="handle"), 16, "row_major", dtype="handle")) + tir.evaluate(tir.tvm_load_matrix_sync(W_shared_wmma_matrix_b, 16, 16, 16, 0, tir.tvm_access_ptr(tir.type_annotation(dtype="float16"), W_shared, (((kw*4096) + (ic_inner*2048)) + (threadIdx_z*1024)), 256, 1, dtype="handle"), 16, "row_major", dtype="handle")) + tir.evaluate(tir.tvm_load_matrix_sync(W_shared_wmma_matrix_b, 16, 16, 16, 1, tir.tvm_access_ptr(tir.type_annotation(dtype="float16"), W_shared, ((((kw*4096) + (ic_inner*2048)) + (threadIdx_z*1024)) + 256), 256, 1, dtype="handle"), 16, "row_major", dtype="handle")) + tir.evaluate(tir.tvm_load_matrix_sync(W_shared_wmma_matrix_b, 16, 16, 16, 2, tir.tvm_access_ptr(tir.type_annotation(dtype="float16"), W_shared, ((((kw*4096) + (ic_inner*2048)) + (threadIdx_z*1024)) + 512), 256, 1, dtype="handle"), 16, "row_major", dtype="handle")) + tir.evaluate(tir.tvm_load_matrix_sync(W_shared_wmma_matrix_b, 16, 16, 16, 3, tir.tvm_access_ptr(tir.type_annotation(dtype="float16"), W_shared, ((((kw*4096) + (ic_inner*2048)) + (threadIdx_z*1024)) + 768), 256, 1, dtype="handle"), 16, "row_major", dtype="handle")) + tir.evaluate(tir.tvm_mma_sync(Conv_wmma_accumulator, 0, Apad_shared_wmma_matrix_a, 0, W_shared_wmma_matrix_b, 0, Conv_wmma_accumulator, 0, dtype="handle")) + tir.evaluate(tir.tvm_mma_sync(Conv_wmma_accumulator, 1, Apad_shared_wmma_matrix_a, 0, W_shared_wmma_matrix_b, 1, Conv_wmma_accumulator, 1, dtype="handle")) + tir.evaluate(tir.tvm_mma_sync(Conv_wmma_accumulator, 2, Apad_shared_wmma_matrix_a, 0, W_shared_wmma_matrix_b, 2, Conv_wmma_accumulator, 2, dtype="handle")) + tir.evaluate(tir.tvm_mma_sync(Conv_wmma_accumulator, 3, Apad_shared_wmma_matrix_a, 0, W_shared_wmma_matrix_b, 3, Conv_wmma_accumulator, 3, dtype="handle")) + tir.evaluate(tir.tvm_mma_sync(Conv_wmma_accumulator, 4, Apad_shared_wmma_matrix_a, 1, W_shared_wmma_matrix_b, 0, Conv_wmma_accumulator, 4, dtype="handle")) + tir.evaluate(tir.tvm_mma_sync(Conv_wmma_accumulator, 5, Apad_shared_wmma_matrix_a, 1, W_shared_wmma_matrix_b, 1, Conv_wmma_accumulator, 5, dtype="handle")) + tir.evaluate(tir.tvm_mma_sync(Conv_wmma_accumulator, 6, Apad_shared_wmma_matrix_a, 1, W_shared_wmma_matrix_b, 2, Conv_wmma_accumulator, 6, dtype="handle")) + tir.evaluate(tir.tvm_mma_sync(Conv_wmma_accumulator, 7, Apad_shared_wmma_matrix_a, 1, W_shared_wmma_matrix_b, 3, Conv_wmma_accumulator, 7, dtype="handle")) + tir.evaluate(tir.tvm_store_matrix_sync(Conv_wmma_accumulator, 16, 16, 16, 0, tir.tvm_access_ptr(tir.type_annotation(dtype="float32"), Conv_1.data, (((((blockIdx_x*12845056) + (threadIdx_y*3211264)) + (blockIdx_z*8192)) + (blockIdx_y*2048)) + (threadIdx_z*1024)), 256, 2, dtype="handle"), 16, "row_major", dtype="handle")) + tir.evaluate(tir.tvm_store_matrix_sync(Conv_wmma_accumulator, 16, 16, 16, 1, tir.tvm_access_ptr(tir.type_annotation(dtype="float32"), Conv_1.data, ((((((blockIdx_x*12845056) + (threadIdx_y*3211264)) + (blockIdx_z*8192)) + (blockIdx_y*2048)) + (threadIdx_z*1024)) + 256), 256, 2, dtype="handle"), 16, "row_major", dtype="handle")) + tir.evaluate(tir.tvm_store_matrix_sync(Conv_wmma_accumulator, 16, 16, 16, 2, tir.tvm_access_ptr(tir.type_annotation(dtype="float32"), Conv_1.data, ((((((blockIdx_x*12845056) + (threadIdx_y*3211264)) + (blockIdx_z*8192)) + (blockIdx_y*2048)) + (threadIdx_z*1024)) + 512), 256, 2, dtype="handle"), 16, "row_major", dtype="handle")) + tir.evaluate(tir.tvm_store_matrix_sync(Conv_wmma_accumulator, 16, 16, 16, 3, tir.tvm_access_ptr(tir.type_annotation(dtype="float32"), Conv_1.data, ((((((blockIdx_x*12845056) + (threadIdx_y*3211264)) + (blockIdx_z*8192)) + (blockIdx_y*2048)) + (threadIdx_z*1024)) + 768), 256, 2, dtype="handle"), 16, "row_major", dtype="handle")) + tir.evaluate(tir.tvm_store_matrix_sync(Conv_wmma_accumulator, 16, 16, 16, 4, tir.tvm_access_ptr(tir.type_annotation(dtype="float32"), Conv_1.data, ((((((blockIdx_x*12845056) + (threadIdx_y*3211264)) + (blockIdx_z*8192)) + (blockIdx_y*2048)) + (threadIdx_z*1024)) + 1605632), 256, 2, dtype="handle"), 16, "row_major", dtype="handle")) + tir.evaluate(tir.tvm_store_matrix_sync(Conv_wmma_accumulator, 16, 16, 16, 5, tir.tvm_access_ptr(tir.type_annotation(dtype="float32"), Conv_1.data, ((((((blockIdx_x*12845056) + (threadIdx_y*3211264)) + (blockIdx_z*8192)) + (blockIdx_y*2048)) + (threadIdx_z*1024)) + 1605888), 256, 2, dtype="handle"), 16, "row_major", dtype="handle")) + tir.evaluate(tir.tvm_store_matrix_sync(Conv_wmma_accumulator, 16, 16, 16, 6, tir.tvm_access_ptr(tir.type_annotation(dtype="float32"), Conv_1.data, ((((((blockIdx_x*12845056) + (threadIdx_y*3211264)) + (blockIdx_z*8192)) + (blockIdx_y*2048)) + (threadIdx_z*1024)) + 1606144), 256, 2, dtype="handle"), 16, "row_major", dtype="handle")) + tir.evaluate(tir.tvm_store_matrix_sync(Conv_wmma_accumulator, 16, 16, 16, 7, tir.tvm_access_ptr(tir.type_annotation(dtype="float32"), Conv_1.data, ((((((blockIdx_x*12845056) + (threadIdx_y*3211264)) + (blockIdx_z*8192)) + (blockIdx_y*2048)) + (threadIdx_z*1024)) + 1606400), 256, 2, dtype="handle"), 16, "row_major", dtype="handle")) + + +def test_opt_conv_tensorcore_lower(): + mod = opt_conv_tensorcore_lower + rt_mod = tvm.hybrid.from_source(tvm.hybrid.ashybrid(mod, True)) + tvm.ir.assert_structural_equal(mod, rt_mod, True) + + +@tvm.hybrid.script +def opt_conv_tensorcore_mod_host(args: ty.handle, arg_type_ids: ty.handle, num_args: ty.int32, out_ret_value: ty.handle, out_ret_tcode: ty.handle, resource_handle: ty.handle) -> ty.int32: + # function attr dict + tir.func_attr({"tir.noalias": True, "global_symbol": "default_function", "tir.is_entry_func": True, "calling_conv": 1}) + # body + stack_tcode: ty.handle = tir.tvm_stack_alloca("arg_tcode", 10, dtype="handle") + stack_value: ty.handle = tir.tvm_stack_alloca("arg_value", 10, dtype="handle") + assert (num_args == 3), "default_function: num_args should be 3" + arg0: ty.handle = tir.tvm_struct_get(args, 0, 12, dtype="handle") + arg0_code: ty.int32 = tir.load("int32", arg_type_ids, 0) + arg1: ty.handle = tir.tvm_struct_get(args, 1, 12, dtype="handle") + arg1_code: ty.int32 = tir.load("int32", arg_type_ids, 1) + arg2: ty.handle = tir.tvm_struct_get(args, 2, 12, dtype="handle") + arg2_code: ty.int32 = tir.load("int32", arg_type_ids, 2) + A: ty.handle = tir.tvm_struct_get(arg0, 0, 1, dtype="handle") + tir.attr(A, "storage_alignment", 128) + arg0_shape: ty.handle = tir.tvm_struct_get(arg0, 0, 2, dtype="handle") + arg0_strides: ty.handle = tir.tvm_struct_get(arg0, 0, 3, dtype="handle") + dev_id: ty.int32 = tir.tvm_struct_get(arg0, 0, 9, dtype="int32") + W: ty.handle = tir.tvm_struct_get(arg1, 0, 1, dtype="handle") + tir.attr(W, "storage_alignment", 128) + arg1_shape: ty.handle = tir.tvm_struct_get(arg1, 0, 2, dtype="handle") + arg1_strides: ty.handle = tir.tvm_struct_get(arg1, 0, 3, dtype="handle") + Conv: ty.handle = tir.tvm_struct_get(arg2, 0, 1, dtype="handle") + tir.attr(Conv, "storage_alignment", 128) + arg2_shape: ty.handle = tir.tvm_struct_get(arg2, 0, 2, dtype="handle") + arg2_strides: ty.handle = tir.tvm_struct_get(arg2, 0, 3, dtype="handle") + assert ((((arg0_code == 3) or (arg0_code == 13)) or (arg0_code == 7)) or (arg0_code == 4)), "default_function: Expect arg[0] to be pointer" + assert ((((arg1_code == 3) or (arg1_code == 13)) or (arg1_code == 7)) or (arg1_code == 4)), "default_function: Expect arg[1] to be pointer" + assert ((((arg2_code == 3) or (arg2_code == 13)) or (arg2_code == 7)) or (arg2_code == 4)), "default_function: Expect arg[2] to be pointer" + assert (6 == tir.tvm_struct_get(arg0, 0, 4, dtype="int32")), "arg0.ndim is expected to equal 6" + assert (6 == tir.tvm_struct_get(arg0, 0, 4, dtype="int32")), "arg0.ndim is expected to equal 6" + assert (((tir.tvm_struct_get(arg0, 0, 5, dtype="uint8") == tir.uint8(2)) and (tir.tvm_struct_get(arg0, 0, 6, dtype="uint8") == tir.uint8(16))) and (tir.tvm_struct_get(arg0, 0, 7, dtype="uint16") == tir.uint16(1))), "arg0.dtype is expected to be float16" + assert (16 == tir.cast("int32", tir.load("int64", arg0_shape, 0))), "Argument arg0.shape[0] has an unsatisfied constraint" + assert (14 == tir.cast("int32", tir.load("int64", arg0_shape, 1))), "Argument arg0.shape[1] has an unsatisfied constraint" + assert (14 == tir.cast("int32", tir.load("int64", arg0_shape, 2))), "Argument arg0.shape[2] has an unsatisfied constraint" + assert (16 == tir.cast("int32", tir.load("int64", arg0_shape, 3))), "Argument arg0.shape[3] has an unsatisfied constraint" + assert (16 == tir.cast("int32", tir.load("int64", arg0_shape, 4))), "Argument arg0.shape[4] has an unsatisfied constraint" + assert (16 == tir.cast("int32", tir.load("int64", arg0_shape, 5))), "Argument arg0.shape[5] has an unsatisfied constraint" + if not (tir.isnullptr(arg0_strides, dtype="bool")): + assert ((((((1 == tir.cast("int32", tir.load("int64", arg0_strides, 5))) and (16 == tir.cast("int32", tir.load("int64", arg0_strides, 4)))) and (256 == tir.cast("int32", tir.load("int64", arg0_strides, 3)))) and (4096 == tir.cast("int32", tir.load("int64", arg0_strides, 2)))) and (57344 == tir.cast("int32", tir.load("int64", arg0_strides, 1)))) and (802816 == tir.cast("int32", tir.load("int64", arg0_strides, 0)))), "arg0.strides: expected to be compact array" + tir.evaluate(0) + assert (tir.uint64(0) == tir.tvm_struct_get(arg0, 0, 8, dtype="uint64")), "Argument arg0.byte_offset has an unsatisfied constraint" + assert (2 == tir.tvm_struct_get(arg0, 0, 10, dtype="int32")), "Argument arg0.device_type has an unsatisfied constraint" + assert (6 == tir.tvm_struct_get(arg1, 0, 4, dtype="int32")), "arg1.ndim is expected to equal 6" + assert (6 == tir.tvm_struct_get(arg1, 0, 4, dtype="int32")), "arg1.ndim is expected to equal 6" + assert (((tir.tvm_struct_get(arg1, 0, 5, dtype="uint8") == tir.uint8(2)) and (tir.tvm_struct_get(arg1, 0, 6, dtype="uint8") == tir.uint8(16))) and (tir.tvm_struct_get(arg1, 0, 7, dtype="uint16") == tir.uint16(1))), "arg1.dtype is expected to be float16" + assert (3 == tir.cast("int32", tir.load("int64", arg1_shape, 0))), "Argument arg1.shape[0] has an unsatisfied constraint" + assert (3 == tir.cast("int32", tir.load("int64", arg1_shape, 1))), "Argument arg1.shape[1] has an unsatisfied constraint" + assert (16 == tir.cast("int32", tir.load("int64", arg1_shape, 2))), "Argument arg1.shape[2] has an unsatisfied constraint" + assert (32 == tir.cast("int32", tir.load("int64", arg1_shape, 3))), "Argument arg1.shape[3] has an unsatisfied constraint" + assert (16 == tir.cast("int32", tir.load("int64", arg1_shape, 4))), "Argument arg1.shape[4] has an unsatisfied constraint" + assert (16 == tir.cast("int32", tir.load("int64", arg1_shape, 5))), "Argument arg1.shape[5] has an unsatisfied constraint" + if not (tir.isnullptr(arg1_strides, dtype="bool")): + assert ((((((1 == tir.cast("int32", tir.load("int64", arg1_strides, 5))) and (16 == tir.cast("int32", tir.load("int64", arg1_strides, 4)))) and (256 == tir.cast("int32", tir.load("int64", arg1_strides, 3)))) and (8192 == tir.cast("int32", tir.load("int64", arg1_strides, 2)))) and (131072 == tir.cast("int32", tir.load("int64", arg1_strides, 1)))) and (393216 == tir.cast("int32", tir.load("int64", arg1_strides, 0)))), "arg1.strides: expected to be compact array" + tir.evaluate(0) + assert (tir.uint64(0) == tir.tvm_struct_get(arg1, 0, 8, dtype="uint64")), "Argument arg1.byte_offset has an unsatisfied constraint" + assert (2 == tir.tvm_struct_get(arg1, 0, 10, dtype="int32")), "Argument arg1.device_type has an unsatisfied constraint" + assert (dev_id == tir.tvm_struct_get(arg1, 0, 9, dtype="int32")), "Argument arg1.device_id has an unsatisfied constraint" + assert (6 == tir.tvm_struct_get(arg2, 0, 4, dtype="int32")), "arg2.ndim is expected to equal 6" + assert (6 == tir.tvm_struct_get(arg2, 0, 4, dtype="int32")), "arg2.ndim is expected to equal 6" + assert (((tir.tvm_struct_get(arg2, 0, 5, dtype="uint8") == tir.uint8(2)) and (tir.tvm_struct_get(arg2, 0, 6, dtype="uint8") == tir.uint8(32))) and (tir.tvm_struct_get(arg2, 0, 7, dtype="uint16") == tir.uint16(1))), "arg2.dtype is expected to be float32" + assert (16 == tir.cast("int32", tir.load("int64", arg2_shape, 0))), "Argument arg2.shape[0] has an unsatisfied constraint" + assert (14 == tir.cast("int32", tir.load("int64", arg2_shape, 1))), "Argument arg2.shape[1] has an unsatisfied constraint" + assert (14 == tir.cast("int32", tir.load("int64", arg2_shape, 2))), "Argument arg2.shape[2] has an unsatisfied constraint" + assert (32 == tir.cast("int32", tir.load("int64", arg2_shape, 3))), "Argument arg2.shape[3] has an unsatisfied constraint" + assert (16 == tir.cast("int32", tir.load("int64", arg2_shape, 4))), "Argument arg2.shape[4] has an unsatisfied constraint" + assert (16 == tir.cast("int32", tir.load("int64", arg2_shape, 5))), "Argument arg2.shape[5] has an unsatisfied constraint" + if not (tir.isnullptr(arg2_strides, dtype="bool")): + assert ((((((1 == tir.cast("int32", tir.load("int64", arg2_strides, 5))) and (16 == tir.cast("int32", tir.load("int64", arg2_strides, 4)))) and (256 == tir.cast("int32", tir.load("int64", arg2_strides, 3)))) and (8192 == tir.cast("int32", tir.load("int64", arg2_strides, 2)))) and (114688 == tir.cast("int32", tir.load("int64", arg2_strides, 1)))) and (1605632 == tir.cast("int32", tir.load("int64", arg2_strides, 0)))), "arg2.strides: expected to be compact array" + tir.evaluate(0) + assert (tir.uint64(0) == tir.tvm_struct_get(arg2, 0, 8, dtype="uint64")), "Argument arg2.byte_offset has an unsatisfied constraint" + assert (2 == tir.tvm_struct_get(arg2, 0, 10, dtype="int32")), "Argument arg2.device_type has an unsatisfied constraint" + assert (dev_id == tir.tvm_struct_get(arg2, 0, 9, dtype="int32")), "Argument arg2.device_id has an unsatisfied constraint" + tir.evaluate(tir.tvm_struct_set(stack_value, 0, 12, tir.cast("int64", 2), dtype="int32")) + stack_tcode[0] = 0 + tir.evaluate(tir.tvm_struct_set(stack_value, 1, 12, tir.cast("int64", dev_id), dtype="int32")) + stack_tcode[1] = 0 + tir.evaluate(tir.tvm_call_packed_lowered("__tvm_set_device", stack_value, stack_tcode, 0, 2, dtype="int32")) + tir.attr(0, "compute_scope", "default_function_compute_") + tir.evaluate(tir.tvm_struct_set(stack_value, 0, 12, A, dtype="int32")) + stack_tcode[0] = 3 + tir.evaluate(tir.tvm_struct_set(stack_value, 1, 12, W, dtype="int32")) + stack_tcode[1] = 3 + tir.evaluate(tir.tvm_struct_set(stack_value, 2, 12, Conv, dtype="int32")) + stack_tcode[2] = 3 + tir.evaluate(tir.tvm_struct_set(stack_value, 3, 12, tir.cast("int64", 196), dtype="int32")) + stack_tcode[3] = 0 + tir.evaluate(tir.tvm_struct_set(stack_value, 4, 12, tir.cast("int64", 2), dtype="int32")) + stack_tcode[4] = 0 + tir.evaluate(tir.tvm_struct_set(stack_value, 5, 12, tir.cast("int64", 4), dtype="int32")) + stack_tcode[5] = 0 + tir.evaluate(tir.tvm_struct_set(stack_value, 6, 12, tir.cast("int64", 4), dtype="int32")) + stack_tcode[6] = 0 + tir.evaluate(tir.tvm_struct_set(stack_value, 7, 12, tir.cast("int64", 2), dtype="int32")) + stack_tcode[7] = 0 + tir.evaluate(tir.tvm_struct_set(stack_value, 8, 12, tir.cast("int64", 32), dtype="int32")) + stack_tcode[8] = 0 + tir.evaluate(tir.tvm_call_packed_lowered("default_function_kernel0", stack_value, stack_tcode, 0, 9, dtype="int32")) + + +def test_opt_conv_tensorcore_mod_host(): + mod = opt_conv_tensorcore_mod_host + rt_mod = tvm.hybrid.from_source(tvm.hybrid.ashybrid(mod, True)) + tvm.ir.assert_structural_equal(mod, rt_mod, True) + + +if __name__ == '__main__': + test_opt_gemm_normalize() + test_opt_gemm_mod_host() + test_opt_gemm_lower() + test_opt_conv_tensorcore_normalize() + test_opt_conv_tensorcore_lower() + test_opt_conv_tensorcore_mod_host()