From ab28afbab2d87b6b3173c0af65b5a63d10750997 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Mon, 19 Dec 2022 06:09:49 -0800 Subject: [PATCH] [TVMScript] Remove obsolete modules (#13638) Removing some minor code path that is not used any longer. --- apps/microtvm/cmsisnn/requirements.txt | 3 - apps/microtvm/ethosu/requirements.txt | 3 - .../install/ubuntu_install_python_package.sh | 1 - docs/README.md | 2 +- docs/contribute/pull_request.rst | 2 +- .../how_to/work_with_microtvm/micro_ethosu.py | 1 - python/gen_requirements.py | 2 - python/tvm/script/parser_v1/__init__.py | 21 - python/tvm/script/parser_v1/_ffi_api.py | 20 - .../script/parser_v1/context_maintainer.py | 248 --- python/tvm/script/parser_v1/diagnostics.py | 55 - python/tvm/script/parser_v1/meta_unparser.py | 45 - python/tvm/script/parser_v1/parser.py | 1391 ----------------- python/tvm/script/parser_v1/registry.py | 62 - python/tvm/script/parser_v1/tir/__init__.py | 33 - python/tvm/script/parser_v1/tir/__init__.pyi | 475 ------ python/tvm/script/parser_v1/tir/intrin.py | 307 ---- python/tvm/script/parser_v1/tir/node.py | 218 --- python/tvm/script/parser_v1/tir/prim_func.py | 45 - .../tvm/script/parser_v1/tir/scope_handler.py | 793 ---------- .../tvm/script/parser_v1/tir/special_stmt.py | 927 ----------- python/tvm/script/parser_v1/tir/ty.py | 226 --- python/tvm/script/parser_v1/utils.py | 105 -- src/tir/schedule/error.h | 2 +- tests/python/unittest/test_tvmscript_spans.py | 73 - tests/scripts/ci.py | 1 - 26 files changed, 3 insertions(+), 5058 deletions(-) delete mode 100644 python/tvm/script/parser_v1/__init__.py delete mode 100644 python/tvm/script/parser_v1/_ffi_api.py delete mode 100644 python/tvm/script/parser_v1/context_maintainer.py delete mode 100644 python/tvm/script/parser_v1/diagnostics.py delete mode 100644 python/tvm/script/parser_v1/meta_unparser.py delete mode 100644 python/tvm/script/parser_v1/parser.py delete mode 100644 python/tvm/script/parser_v1/registry.py delete mode 100644 python/tvm/script/parser_v1/tir/__init__.py delete mode 100644 python/tvm/script/parser_v1/tir/__init__.pyi delete mode 100644 python/tvm/script/parser_v1/tir/intrin.py delete mode 100644 python/tvm/script/parser_v1/tir/node.py delete mode 100644 python/tvm/script/parser_v1/tir/prim_func.py delete mode 100644 python/tvm/script/parser_v1/tir/scope_handler.py delete mode 100644 python/tvm/script/parser_v1/tir/special_stmt.py delete mode 100644 python/tvm/script/parser_v1/tir/ty.py delete mode 100644 python/tvm/script/parser_v1/utils.py delete mode 100644 tests/python/unittest/test_tvmscript_spans.py diff --git a/apps/microtvm/cmsisnn/requirements.txt b/apps/microtvm/cmsisnn/requirements.txt index 72ae166963ee..1c99bd49a92e 100644 --- a/apps/microtvm/cmsisnn/requirements.txt +++ b/apps/microtvm/cmsisnn/requirements.txt @@ -216,9 +216,6 @@ scipy==1.5.4 \ --hash=sha256:ed572470af2438b526ea574ff8f05e7f39b44ac37f712105e57fc4d53a6fb660 \ --hash=sha256:f87b39f4d69cf7d7529d7b1098cb712033b17ea7714aed831b95628f483fd012 \ --hash=sha256:fa789583fc94a7689b45834453fec095245c7e69c58561dc159b5d5277057e4c -synr==0.6.0 \ - --hash=sha256:0b4e16b10c3988e1981e3372153a31956f74d86752eaaa55e8c4e7b7fe591e4e \ - --hash=sha256:9399b27d9f21c5d439eae92e0159d6f521cc396d27149ac45473012a205a3c30 tflite==2.10.0 \ --hash=sha256:6818a5d7776958b803944ba0a1f4c4395559606d9e795d67ac467a8a3904757d \ --hash=sha256:89cb9f57df0f5345f8fad1381e0fae6180ded687113eb552cfbb60a05edc002c diff --git a/apps/microtvm/ethosu/requirements.txt b/apps/microtvm/ethosu/requirements.txt index d9593a8184e9..d8a7fa7bd901 100644 --- a/apps/microtvm/ethosu/requirements.txt +++ b/apps/microtvm/ethosu/requirements.txt @@ -216,9 +216,6 @@ scipy==1.5.4 \ --hash=sha256:ed572470af2438b526ea574ff8f05e7f39b44ac37f712105e57fc4d53a6fb660 \ --hash=sha256:f87b39f4d69cf7d7529d7b1098cb712033b17ea7714aed831b95628f483fd012 \ --hash=sha256:fa789583fc94a7689b45834453fec095245c7e69c58561dc159b5d5277057e4c -synr==0.6.0 \ - --hash=sha256:0b4e16b10c3988e1981e3372153a31956f74d86752eaaa55e8c4e7b7fe591e4e \ - --hash=sha256:9399b27d9f21c5d439eae92e0159d6f521cc396d27149ac45473012a205a3c30 tflite==2.4.0 \ --hash=sha256:0510db1b48a3eec86bf9bb8d2749cd9d6d26d6a4fb329fd141bde5b4404932d1 \ --hash=sha256:0796f6ce6eb2aef4a318f5509e5fb0ce808e29cd3094801b4abbb1d8575a28cd diff --git a/docker/install/ubuntu_install_python_package.sh b/docker/install/ubuntu_install_python_package.sh index 757ad0228c5d..93abac52beaa 100755 --- a/docker/install/ubuntu_install_python_package.sh +++ b/docker/install/ubuntu_install_python_package.sh @@ -41,7 +41,6 @@ pip3 install --upgrade \ requests \ scipy \ Jinja2 \ - synr==0.6.0 \ junitparser==2.4.2 \ six \ tornado \ diff --git a/docs/README.md b/docs/README.md index b6ca8e06f3f2..6c32d2d6bfed 100644 --- a/docs/README.md +++ b/docs/README.md @@ -48,7 +48,7 @@ This folder contains the source of TVM's documentation, hosted at https://tvm.ap ```bash # Pillow on Ubuntu may require libjpeg-dev from apt ./docker/bash.sh ci_gpu -c \ - 'python3 -m pip install --quiet tlcpack-sphinx-addon==0.2.1 synr==0.5.0 && python3 -m pip freeze' > frozen-requirements.txt + 'python3 -m pip install --quiet tlcpack-sphinx-addon==0.2.1 && python3 -m pip freeze' > frozen-requirements.txt pip install -r frozen-requirements.txt ``` diff --git a/docs/contribute/pull_request.rst b/docs/contribute/pull_request.rst index 7b5509be0aa9..60faff307457 100644 --- a/docs/contribute/pull_request.rst +++ b/docs/contribute/pull_request.rst @@ -254,7 +254,7 @@ Necessary dependencies: .. code:: bash - pip install --user pytest Cython synr + pip install --user pytest Cython If you want to run all tests: diff --git a/gallery/how_to/work_with_microtvm/micro_ethosu.py b/gallery/how_to/work_with_microtvm/micro_ethosu.py index 386c658ea818..e80860dc0ce6 100644 --- a/gallery/how_to/work_with_microtvm/micro_ethosu.py +++ b/gallery/how_to/work_with_microtvm/micro_ethosu.py @@ -95,7 +95,6 @@ # Pillow==8.3.2 # psutil==5.8.0 # scipy==1.5.4 -# synr==0.6 # tflite==2.4.0 # tornado==6.1 # diff --git a/python/gen_requirements.py b/python/gen_requirements.py index 9778937ae80b..b8c72a8f2744 100755 --- a/python/gen_requirements.py +++ b/python/gen_requirements.py @@ -70,7 +70,6 @@ "numpy", "psutil", "scipy", - "synr", "tornado", ], ), @@ -270,7 +269,6 @@ ("sphinx_autodoc_annotation", None), ("sphinx_gallery", None), ("sphinx_rtd_theme", None), - ("synr", "==0.6.0"), ("tensorflow", None), ("tensorflow-estimator", None), ("tflite", None), diff --git a/python/tvm/script/parser_v1/__init__.py b/python/tvm/script/parser_v1/__init__.py deleted file mode 100644 index 555659d0c55e..000000000000 --- a/python/tvm/script/parser_v1/__init__.py +++ /dev/null @@ -1,21 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""TVM Script APIs of TVM Python Package, aimed to support TIR""" - -from . import tir - -from .parser import ir_module, from_source diff --git a/python/tvm/script/parser_v1/_ffi_api.py b/python/tvm/script/parser_v1/_ffi_api.py deleted file mode 100644 index 926d17b1667e..000000000000 --- a/python/tvm/script/parser_v1/_ffi_api.py +++ /dev/null @@ -1,20 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""FFI APIs for tvm.script""" -import tvm._ffi - -tvm._ffi._init_api("script", __name__) diff --git a/python/tvm/script/parser_v1/context_maintainer.py b/python/tvm/script/parser_v1/context_maintainer.py deleted file mode 100644 index b84b7d398084..000000000000 --- a/python/tvm/script/parser_v1/context_maintainer.py +++ /dev/null @@ -1,248 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""TVM Script Context Maintainer for TIR""" - -from typing import List, Mapping, Union, Optional, Dict, Callable -import synr - - -import tvm -from tvm.ir import Span -from tvm.ir.expr import Range -from tvm.tir import Var, Buffer, PrimExpr, Stmt, MatchBufferRegion -from tvm.runtime import Object -from tvm.tir.expr import IterVar -from .tir.node import BufferSlice - - -class BlockInfo: - """Information for block and block_realize signature - - Examples - ---------- - .. code-block:: python - - @T.prim_func - def example_func(a: T.handle, b: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (16, 16), "float32") - B = T.match_buffer(b, (16, 16), "float32") - C = T.match_buffer(a, (16, 16), "float32") - - for i, j, k in T.grid(16, 16, 16): - with T.block("matmul"): - vi = T.axis.S(16, i) - vj = T.axis.S(16, j) - vk = T.axis.R(16, k) # iter_bindings = {vj: i, vj: j, vk: k} - - T.where(True) # predicate of the block_realize - - T.reads(A[0:16, 0:16], B[0: 16, 0: 16]) # reads region of the block - T.writes(C[0: 16, 0: 16]) # writes region of the block - T.block_attr({"attr_key": "attr_value"}) # block annotations - - # alloc_buffers inside the block - CC = T.alloc_buffer((1, 1), dtype="float32") - - # match_buffers of the block, - # which bind a sub-region of source buffer into a new buffer - D = T.match_buffer(C[vi, vj], ()) - - # init part of the block, executed when all reduce axes are the beginning value - with T.init(): - C[vi, vj] = T.float32(0) - - # block body - CC[0, 0] = A[vi, vk] * B[vj, vk] - D[()] += CC[0, 0] # The same as C[vi, vj] += CC[0, 0] - """ - - alloc_buffers: List[Buffer] = [] - """List[Buffer]: list of T.alloc_buffer statements in the block signature""" - match_buffers: List[MatchBufferRegion] = [] - """List[MatchBufferRegion]: list of T.match_buffer statements in the block signature""" - iter_values: List[PrimExpr] = [] - """List[PrimExpr]: list of binding values for iter vars""" - iter_vars: List[IterVar] = [] - """List[PrimExpr]: list of iter vars in the block""" - reads: Optional[List[BufferSlice]] = None - """Optional[List[BufferSlice]]: - list of T.reads statements in the block signature, None for not-visited""" - writes: Optional[List[BufferSlice]] = None - """Optional[List[BufferSlice]]: - list of T.writes statements in the block signature, None for not-visited""" - annotations: Optional[Mapping[str, Object]] = None - """Optional[Mapping[str, Object]]: - list of T.block_attr statements in the block signature, None for not-visited""" - predicate: Optional[PrimExpr] = None - """Optional[PrimExpr]: block realize predicate, None for not-visited""" - init: Optional[Stmt] = None - """Optional[Stmt]: init part of the block, None for not-visited""" - - def __init__(self): - self.alloc_buffers = [] - self.match_buffers = [] - self.iter_values = [] - self.iter_vars = [] - self.reads = None - self.writes = None - self.annotations = None - self.predicate = None - self.init = None - - -class ContextMaintainer: - """Maintain all the necessary context info - Parameters - ---------- - _report_error : Callable[[str, Union[Span, synr.ast.Span]], None] - The report error function handle - """ - - # scope context - node_stack: List[List[synr.ast.Node]] = [] - """List[List[synr.ast.Node]]: The ast nodes insides the current scope""" - block_info_stack: List[BlockInfo] = [] - """List[BlockInfo]: The block info for the current block scope""" - loop_stack: Dict[Var, Range] = {} - """Dict[Var, Range]: The dict from loop var to its domain outside the block""" - symbols: List[Dict[str, Union[Var, Buffer]]] = [] - """List[Dict[str, Union[Var, Buffer]]]: Symbol map from name to object for the current scope""" - closure_vars: Dict[str, Object] = {} - """ClosureVars: The closure vars defined in Python interpreter""" - - # function context - func_params: List[Var] = [] - """List[Var]: The function parameters""" - func_buffer_map: Mapping[Var, Buffer] = {} - """Mapping[Var, Buffer]: The function buffer map""" - func_dict_attr: Mapping[str, Object] = {} - """Mapping[str, Object]: The function attrs""" - func_var_env_dict: Mapping[Var, str] = {} - """Mapping[Var, str]: The map from var to env thread""" - - # parser and analyzer - analyzer: tvm.arith.Analyzer = tvm.arith.Analyzer() - """tvm.arith.Analyzer: The analyzer for simplifying""" - _report_error: Callable[[str, Union[Span, synr.ast.Span]], None] - """Callable[[str, Union[Span, synr.ast.Span]], None]: The report error function handle""" - - # root alloc_buffer - root_alloc_buffers: List[Buffer] = [] - """List[Buffer]: The buffers allocated under root block""" - - def __init__( - self, - _report_error: Callable[[str, Union[Span, synr.ast.Span]], None], - closure_vars: Dict[str, Object], - ): - # scope context - self.node_stack = [] - self.block_info_stack = [] - self.loop_stack = {} - self.symbols = [] - self.closure_vars = closure_vars - # function context - self.func_params = [] - self.func_buffer_map = {} - self.func_dict_attr = {} - self.func_var_env_dict = {} - # parser and analyzer - self._report_error = _report_error - self.analyzer = tvm.arith.Analyzer() - # root alloc_buffer - self.root_alloc_buffers = [] - - def enter_scope(self, nodes: Optional[List[synr.ast.Node]] = None): - """Creates a new scope - - Note - ---- - This function is used for normal scopes that do not involve - a `with block` scope. Use `enter_block_scope` - for block scope cases. - - Parameters - ---------- - nodes : Optional[List[synr.ast.Node]] - The synr AST nodes in new scope - """ - if nodes is None: - nodes = [] - self.node_stack.append(list(reversed(nodes))) - self.symbols.append(dict()) - - def enter_block_scope(self, nodes: Optional[List[synr.ast.Node]] = None): - """Creates a new block scope, the function will call `enter_scope` implicitly - Besides the behaviors of `enter_scope`, it will update loop_stack and block_info_stack - to maintain block info. - - Note - ---- - This function should be used to handle a block scope, - aka the blocks that involve a `with block` scope. - - Parameters - ---------- - nodes : Optional[List[synr.ast.Node]] - The synr AST nodes in new scope - """ - self.enter_scope(nodes) - # Create a new BlockInfo for the new block - self.block_info_stack.append(BlockInfo()) - - def exit_scope(self): - """Pop the inner most scope""" - self.symbols.pop() - self.node_stack.pop() - - def exit_block_scope(self): - """Pop the inner most block scope, the function will call `exit_scope` implicitly""" - self.exit_scope() - # Pop block_info - self.block_info_stack.pop() - - def update_symbol(self, name: str, symbol: Union[Buffer, Var], node: synr.ast.Node): - """Append a symbol into current scope""" - if isinstance(symbol, Buffer): - if name in self.symbols[0]: - self.report_error("Duplicate Buffer name: " + symbol.name, node.span) - self.symbols[0][name] = symbol - else: - self.symbols[-1][name] = symbol - - def remove_symbol(self, name: str): - """Remove a symbol""" - for symbols in reversed(self.symbols): - if name in symbols: - symbols.pop(name) - return - raise RuntimeError("Internal error of tvm script parser: no symbol named " + name) - - def lookup_symbol(self, name: str) -> Optional[Union[Buffer, Var]]: - """Look up symbol by name""" - for symbols in reversed(self.symbols): - if name in symbols: - return symbols[name] - return self.closure_vars.get(name) - - def report_error(self, message: str, span: Union[Span, synr.ast.Span]): - self._report_error(message, span) - - def current_block_scope(self) -> BlockInfo: - if self.block_info_stack: - return self.block_info_stack[-1] - return None diff --git a/python/tvm/script/parser_v1/diagnostics.py b/python/tvm/script/parser_v1/diagnostics.py deleted file mode 100644 index e676461ab39e..000000000000 --- a/python/tvm/script/parser_v1/diagnostics.py +++ /dev/null @@ -1,55 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Bridge from synr's (the library used for parsing the python AST) - DiagnosticContext to TVM's diagnostics -""" -from synr import DiagnosticContext, ast - -import tvm -from tvm.ir.diagnostics import DiagnosticContext as TVMCtx -from tvm.ir.diagnostics import get_renderer, DiagnosticLevel, Diagnostic - - -class TVMDiagnosticCtx(DiagnosticContext): - """TVM diagnostics for synr""" - - diag_ctx: TVMCtx - - def __init__(self) -> None: - self.diag_ctx = TVMCtx(tvm.IRModule(), get_renderer()) - self.source_name = None - - def to_tvm_span(self, src_name, ast_span: ast.Span) -> tvm.ir.Span: - return tvm.ir.Span( - src_name, - ast_span.start_line, - ast_span.end_line, - ast_span.start_column, - ast_span.end_column, - ) - - def add_source(self, name: str, source: str) -> None: - src_name = self.diag_ctx.module.source_map.add(name, source) - self.source_name = src_name - - def emit(self, _level, message, span): - span = self.to_tvm_span(self.source_name, span) - self.diag_ctx.emit(Diagnostic(DiagnosticLevel.ERROR, span, message)) - self.diag_ctx.render() # Raise exception on the first error we hit. TODO remove - - def render(self): - self.diag_ctx.render() diff --git a/python/tvm/script/parser_v1/meta_unparser.py b/python/tvm/script/parser_v1/meta_unparser.py deleted file mode 100644 index b1472ccdc758..000000000000 --- a/python/tvm/script/parser_v1/meta_unparser.py +++ /dev/null @@ -1,45 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Unparse meta AST node into a dict""" -# pylint: disable=invalid-name - -from synr import Transformer - - -class MetaUnparser(Transformer): - """Python AST Visitor to unparse meta AST node into a dict""" - - def transform(self, node): - method = "transform_" + node.__class__.__name__ - visitor = getattr(self, method, None) - if visitor is None: - self.error(f"Unexpected node type {type(node)} when parsing __tvm_meta__", node.span) - return visitor(node) - - def transform_DictLiteral(self, node): - keys = [self.visit(key) for key in node.keys] - values = [self.visit(value) for value in node.values] - return dict(zip(keys, values)) - - def transform_Tuple(self, node): - return tuple(self.visit(element) for element in node.elts) - - def transform_ArrayLiteral(self, node): - return [self.visit(element) for element in node.elts] - - def transform_Constant(self, node): - return node.value diff --git a/python/tvm/script/parser_v1/parser.py b/python/tvm/script/parser_v1/parser.py deleted file mode 100644 index ce8c1fe161a3..000000000000 --- a/python/tvm/script/parser_v1/parser.py +++ /dev/null @@ -1,1391 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""TVM Script Parser For TIR - -We use [synr](https://synr.readthedocs.io) to get an AST that is stable over -different python versions. Synr also provides an error handling context that we -use for error reporting. -""" -# pylint: disable=invalid-name, inconsistent-return-statements, no-else-return, broad-except -import types -import json -import operator -import inspect -from typing import Any, Callable, Dict, List, Optional, Union -from synr import ast, Transformer, to_ast - -import tvm -from tvm import IRModule -from tvm._ffi.base import TVMError -from tvm.ir import GlobalVar -from tvm.ir.function import BaseFunc -from tvm.tir import buffer -from tvm.tir.function import PrimFunc -from . import _ffi_api -from . import tir - -from .context_maintainer import ContextMaintainer -from .meta_unparser import MetaUnparser -from .registry import Registry -from .diagnostics import TVMDiagnosticCtx -from .utils import tvm_span_from_synr, synr_span_from_tvm, call_with_error_reporting - -from .tir.intrin import Intrin -from .tir.node import Slice, BufferSlice -from .tir.scope_handler import ScopeHandler, WithScopeHandler, ForScopeHandler -from .tir.special_stmt import SpecialStmt -from .tir import ty - - -class CallArgumentReader(object): - """Helper class to read required arguments from passed arguments. - - When parsing a function call, we need to match the arguments provided in - the AST to the required arguments of the function. This class makes sure - all the positional arguments are filled and also fill keyword arguments - with thier default value if a different value was not provided. - """ - - def __init__(self, func_name, args, kwargs, parser, node): - self.func_name = func_name - self.args = args - self.kwargs = kwargs - self.parser = parser - self.node = node - - def get_pos_only_arg(self, pos, name): - """Get corresponding position only function argument from argument list""" - if len(self.args) >= pos: - arg = self.args[pos - 1] - elif name not in self.kwargs: - # If no positional argument was found in the AST, we see if it was - # defined by name instead. - # TODO(tkonolige): this error message is not quite correct. The - # number of required arguments is >= pos - self.parser.report_error( - f"{self.func_name} requires {pos} arguments, but only {len(self.args)} were given.", - self.node.span, - ) - else: - arg = self.kwargs[name] - - return arg - - def get_kwarg(self, pos, name, default): - """Get corresponding keyword function argument from argument list. - - If the user hasn't provided the argument, set it to the default value. - """ - if len(self.args) >= pos: - arg = self.args[pos - 1] - elif name in self.kwargs: - arg = self.kwargs[name] - else: - return default - - return arg - - def get_varargs(self, pos): - """Get corresponding variable argument from argument list""" - if len(self.args) >= pos and len(self.kwargs) == 0: - return self.args[pos - 1 :] - return [] - - -class TVMScriptParser(Transformer): - """Synr AST visitor pass which finally lowers to TIR. - - Notes for Extension - ------------------- - 1. To support a new type of AST node, add a function transform_xxx(). - 2. To support new functions, add the function to the appropriate registry: - We divide allowed function calls in TVM script into 3 categories, - intrin, scope_handler and special_stmt. - 1. intrin functions are low level functions like mod, load, and - constants. They correspond to a tir `IRNode`. They must have a - return value. The user can register intrin functions for the parser to - use. - 2. scope_handler functions have no return value. They take two - arguments: the parser and the AST node. scope_handler functions are - used in with and for statements. - 3. special_stmt functions handle cases that do not have a corresponding - tir `IRNode`. These functions take the parser and the AST node as - arguments and may return a value. - When visiting a Call node, we check the special_stmt registry first. If - no registered function is found, we then check the intrin registry. - When visiting With node, we check the with_scope registry. - When visiting For node, we check the for_scope registry. - """ - - _binop_maker = { - ast.BuiltinOp.Add: tvm.tir.Add, - ast.BuiltinOp.Sub: tvm.tir.Sub, - ast.BuiltinOp.Mul: tvm.tir.Mul, - ast.BuiltinOp.Div: tvm.tir.Div, - ast.BuiltinOp.FloorDiv: tvm.tir.FloorDiv, - ast.BuiltinOp.Mod: tvm.tir.FloorMod, - ast.BuiltinOp.BitOr: lambda lhs, rhs, span: operator.or_(lhs, rhs), - ast.BuiltinOp.BitAnd: lambda lhs, rhs, span: operator.and_(lhs, rhs), - ast.BuiltinOp.BitXor: lambda lhs, rhs, span: operator.xor(lhs, rhs), - ast.BuiltinOp.GT: tvm.tir.GT, - ast.BuiltinOp.GE: tvm.tir.GE, - ast.BuiltinOp.LT: tvm.tir.LT, - ast.BuiltinOp.LE: tvm.tir.LE, - ast.BuiltinOp.Eq: tvm.tir.EQ, - ast.BuiltinOp.NotEq: tvm.tir.NE, - ast.BuiltinOp.And: tvm.tir.And, - ast.BuiltinOp.Or: tvm.tir.Or, - } - - _unaryop_maker = { - ast.BuiltinOp.USub: lambda rhs, span: operator.neg(rhs), - ast.BuiltinOp.Invert: lambda rhs, span: operator.invert(rhs), - ast.BuiltinOp.Not: tvm.tir.Not, - } - - # pylint gets confused here with synr.Transformer which doesn't have a - # custom init, so just disable it - def __init__( - self, base_lineno, tir_namespace, closure_vars - ): # pylint: disable=super-init-not-called - self.context = None - - self.base_lineno = base_lineno - self.current_lineno = 0 - self.current_col_offset = 0 - self.tir_namespace = tir_namespace - self.closure_vars = closure_vars - self.meta = None - self._inside_buffer_sugar = False - - def init_function_parsing_env(self): - """Initialize function parsing environment""" - self.context = ContextMaintainer(self.report_error, self.closure_vars) # scope emitter - - def init_meta(self, meta_dict): - if meta_dict is not None: - self.meta = tvm.ir.load_json(json.dumps(meta_dict)) - - def transform(self, node): - """Generic transformation for visiting the AST. Dispatches to - `transform_ClassName` for the appropriate ClassName.""" - old_lineno, old_col_offset = self.current_lineno, self.current_col_offset - - if hasattr(node, "lineno"): - self.current_lineno = self.base_lineno + node.lineno - 1 - if hasattr(node, "col_offset"): - self.current_col_offset = node.col_offset - - method = "transform_" + node.__class__.__name__ - visitor = getattr(self, method, self.generic_visit) - transform_res = visitor(node) - - self.current_lineno, self.current_col_offset = old_lineno, old_col_offset - - return transform_res - - def match_tir_namespace(self, identifier: str) -> bool: - """Check if the namespace is equal to tvm.script.tir""" - return identifier in self.tir_namespace - - def report_error(self, message: str, span: Union[ast.Span, tvm.ir.Span]): - """Report an error occuring at a location. - - This just dispatches to synr's DiagnosticContext. - - Parameters - ---------- - message : str - Error message - span : Union[synr.ast.Span, tvm.ir.Span] - Location of the error - """ - if isinstance(span, tvm.ir.Span): - span = synr_span_from_tvm(span) - self.error(message, span) - - def parse_body(self, parent): - """Parse remaining statements in this scope. - - Parameters - ---------- - parent : synr.ast.Node - Parent node of this scope. Errors will be reported here. - """ - body = [] - spans = [] - stmt = parent - while len(self.context.node_stack[-1]) > 0: - stmt = self.context.node_stack[-1].pop() - spans.append(stmt.span) - res = self.transform(stmt) - if res is not None: - body.append(res) - if len(body) == 0: - self.report_error( - "Expected another statement at the end of this block. Perhaps you " - "used a concise statement and forgot to include a body afterwards.", - stmt.span, - ) - else: - return ( - tvm.tir.SeqStmt(body, tvm_span_from_synr(ast.Span.union(spans))) - if len(body) > 1 - else body[0] - ) - - def parse_arg_list(self, func, node_call): - """Match the arguments of a function call in the AST to the required - arguments of the function. This handles positional arguments, - positional arguments specified by name, keyword arguments, and varargs. - - Parameters - ---------- - func : Function - The function that provides the signature - - node_call: Union[ast.Call, ast.TypeApply, ast.TypeCall] - The AST call node that calls into the function. - - Returns - ------- - arg_list : list - The parsed positional argument. - """ - assert isinstance(node_call, (ast.Call, ast.TypeApply, ast.TypeCall)) - # collect arguments - args = [self.transform(arg) for arg in node_call.params] - if isinstance(node_call, ast.TypeApply): - kw_args = {} # TypeApply (e.g. foo[bar]) doesn't have kwargs defined in synr - else: - kw_args = { - self.transform(k): self.transform(v) for k, v in node_call.keyword_params.items() - } - # get the name and parameter list of func - if isinstance(func, (Intrin, ScopeHandler, SpecialStmt)): - func_name, param_list = func.signature() - else: - self.report_error( - "Internal Error: function must be of type Intrin, ScopeHandler or SpecialStmt, " - f"but it is {type(func).__name__}", - node_call.span, - ) - # check arguments and parameter list and get a list of arguments - reader = CallArgumentReader(func_name, args, kw_args, self, node_call) - pos_only, kwargs, varargs = param_list - internal_args = list() - - for i, arg_name in enumerate(pos_only): - internal_args.append(reader.get_pos_only_arg(i + 1, arg_name)) - for i, arg_info in enumerate(kwargs): - arg_name, default = arg_info - internal_args.append(reader.get_kwarg(i + 1 + len(pos_only), arg_name, default=default)) - if varargs is not None: - internal_args.extend(reader.get_varargs(len(pos_only) + len(kwargs) + 1)) - elif len(args) + len(kw_args) > len(pos_only) + len(kwargs): - self.report_error( - "Arguments mismatched. " - + f"Expected {len(pos_only) + len(kwargs)} args but got " - + f"{len(args) + len(kw_args)}", - node_call.span, - ) - return internal_args - - def parse_type(self, type_node, parent): - """Parse a type annotation. - - We require the parent object to the type so that we have a place to - report the error message if the type does not exist. - """ - if type_node is None: - self.report_error("A type annotation is required", parent.span) - res_type = self.transform(type_node) - return tvm.ir.TupleType([]) if res_type is None else res_type.evaluate() - - def generic_visit(self, node): - """Fallback visitor if node type is not handled. Reports an error.""" - - self.report_error(type(node).__name__ + " AST node is not supported", node.span) - - def transform_Module(self, node): - """Module visitor - - Right now, we only support two formats for TVM Script. - - Example - ------- - 1. Generate a PrimFunc (If the code is printed, then it may also contain metadata) - .. code-block:: python - - import tvm - - @tvm.script - def A(...): - ... - - # returns a PrimFunc - func = A - - 2. Generate an IRModule - .. code-block:: python - - import tvm - - @tvm.script.ir_module - class MyMod(): - @T.prim_func - def A(...): - ... - @T.prim_func - def B(...): - ... - - __tvm_meta__ = ... - - # returns an IRModule - mod = MyMod - """ - if len(node.funcs) == 1: - return self.transform(next(iter(node.funcs.values()))) - elif len(node.funcs) == 0: - self.report_error( - "You must supply at least one class or function definition", node.span - ) - else: - self.report_error( - "Only one-function, one-class or function-with-meta source code is allowed", - ast.Span.union([x.span for x in list(node.funcs.values())[1:]]), - ) - - def transform_Class(self, node): - """Class definition visitor. - - A class can have multiple function definitions and a single - :code:`__tvm_meta__` statement. Each class corresponds to a single - :code:`IRModule`. - - Example - ------- - .. code-block:: python - - @tvm.script.ir_module - class MyClass: - __tvm_meta__ = {} - def A(): - T.evaluate(0) - """ - if len(node.assignments) == 1: - if not ( - len(node.assignments[0].lhs) == 1 - and isinstance(node.assignments[0].lhs[0], ast.Var) - and node.assignments[0].lhs[0].id.name == "__tvm_meta__" - ): - self.report_error( - "The only top level assignments allowed are `__tvm_meta__ = ...`", - node.assignments[0].span, - ) - self.init_meta( - MetaUnparser().do_transform(node.assignments[0].rhs, self._diagnostic_context) - ) - elif len(node.assignments) > 1: - self.report_error( - "Only a single top level `__tvm_meta__` is allowed", - ast.Span.union([x.span for x in node.assignments[1:]]), - ) - - return IRModule( - {GlobalVar(name): self.transform(func) for name, func in node.funcs.items()} - ) - - def transform_Function(self, node): - """Function definition visitor. - - Each function definition is translated to a single :code:`PrimFunc`. - - There are a couple restrictions on TVM Script functions: - 1. Function arguments must have their types specified. - 2. The body of the function can contain :code:`func_attr` to specify - attributes of the function (like it's name). - 3. The body of the function can also contain multiple :code:`buffer_bind`s, - which give shape and dtype information to arguments. - 4. Return statements are implicit. - - Example - ------- - .. code-block:: python - - @T.prim_func - def my_function(x: T.handle): # 1. Argument types - T.func_attr({"global_symbol": "mmult"}) # 2. Function attributes - X_1 = tir.buffer_bind(x, [1024, 1024]) # 3. Buffer binding - T.evaluate(0) # 4. This function returns 0 - """ - - def check_as_torch_decorator(decorator: Union[ast.Call, ast.Var]): - if isinstance(decorator, ast.Call): - if len(decorator.params) != 1: - return False - func_name = decorator.func_name - else: - func_name = decorator - if isinstance(func_name, ast.Var): - return func_name.id.name == "as_torch" - - def check_decorator(decorators: List[ast.Expr]) -> bool: - """Check the decorator is `T.prim_func""" - if len(decorators) > 2 or len(decorators) == 0: - return False - if len(decorators) == 2 and not check_as_torch_decorator(decorators[0]): - return False - d: ast.Expr = decorators[-1] - return ( - isinstance(d, ast.Attr) - and isinstance(d.object, ast.Var) - and self.match_tir_namespace(d.object.id.name) - and d.field.name == "prim_func" - ) - - self.init_function_parsing_env() - self.context.enter_scope(nodes=node.body.stmts) - - # add parameters of function - for arg in node.params: - # Note that this case is for T.match_buffer syntax sugar - if isinstance(arg.ty, (ast.TypeCall, ast.TypeApply)) and isinstance( - self.transform(arg.ty.func_name), ty.GenericBufferType - ): - result = self.handle_match_buffer_type(arg.ty, arg.name) - if not isinstance(result, buffer.Buffer): - self.report_error( - "The result type of evaluating TypeCall and TypeApply stmt" - f" is wrong: {type(result)}. It should be a Buffer", - node.span, - ) - arg_name_with_handle = arg.name + "_handle" - arg_var = tvm.te.var(arg_name_with_handle, tvm.ir.PrimType("handle")) - self.context.func_buffer_map[arg_var] = result - self.context.update_symbol(arg.name, result, node) - else: - arg_var = tvm.te.var(arg.name, self.parse_type(arg.ty, arg)) - self.context.update_symbol(arg.name, arg_var, node) - self.context.func_params.append(arg_var) - - if not check_decorator(node.decorators): - self.report_error( - "All functions should be decorated by `T.prim_func`", - node.span, - ) - - # fetch the body of root block - body = self.parse_body(node.body) - - # return a tir.PrimFunc - dict_attr = self.context.func_dict_attr - ret_type = self.parse_type(node.ret_type, node) if node.ret_type is not None else None - func = tvm.tir.PrimFunc( - self.context.func_params, - body, - ret_type, - buffer_map=self.context.func_buffer_map, - attrs=tvm.ir.make_node("DictAttrs", **dict_attr) if dict_attr else None, - span=tvm_span_from_synr(node.span), - ) - - # New Scope : Implicit root block - # Each function contains an implicit root block in TensorIR, - # so here we need a block scope for it. - # If the PrimFunc is not a TensorIR func (e.g. TE scheduled func or low-level func), - # the root block will not be added. The logic to add root block is in `_ffi_api.Complete` - - # Fix the PrimFunc - # 1. generate root block if necessary - # 2. generate surrounding loops for blocks if necessary - - func = call_with_error_reporting( - self.report_error, - node.span, - _ffi_api.Complete, - func, - self.context.root_alloc_buffers, - ) - - self.context.exit_scope() - return func - - def transform_Lambda(self, node): - """Lambda visitor - - Return an array of input parameters and the transformed lambda body. - """ - - self.context.enter_scope(nodes=[node.body]) - - # add parameters of the lambda - arg_vars = [] - for arg in node.params: - # Use "void" for dtype here. The actual type is not yet known and will be - # determined later. Using void type will allow IRSubstitute to do the - # replacement without flagging a type-mismatch error. - arg_var = tvm.te.var(arg.name, dtype="") - arg_vars.append(arg_var) - self.context.update_symbol(arg.name, arg_var, node) - - # the body of a lambda must be an expr - if not isinstance(node.body, ast.Expr): - self.report_error("The body of a lambda must be an expression", node.span) - - # transform the body of the lambda - body = self.transform(node.body) - - self.context.exit_scope() - return arg_vars, body - - def transform_Assign(self, node): - """Assign visitor - AST abstract grammar: - Assign(expr* targets, expr value, string? type_comment) - - By now 5 patterns of Assign is supported: - 1. special stmts with return value - 1.1 Buffer = T.match_buffer()/T.buffer_decl() - 1.2 Var = T.var() - 1.3 Var = T.env_thread() - 2. (BufferStore) Buffer[PrimExpr, PrimExpr, ..., PrimExpr] = PrimExpr - 3. (Store) Var[PrimExpr] = PrimExpr - 4. with scope handlers with concise scoping and var def - 4.1 var = T.allocate() - 5. A call to a pure python function, consuming and producing TVMScript values. - The outputs are inlined into the following body (no variable is created). - x, y = f(...) - """ - - if isinstance(node.rhs, ast.Call): - # Pattern 1 & Pattern 4 - if isinstance(node.rhs.func_name, ast.Op): - func = None - else: - func = self.transform(node.rhs.func_name) - - if isinstance(func, WithScopeHandler): - if not func.concise_scope or not func.def_symbol: - self.report_error( - "with scope handler " + func.signature()[0] + " is not suitable here", - node.rhs.span, - ) - # Pattern 4 - arg_list = self.parse_arg_list(func, node.rhs) - func.enter_scope(node, self.context, arg_list, node.rhs.func_name.span) - func.body = self.parse_body(node) - return func.exit_scope(node, self.context, arg_list, node.rhs.func_name.span) - elif isinstance(func, SpecialStmt): - # Pattern 1 - arg_list = self.parse_arg_list(func, node.rhs) - func.handle(node, self.context, arg_list, node.rhs.func_name.span) - return self.parse_body(node) - elif isinstance(func, types.FunctionType): - # Pattern 5 - args = [self.transform(arg) for arg in node.rhs.params] - try: - out = func(*args) - except Exception as e: - self.report_error( - "Error occurred when invoking the function " - + func.__name__ - + ": \n" - + str(e), - node.rhs.span, - ) - - if len(node.lhs) == 1 and not isinstance(out, list): - out = [out] - - assert len(out) == len(node.lhs) - - for var, value in zip(node.lhs, out): - self.context.update_symbol(var.id.name, value, node) - - body = self.parse_body(node) - - for var, value in zip(node.lhs, out): - self.context.remove_symbol(var.id.name) - - return body - - if isinstance(node.rhs, (ast.Call, ast.Constant)): - # Pattern 4 of let binding - value = self.transform(node.rhs) - if len(node.lhs) == 1 and not isinstance(node.lhs[0], ast.Var): - # This is a little confusing because it only is true when - # we have taken this branch. We might need to clarify what - # exectly is allowed in Assignments in tvmscript. - self.report_error( - "Left hand side of assignment must be an unqualified variable", - node.span, - ) - ast_var = node.lhs[0] - - if node.ty is None and hasattr(value, "dtype"): - var_ty = value.dtype - else: - var_ty = self.parse_type(node.ty, ast_var) - - var = tvm.te.var( - ast_var.id.name, - var_ty, - span=tvm_span_from_synr(ast_var.span), - ) - self.context.update_symbol(var.name, var, node) - body = self.parse_body(node) - self.context.remove_symbol(var.name) - return tvm.tir.LetStmt(var, value, body, span=tvm_span_from_synr(node.span)) - - self.report_error( - """Assignments should be one of: - 1. A "special statement" with return value - 1.1 Buffer = T.match_buffer()/T.buffer_decl() - 1.2 Var = T.var() - 1.3 Var = T.env_thread() - 2. A store into a buffer: Buffer[PrimExpr, PrimExpr, ..., PrimExpr] = PrimExpr - 3. A store into a variable: Var[PrimExpr] = PrimExpr - 4. A with scope handler with concise scoping and var def - 4.1 var = T.allocate() - 5. The right-hand side being a call to a pure python function, consuming and - producing TVMScript values. - x, y = f(...)""", - node.span, - ) - - def transform_SubscriptAssign(self, node): - """Visitor for statements of the form :code:`x[1] = 2`.""" - symbol = self.transform(node.params[0]) - indexes = self.transform(node.params[1]) - rhs = self.transform(node.params[2]) - rhs_span = tvm_span_from_synr(node.params[2].span) - if isinstance(symbol, tvm.tir.Buffer): - if len(indexes) != len(symbol.shape): - self.report_error( - f"Buffer {symbol.name} is {len(symbol.shape)}-dimensional, " - f"cannot be indexed by {len(indexes)}-dimensional indices.", - node.params[1].span, - ) - - def __convert_index(x): - if isinstance(x, Slice): - return x.as_index_expr(self.report_error) - return x - - # BufferStore - indexes = [__convert_index(x) for x in indexes] - return tvm.tir.BufferStore( - symbol, - tvm.runtime.convert(rhs, span=rhs_span), - indexes, - span=tvm_span_from_synr(node.span), - ) - else: - if symbol.dtype == "handle" and len(indexes) != 1: - self.report_error( - "Handles only support one-dimensional indexing. Use `T.match_buffer` to " - "construct a multidimensional buffer from a handle.", - node.params[0].span, - ) - if len(indexes) != 1: - self.report_error( - f"Store is only allowed with one index, but {len(indexes)} were provided.", - node.params[1].span, - ) - self.report_error( - "Use of tir.Store has been deprecated in favor of tir.BufferStore.", node.span - ) - - def transform_AttrAssign(self, node): - """Visitor for statements of the form :code:`x.y = 2`.""" - obj = self.transform(node.params[0]) - field = node.params[1] - value = self.transform(node.params[2]) - - if not hasattr(obj, field.name): - self.error(f"Field {field.name} does not exist", field.span) - - var = getattr(obj, field.name) - - if not isinstance(var, tvm.tir.Var): - self.error( - f"Can only assign to tir.Var attributes, not {type(var).__name__}", node.span - ) - - body = self.parse_body(node) - return tvm.tir.LetStmt(var, value, body, span=tvm_span_from_synr(node.span)) - - def transform_Assert(self, node): - """Assert visitor - - Pattern corresponds to concise mode of :code:`with T.Assert()`. - """ - - condition = self.transform(node.condition) - if node.msg is None: - self.report_error("Assert statements must have an error message.", node.span) - message = self.transform(node.msg) - body = self.parse_body(node) - return tvm.tir.AssertStmt( - condition, tvm.runtime.convert(message), body, span=tvm_span_from_synr(node.span) - ) - - def transform_For(self, node): - """For visitor - AST abstract grammar: - For(expr target, expr iter, stmt* body, stmt* orelse, string? type_comment) - By now 1 pattern of For is supported: - 1. for scope handler - for name in T.serial()/T.parallel()/T.vectorized()/T.unroll()/range()/ - T.grid()/T.thread_binding() - """ - - if not isinstance(node.rhs, ast.Call): - self.report_error("The loop iterator should be a function call.", node.rhs.span) - func = self.transform(node.rhs.func_name) - if not isinstance(func, ForScopeHandler): - self.report_error( - "Only For scope handlers can be used in a for statement.", node.rhs.func_name.span - ) - # prepare for new for scope - old_lineno, old_col_offset = self.current_lineno, self.current_col_offset - self.current_lineno = node.span.start_line - self.current_col_offset = node.span.start_column - self.context.enter_scope(nodes=node.body.stmts) - # for scope handler process the scope - arg_list = [ - tvm.runtime.convert(arg, span=tvm_span_from_synr(node.rhs.span)) - for arg in self.parse_arg_list(func, node.rhs) - ] - func.enter_scope(node, self.context, arg_list, node.rhs.func_name.span) - func.body = self.parse_body(node) - res = func.exit_scope(node, self.context, arg_list, node.rhs.func_name.span) - # exit the scope - self.context.exit_scope() - self.current_lineno, self.current_col_offset = old_lineno, old_col_offset - return res - - def transform_While(self, node): - """While visitor - AST abstract grammar: - While(expr condition, stmt* body) - """ - condition = self.transform(node.condition) - # body - self.context.enter_scope(nodes=node.body.stmts) - body = self.parse_body(node) - self.context.exit_scope() - - return tvm.tir.While(condition, body, span=tvm_span_from_synr(node.span)) - - def transform_With(self, node): - """With visitor - AST abstract grammar: - With(withitem* items, stmt* body, string? type_comment) - withitem = (expr context_expr, expr? optional_vars) - By now 2 patterns of With is supported: - 1. with scope handler with symbol def - with T.allocate() as targets: - 2. with scope handler without symbol def - with T.block(*axes)/T.let()/T.Assert()/T.attr()/T.realize() - """ - - if not isinstance(node.rhs, ast.Call): - self.report_error( - "The context expression of a `with` statement should be a function call.", - node.rhs.span, - ) - - func = self.transform(node.rhs.func_name) - - if not isinstance(func, WithScopeHandler): - self.report_error( - f"Function {func} cannot be used in a `with` statement.", node.rhs.func_name.span - ) - # prepare for new block scope - old_lineno, old_col_offset = self.current_lineno, self.current_col_offset - self.current_lineno = node.body.span.start_line - self.current_col_offset = node.body.span.start_column - self.context.enter_block_scope(nodes=node.body.stmts) - # with scope handler process the scope - arg_list = self.parse_arg_list(func, node.rhs) - func.enter_scope(node, self.context, arg_list, node.rhs.func_name.span) - func.body = self.parse_body(node) - res = func.exit_scope(node, self.context, arg_list, node.rhs.func_name.span) - # exit the scope - self.context.exit_block_scope() - self.current_lineno, self.current_col_offset = old_lineno, old_col_offset - return res - - def transform_If(self, node): - """If visitor - AST abstract grammar: - If(expr test, stmt* body, stmt* orelse) - """ - - condition = self.transform(node.condition) - # then body - self.context.enter_scope(nodes=node.true.stmts) - then_body = self.parse_body(node) - self.context.exit_scope() - - # else body - if len(node.false.stmts) > 0: - self.context.enter_scope(nodes=node.false.stmts) - else_body = self.parse_body(node) - self.context.exit_scope() - else: - else_body = None - - return tvm.tir.IfThenElse( - condition, then_body, else_body, span=tvm_span_from_synr(node.span) - ) - - def transform_Call(self, node): - """Call visitor - - 3 different Call patterns are allowed: - 1. Intrin representing a PrimExpr/IterVar - 1.1 tir.int/uint/float8/16/32/64/floormod/floordiv/load/cast/ramp/broadcast/max - 1.2 tir.range/reduce_axis/scan_axis/opaque_axis - 2. tir.Op(dtype, ...) - 3. other callable functions - """ - - if isinstance(node.func_name, ast.Op): - if node.func_name.name == ast.BuiltinOp.Subscript: - return self.transform_Subscript(node) - if node.func_name.name in self._binop_maker: - lhs = self.transform(node.params[0]) - # There is no supertype for everything that can appear in - # an expression, so we manually add what we might get here. - if not isinstance(lhs, (tvm.tir.PrimExpr, BufferSlice)): - # We would really like to report a more specific - # error here, but this parser contains no distinction - # between parsing statements and parsing expressions. All - # rules just call `transform`. - self.report_error( - f"Left hand side of binary op must be a PrimExpr, " - "but it is a {type(lhs).__name__}", - node.params[0].span, - ) - rhs = self.transform(node.params[1]) - if not isinstance(rhs, (tvm.tir.PrimExpr, BufferSlice)): - self.report_error( - f"Right hand side of binary op must be a PrimExpr, " - "but it is a {type(rhs).__name__}", - node.params[1].span, - ) - return call_with_error_reporting( - self.report_error, - node.span, - lambda node, lhs, rhs, span: self._binop_maker[node.func_name.name]( - lhs, rhs, span=span - ), - node, - lhs, - rhs, - tvm_span_from_synr(node.span), - ) - if node.func_name.name in self._unaryop_maker: - rhs = self.transform(node.params[0]) - if node.func_name.name == ast.BuiltinOp.USub and isinstance( - node.params[0], ast.Constant - ): - # '-literal' should be parsed together for proper literal type inference - if not isinstance(rhs, (tvm.tir.IntImm, tvm.tir.FloatImm)): - self.report_error("The literal is illegal after -", node.params[0].span) - return tvm.tir.const(-rhs.value) - return self._unaryop_maker[node.func_name.name]( - rhs, span=tvm_span_from_synr(node.span) - ) - self.report_error(f"Unsupported operator {node.func_name.name}.", node.func_name.span) - else: - func = self.transform(node.func_name) - if isinstance(func, Intrin) and not func.stmt: - # pattern 1 - arg_list = self.parse_arg_list(func, node) - return call_with_error_reporting( - self.report_error, - node.func_name.span, - func.handle, - arg_list, - node.func_name.span, - ) - else: - args = [self.transform(arg) for arg in node.params] - kw_args = { - self.transform(k): self.transform(v) for k, v in node.keyword_params.items() - } - if isinstance(func, tvm.tir.op.Op): - if not "dtype" in kw_args.keys(): - self.report_error(f"{func} requires a dtype keyword argument.", node.span) - # pattern 2 - return tvm.tir.Call( - kw_args["dtype"], func, args, span=tvm_span_from_synr(node.span) - ) - elif callable(func): - # pattern 3 - return func(*args, **kw_args) - else: - self.report_error( - f"Function is neither callable nor a tvm.tir.op.Op (it is a {type(func)}).", - node.func_name.span, - ) - - def transform_UnassignedCall(self, node): - """Visitor for statements that are function calls. - - This handles function calls that appear on thier own line like `tir.realize`. - - Examples - -------- - .. code-block:: python - - @T.prim_func - def f(): - A = T.buffer_decl([10, 10]) - T.realize(A[1:2, 1:2], "") # This is an UnassignedCall - A[1, 1] = 2 # This is also an UnassignedCall - """ - # Only allowed builtin operator that can be a statement is x[1] = 3 i.e. subscript assign. - if isinstance(node.call.func_name, ast.Op): - if node.call.func_name.name == ast.BuiltinOp.SubscriptAssign: - return self.transform_SubscriptAssign(node.call) - - if node.call.func_name.name == ast.BuiltinOp.AttrAssign: - return self.transform_AttrAssign(node.call) - - self.report_error( - "Binary and unary operators are not allowed as a statement", node.span - ) - - # handle a regular function call - func = self.transform(node.call.func_name) - arg_list = self.parse_arg_list(func, node.call) - - if isinstance(func, tir.scope_handler.AssertHandler): - self.report_error( - "A standalone `T.Assert` is not allowed. Use `assert condition, message` " - "instead.", - node.call.func_name.span, - ) - - if isinstance(func, Intrin): - if func.stmt: - return call_with_error_reporting( - self.report_error, - node.call.func_name.span, - func.handle, - arg_list, - node.call.func_name.span, - ) - else: - self.report_error(f"This intrinsic cannot be used as a statement.", node.call.span) - elif isinstance(func, WithScopeHandler) and func.concise_scope and not func.def_symbol: - func.enter_scope(node, self.context, arg_list, node.call.func_name.span) - func.body = self.parse_body(node) - return func.exit_scope(node, self.context, arg_list, node.call.func_name.span) - elif isinstance(func, SpecialStmt) and not func.def_symbol: - func.handle(node, self.context, arg_list, node.call.func_name.span) - return - - self.report_error( - "Unexpected statement. Expected an assert, an intrinsic, a with statement, or a " - f"special statement, but got {type(func).__name__}.", - node.call.func_name.span, - ) - - def transform_Slice(self, node): - """Index slice visitor.""" - start = self.transform(node.start) - end = self.transform(node.end) - if not ( - isinstance(node.step, ast.Constant) - and isinstance(node.step.value, int) - and node.step.value > 0 - ): - self.report_error( - "Only positive integer step size is supported for slices.", node.step.span - ) - return Slice(start, end, node.step.value, tvm_span_from_synr(node.span)) - - def transform_Subscript(self, node): - """Array access visitor. - - By now only 3 types of Subscript are supported: - 1. Buffer[index, index, ...], Buffer element access(BufferLoad & BufferStore) - Var[index] Buffer element access() - 2. Buffer[start: stop, start: stop, ...], BufferRealize(realize(buffer[...])) - 3. Array[index], Buffer element access - """ - - symbol = self.transform(node.params[0]) - if symbol is None: - self.report_error( - f"Variable {node.params[0].id.name} is not defined.", node.params[0].span - ) - - indexes = [self.transform(x) for x in node.params[1].values] - if isinstance(symbol, tvm.tir.expr.Var): - if symbol.dtype == "handle": - self.report_error( - "Cannot read directly from a handle, use `T.match_buffer` " - "to create a buffer to read from.", - node.params[0].span, - ) - if len(indexes) > 1: - self.report_error( - "Only a single index can be provided when indexing into a `var`.", - node.params[1].span, - ) - index = indexes[0] - if not isinstance(index, (tvm.tir.PrimExpr, int)): - self.report_error( - "Var load index should be an int or PrimExpr, but it is a" + type(index), - node.span, - ) - - self.report_error( - "Use of tir.Load has been deprecated in favor of tir.BufferLoad", node.span - ) - elif isinstance(symbol, tvm.tir.Buffer): - return BufferSlice( - symbol, indexes, self.report_error, span=tvm_span_from_synr(node.span) - ) - elif isinstance(symbol, tvm.container.Array): - if len(indexes) > 1: - self.report_error( - "Array access should be one-dimension access, but the indices are " - + str(indexes), - node.span, - ) - index = indexes[0] - if not isinstance(index, (int, tvm.tir.expr.IntImm)): - self.report_error( - "Array access index expected int or IntImm, but got " + type(index), - node.span, - ) - if int(index) >= len(symbol): - self.report_error( - f"Array access out of bound, size: {len(symbol)}, got index {index}.", - node.span, - ) - return symbol[int(index)] - else: - self.report_error( - f"Cannot subscript from a {type(symbol).__name__}. Only variables and " - "buffers are supported.", - node.params[0].span, - ) - - def transform_Attr(self, node): - """Visitor for field access of the form `x.y`. - - This visitor is used to lookup function and symbol names. We have two - cases to handle here: - 1. If we have a statement of the form `tir.something`, then we lookup - `tir.something` in the `Registry`. If the function is not in the - registry, then we try to find a `tvm.ir.op.Op` with the same name. - 2. All other names `tvm.something` are lookup up in this current python - namespace. - """ - - def get_full_attr_name(node: ast.Attr) -> str: - reverse_field_names = [node.field.name] - while isinstance(node.object, ast.Attr): - node = node.object - reverse_field_names.append(node.field.name) - if isinstance(node.object, ast.Var): - reverse_field_names.append(node.object.id.name) - return ".".join(reversed(reverse_field_names)) - - if isinstance(node.object, (ast.Var, ast.Attr)): - full_attr_name = get_full_attr_name(node) - attr_object, fields = full_attr_name.split(".", maxsplit=1) - if self.match_tir_namespace(attr_object): - func_name = "tir." + fields - res = Registry.lookup(func_name) - if res is not None: - return res - try: - return tvm.ir.op.Op.get(func_name) - except TVMError as e: - # Check if we got an attribute error - if e.args[0].find("AttributeError"): - self.report_error(f"Unregistered function `tir.{fields}`.", node.span) - else: - raise e - - symbol = self.transform(node.object) - if symbol is None: - self.report_error("Unsupported Attribute expression.", node.object.span) - if not hasattr(symbol, node.field.name): - self.report_error( - f"Type {type(symbol)} does not have a field called `{node.field.name}`.", node.span - ) - res = getattr(symbol, node.field.name) - return res - - def transform_TypeAttr(self, node): - """Visitor for field access of the form `x.y` for types. - - We have two cases here: - 1. If the type is of the form `T.something`, we look up the type in - the `tir` namespace in this module. - 2. If the type is of the form `tvm.x.something` then we look up - `tvm.x.something` in this modules namespace. - """ - if isinstance(node.object, ast.TypeVar): - if self.match_tir_namespace(node.object.id.name): - if not hasattr(tir, node.field.name): - self.report_error( - f"Invalid type annotation `tir.{node.field.name}`.", node.span - ) - return getattr(tir, node.field.name) - - symbol = self.transform(node.object) - if symbol is None: - self.report_error("Unsupported Attribute expression", node.object.span) - if not hasattr(symbol, node.field): - self.report_error( - f"Type {type(symbol)} does not have a field called `{node.field}`.", node.span - ) - res = getattr(symbol, node.field) - return res - - def transform_DictLiteral(self, node): - """Dictionary literal visitor. - - Handles dictionary literals of the form `{x:y, z:2}`. - """ - - keys = [self.transform(key) for key in node.keys] - values = [self.transform(value) for value in node.values] - - return dict(zip(keys, values)) - - def transform_Tuple(self, node): - """Tuple visitor. - - Handles tuples of the form `(x, y, 2)`. - """ - - return tuple(self.transform(element) for element in node.values) - - def transform_ArrayLiteral(self, node): - """List literal visitor. - - Handles lists of the form `[x, 2, 3]`. - """ - - return [self.transform(element) for element in node.values] - - def transform_Var(self, node): - """Variable visitor - - Handles variables like `x` in `x = 2`. - """ - - name = node.id.name - if name == "meta": - return self.meta - symbol = Registry.lookup(name) - if symbol is not None: - return symbol - symbol = self.context.lookup_symbol(name) - if symbol is not None: - return symbol - self.report_error(f"Unknown identifier {name}.", node.span) - - def transform_TypeVar(self, node): - """Type variable visitor. - - Equivalent to `transform_Var` but for types. - """ - name = node.id.name - symbol = Registry.lookup(name) or self.context.lookup_symbol(name) - if symbol is not None: - return symbol - self.report_error(f"Unknown identifier {name}.", node.span) - - def transform_Constant(self, node): - """Constant value visitor. - - Constant values include `None`, `"strings"`, `2` (integers), `4.2` - (floats), and `true` (booleans). - """ - return tvm.runtime.convert(node.value, span=tvm_span_from_synr(node.span)) - - def transform_TypeConstant(self, node): - """Constant value visitor for types. - - See `transform_Constant`. - """ - if self._inside_buffer_sugar: - return self.transform_Constant(node) - - return node.value - - def transform_TypeTuple(self, node): - """Tuple value visitor for types. - - Mostly used in `transform_TypeCall` and `transform_TypeApply`. - """ - return [self.transform(value) for value in node.values] - - def transform_TypeCall(self, node): - """TypeCall visitor - - This occurs when an expression is used inside a T.Buffer - parameter annotation. - """ - - # ast.Call has the BuiltinOp as node.func_name.name, where - # ast.TypeCall has the BuiltinOp as node.func_name. So we can - # delegate to self.transform_Call, but the error messages for - # unsupported operations will highlight the entire expression - # and not just the function itself. - op = ast.Op(node.span, node.func_name) - call = ast.Call(node.span, op, node.params, node.keyword_params) - return self.transform_Call(call) - - def transform_TypeApply(self, node): - """Visitor for Type[Type] expressions. - - Mostly used for ``T.Ptr`` expressions. - """ - func = self.transform(node.func_name) - - if not isinstance(func, ty.TypeGeneric) or not hasattr(func, "__getitem__"): - self.report_error( - f"Use of type arguments requires a type that accepts type arguments (e.g. T.Ptr), " - f"but found {type(func).__name__} instead.", - node.span, - ) - - param_types = [] - for idx, param in enumerate(node.params): - param_type = self.transform(param) - if not isinstance(param_type, ty.TypeGeneric) and func.require_type_generic_at(idx): - self.report_error( - f"Expected a type but found {type(param).__name__} " - f"at {idx}th type argument", - param.span, - ) - - param_types.append(param_type) - - if len(param_types) == 1: - return func[param_types[0]] - else: - return func[param_types] - - def handle_match_buffer_type(self, node, buffer_name): - """special function to handle syntax sugar for match buffer. - - This method is for buffer declarations in the function parameters. - """ - func = self.transform(node.func_name) - assert isinstance(func, SpecialStmt) - - # parse args and kwargs for TypeCall and TypeApply - self._inside_buffer_sugar = True - try: - arg_list = self.parse_arg_list(func, node) - finally: - self._inside_buffer_sugar = False - - # Note that the third element in arg_list would always be the 'name' - # TODO: This index is hardcoded as a workaround. Better to make it programmatic - if arg_list[2] is None: - arg_list[2] = buffer_name - buf = func.handle(node, self.context, arg_list, node.func_name.span) - return buf - - def transform_Return(self, node): - self.report_error( - "TVM script does not support return statements. Instead the last statement in any " - "block is implicitly returned.", - node.span, - ) - - -def get_tir_namespace(script: Union[Callable, type]) -> List[str]: - assert inspect.isfunction(script) or inspect.isclass(script) - env: Dict[str, Any] = script.__globals__ - return [key for key in env.keys() if env[key] == tir] - - -def from_source( - input_func: Union[str, Callable], tir_prefix: Optional[List[str]] = None -) -> Union[PrimFunc, IRModule]: - """Parse function or string into PrimFunc or IRModule. - - If possible, pass the TVM script in as a function so that line numbers and - filename will be accurate. - - Parameters - ---------- - input_module : Union[str, Callable] - The python function to be parsed. - - tir_prefix : Optional[List[str]] - The tir prefix list. Only works for str input, default by "tir" and "T". - - Returns - ------- - output : Union[Function, Module] - The Function or Module in IR. - """ - if isinstance(input_func, str): - tir_prefix = ["T", "tir"] if tir_prefix is None else tir_prefix - return to_ast(input_func, TVMDiagnosticCtx(), TVMScriptParser(0, tir_prefix, {})) - elif inspect.isfunction(input_func): - _, start_line = inspect.getsourcelines(input_func) - env: Dict[str, Any] = input_func.__globals__ - namespace = [key for key in env.keys() if env[key] is tir] - _closure_vars = inspect.getclosurevars(input_func) - closure_vars = {**_closure_vars.nonlocals, **_closure_vars.globals} - parser = TVMScriptParser(start_line, namespace, closure_vars) - result = to_ast(input_func, TVMDiagnosticCtx(), parser) - return result - else: - raise TypeError("Only function definitions are supported.") - - -def ir_module(input_module: type) -> IRModule: - """Decorate a python class as tvm IRModule. - - Parameters - ---------- - input_module : type - The python class to be parsed. - - Returns - ------- - output : IRModule - The result IRModule. - """ - if inspect.isclass(input_module): - func_dict = { - name: f for name, f in input_module.__dict__.items() if isinstance(f, BaseFunc) - } - return IRModule(func_dict) - raise TypeError("Only class definitions are supported.") diff --git a/python/tvm/script/parser_v1/registry.py b/python/tvm/script/parser_v1/registry.py deleted file mode 100644 index e7d90dd51517..000000000000 --- a/python/tvm/script/parser_v1/registry.py +++ /dev/null @@ -1,62 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""TVM Script Parser Function Registry """ -# pylint: disable=inconsistent-return-statements, relative-beyond-top-level, import-outside-toplevel -import types -from typing import Union, Callable, Dict, Optional, Any - - -class Registry(object): - """Registration map - All these maps are static - """ - - registrations: Dict[str, type] = dict() - - @staticmethod - def lookup(name: str) -> Optional[Any]: - if name in Registry.registrations: - # every time we create a new handler - # since we may want to keep some local info inside it - return Registry.registrations[name]() - return None - - -def register(inputs: Union[Callable, type]) -> type: - """Register Intrin/ScopeHandler/SpecialStmt""" - registration: type - if isinstance(inputs, types.FunctionType): - # is function - from .tir.intrin import Intrin - - def create_new_intrin(func) -> type: - class NewIntrin(Intrin): - def __init__(self): - super().__init__(func) - - return NewIntrin - - registration = create_new_intrin(inputs) - elif isinstance(inputs, type): - # is class - registration = inputs - else: - raise ValueError() - - key: str = registration().signature()[0] - Registry.registrations[key] = registration - return registration diff --git a/python/tvm/script/parser_v1/tir/__init__.py b/python/tvm/script/parser_v1/tir/__init__.py deleted file mode 100644 index 662dd10ec068..000000000000 --- a/python/tvm/script/parser_v1/tir/__init__.py +++ /dev/null @@ -1,33 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""TVMScript for TIR""" - -# Type system -from .ty import void, boolean, handle, Ptr, Tuple, Buffer -from .ty import bool # pylint: disable=redefined-builtin - -from .prim_func import prim_func - -# add all floating point and integer datatypes to the module -for _dtype in ["float", "uint", "int"]: - for _size in ["8", "16", "32", "64"]: - for _lanes in ["", "x4", "x8", "x16", "x32", "x64"]: - from . import ty - - _name = _dtype + _size + _lanes - if hasattr(ty, _name): - globals()[_name] = getattr(ty, _name) diff --git a/python/tvm/script/parser_v1/tir/__init__.pyi b/python/tvm/script/parser_v1/tir/__init__.pyi deleted file mode 100644 index beefaf4c75d7..000000000000 --- a/python/tvm/script/parser_v1/tir/__init__.pyi +++ /dev/null @@ -1,475 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=redefined-builtin -from typing import ( - Any, - Callable, - ContextManager, - Dict, - Iterable, - Optional, - Tuple, - Union, - Sequence, - List, - Mapping, - overload, -) -from numbers import Number -import builtins - -from tvm.tir.function import PrimFunc -from tvm.tir import Range -from tvm.runtime import Object -from tvm.target import Target -from .node import BufferSlice - -""" -redefine types -""" - -class PrimExpr: - def __init__(self: PrimExpr) -> None: ... - @overload - def __add__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ... - @overload - def __add__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ... - @overload - def __sub__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ... - @overload - def __sub__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ... - @overload - def __mul__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ... - @overload - def __mul__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ... - @overload - def __div__(self: PrimExpr, other: PrimExpr) -> PrimExpr: ... - @overload - def __div__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ... - def __mod__(self: PrimExpr, other: Union[int, float, PrimExpr]) -> PrimExpr: ... - def __radd__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ... - def __rsub__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ... - def __rmul__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ... - def __rdiv__(self: PrimExpr, other: Union[int, float]) -> PrimExpr: ... - def __floordiv__(self: PrimExpr, other: Union[int, float, PrimExpr]) -> PrimExpr: ... - def __index__(self: PrimExpr) -> int: ... # so range doesn't complain - -class Var(PrimExpr): ... -class IterVar(Var): ... - -class Buffer: - @overload - def __getitem__(self: Buffer, pos: Sequence[Union[PrimExpr, int, slice]]) -> PrimExpr: ... - @overload - def __getitem__(self: Buffer, pos: Union[PrimExpr, int, slice]) -> PrimExpr: ... - @overload - def __setitem__( - self: Buffer, pos: Sequence[Union[PrimExpr, int, slice]], value: PrimExpr - ) -> None: ... - @overload - def __setitem__(self: Buffer, pos: Union[PrimExpr, int, slice], value: PrimExpr) -> None: ... - @property - def data(self: Buffer) -> Ptr: ... - -""" -Intrinsic -""" - -def min_value(dtype: str) -> PrimExpr: ... -def max_value(dtype: str) -> PrimExpr: ... -def floordiv(x: PrimExpr, y: PrimExpr) -> PrimExpr: ... -def floormod(x: PrimExpr, y: PrimExpr) -> PrimExpr: ... -def ceildiv(x: PrimExpr, y: PrimExpr) -> PrimExpr: ... -def truncmod(x: PrimExpr, y: PrimExpr) -> PrimExpr: ... -def truncdiv(x: PrimExpr, y: PrimExpr) -> PrimExpr: ... -def abs(x: PrimExpr) -> PrimExpr: ... -def load( - dtype: str, var: Var, index: PrimExpr, predicate: Union[PrimExpr, builtins.bool] = None -) -> PrimExpr: ... -def cast(value: PrimExpr, dtype: str) -> PrimExpr: ... -def ramp(base: PrimExpr, stride: Any, lanes: int) -> PrimExpr: ... -def broadcast(value: PrimExpr, lanes: int) -> PrimExpr: ... -def iter_var(var: Union[Var, str], dom: Range, iter_type: int, thread_tag: str) -> IterVar: ... -def max(a: PrimExpr, b: PrimExpr) -> PrimExpr: ... -def min(a: PrimExpr, b: PrimExpr) -> PrimExpr: ... -def Select(cond: PrimExpr, if_body: PrimExpr, else_body: PrimExpr) -> PrimExpr: ... -def if_then_else(cond: PrimExpr, t: PrimExpr, f: PrimExpr, dtype: str) -> PrimExpr: ... -def evaluate(value: PrimExpr) -> None: ... -def reinterpret(value: PrimExpr, dtype: str) -> PrimExpr: ... -def vectorlow(value: PrimExpr, dtype: str) -> PrimExpr: ... -def vectorhigh(value: PrimExpr, dtype: str) -> PrimExpr: ... -def store( - var: Var, index: PrimExpr, value: PrimExpr, predicate: Union[PrimExpr, builtins.bool] = True -) -> None: ... -def comm_reducer(lambda_io: Callable[[Any, Any], Any], identities: List[PrimExpr]) -> PrimExpr: ... -def llvm_lookup_intrinsic_id(name: str) -> PrimExpr: ... - -""" -Intrinsics - tvm builtin -""" - -def tvm_thread_allreduce( - *freduceargs: Union[PrimExpr, builtins.bool, Ptr], dtype: str -) -> PrimExpr: ... - -""" -Unary operator -Note that any intrinsics not registered in script.tir.intrin -should add "dtype" as an argument. This is different from their -definition but intentional. -""" - -def exp(x: PrimExpr, dtype: str) -> PrimExpr: ... -def exp2(x: PrimExpr, dtype: str) -> PrimExpr: ... -def exp10(x: PrimExpr, dtype: str) -> PrimExpr: ... -def erf(x: PrimExpr, dtype: str) -> PrimExpr: ... -def tanh(x: PrimExpr, dtype: str) -> PrimExpr: ... -def sigmoid(x: PrimExpr, dtype: str) -> PrimExpr: ... -def log(x: PrimExpr, dtype: str) -> PrimExpr: ... -def log2(x: PrimExpr, dtype: str) -> PrimExpr: ... -def log10(x: PrimExpr, dtype: str) -> PrimExpr: ... -def log1p(x: PrimExpr, dtype: str) -> PrimExpr: ... -def tan(x: PrimExpr, dtype: str) -> PrimExpr: ... -def cos(x: PrimExpr, dtype: str) -> PrimExpr: ... -def cosh(x: PrimExpr, dtype: str) -> PrimExpr: ... -def acos(x: PrimExpr, dtype: str) -> PrimExpr: ... -def acosh(x: PrimExpr, dtype: str) -> PrimExpr: ... -def sin(x: PrimExpr, dtype: str) -> PrimExpr: ... -def sinh(x: PrimExpr, dtype: str) -> PrimExpr: ... -def asin(x: PrimExpr, dtype: str) -> PrimExpr: ... -def asinh(x: PrimExpr, dtype: str) -> PrimExpr: ... -def atan(x: PrimExpr, dtype: str) -> PrimExpr: ... -def atanh(x: PrimExpr, dtype: str) -> PrimExpr: ... -def atan2(x: PrimExpr, dtype: str) -> PrimExpr: ... -def sqrt(x: PrimExpr, dtype: str) -> PrimExpr: ... -def rsqrt(x: PrimExpr, dtype: str) -> PrimExpr: ... - -""" -special_stmt - Buffers -""" - -def match_buffer( - param: Union[Var, BufferSlice], - shape: Sequence[Union[PrimExpr, int]], - dtype: str = "float32", - data: Var = None, - strides: Optional[Sequence[int]] = None, - elem_offset: Optional[int] = None, - scope: str = "global", - align: int = -1, - offset_factor: int = 0, - buffer_type: str = "default", - axis_separators: Optional[List[int]] = None, -) -> Buffer: ... -def decl_buffer( - shape: Sequence[Union[PrimExpr, int]], - dtype: str = "float32", - data: Var = None, - strides: Optional[Sequence[int]] = None, - elem_offset: Optional[int] = None, - scope: str = "global", - align: int = -1, - offset_factor: int = 0, - buffer_type: str = "default", - axis_separators: Optional[List[int]] = None, -) -> Buffer: ... -def buffer_decl( - shape: Sequence[Union[PrimExpr, int]], - dtype: str = "float32", - data: Var = None, - strides: Optional[Sequence[int]] = None, - elem_offset: Optional[int] = None, - scope: str = "global", - align: int = -1, - offset_factor: int = 0, - buffer_type: str = "default", - axis_separators: Optional[List[int]] = None, -) -> Buffer: ... -def alloc_buffer( - shape: Sequence[Union[PrimExpr, int]], - dtype: str = "float32", - data: Var = None, - strides: Optional[Sequence[int]] = None, - elem_offset: Optional[int] = None, - scope: str = "global", - align: int = -1, - offset_factor: int = 0, - buffer_type: str = "default", - axis_separators: Optional[List[int]] = None, -) -> Buffer: ... - -""" -special_stmt - Reads/Writes -""" - -@overload -def reads(read_regions: List[BufferSlice]) -> None: ... -@overload -def reads(*read_regions: BufferSlice) -> None: ... -@overload -def writes(write_region: List[BufferSlice]) -> None: ... -@overload -def writes(*write_region: BufferSlice) -> None: ... -def block_attr(attrs: Mapping[str, Object]) -> None: ... - -""" -special_stmt - Axis -""" - -class axis: - @overload - @staticmethod - def spatial(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ... - @overload - @staticmethod - def spatial( - dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr - ) -> IterVar: ... - @overload - @staticmethod - def S(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ... - @overload - @staticmethod - def S(dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr) -> IterVar: ... - @overload - @staticmethod - def reduce(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ... - @overload - @staticmethod - def reduce( - dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr - ) -> IterVar: ... - @overload - @staticmethod - def R(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ... - @overload - @staticmethod - def R(dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr) -> IterVar: ... - @overload - @staticmethod - def scan(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ... - @overload - @staticmethod - def scan( - dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr - ) -> IterVar: ... - @overload - @staticmethod - def opaque(dom: Union[PrimExpr, int], value: PrimExpr) -> IterVar: ... - @overload - @staticmethod - def opaque( - dom: Tuple[Union[PrimExpr, int], Union[PrimExpr, int]], value: PrimExpr - ) -> IterVar: ... - @staticmethod - def remap(iter_types: str, loop_vars: List[Var]) -> List[IterVar]: ... - -def get_axis(begin: PrimExpr, end: PrimExpr, iter_type: int) -> IterVar: ... - -""" -special_stmt - Annotations -""" - -def buffer_var(dtype: str, storage_scope: str) -> Var: ... -def func_attr(attrs: Mapping[str, Union[Object, str, bool, int, float]]) -> None: ... -def prim_func(input_func: Callable) -> PrimFunc: ... - -""" -special_stmt - Threads and Bindings -""" - -def env_thread(env_name: str) -> IterVar: ... -def bind(iter_var: IterVar, expr: PrimExpr) -> None: ... - -""" -Scope handler -""" - -class block(ContextManager): - def __init__(self, name_hint: str = "") -> None: ... - def __enter__(self) -> Sequence[IterVar]: ... - -class init(ContextManager): - def __init__(self) -> None: ... - -class let(ContextManager): - def __init__(self, var: Var, value: PrimExpr) -> None: ... - -def where(cond: PrimExpr) -> None: ... -def allocate( - extents: List[PrimExpr], - dtype: str, - scope: str, - condition: Union[PrimExpr, builtins.bool] = True, - annotations: Optional[Mapping[str, Object]] = None, -) -> Buffer: ... -def launch_thread(env_var: Var, extent: Union[int, PrimExpr]) -> Var: ... -def realize( - buffer_slice: BufferSlice, scope: str, condition: Union[PrimExpr, builtins.bool] = True -) -> None: ... -def attr(node: PrimExpr, attr_key: str, value: PrimExpr) -> None: ... -def Assert(condition: Union[PrimExpr, builtins.bool], message: str) -> PrimExpr: ... - -""" -Scope handler - Loops -""" - -@overload -def serial( - begin: Union[PrimExpr, int], - end: Union[PrimExpr, int], - annotations: Optional[Mapping[str, Object]] = None, -) -> Iterable[IterVar]: ... -@overload -def serial( - end: Union[PrimExpr, int], - annotations: Optional[Mapping[str, Object]] = None, -) -> Iterable[IterVar]: ... -@overload -def parallel( - begin: Union[PrimExpr, int], - end: Union[PrimExpr, int], - annotations: Optional[Mapping[str, Object]] = None, -) -> Iterable[IterVar]: ... -@overload -def parallel( - end: Union[PrimExpr, int], - annotations: Optional[Mapping[str, Object]] = None, -) -> Iterable[IterVar]: ... -@overload -def vectorized( - begin: Union[PrimExpr, int], - end: Union[PrimExpr, int], - annotations: Optional[Mapping[str, Object]] = None, -) -> Iterable[IterVar]: ... -@overload -def vectorized( - end: Union[PrimExpr, int], - annotations: Optional[Mapping[str, Object]] = None, -) -> Iterable[IterVar]: ... -@overload -def unroll( - begin: Union[PrimExpr, int], - end: Union[PrimExpr, int], - annotations: Optional[Mapping[str, Object]] = None, -) -> Iterable[IterVar]: ... -@overload -def unroll( - end: Union[PrimExpr, int], - annotations: Optional[Mapping[str, Object]] = None, -) -> Iterable[IterVar]: ... -@overload -def thread_binding( - begin: Union[PrimExpr, int], - end: Union[PrimExpr, int], - thread: str, - annotations: Optional[Mapping[str, Object]] = None, -) -> Iterable[IterVar]: ... -@overload -def thread_binding( - end: Union[PrimExpr, int], - thread: str, - annotations: Optional[Mapping[str, Object]] = None, -) -> Iterable[IterVar]: ... -@overload -def for_range( - begin: Union[PrimExpr, int], - end: Union[PrimExpr, int], - annotations: Optional[Mapping[str, Object]] = None, -) -> Iterable[IterVar]: ... -@overload -def for_range( - end: Union[PrimExpr, int], - annotations: Optional[Mapping[str, Object]] = None, -) -> Iterable[IterVar]: ... -def grid(*extents: Union[PrimExpr, int]) -> Iterable[Sequence[IterVar]]: ... - -""" -ty - redefine types -""" - -class boolean: ... - -class handle(Var): - @overload - def __getitem__(self: handle, pos: Sequence[Union[int, PrimExpr, slice]]) -> Buffer: ... - @overload - def __getitem__(self: handle, pos: Union[int, PrimExpr, slice]) -> Buffer: ... - @overload - def __setitem__( - self: handle, pos: Sequence[Union[int, PrimExpr, slice]], value: Buffer - ) -> None: ... - @overload - def __setitem__(self: handle, pos: Union[int, PrimExpr, slice], value: Buffer) -> None: ... - @property - def data(self: handle) -> Ptr: ... - -class Ptr: ... - -def target(target_str: Union[str, Mapping[str, Object]]) -> Target: ... - -class var(Var): - def __init__(self: Var, dtype: str): ... - -class bool(PrimExpr): - def __init__(self: bool, imm: Union[PrimExpr, builtins.bool, builtins.int]): ... - -class int8(PrimExpr): - def __init__(self: int8, imm: Union[PrimExpr, int]): ... - -class int16(PrimExpr): - def __init__(self: int16, imm: Union[PrimExpr, int]): ... - -class int32(PrimExpr): - def __init__(self: int32, imm: Union[PrimExpr, int]): ... - -class int64(PrimExpr): - def __init__(self: int64, imm: Union[PrimExpr, int]): ... - -class uint8(PrimExpr): - def __init__(self: uint8, imm: Union[PrimExpr, int]): ... - -class uint16(PrimExpr): - def __init__(self: uint16, imm: Union[PrimExpr, int]): ... - -class uint32(PrimExpr): - def __init__(self: uint32, imm: Union[PrimExpr, int]): ... - -class uint64(PrimExpr): - def __init__(self: uint64, imm: Union[PrimExpr, int]): ... - -# use typing.Literal instead for python 3.8 or higher -import sys - -if sys.version_info >= (3, 8): - from typing import Literal - - SpecialFloatLiteral = Literal["inf", "-inf", "nan"] -else: - SpecialFloatLiteral = str - -class float8(PrimExpr): - def __init__(self: float8, imm: Union[PrimExpr, int, float, SpecialFloatLiteral]): ... - -class float16(PrimExpr): - def __init__(self: float16, imm: Union[PrimExpr, int, float, SpecialFloatLiteral]): ... - -class float32(PrimExpr): - def __init__(self: float32, imm: Union[PrimExpr, int, float, SpecialFloatLiteral]): ... - -class float64(PrimExpr): - def __init__(self: float64, imm: Union[PrimExpr, int, float, SpecialFloatLiteral]): ... diff --git a/python/tvm/script/parser_v1/tir/intrin.py b/python/tvm/script/parser_v1/tir/intrin.py deleted file mode 100644 index 9cde8e3f6d08..000000000000 --- a/python/tvm/script/parser_v1/tir/intrin.py +++ /dev/null @@ -1,307 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""TVM Script Parser Intrinsic Classes""" -# pylint: disable=redefined-builtin, relative-beyond-top-level -import builtins -from typing import Any, List - -import tvm.tir -from tvm.tir import FloatImm - -from ....target import codegen -from ..registry import register -from ..utils import get_param_list, tvm_span_from_synr - - -class Intrin: - def __init__(self, intrin, stmt=False): - self.intrin = intrin - self.stmt = stmt - - def signature(self): - return "tir." + self.intrin.__name__, get_param_list(self.intrin) - - def handle(self, arg_list: List[Any], span: tvm.ir.Span): - return self.intrin(*arg_list, span=tvm_span_from_synr(span)) - - -@register -def bool(imm, span): - return imm.astype("bool", span) - - -# register all datatypes -for _dtype in ["float", "uint", "int"]: - for _size in ["8", "16", "32", "64"]: - for _lanes in ["", "x4", "x8", "x16", "x32"]: - _name = _dtype + _size + _lanes - - # nest closures so we copy the name string - def wrap(name): - def f(imm, span): - if name.startswith("float"): - if imm in {"inf", "-inf", "nan"}: - return FloatImm(dtype=name, value=float(imm), span=span) - return imm.astype(name, span) - - f.__name__ = name - return f - - _intrin = wrap(_name) - register(_intrin) - - -@register -def min_value(dtype, span): - return tvm.tir.min_value(dtype, span) - - -@register -def max_value(dtype, span): - return tvm.tir.max_value(dtype, span) - - -@register -def floordiv(x, y, span): - return tvm.tir.floordiv(x, y, span) - - -@register -def floormod(x, y, span): - return tvm.tir.floormod(x, y, span) - - -@register -def truncmod(x, y, span): - return tvm.tir.truncmod(x, y, span) - - -@register -def truncdiv(x, y, span): - return tvm.tir.truncdiv(x, y, span) - - -@register -def ceildiv(x, y, span): - return tvm.tir.ceildiv(x, y, span) - - -@register -def abs(x, span): - return tvm.tir.abs(x, span) - - -@register -def load(dtype, var, index, predicate=None, span=None): - return tvm.tir.Load(dtype, var, index, predicate, span) - - -@register -def cast(value, dtype, span): - return tvm.tir.Cast(dtype, value, span) - - -@register -def ramp(base, stride, lanes, span): - return tvm.tir.Ramp(base, stride, lanes.value, span) - - -@register -def broadcast(value, lanes, span): - return tvm.tir.Broadcast(value, lanes.value, span) - - -@register -def iter_var(var, dom, iter_type, thread_tag, span): - iter_type = getattr(tvm.tir.IterVar, iter_type) - return tvm.tir.IterVar(dom, var, iter_type, thread_tag, span) - - -@register -def max(a, b, span): # pylint: disable=redefined-builtin - return tvm.tir.Max(a, b, span) - - -@register -def min(a, b, span): # pylint: disable=redefined-builtin - return tvm.tir.Min(a, b, span) - - -def get_axis(begin, end, iter_type, span): - ana = tvm.arith.Analyzer() - extent = ana.simplify(end - begin) - block_var_dom = tvm.ir.Range.from_min_extent(begin, extent) - - iter_type_dict = {"data_par": 0, "reduce": 2, "scan": 3, "opaque": 4} - return tvm.tir.IterVar(block_var_dom, "bv", iter_type_dict[iter_type], span=span) - - -@register -def range(begin, end, span): - return get_axis(begin, end, "data_par", span) - - -@register -def reduce_axis(begin, end, span): - return get_axis(begin, end, "reduce", span) - - -@register -def scan_axis(begin, end, span): - return get_axis(begin, end, "scan", span) - - -@register -def opaque_axis(begin, end, span): - return get_axis(begin, end, "opaque", span) - - -@register -def Select(cond, if_body, else_body, span): # pylint: disable=invalid-name - return tvm.tir.Select(cond, if_body, else_body, span) - - -@register -def Let(var, value, body, span): # pylint: disable=invalid-name - return tvm.tir.Let(var, value, body, span) - - -@register -class EvaluateIntrin(Intrin): - def __init__(self): - def evaluate(value, span): - return tvm.tir.Evaluate(value, span) - - super().__init__(evaluate, stmt=True) - - -@register -class StoreIntrin(Intrin): - def __init__(self): - def store(var, index, value, predicate=True, span=None): - return tvm.tir.Store(var, value, index, predicate, span) - - super().__init__(store, stmt=True) - - -@register -class AssumeIntrin(Intrin): - def __init__(self): - def assume(constraint, span): - return tvm.tir.Evaluate( - tvm.tir.call_intrin("bool", "tir.assume", constraint, span=span) - ) - - super().__init__(assume, stmt=True) - - -@register -def comm_reducer(lambda_io, identities, span): - """Create a CommReducer from lambda inputs/outputs and the identities""" - lambda_input = lambda_io[0] - lambda_output = lambda_io[1] - - num_args = len(lambda_input) - num_arg_per_group = num_args // 2 - x = [lambda_input[i] for i in builtins.range(0, num_arg_per_group)] - y = [lambda_input[i] for i in builtins.range(num_arg_per_group, num_args)] - - if not isinstance(lambda_output, tuple): - lambda_output = (lambda_output,) - - return tvm.tir.CommReducer(x, y, lambda_output, identities, span) - - -@register -def llvm_lookup_intrinsic_id(name, span): - # pylint: disable=unused-argument - return codegen.llvm_lookup_intrinsic_id(name) - - -@register -def FloorMod(x, y, span): # pylint: disable=invalid-name - return tvm.tir.FloorMod(x, y, span) - - -@register -def FloorDiv(x, y, span): # pylint: disable=invalid-name - return tvm.tir.FloorDiv(x, y, span) - - -@register -def Mul(x, y, span): # pylint: disable=invalid-name - return tvm.tir.Mul(x, y, span) - - -@register -def Div(x, y, span): # pylint: disable=invalid-name - return tvm.tir.Div(x, y, span) - - -@register -def Add(x, y, span): # pylint: disable=invalid-name - return tvm.tir.Add(x, y, span) - - -@register -def Sub(x, y, span): # pylint: disable=invalid-name - return tvm.tir.Sub(x, y, span) - - -@register -def LT(x, y, span): # pylint: disable=invalid-name - return tvm.tir.LT(x, y, span) - - -@register -def LE(x, y, span): # pylint: disable=invalid-name - return tvm.tir.LE(x, y, span) - - -@register -def GT(x, y, span): # pylint: disable=invalid-name - return tvm.tir.GT(x, y, span) - - -@register -def GE(x, y, span): # pylint: disable=invalid-name - return tvm.tir.GE(x, y, span) - - -@register -def EQ(x, y, span): # pylint: disable=invalid-name - return tvm.tir.EQ(x, y, span) - - -@register -def NE(x, y, span): # pylint: disable=invalid-name - return tvm.tir.NE(x, y, span) - - -@register -def And(x, y, span): # pylint: disable=invalid-name - return tvm.tir.And(x, y, span) - - -@register -def Or(x, y, span): # pylint: disable=invalid-name - return tvm.tir.Or(x, y, span) - - -@register -def Cast(dtype, value, span): # pylint: disable=invalid-name - return tvm.tir.Cast(dtype, value, span) diff --git a/python/tvm/script/parser_v1/tir/node.py b/python/tvm/script/parser_v1/tir/node.py deleted file mode 100644 index 29e79607fbc9..000000000000 --- a/python/tvm/script/parser_v1/tir/node.py +++ /dev/null @@ -1,218 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=redefined-builtin -"""TVM Script nodes.""" - -from typing import Optional, Union, List, Callable -import synr -from tvm.arith import Analyzer -from tvm.runtime import ObjectGeneric, convert -from tvm.tir import PrimExpr, Buffer, BufferLoad, IntImm, Ramp, BufferRegion -from tvm.ir import Span, Range - - -class Slice: - """A helper class to present slice information for BufferSlice - - Parameters - ---------- - start : Union[PrimExpr, int] - The start index. - - stop : Optional[Union[PrimExpr, int]] - The stop index, None means the Slice is an element-wise index - - step : int - The slice step - - span : Optional[Span] - The location of the slice in the source. - """ - - start: Union[PrimExpr, int] - stop: Optional[Union[PrimExpr, int]] - step: int - span: Optional[Span] - - def __init__( - self, - start: Union[PrimExpr, int], - stop: Optional[Union[PrimExpr, int]] = None, - step: int = 1, - span: Optional[Span] = None, - ): - self.start = start - self.stop = stop - self.step = step - self.span = span - - def as_index_expr(self, report_error: Callable[[str, Union[Span, synr.ast.Span]], None]): - """Helper to create index PrimExpr from slice object - Parameters - ---------- - report_error: Callable[[str, Union[Span, synr.ast.Span]], None] - The error report func - """ - if self.stop is None: - # scalar index - return self.start - if self.step < 1: - report_error("Slice's step should be positive integer", self.span) - lanes = Analyzer().simplify((self.stop - self.start + self.step - 1) // self.step) - if not isinstance(lanes, (int, IntImm)): - report_error("Slice's lanes should be constant for buffer indices", self.span) - if lanes == 1: - return self.start - return Ramp(self.start, self.step, int(lanes), self.span) - - -class BufferSlice(ObjectGeneric): - """A generic object for representing general buffer access. Following cases are supported: - - element wise access buffer[i, j], which can be converted to BufferLoad if necessary - - slice access buffer[i: i + 1, j : j + 2] - - union of element and slice buffer[i, j: j + 2] - - This node is used in TVMScript to parse BufferLoad, BufferRegion and Realize - - Parameters - ---------- - buffer : Buffer - The buffer. - - indices : List[Union[Slice, PrimExpr, int]] - The access indexes can be slice, PrimExpr or int. - - report_error: Callable[[str, Union[Span, synr.ast.Span]], None] - The error report func - - span : Optional[Span] - The location of the buffer access in the source. - """ - - buffer: Buffer - slices: List[Slice] - report_error: Callable[[str, Union[Span, synr.ast.Span]], None] - span: Optional[Span] - - def __init__( - self, - buffer: Buffer, - indices: List[Union[Slice, PrimExpr, int]], - report_error: Callable[[str, Union[Span, synr.ast.Span]], None], - span: Optional[Span] = None, - ): - def check_index(index: Union[int, PrimExpr]): - """Check input index is non-negative integer or PrimExpr""" - if isinstance(index, int): - if index < 0: - report_error("Negative index is not allowed during buffer access", span) - elif isinstance(index, PrimExpr): - element_dtype = index.dtype.split("x", maxsplit=1)[0] - if element_dtype[:3] != "int": - report_error( - "index expected an integer type PrimExpr but got " + str(index.dtype), - index.span, - ) - else: - report_error( - "Unsupported index type, expected int or tvm.tir.PrimExpr, but got " - + str(type(index)), - span, - ) - - slices: List[Union[Slice, BufferSlice]] = [] - for index in indices: - if isinstance(index, Slice): - index.start, index.stop = [convert(_) for _ in [index.start, index.stop]] - check_index(index.start) - check_index(index.stop) - slices.append(index) - elif isinstance(index, (PrimExpr, int)): - check_index(index) - slices.append(Slice(index)) - elif isinstance(index, BufferSlice): - buffer_load = index.asobject() - check_index(buffer_load) - slices.append(Slice(buffer_load)) - else: - report_error( - "Unsupported index type for BufferSlice, " - + "expected int, tvm.tir.PrimExpr, tvm.tir.Slice, but got " - + str(type(index)), - span, - ) - - self.buffer = buffer - self.slices = slices - self.report_error = report_error - self.span = span - - def __str__(self): - regions: List[str] = [] - for s in self.slices: - if s.stop is None: - regions.append(str(s.start)) - else: - regions.append(str(s.start) + ": " + str(s.stop)) - - return self.buffer.name + "[" + ", ".join(regions) + "]" - - def asobject(self) -> BufferLoad: - """Convert object.""" - indices = [s.as_index_expr(self.report_error) for s in self.slices] - return BufferLoad(self.buffer, indices, span=self.span) - - def as_buffer_region(self, analyzer: Optional[Analyzer] = None) -> BufferRegion: - """Construct BufferRegion from BufferSlice - - Parameters - ---------- - analyzer : Optional[tvm.arith.Analyzer] - The analyzer for simplifying. If not provided, the method will construct a new one - - Returns - ------- - buffer_region : BufferRegion - The constructed BufferRegion. - """ - region: List[Range] = [] - for s in self.slices: - start = s.start if isinstance(s.start, PrimExpr) else IntImm("int32", s.start) - extent = IntImm(start.dtype, 1) if s.stop is None else s.stop - s.start - if not analyzer: - analyzer = Analyzer() - if isinstance(extent, PrimExpr): - extent = analyzer.simplify(extent) - if s.step != 1: - self.report_error("BufferRegion do not support non-trivial stride", s.span) - region.append(Range.from_min_extent(start, extent, span=s.span)) - return BufferRegion(self.buffer, region) - - def astype(self, dtype: str, span: Optional[Span] = None) -> PrimExpr: - return self.asobject().astype(dtype, span) - - @property - def dtype(self) -> str: - """Return the dtype referenced by the slice. - - Implemented as a property so that ``slice.dtype`` has the same - calling convention as ``primexpr.dtype``. This allows a - BufferSlice object can be assigned to a variable without - requiring a type annotation on the variable, similar to other - expressions. - """ - return self.asobject().dtype diff --git a/python/tvm/script/parser_v1/tir/prim_func.py b/python/tvm/script/parser_v1/tir/prim_func.py deleted file mode 100644 index 923eb97d2758..000000000000 --- a/python/tvm/script/parser_v1/tir/prim_func.py +++ /dev/null @@ -1,45 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""TVM Script Interface for PrimFunc""" - -import inspect -from typing import Callable - -from tvm.tir.function import PrimFunc -from ..parser import from_source - - -def prim_func(input_func: Callable) -> PrimFunc: - """Decorate a python function as tvm script. - - Parameters - ---------- - func : input_func - The function to be parsed. - - Returns - ------- - output : PrimFunc - The result functions. - """ - if inspect.isfunction(input_func): - result = from_source(input_func) - result.__name__ = input_func.__name__ - result.__qualname__ = input_func.__qualname__ - return result - - raise TypeError("Only function definitions are supported.") diff --git a/python/tvm/script/parser_v1/tir/scope_handler.py b/python/tvm/script/parser_v1/tir/scope_handler.py deleted file mode 100644 index 69a414890655..000000000000 --- a/python/tvm/script/parser_v1/tir/scope_handler.py +++ /dev/null @@ -1,793 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""TVM Script Parser Scope Handler Classes""" -# pylint: disable=redefined-builtin, unused-argument, invalid-name, relative-beyond-top-level -from typing import Tuple, Any, Callable, Optional, List, Union, Mapping - -import synr -import numpy as np -import tvm.tir -from tvm.runtime import Object, String, convert -from tvm.ir import Span, Range -from tvm.tir import Stmt, PrimExpr, IterVar, Var, Buffer, BufferRegion, ForKind - -from .node import BufferSlice - -from ..context_maintainer import ContextMaintainer -from ..registry import register -from ..utils import ( - get_param_list, - tvm_span_from_synr, - call_with_error_reporting, -) - - -class ScopeHandler: - """Base class for all scope handlers""" - - def __init__(self, func: Callable): - self.func: Callable = func - self.body: Optional[Stmt] = None - self.node: Optional[synr.ast.Node] = None - self.context: Optional[ContextMaintainer] = None - - def signature(self) -> Tuple[str, Tuple[list, list, Any]]: - return "tir." + self.func.__name__, get_param_list(self.func) - - def enter_scope( - self, - node: synr.ast.Node, - context: ContextMaintainer, - arg_list: List[Any], - span: synr.ast.Span, - ): - pass - - def exit_scope( - self, - node: synr.ast.Node, - context: ContextMaintainer, - arg_list: List[Any], - span: synr.ast.Span, - ): - self.node = node - self.context = context - return call_with_error_reporting( - context.report_error, span, self.func, *arg_list, span=tvm_span_from_synr(span) - ) - - -class WithScopeHandler(ScopeHandler): - """Base class for all with scope handlers""" - - def __init__(self, func, concise_scope, def_symbol): - super().__init__(func) - self.concise_scope = concise_scope - self.def_symbol = def_symbol - - @staticmethod - def get_optional_vars(node, context): - """Get a list synr.ast.With's optional_vars""" - assert isinstance( - node, synr.ast.With - ), f"WithScopeHandler expected synr.ast.With but got {type(node)}" - - if isinstance(node.lhs, list): - for var in node.lhs: - if not isinstance(var, synr.ast.Var): - context.report_error( - f"Invalid optional var definition, expected Var but got {type(var)}", - node.span, - ) - vars = node.lhs - else: - context.report_error( - f"Invalid optional var definition, expected list of Var but got {type(node.lhs)}", - node.span, - ) - return vars - - -@register -class Allocate(WithScopeHandler): - """With scope handler T.allocate(extents, dtype, scope, condition, annotations)""" - - def __init__(self): - def allocate(extents, dtype, scope, condition=True, annotations=None, span=None): - condition = tvm.runtime.convert(condition) - scope = tvm.runtime.convert(scope) - - return tvm.tir.Allocate( - self.buffer_var, - dtype, - extents, - condition, - self.body, - annotations=annotations, - span=span, - ) - - super().__init__(allocate, concise_scope=True, def_symbol=True) - self.buffer_var = None - - def enter_scope( - self, - node: synr.ast.Node, - context: ContextMaintainer, - arg_list: List[Any], - span: synr.ast.Span, - ): - # define buffer vars in symbol table - if isinstance(node, synr.ast.With): - vars = WithScopeHandler.get_optional_vars(node, context) - if len(vars) != 1: - context.report_error(f"Unexpected number of vars: 1 vs. {len(vars)}", node.span) - name = vars[0].id.name - var_span = vars[0].id.span - elif isinstance(node, synr.ast.Assign): - if len(node.lhs) != 1: - context.report_error(f"Unexpected number of vars: 1 vs. {len(node.lhs)}", node.span) - name = node.lhs[0].id.name - var_span = node.lhs[0].id.span - else: - raise Exception("Internal Bug") - - def setup_buffer_var( - extents, dtype, scope, condition=True, annotations=None, span: Span = None - ): - """Setup buffer var for a given type.""" - buffer_ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype), scope) - self.buffer_var = tvm.tir.Var(name, buffer_ptr_type, span) - - setup_buffer_var(*arg_list, span=tvm_span_from_synr(var_span)) - context.update_symbol(name, self.buffer_var, node) - - -@register -class AllocateConst(WithScopeHandler): - """With scope handler T.allocate_const(data, extents, dtype, condition) - - TIR constant node to represent non-scalar constant - """ - - def __init__(self): - def allocate_const(raw_data, dtype, shape, annotations=None, span=None): - list_data = [] - for i in raw_data: - list_data.append(i.value) - nd_data = tvm.nd.array(np.asarray(list_data, dtype=dtype)) - n = tvm.tir.AllocateConst( - self.buffer_var, - dtype, - shape, - nd_data, - self.body, - annotations=annotations, - span=span, - ) - return n - - super().__init__(allocate_const, concise_scope=True, def_symbol=True) - self.buffer_var = None - - def enter_scope( - self, - node: synr.ast.Node, - context: ContextMaintainer, - arg_list: List[Any], - span: synr.ast.Span, - ): - # define buffer vars in symbol table - if isinstance(node, synr.ast.With): - vars = WithScopeHandler.get_optional_vars(node, context) - if len(vars) != 1: - context.report_error(f"Unexpected number of vars: 1 vs. {len(vars)}", node.span) - name = vars[0].id.name - var_span = vars[0].id.span - elif isinstance(node, synr.ast.Assign): - if len(node.lhs) != 1: - context.report_error(f"Unexpected number of vars: 1 vs. {len(node.lhs)}", node.span) - name = node.lhs[0].id.name - var_span = node.lhs[0].id.span - else: - raise Exception("Internal Bug") - - def setup_buffer_var(data, dtype, shape, annotations: dict = None, span: Span = None): - """Setup buffer var for a given type.""" - buffer_ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype)) - self.buffer_var = tvm.tir.Var(name, buffer_ptr_type, span) - - setup_buffer_var(*arg_list, span=tvm_span_from_synr(var_span)) - context.update_symbol(name, self.buffer_var, node) - - -@register -class DeclBuffer(WithScopeHandler): - """Special Stmt decl_buffer(shape, dtype, data, strides, elem_offset, scope, align, - offset_factor, buffer_type, axis_separators) - Example - ------- - .. code-block:: python - A = T.decl_buffer((128, 128), dtype="float32") - """ - - def __init__(self): - def decl_buffer( - shape, - dtype="float32", - data=None, - strides=None, - elem_offset=None, - scope="global", - align=-1, - offset_factor=0, - buffer_type="default", - axis_separators=None, - span=None, - ): - decl_buffer = tvm.tir.DeclBuffer(self.buffer, self.body, span=span) - if data is None: - # when data is not specified, the buffer is implicitly allocated - return tvm.tir.Allocate( - self.buffer.data, - dtype, - shape, - tvm.runtime.convert(True), - decl_buffer, - span=span, - ) - return decl_buffer - - super().__init__(decl_buffer, concise_scope=True, def_symbol=True) - - def enter_scope( - self, - node: synr.ast.Node, - context: ContextMaintainer, - arg_list: List[Any], - span: synr.ast.Span, - ): - # define buffer vars in symbol table - if isinstance(node, synr.ast.With): - vars = WithScopeHandler.get_optional_vars(node, context) - if len(vars) != 1: - context.report_error(f"Unexpected number of vars: 1 vs. {len(vars)}", node.span) - name = vars[0].id.name - var_span = vars[0].id.span - elif isinstance(node, synr.ast.Assign): - if len(node.lhs) != 1: - context.report_error(f"Unexpected number of vars: 1 vs. {len(node.lhs)}", node.span) - name = node.lhs[0].id.name - var_span = node.lhs[0].id.span - else: - raise Exception("Internal Bug") - - def setup_buffer( - shape, - dtype, - data, - strides, - elem_offset, - scope, - align, - offset_factor, - buffer_type, - axis_separators, - span: Span = None, - ): - self.buffer = tvm.tir.decl_buffer( - shape=shape, - dtype=dtype, - data=data, - strides=strides, - elem_offset=elem_offset, - scope=scope, - data_alignment=align, - offset_factor=offset_factor, - buffer_type=buffer_type, - axis_separators=axis_separators, - name=name, - span=span, - ) - - setup_buffer(*arg_list, span=tvm_span_from_synr(var_span)) - context.update_symbol(name, self.buffer, node) - - -@register -class LaunchThread(WithScopeHandler): - """With scope handler T.launch_thread(env_var, extent)""" - - def __init__(self): - def launch_thread(env_var, extent, span): - extent = tvm.runtime.convert(extent, span=span) - thread_id = self.context.func_var_env_dict[env_var] - attr_key = "virtual_thread" if thread_id == "vthread" else "thread_extent" - return tvm.tir.AttrStmt( - IterVar( - (0, extent), - env_var, - getattr(IterVar, "ThreadIndex"), - thread_id, - span=span, - ), - attr_key, - extent, - self.body, - span=span, - ) - - super().__init__(launch_thread, concise_scope=True, def_symbol=False) - - -@register -class Realize(WithScopeHandler): - """With scope handler T.realize(buffer_bounds, scope, condition)""" - - def __init__(self): - def realize( - buffer_slice: BufferSlice, scope: str, condition: bool = True, span: bool = None - ): - assert self.context, "call 'exit_scope' before 'enter_scope'" - buffer: Buffer = buffer_slice.buffer - bounds: List[Range] = [] - for s in buffer_slice.slices: - min: Union[PrimExpr, int] = s.start - extent: Union[PrimExpr, int] = 1 if s.stop is None else s.stop - s.start - if isinstance(extent, PrimExpr): - extent = self.context.analyzer.simplify(extent) - bounds.append(Range.from_min_extent(min, extent, span=s.span)) - - scope = tvm.runtime.convert(scope, span=span) - return tvm.tir.AttrStmt( - buffer, - "realize_scope", - scope, - tvm.tir.BufferRealize(buffer, bounds, condition, self.body, span=span), - span=span, - ) - - super().__init__(realize, concise_scope=True, def_symbol=False) - - -@register -class Attr(WithScopeHandler): - """With scope handler T.attr(attr_node, attr_key, value)""" - - def __init__(self): - def attr(attr_node, attr_key, value, span): - attr_node = tvm.runtime.convert(attr_node, span=span) - value = tvm.runtime.convert(value, span=span) - return tvm.tir.AttrStmt(attr_node, attr_key, value, self.body, span=span) - - super().__init__(attr, concise_scope=True, def_symbol=False) - - -@register -class AssertHandler(WithScopeHandler): - """With scope handler T.Assert(condition, message)""" - - def __init__(self): - def Assert(condition, message, span): - return tvm.tir.AssertStmt(condition, tvm.runtime.convert(message), self.body, span=span) - - super().__init__(Assert, concise_scope=True, def_symbol=False) - - -@register -class Let(WithScopeHandler): - """With scope handler T.let(var, value)""" - - def __init__(self): - def let(var, value, span): - return tvm.tir.LetStmt(var, value, self.body, span=span) - - super().__init__(let, concise_scope=False, def_symbol=False) - - def __call__(self, var: tvm.tir.Var, value: tvm.tir.PrimExpr, body: tvm.tir.PrimExpr): - return tvm.tir.Let(var, value, body) - - -@register -class Block(WithScopeHandler): - """With scope handler T.block(name)""" - - def __init__(self): - def block(name_hint: str = "", span: Optional[Span] = None): - assert ( - self.node and self.context and self.body - ), "call 'exit_scope' before 'enter_scope'" - block_info = self.context.block_info_stack[-1] - - # create block read/write regions - reads: List[BufferRegion] = ( - [read.as_buffer_region() for read in block_info.reads] if block_info.reads else [] - ) - writes: List[BufferRegion] = ( - [write.as_buffer_region() for write in block_info.writes] - if block_info.writes - else [] - ) - - region_detect_mask: int = (block_info.reads is None) | ( - (block_info.writes is None) << 1 - ) - annotations = {} if block_info.annotations is None else block_info.annotations - if region_detect_mask != 0: - annotations["tir.script_parsing_detect_access"] = region_detect_mask - inner = tvm.tir.Block( - block_info.iter_vars, - reads, - writes, - name_hint, - self.body, - block_info.init, - block_info.alloc_buffers, - block_info.match_buffers, - annotations, - span, - ) - assert len(block_info.iter_vars) == len(block_info.iter_values) - predicate = ( - tvm.tir.const(True, "bool") - if block_info.predicate is None - else block_info.predicate - ) - body = tvm.tir.BlockRealize(block_info.iter_values, predicate, inner, span) - return body - - super().__init__(func=block, concise_scope=False, def_symbol=True) - self.block_vars = None - - def enter_scope( - self, - node: synr.ast.Node, - context: ContextMaintainer, - arg_list: List[Any], - span: synr.ast.Span, - ): - # define block vars - assert isinstance( - node, synr.ast.With - ), f"BlockScopeHandler expected to work on synr.ast.With but got {type(node)}" - - optional_vars = [var.id.name for var in WithScopeHandler.get_optional_vars(node, context)] - if optional_vars: - context.report_error( - f"Block expected no optional_vars (e.g., `x` in `with block() as x`), " - f"but got {optional_vars}", - node.span, - ) - - -@register -class InitBlock(WithScopeHandler): - """With scope handler T.init()""" - - def __init__(self): - def init(span: Span = None): - assert self.context, "call 'exit_scope' before 'enter_scope'" - if self.context.block_info_stack[-2].init is not None: - self.context.report_error("Duplicate init block declaration", span) - self.context.block_info_stack[-2].init = self.body - - super().__init__(func=init, concise_scope=False, def_symbol=True) - - -class LoopInfo: - """Helper class for loop information""" - - loop_var: Var - begin: PrimExpr - extent: PrimExpr - kind: ForKind - thread_binding: Optional[str] - annotations: Optional[Mapping[str, Object]] - - def __init__( - self, - begin: PrimExpr, - extent: PrimExpr, - kind: ForKind, - thread_binding: Optional[str] = None, - annotations: Optional[Mapping[str, Object]] = None, - ) -> None: - self.begin = begin - self.extent = extent - self.kind = kind - self.thread_binding = thread_binding - self.annotations = annotations - - -class ForScopeHandler(ScopeHandler): - """Base class for all for scope handlers""" - - def __init__(self, func): - super().__init__(func) - self.loop_vars: List[Var] = [] - self.loop_info: List[LoopInfo] = [] - - def enter_scope( - self, - node: synr.ast.Node, - context: ContextMaintainer, - arg_list: List[Any], - span: synr.ast.Span, - ): - assert isinstance( - node, synr.ast.For - ), f"ForScopeHandler expected synr.ast.For but got {type(node)}" - - loop_var_names = list() - spans = list() - if isinstance(node.lhs, synr.ast.Var): - loop_var_names.append(node.lhs.id.name) - spans.append(tvm_span_from_synr(node.lhs.id.span)) - elif isinstance(node.lhs, list): - for elt in node.lhs: - if not isinstance(elt, synr.ast.Var): - context.report_error( - f"Invalid loop var. Expected a var, but got {type(elt)}", elt.span - ) - loop_var_names.append(elt.id.name) - spans.append(tvm_span_from_synr(elt.id.span)) - else: - context.report_error( - f"Invalid loop var. Expected var or list of vars as lhs, but got {type(node.lhs)}", - span, - ) - - self.node = node - self.context = context - # collect loop infos by calling self.func - call_with_error_reporting(context.report_error, span, self.func, *arg_list) - if len(loop_var_names) != len(self.loop_info): - self.context.report_error( - f"Inconsistent number of vars and loops, got {len(loop_var_names)} " - + f"vs {len(self.loop_info)}", - self.node.span, - ) - # generate loop vars - self.loop_vars = [] - for name, lv_span, li in zip(loop_var_names, spans, self.loop_info): - if not li.begin.dtype.startswith("int"): - raise NotImplementedError(f"Unsupported dtype in loop begin: {li.begin.dtype}") - if not li.extent.dtype.startswith("int"): - raise NotImplementedError(f"Unsupported dtype in loop extent: {li.extent.dtype}") - dtype = "int64" if "int64" in [li.begin.dtype, li.extent.dtype] else "int32" - self.loop_vars.append(tvm.te.var(name, dtype=dtype, span=lv_span)) - - for loop_var, loop_info in zip(self.loop_vars, self.loop_info): - context.update_symbol(loop_var.name, loop_var, node) - context.loop_stack[loop_var] = Range.from_min_extent(loop_info.begin, loop_info.extent) - - def exit_scope( - self, - node: synr.ast.Node, - context: ContextMaintainer, - arg_list: List[Any], - span: synr.ast.Span, - ): - assert self.loop_vars, "call 'exit_scope' before 'enter_scope'" - for loop_var in self.loop_vars: - context.loop_stack.pop(loop_var) - # Use assert here since we have check it in `enter_scope` - assert len(self.loop_vars) == len(self.loop_info) - - body = self.body - for var, info in zip(reversed(self.loop_vars), reversed(self.loop_info)): - body = tvm.tir.For( - var, - info.begin, - info.extent, - info.kind, - body, - info.thread_binding, - info.annotations, - span=tvm_span_from_synr(span), - ) - - return body - - def create_loop_info( - self, - begin: Optional[PrimExpr], - end: PrimExpr, - kind: ForKind, - thread_binding: Optional[str] = None, - annotations: Optional[Mapping[str, Object]] = None, - ) -> None: - """ - Helper function for creating For in TVM Script parser. - - Parameters - ---------- - begin : Optional[PrimExpr] - The beginning value. If None, it will be set to 0. - - end : PrimExpr - The endding value. - - kind : ForKind - The type of the for. - - thread_binding: Optional[str] - The thread this loop binds to. - - annotations : Optional[Mapping[str, Object]] - Additional annotation hints. - - span : Optional[Span] - The location of this for in the source code. - - Returns - ------- - for : For - The constructed For. - """ - end = convert(end) - if begin is None: - begin = tvm.tir.const(0, end.dtype) - else: - begin = convert(begin) - assert self.context and self.node, "call 'exit_scope' before 'enter_scope'" - extent = ( - end - if self.context.analyzer.can_prove_equal(begin, 0) - else self.context.analyzer.simplify(end - begin) - ) - self.annotations: Mapping[str, Object] = {} - if annotations is not None: - self.annotations = { - key: String(val) if isinstance(val, str) else val - for key, val in annotations.items() - } - - self.loop_info.append(LoopInfo(begin, extent, kind, thread_binding, annotations)) - - -@register -class Serial(ForScopeHandler): - """For scope handler T.serial(begin, end, annotations)""" - - def __init__(self): - def serial( - begin: PrimExpr, - end: PrimExpr = None, - annotations: Optional[Mapping[str, Object]] = None, - ): - if end is None: - end, begin = begin, end - self.create_loop_info(begin, end, ForKind.SERIAL, annotations=annotations) - - super().__init__(serial) - - -@register -class Parallel(ForScopeHandler): - """For scope handler T.parallel(begin, end, annotations)""" - - def __init__(self): - def parallel( - begin: PrimExpr, - end: PrimExpr = None, - annotations: Optional[Mapping[str, Object]] = None, - ): - if end is None: - end, begin = begin, end - self.create_loop_info(begin, end, ForKind.PARALLEL, annotations=annotations) - - super().__init__(parallel) - - -@register -class Vectorized(ForScopeHandler): - """For scope handler T.vectorized(begin, end, annotations)""" - - def __init__(self): - def vectorized( - begin: PrimExpr, - end: PrimExpr = None, - annotations: Optional[Mapping[str, Object]] = None, - ): - if end is None: - end, begin = begin, end - self.create_loop_info(begin, end, ForKind.VECTORIZED, annotations=annotations) - - super().__init__(vectorized) - - -@register -class Unroll(ForScopeHandler): - """For scope handler T.unroll(begin, end, annotations)""" - - def __init__(self): - def unroll( - begin: PrimExpr, - end: PrimExpr = None, - annotations: Optional[Mapping[str, Object]] = None, - ): - if end is None: - end, begin = begin, end - self.create_loop_info(begin, end, ForKind.UNROLLED, annotations=annotations) - - super().__init__(unroll) - - -@register -class ThreadBinding(ForScopeHandler): - """For scope handler T.thread_binding(begin, end, thread, annotations)""" - - def __init__(self): - def thread_binding( - begin: PrimExpr, - end: PrimExpr = None, - thread: str = None, - annotations: Optional[Mapping[str, Object]] = None, - ): - if thread is None: - if isinstance(end, str): # handle case like thread_binding(128, "threadIdx.x") - thread = end - end = None - else: - raise ValueError("Thread cannot be None for thread_binding") - if end is None: - end, begin = begin, end - thread_iter_var = IterVar(None, None, IterVar.ThreadIndex, thread) - self.create_loop_info( - begin, - end, - ForKind.THREAD_BINDING, - thread_binding=thread_iter_var, - annotations=annotations, - ) - - super().__init__(thread_binding) - - -@register -class RangeHandler(ForScopeHandler): - """For scope handler range(begin, end, annotations) - Note that tir.range is totally the same as T.serial - """ - - def __init__(self): - def for_range( - begin: PrimExpr, - end: PrimExpr = None, - annotations: Optional[Mapping[str, Object]] = None, - ): - if end is None: - end, begin = begin, end - self.create_loop_info(begin, end, ForKind.SERIAL, annotations=annotations) - - super().__init__(for_range) - - def signature(self): - return "range", get_param_list(self.func) - - -@register -class Grid(ForScopeHandler): - """For scope handler T.grid(extents)""" - - def __init__(self): - def grid(*extents: List[PrimExpr]): - for extent in extents: - self.create_loop_info(None, extent, ForKind.SERIAL) - - super().__init__(grid) diff --git a/python/tvm/script/parser_v1/tir/special_stmt.py b/python/tvm/script/parser_v1/tir/special_stmt.py deleted file mode 100644 index f558eb6b7f73..000000000000 --- a/python/tvm/script/parser_v1/tir/special_stmt.py +++ /dev/null @@ -1,927 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""TVM Script Parser Special Stmt Classes""" -# pylint: disable=unused-argument, no-self-argument, inconsistent-return-statements -# pylint: disable=relative-beyond-top-level -from typing import Callable, List, Optional, Tuple, Any, Mapping, Union - -import synr -from synr import ast -from tvm.ir.expr import PrimExpr, Range - -import tvm.tir -from tvm.runtime import Object, String -from tvm.target import Target -from tvm.ir import Span -from tvm.tir import IntImm, IterVar, Var - -from .node import BufferSlice - -from ..context_maintainer import BlockInfo, ContextMaintainer -from ..registry import register -from ..utils import ( - get_param_list, - tvm_span_from_synr, - call_with_error_reporting, -) - - -def convert_to_int( - value: Union[IntImm, int], - arg_name: str, - report_error: Callable, - span: Union[Span, synr.ast.Span], -) -> int: - """convert a const int or TVM IntImm to Python int. - Reports an error when input cannot be converted to int. - - Parameters - ---------- - value : Union[tvm.tir.IntImm, int] - The input value to be converted. - arg_name : str - Function argument name for error reporting. - report_error: Callable - The report error function handle - span : Union[synr.ast.Span, tvm.ir.Span] - Location of the error - """ - if isinstance(value, IntImm): - return value.value - if isinstance(value, int): - return value - report_error( - f"Expected int or IntImm for {arg_name}, but got {str(type(value))}", - span, - ) - - -class SpecialStmt: - """Base class for all Special Stmts""" - - def __init__(self, func: Callable, def_symbol: bool): - self.func: Callable = func - self.def_symbol: bool = def_symbol - self.node: Optional[synr.ast.Node] = None - self.context: Optional[ContextMaintainer] = None - - def signature(self) -> Tuple[str, Tuple[list, list, Any]]: - return "tir." + self.func.__name__, get_param_list(self.func) - - def handle( - self, - node: ast.Node, - context: ContextMaintainer, - arg_list: List[Any], - span: synr.ast.Span, - ): - self.node = node - self.context = context - return call_with_error_reporting( - context.report_error, span, self.func, *arg_list, span=tvm_span_from_synr(span) - ) - - -@register -class MatchBuffer(SpecialStmt): - """Special Stmt match_buffer(param, shape, dtype, data, strides, elem_offset, scope, align, - offset_factor, buffer_type, axis_separators) - - Note - ---- - This Special Stmt will perform different behavior depends on the type of param. - If the param is a var in function parameter, it will create a buffer from DLTensor. - Else if the param is a subregion of other buffers, then create a subregion match inside a block. - - Example - ------- - Match buffer from function parameter - .. code-block:: python - A = T.match_buffer(a, (128, 128), dtype="float32") - - Match buffer from Buffer subregion - .. code-block:: python - A = T.match_buffer(B[0:128, i * 128 : i * 128 + 128], (128, 128), dtype="float32") - """ - - def __init__(self): - def match_buffer( - param, - shape=None, - dtype=None, - data=None, - strides=None, - elem_offset=None, - scope="global", - align=-1, - offset_factor=0, - buffer_type="default", - axis_separators=None, - span=None, - ): - if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1: - self.context.report_error( - "`match_buffer` must be assigned to a single buffer, " - "e.g. A = match_buffer(...)", - self.node.span, - ) - if strides is None: - strides = [] - align = convert_to_int(align, "align", self.context.report_error, self.node.span) - offset_factor = convert_to_int( - offset_factor, "offset_factor", self.context.report_error, self.node.span - ) - buffer_name: str = self.node.lhs[0].id.name - - if isinstance(param, tvm.tir.Var): - if shape is None: - self.context.report_error( - "Shape must be specified when binding input param", - self.node.rhs.span, - ) - - if dtype is None: - dtype = "float32" - - buffer = tvm.tir.decl_buffer( - shape, - dtype, - buffer_name, - data, - strides, - elem_offset, - scope, - align, - offset_factor, - buffer_type, - axis_separators, - span=span, - ) - if param not in self.context.func_params: - self.context.report_error( - "Can not bind non-input param to buffer", self.node.rhs.params[0].span - ) - self.context.func_buffer_map[param] = buffer - - elif isinstance(param, BufferSlice): - buffer_region = param.as_buffer_region() - - if shape is None: - shape = [dim.extent for dim in buffer_region.region] - - if dtype is None: - dtype = buffer_region.buffer.dtype - - if elem_offset is None and offset_factor == 0: - offset_factor = 1 - - buffer = tvm.tir.decl_buffer( - shape, - dtype, - buffer_name, - data, - strides, - elem_offset, - scope, - align, - offset_factor, - buffer_type, - axis_separators, - span=span, - ) - - self.context.current_block_scope().match_buffers.append( - tvm.tir.MatchBufferRegion(buffer, buffer_region) - ) - else: - self.context.report_error( - "The source of match_buffer expected Var or BufferSlice, but got " - + str(type(param)), - self.node.rhs.params[0].span, - ) - self.context.update_symbol(buffer_name, buffer, self.node) - - super().__init__(match_buffer, def_symbol=True) - - -@register -class BufferDeclare(SpecialStmt): - """Special Stmt buffer_decl(shape, dtype, data, strides, elem_offset, scope, align, - offset_factor, buffer_type, axis_separators) - Example - ------- - .. code-block:: python - A = T.buffer_decl((128, 128), dtype="float32") - """ - - def __init__(self): - def buffer_decl( - shape, - dtype="float32", - data=None, - strides=None, - elem_offset=None, - scope="global", - align=-1, - offset_factor=0, - buffer_type="default", - axis_separators=None, - span=None, - ): - if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1: - self.context.report_error( - "`buffer_decl` must be assigned to a single buffer, e.g. A = buffer_decl(...)", - self.node.span, - ) - - if strides is None: - strides = [] - align = convert_to_int(align, "align", self.context.report_error, self.node.span) - offset_factor = convert_to_int( - offset_factor, "offset_factor", self.context.report_error, self.node.span - ) - buffer_name: str = self.node.lhs[0].id.name - buffer = tvm.tir.decl_buffer( - shape, - dtype, - buffer_name, - data, - strides, - elem_offset, - scope, - align, - offset_factor, - buffer_type, - axis_separators, - span=span, - ) - self.context.update_symbol(buffer_name, buffer, self.node) - return buffer - - super().__init__(buffer_decl, def_symbol=True) - - -@register -class AllocBuffer(SpecialStmt): - """Special function alloc_buffer(shape, dtype, data, strides, elem_offset, scope, align, - offset_factor, buffer_type, axis_separators) - - Example - ------- - .. code-block:: python - - A = T.alloc_buffer((128, 128), dtype="float32") - """ - - def __init__(self): - def alloc_buffer( - shape, - dtype="float32", - data=None, - strides=None, - elem_offset=None, - scope="global", - align=-1, - offset_factor=0, - buffer_type="default", - axis_separators=None, - span=None, - ): - if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1: - self.context.report_error( - "`alloc_buffer` must be assigned to a single buffer, " - "e.g. A = alloc_buffer(...)", - self.node.span, - ) - - if strides is None: - strides = [] - align = convert_to_int(align, "align", self.context.report_error, self.node.span) - offset_factor = convert_to_int( - offset_factor, "offset_factor", self.context.report_error, self.node.span - ) - buffer_name: str = self.node.lhs[0].id.name - buffer = tvm.tir.decl_buffer( - shape, - dtype, - buffer_name, - data, - strides, - elem_offset, - scope, - align, - offset_factor, - buffer_type, - axis_separators, - span=span, - ) - if self.context.current_block_scope(): - self.context.current_block_scope().alloc_buffers.append(buffer) - else: - # If it is allocated outside all blocks, allocate it under root block. - self.context.root_alloc_buffers.append(buffer) - self.context.update_symbol(buffer_name, buffer, self.node) - - super().__init__(alloc_buffer, def_symbol=True) - - -@register -class BlockReads(SpecialStmt): - """Special function reads([read_regions], *other_regions) - - Note - ---- - *other_region is an unpackable list of BufferSlice to support - reads syntax sugar like reads(BufferRegion1, BufferRegion2, ...) - - Example - ------- - .. code-block:: python - - T.reads([A[vi: vi + 4, vk: vk + 4], B[vk: vk + 4, vj]]) - """ - - def __init__(self): - def reads( - *read_regions: Union[BufferSlice, List[BufferSlice]], - span: Span = None, - ): - assert self.context, "call 'exit_scope' before 'enter_scope'" - block_scope = self.context.current_block_scope() - if block_scope is None: - self.context.report_error( - "Expected to declare read regions inside a block.", - span, - ) - if block_scope.reads is not None: - self.context.report_error( - "Duplicate write region declaration, " - + "previous one is " - + str(", ".join(str(x) for x in block_scope.reads)), - span, - ) - if len(read_regions) > 1: - for read_region in read_regions: - if not isinstance(read_region, BufferSlice): - self.context.report_error( - "Incorrect input type. Expected *BufferSlice or List[BufferSlice]," - + f" but got {type(read_regions)}", - span, - ) - elif len(read_regions) == 1: - if isinstance(read_regions[0], list): - read_regions = read_regions[0] - - block_scope.reads = read_regions - - super().__init__(reads, def_symbol=False) - - -@register -class BlockWrites(SpecialStmt): - """Special function writes([write_regions], *other_regions) - - Note - ---- - *other_region is an unpackable list of BufferSlice to support - writes syntax sugar like writes(BufferRegion1, BufferRegion2, ...) - - Example - ------- - .. code-block:: python - - T.writes([C[vi: vi + 4, vj]) - """ - - def __init__(self): - def writes( - *write_regions: Union[BufferSlice, List[BufferSlice]], - span: Span = None, - ): - assert self.context, "call 'exit_scope' before 'enter_scope'" - block_scope = self.context.current_block_scope() - if block_scope is None: - self.context.report_error( - "Expected to declare write regions inside a block.", - span, - ) - if block_scope.writes is not None: - self.context.report_error( - "Duplicate write region declaration, " - + "previous one is " - + str(", ".join(str(x) for x in block_scope.writes)), - span, - ) - if len(write_regions) > 1: - for write_region in write_regions: - if not isinstance(write_region, BufferSlice): - self.context.report_error( - "Incorrect input type. Expected *BufferSlice or List[BufferSlice]," - + f" but got {type(write_regions)}", - span, - ) - elif len(write_regions) == 1: - if isinstance(write_regions[0], list): - write_regions = write_regions[0] - block_scope.writes = write_regions - - super().__init__(writes, def_symbol=False) - - -@register -class BlockAttr(SpecialStmt): - """Special function block_attr({attr_key: attr_value}) - - Example - ------- - .. code-block:: python - - T.block_attr({"double_buffer_scope": 1}) - """ - - def __init__(self): - def block_attr(attrs: Mapping[str, Object], span: Span = None): - assert self.context, "call 'exit_scope' before 'enter_scope'" - block_scope = self.context.current_block_scope() - if block_scope is None: - self.context.report_error( - "Expected to declare block annotations inside a block.", - span, - ) - if block_scope.annotations is not None: - self.context.report_error( - "Duplicate block annotations declaration, " - + "previous one is " - + str(block_scope.annotations), - span, - ) - attrs = { - key: String(val) if isinstance(val, str) else val for key, val in attrs.items() - } - block_scope.annotations = attrs - - super().__init__(block_attr, def_symbol=False) - - -class BlockAxis(SpecialStmt): - """Special stmt for defining a spatial block axis - axis.S(dom, iter_value) - - Example - ------- - .. code-block:: python - - vi = T.axis.S(128, i * 4 + j) - """ - - def axis( - self, - var_name: str, - dom: Union[PrimExpr, Range], - value: PrimExpr, - iter_type: int, - span: Optional[Span] = None, - ) -> None: - """ - Helper function for creating block axis - - Parameters - ---------- - var_name : str - The name_hint of var - - dom : Union[PrimExpr, Range] - The iter domain. - - value : PrimExpr - The binding value - - iter_type : int - The iteration type. - - span : Optional[Span] - The location of this for in the source code. - """ - assert self.context, "call 'exit_scope' before 'enter_scope'" - block_scope: BlockInfo = self.context.current_block_scope() - if block_scope is None: - self.context.report_error( - "Expected to declare block axes inside a block.", - self.node.span, - ) - if var_name in [iter_var.var.name for iter_var in block_scope.iter_vars]: - self.context.report_error("Duplicate block axis " + var_name, self.node.span) - - dom = tvm.runtime.convert(dom) - if isinstance(dom, PrimExpr): - dom = tvm.ir.Range(dom) - elif isinstance(dom, tvm.ir.container.Array) and len(dom) == 2: - dom = tvm.ir.Range(dom[0], dom[1]) - elif not isinstance(dom, tvm.ir.Range): - self.context.report_error( - f"Block axis domain expected PrimExpr or Range, but got {type(dom)}", - self.node.span, - ) - block_var = tvm.tir.Var(var_name, dtype=dom.extent.dtype) - value = tvm.runtime.convert(value) - if not isinstance(value, PrimExpr): - self.context.report_error( - f"Block axis value expected PrimExpr, but got {type(value)}", - self.node.span, - ) - iter_var = tvm.tir.IterVar(dom, block_var, iter_type) - block_scope.iter_vars.append(iter_var) - block_scope.iter_values.append(value) - self.context.update_symbol(var_name, block_var, self.node) - - -@register -class BlockAxisSpatial(BlockAxis): - """Special stmt for defining a spatial block axis - axis.spatial(dom, iter_value) - - Example - ------- - .. code-block:: python - - vi = T.axis.spatial(128, k) - """ - - def __init__(self): - def axis_spatial( - dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr, span: Span = None - ): - if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1: - self.context.report_error( - "`axis.spatial` must be assigned to a var, e.g. vi = axis.spatial(...)", - self.node.span, - ) - self.axis(self.node.lhs[0].id.name, dom, value, IterVar.DataPar) - - super().__init__(axis_spatial, def_symbol=True) - - def signature(self) -> Tuple[str, Tuple[list, list, Any]]: - return "tir.axis.spatial", get_param_list(self.func) - - -@register -class BlockAxisS(BlockAxis): - """The sugar special stmt for defining a spatial block axis - axis.S(dom, iter_value) - - Example - ------- - .. code-block:: python - - vi = T.axis.S(128, k) - """ - - def __init__(self): - def axis_spatial( - dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr, span: Span = None - ): - if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1: - self.context.report_error( - "`axis.S` must be assigned to a var, e.g. vi = axis.S(...)", - self.node.span, - ) - self.axis(self.node.lhs[0].id.name, dom, value, IterVar.DataPar) - - super().__init__(axis_spatial, def_symbol=True) - - def signature(self) -> Tuple[str, Tuple[list, list, Any]]: - return "tir.axis.S", get_param_list(self.func) - - -@register -class BlockAxisReduce(BlockAxis): - """Special stmt for defining a reduce block axis - axis.reduce(dom, iter_value) - - Example - ------- - .. code-block:: python - - vi = T.axis.reduce(128, k) - """ - - def __init__(self): - def axis_reduce( - dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr, span: Span = None - ): - if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1: - self.context.report_error( - "`axis.reduce` must be assigned` to a var, e.g. vi = axis.reduce(...)", - self.node.span, - ) - self.axis(self.node.lhs[0].id.name, dom, value, IterVar.CommReduce) - - super().__init__(axis_reduce, def_symbol=True) - - def signature(self) -> Tuple[str, Tuple[list, list, Any]]: - return "tir.axis.reduce", get_param_list(self.func) - - -@register -class BlockAxisR(BlockAxis): - """The sugar special stmt for defining a reduce block axis - axis.R(dom, iter_value) - - Example - ------- - .. code-block:: python - - vi = T.axis.R(128, k) - """ - - def __init__(self): - def axis_reduce( - dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr, span: Span = None - ): - if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1: - self.context.report_error( - "`axis.R` must be assigned to a var, e.g. vi = axis.R(...)", - self.node.span, - ) - self.axis(self.node.lhs[0].id.name, dom, value, IterVar.CommReduce) - - super().__init__(axis_reduce, def_symbol=True) - - def signature(self) -> Tuple[str, Tuple[list, list, Any]]: - return "tir.axis.R", get_param_list(self.func) - - -@register -class BlockAxisScan(BlockAxis): - """Special stmt for defining a ordered block axis - axis.scan(dom, iter_value) - - Example - ------- - .. code-block:: python - - vi = T.axis.scan(128, k) - """ - - def __init__(self): - def axis_scan( - dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr, span: Span = None - ): - if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1: - self.context.report_error( - "`axis.scan` must be assigned to a var, e.g. vi = axis.scan(...)", - self.node.span, - ) - self.axis(self.node.lhs[0].id.name, dom, value, IterVar.Ordered) - - super().__init__(axis_scan, def_symbol=True) - - def signature(self) -> Tuple[str, Tuple[list, list, Any]]: - return "tir.axis.scan", get_param_list(self.func) - - -@register -class BlockAxisOpaque(BlockAxis): - """Special stmt for defining a opaque block axis - axis.opaque(dom, iter_value) - - Example - ------- - .. code-block:: python - - vi = T.axis.opaque(128, k) - """ - - def __init__(self): - def axis_opaque( - dom: Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], value: PrimExpr, span: Span = None - ): - if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1: - self.context.report_error( - "`axis.opaque` must be assigned to a var, e.g. vi = axis.opaque(...)", - self.node.span, - ) - self.axis(self.node.lhs[0].id.name, dom, value, IterVar.DimInfo) - - super().__init__(axis_opaque, def_symbol=True) - - def signature(self) -> Tuple[str, Tuple[list, list, Any]]: - return "tir.axis.opaque", get_param_list(self.func) - - -@register -class BlockAxisRemap(BlockAxis): - """Special stmt for remapping loops vars to block axes. - axis.remap(iter_type, iter_value) - - Note - ---- - Iter_type is a string consisting of 'S' and 'R', where 'S' means - for spatial and 'R' means for reduce. - - Example - ------- - .. code-block:: python - - vi, vj = T.axis.remap("SS", [i, j]) - """ - - def __init__(self): - def axis_remap(iter_types: str, loop_vars: List[tvm.tir.expr.Var], span: Span = None): - if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) >= 1: - self.context.report_error( - "`axis.remap` must be assigned to one or more vars, " - "e.g. vi, vj = axis.remap(...)", - self.node.span, - ) - var_num: int = len(self.node.lhs) - if var_num != len(iter_types): - self.context.report_error( - f"`iter_type` expected {var_num} charactor(s), " - f"but got {len(iter_types)}: {iter_types}", - span, - ) - if var_num != len(loop_vars): - self.context.report_error( - f"`iter_type` expected {var_num} loop var(s), " - f"but got {len(loop_vars)}: {loop_vars}", - span, - ) - for var, iter_ty, loop_var in zip(self.node.lhs, iter_types, loop_vars): - iter_type: int - if iter_ty == "S": - iter_type = IterVar.DataPar - elif iter_ty == "R": - iter_type = IterVar.CommReduce - else: - self.context.report_error( - f'`iter_type` only expected "S" (for spatial) or "R" (for reduce), ' - f'but got "{iter_ty}"', - span, - ) - - if not isinstance(loop_var, tvm.tir.expr.Var): - self.context.report_error( - f"Values of `axis.remap` expected single loop var, but got {loop_var}", - loop_var.span, - ) - loops = self.context.loop_stack - if loop_var not in loops: - self.context.report_error( - f"Cannot find loop var {loop_var} in loop nesting.", - span, - ) - self.axis(var.id.name, loops[loop_var], loop_var, iter_type) - - super().__init__(axis_remap, def_symbol=True) - - def signature(self) -> Tuple[str, Tuple[list, list, Any]]: - return "tir.axis.remap", get_param_list(self.func) - - -@register -class BlockPredicate(SpecialStmt): - """Special function where(predicate) - - Example - ------- - .. code-block:: python - - T.where(i < 4) - """ - - def __init__(self): - def where(predicate, span=None): - assert self.context, "call 'exit_scope' before 'enter_scope'" - block_scope = self.context.current_block_scope() - if block_scope is None: - self.context.report_error( - "Expected to declare the predicate inside a block.", - span, - ) - if block_scope.predicate is not None: - self.context.report_error( - "Duplicate block predicate declaration, " - + "previous one is " - + str(block_scope.predicate), - span, - ) - - block_scope.predicate = predicate - - super().__init__(where, def_symbol=False) - - -@register -class VarDef(SpecialStmt): - """Special function for defining a Var""" - - def __init__(self): - def var(dtype, span): - assert isinstance( - self.node, ast.Assign - ), f"VarDef expected ast.Assign but got {type(self.node)}" - names = [x.id.name for x in self.node.lhs] - if len(names) != 1: - self.context.report_error( - f"VarDef expected assign to only one var, but got {names}", span - ) - v = Var(names[0], dtype, span=span) - self.context.update_symbol(v.name, v, self.node) - - super().__init__(var, def_symbol=True) - - -@register -class BufferVarDef(SpecialStmt): - """Special function for defining a variable of pointer type""" - - def __init__(self): - def buffer_var(dtype, storage_scope, span): - assert isinstance( - self.node, ast.Assign - ), f"BufferVarDef expected ast.Assign but got {type(self.node)}" - names = [x.id.name for x in self.node.lhs] - if len(names) != 1: - self.context.report_error( - f"VarDef expected assign to only one var, but got {names}", span - ) - ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype), storage_scope) - v = Var(names[0], ptr_type, span=span) - self.context.update_symbol(v.name, v, self.node) - - super().__init__(buffer_var, def_symbol=True) - - -@register -class EnvThread(SpecialStmt): - """Bind a var to thread env""" - - def __init__(self): - def env_thread(env_name, span): - assert isinstance( - self.node, ast.Assign - ), f"EnvThread expected ast.Assign but got {type(self.node)}" - names = [x.id.name for x in self.node.lhs] - if len(names) != 1: - self.context.report_error( - f"VarDef expected assign to only one var, but got {names}", span - ) - v = Var(names[0], dtype="int32", span=span) - self.context.func_var_env_dict[v] = env_name - self.context.update_symbol(v.name, v, self.node) - - super().__init__(env_thread, def_symbol=True) - - -@register -class FuncAttr(SpecialStmt): - """Special Stmt for declaring the DictAttr of PrimFunc - Example - ------- - .. code-block:: python - T.func_attr({"tir.noalias": True, "global_symbol"}) - """ - - def __init__(self): - def func_attr(dict_attr, span): - self.context.func_dict_attr = dict_attr - - super().__init__(func_attr, def_symbol=False) - - -@register -class TargetAttrValue(SpecialStmt): - """Special Stmt for target attr value. - Example - ------- - .. code-block:: python - T.target("llvm") - """ - - def __init__(self): - def target(*args, span): - self.context.report_error(f"T.target should not appear as a stmt", span) - - super().__init__(target, def_symbol=False) - - def __call__(self, target_config): - if not isinstance(target_config, (str, dict)): - raise ValueError( - f"T.target expected a config dict or string, but got {type(target_config)}" - ) - return Target(target_config) diff --git a/python/tvm/script/parser_v1/tir/ty.py b/python/tvm/script/parser_v1/tir/ty.py deleted file mode 100644 index b17b571e88e7..000000000000 --- a/python/tvm/script/parser_v1/tir/ty.py +++ /dev/null @@ -1,226 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""TVM Script Parser Typing Class for TIR - -This module provides typing class for TVM script type annotation usage, it can be viewed as -a wrapper for uniform Type system in IR -""" -# pylint: disable=invalid-name -from numbers import Integral - -import tvm -from .special_stmt import SpecialStmt, convert_to_int - - -class TypeGeneric: # pylint: disable=too-few-public-methods - """Base class for all the TVM script typing class""" - - def evaluate(self): - """Return an actual ir.Type Object that this Generic class wraps""" - raise TypeError("Cannot get tvm.Type from a generic type") - - def require_type_generic_at(self, idx): # pylint: disable=unused-argument - """If True, the `idx`th type argument must be TypeGeneric""" - return True - - # This function is added here to avoid a pylint error - # for T.int/float below not being callable - def __call__(self): - raise NotImplementedError() - - -class ConcreteType(TypeGeneric): # pylint: disable=too-few-public-methods, abstract-method - """TVM script typing class for uniform Type objects - - Params - ------ - vtype: Union[str, tvm.ir.Type] - - The IR type represented by the type annotation. If a string - (e.g. "float32"), this represents a `ir.PrimType` generated - from that string. If a `ir.Type` is provided, this represents - the type provided. - """ - - def __init__(self, vtype): - if isinstance(vtype, tvm.ir.Type): - self.type = vtype - else: - self.type = tvm.ir.PrimType(vtype) - - def __call__(self, *args): # pylint: disable=arguments-differ - pass - - def evaluate(self): - return self.type - - -class VoidType(ConcreteType): # pylint: disable=too-few-public-methods, abstract-method - """TVM script typing class for void type""" - - def __init__(self): - super().__init__("") - - -class GenericPtrType(TypeGeneric): # pylint: disable=abstract-method - """TVM script typing class generator for PtrType - - [] operator is overloaded, accepts a ConcreteType and an optional storage scope string, - returns a ConcreteType wrapping PtrType - """ - - def __getitem__(self, args): - if isinstance(args, TypeGeneric): - args = [args] - if len(args) == 1: - vtype, scope = args[0], "global" - elif len(args) == 2: - vtype, scope = args[0], args[1] - else: - raise TypeError(f"Illegal type argument num for Ptr") - if not isinstance(vtype, TypeGeneric): - raise TypeError(f"Ptr expects a type argument, but received {type(vtype).__name__}") - if not isinstance(scope, str): - raise TypeError(f"Ptr expects storage scope argument be a string") - return ConcreteType(tvm.ir.PointerType(vtype.evaluate(), scope)) - - def require_type_generic_at(self, idx): - return idx != 1 # the second argument is storage scope for Ptr - - -class GenericTupleType(TypeGeneric): # pylint: disable=abstract-method - """TVM script typing class generator for TupleType - - [] operator is overloaded, accepts a list of ConcreteType and returns a ConcreteType - wrapping TupleType - """ - - def __getitem__(self, vtypes): - if isinstance(vtypes, TypeGeneric): - vtypes = [vtypes] - return ConcreteType(tvm.ir.TupleType([vtype.evaluate() for vtype in vtypes])) - - -class GenericBufferType(SpecialStmt): # pylint: disable=too-few-public-methods, abstract-method - """TVM script typing class for uniform Type objects""" - - def __init__(self, vtype): - def match_buffer_syntax_sugar( - shape, - dtype: str = "float32", - name: str = None, - data=None, - strides=None, - elem_offset=None, - scope="global", - align=-1, - offset_factor=0, - buffer_type="default", - axis_separators=None, - span=None, - ): - if strides is None: - strides = [] - align = convert_to_int(align, "align", self.context.report_error, self.node.span) - offset_factor = convert_to_int( - offset_factor, "offset_factor", self.context.report_error, self.node.span - ) - buffer = tvm.tir.decl_buffer( - shape, - dtype, - name, - data, - strides, - elem_offset, - scope, - align, - offset_factor, - buffer_type, - axis_separators, - span=span, - ) - return buffer - - self.type = vtype - super().__init__(match_buffer_syntax_sugar, def_symbol=True) - - def __call__( - self, - shape, - dtype="float32", - *, - name: str = None, - data=None, - strides=None, - elem_offset=None, - scope="global", - align=-1, - offset_factor=0, - buffer_type="default", - axis_separators=None, - span=None, - ): - """ - This function is for Buffer(...) syntax sugar. - """ - pass # pylint: disable=unnecessary-pass - - def __getitem__(self, args): - """ - This function is for Buffer[...] syntax sugar - Note that args is the list of all arguments - """ - if len(args) < 2: - raise ValueError("T.Buffer[...] needs at least two arguments: shape and dtype.") - - shape = args[0] - dtype = args[1] - - valid_shape = isinstance(shape, (tvm.ir.PrimExpr, Integral, tuple, list)) - valid_dtype = isinstance(dtype, str) - if not (valid_shape and valid_dtype): - raise ValueError( - "The first argument of T.Buffer[...] needs to be a tuple, " - "followed by the second argument dtype as a string" - ) - - -# add all floating point and integer datatypes to the module -for _dtype in ["float", "uint", "int"]: - for _size in ["8", "16", "32", "64"]: - for _lanes in ["", "x4", "x8", "x16", "x32", "x64"]: - _name = _dtype + _size + _lanes - globals()[_name] = ConcreteType(_name) - - -# All other DataType annotations are represented with the same string -# as is used by `tvm.runtime.DataType`. This does redefine the Python -# built-in bool, but only within the context of `tvm.script.tir.ty` -# and `tvm.script.tir` modules. The `T.boolean` alias is maintained -# for backwards compatibility. - -bool = ConcreteType("bool") # pylint: disable=redefined-builtin -boolean = bool - - -handle = ConcreteType("handle") -void = VoidType() -Ptr = GenericPtrType() -Tuple = GenericTupleType() -# we don't have 'buffer' type on the cpp side -# thus 'handle' is used here for convenience's sake -Buffer = GenericBufferType("handle") diff --git a/python/tvm/script/parser_v1/utils.py b/python/tvm/script/parser_v1/utils.py deleted file mode 100644 index c655a6223740..000000000000 --- a/python/tvm/script/parser_v1/utils.py +++ /dev/null @@ -1,105 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Helper functions in TVM Script Parser""" - -from typing import Callable, List, Any, Optional, Tuple - -import inspect -import synr - -from tvm.ir import Span, SourceName -from tvm.error import DiagnosticError - - -def get_param_list( - func: Callable, -) -> Tuple[List[str], List[Tuple[str, Tuple[Any, ...]]], Optional[str]]: - """Get the parameter list from definition of function""" - full_arg_spec: inspect.FullArgSpec = inspect.getfullargspec(func) - - args: List[str] - defaults: Optional[Tuple[Any, ...]] - kwonlyargs: List[str] - args, defaults, kwonlyargs = ( - full_arg_spec.args, - full_arg_spec.defaults, - full_arg_spec.kwonlyargs, - ) - - if defaults is None: - defaults = tuple() - - if full_arg_spec.varkw is not None: - raise RuntimeError( - "TVM Script register error : variable keyword argument is not supported now" - ) - - if len(kwonlyargs) == 1 and kwonlyargs[0] == "span": - pass - elif not len(kwonlyargs) == 0: - raise RuntimeError("TVM Script register error : keyword only argument is not supported now") - - pos_only: List[str] = list() - for arg in args[: len(args) - len(defaults)]: - if arg != "span": - pos_only.append(arg) - kwargs: List[Tuple[str, Tuple[Any, ...]]] = list() - for default, arg in zip(defaults, args[len(args) - len(defaults) :]): - if arg != "span": - kwargs.append((arg, default)) - - return pos_only, kwargs, full_arg_spec.varargs - - -def tvm_span_from_synr(span: synr.ast.Span) -> Span: - """Convert a synr span to a TVM span""" - return Span( - SourceName(span.filename), - span.start_line, - span.end_line, - span.start_column, - span.end_column, - ) - - -def synr_span_from_tvm(span: Span) -> synr.ast.Span: - """Convert a TVM span to a synr span""" - return synr.ast.Span( - span.source_name.name, - span.line, - span.column, - span.end_line, - span.end_column, - ) - - -def call_with_error_reporting( - report_error, - node_span, - func, - *args, - **kwargs, -): - """Call function with exception handling and report error using node_span""" - try: - return func(*args, **kwargs) - except DiagnosticError: - raise - except Exception as err: # pylint: disable=broad-except - # printing last non-empty row of error message. - error_msg = list(filter(None, str(err).split("\n")))[-1] - report_error(error_msg, node_span) diff --git a/src/tir/schedule/error.h b/src/tir/schedule/error.h index e28164c6c39b..d344f4687305 100644 --- a/src/tir/schedule/error.h +++ b/src/tir/schedule/error.h @@ -42,7 +42,7 @@ class ScheduleError : public tvm::runtime::Error { * "Some error occurred on block {0} and loop {1} blah blah" * And renderer will replace {0} and {1} according to the list provided LocationsOfInterest. Right * now it only printed out all the locations in plain text, but in the future, we may want to mark - * the IR with underscores and attach names to each location of interest, like what synr does. + * the IR with underscores and attach names to each location of interest. */ virtual String DetailRenderTemplate() const = 0; /*! diff --git a/tests/python/unittest/test_tvmscript_spans.py b/tests/python/unittest/test_tvmscript_spans.py deleted file mode 100644 index 2c0522e3e3c9..000000000000 --- a/tests/python/unittest/test_tvmscript_spans.py +++ /dev/null @@ -1,73 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - - -from tvm.script.parser_v1 import tir as T - - -@T.prim_func -def loops() -> None: - for i in T.parallel(0, 2): - for j in T.serial(0, 1): - for z in T.vectorized(3, 4): - T.evaluate(0) - - -def test_loops(): - start_line = 23 - parsed = loops - - assert parsed.span.line == start_line - - assert parsed.body.span.line == start_line + 1 - assert parsed.body.min.span.column == 25 - assert parsed.body.extent.span.column == 28 - assert parsed.body.extent.span.line == start_line + 1 - - assert parsed.body.body.span.line == start_line + 2 - assert parsed.body.body.loop_var.span.line == start_line + 2 - assert parsed.body.body.loop_var.span.column == 13 - - assert parsed.body.body.body.span.line == start_line + 3 - assert parsed.body.body.body.span.column == 22 - - assert parsed.body.body.body.body.span.line == start_line + 4 - assert parsed.body.body.body.body.span.column == 17 - - -@T.prim_func -def statements() -> None: - T.evaluate(1) - T.evaluate("test") - - -def test_statements(): - start_line = 53 - parsed = statements - - assert parsed.body.span.line == start_line + 1 - - assert parsed.body[0].span.line == start_line + 1 - assert parsed.body[0].span.column == 5 - - assert parsed.body[0].span.line == start_line + 1 - assert parsed.body[0].span.column == 5 - - -if __name__ == "__main__": - test_loops() - test_statements() diff --git a/tests/scripts/ci.py b/tests/scripts/ci.py index b11ee538dc68..16389d29354c 100755 --- a/tests/scripts/ci.py +++ b/tests/scripts/ci.py @@ -274,7 +274,6 @@ def docs( requirements = [ "Sphinx==4.2.0", "tlcpack-sphinx-addon==0.2.1", - "synr==0.5.0", "image==1.5.33", # Temporary git link until a release is published "git+https://github.com/sphinx-gallery/sphinx-gallery.git@6142f1791151849b5bec4bf3959f75697ba226cd",