From 037e94d865f2041e1ecce02d32e32e001609e22b Mon Sep 17 00:00:00 2001 From: Yuanjing Shi Date: Tue, 7 Dec 2021 16:19:47 -0800 Subject: [PATCH] [TVMScript] Add for loop syntax sugar (#9620) * add for loop syntax sugar * remove prints * better doc * finish thread binding * fix CI * fix CI * address comments * update sstub * fix CI * remove failed test * update stub * address comments * add decorator --- python/tvm/script/tir/__init__.pyi | 40 +++++++++++++++- python/tvm/script/tir/scope_handler.py | 33 ++++++++++--- python/tvm/testing/__init__.py | 2 + python/tvm/testing/tir.py | 48 +++++++++++++++++++ .../unittest/test_tvmscript_error_report.py | 38 +-------------- .../unittest/test_tvmscript_syntax_sugar.py | 39 +++++++++++++++ 6 files changed, 155 insertions(+), 45 deletions(-) create mode 100644 python/tvm/testing/tir.py diff --git a/python/tvm/script/tir/__init__.pyi b/python/tvm/script/tir/__init__.pyi index 663cd20cdfb3..ac4ee3018f7c 100644 --- a/python/tvm/script/tir/__init__.pyi +++ b/python/tvm/script/tir/__init__.pyi @@ -322,36 +322,72 @@ def Assert(condition: Union[PrimExpr, builtins.bool], message: str) -> PrimExpr: """ Scope handler - Loops """ - +@overload def serial( begin: Union[PrimExpr, int], end: Union[PrimExpr, int], annotations: Optional[Mapping[str, Object]] = None, ) -> Iterable[IterVar]: ... +@overload +def serial( + end: Union[PrimExpr, int], + annotations: Optional[Mapping[str, Object]] = None, +) -> Iterable[IterVar]: ... +@overload def parallel( begin: Union[PrimExpr, int], end: Union[PrimExpr, int], annotations: Optional[Mapping[str, Object]] = None, ) -> Iterable[IterVar]: ... +@overload +def parallel( + end: Union[PrimExpr, int], + annotations: Optional[Mapping[str, Object]] = None, +) -> Iterable[IterVar]: ... +@overload def vectorized( begin: Union[PrimExpr, int], end: Union[PrimExpr, int], annotations: Optional[Mapping[str, Object]] = None, ) -> Iterable[IterVar]: ... +@overload +def vectorized( + end: Union[PrimExpr, int], + annotations: Optional[Mapping[str, Object]] = None, +) -> Iterable[IterVar]: ... +@overload def unroll( begin: Union[PrimExpr, int], end: Union[PrimExpr, int], annotations: Optional[Mapping[str, Object]] = None, ) -> Iterable[IterVar]: ... +@overload +def unroll( + end: Union[PrimExpr, int], + annotations: Optional[Mapping[str, Object]] = None, +) -> Iterable[IterVar]: ... +@overload def thread_binding( begin: Union[PrimExpr, int], end: Union[PrimExpr, int], thread: str, annotations: Optional[Mapping[str, Object]] = None, ) -> Iterable[IterVar]: ... +@overload +def thread_binding( + end: Union[PrimExpr, int], + thread: str, + annotations: Optional[Mapping[str, Object]] = None, +) -> Iterable[IterVar]: ... +@overload def for_range( begin: Union[PrimExpr, int], - end: Union[PrimExpr, int] = None, + end: Union[PrimExpr, int], + annotations: Optional[Mapping[str, Object]] = None, +) -> Iterable[IterVar]: ... +@overload +def for_range( + end: Union[PrimExpr, int], annotations: Optional[Mapping[str, Object]] = None, ) -> Iterable[IterVar]: ... def grid(*extents: Union[PrimExpr, int]) -> Iterable[Sequence[IterVar]]: ... diff --git a/python/tvm/script/tir/scope_handler.py b/python/tvm/script/tir/scope_handler.py index 0ce02d4cc244..42f84bc40f60 100644 --- a/python/tvm/script/tir/scope_handler.py +++ b/python/tvm/script/tir/scope_handler.py @@ -500,9 +500,12 @@ class Serial(ForScopeHandler): def __init__(self): def serial( begin: PrimExpr, - end: PrimExpr, + end: PrimExpr = None, annotations: Optional[Mapping[str, Object]] = None, ): + if end is None: + end = begin + begin = 0 self.create_loop_info(begin, end, ForKind.SERIAL, annotations=annotations) super().__init__(serial) @@ -515,9 +518,12 @@ class Parallel(ForScopeHandler): def __init__(self): def parallel( begin: PrimExpr, - end: PrimExpr, + end: PrimExpr = None, annotations: Optional[Mapping[str, Object]] = None, ): + if end is None: + end = begin + begin = 0 self.create_loop_info(begin, end, ForKind.PARALLEL, annotations=annotations) super().__init__(parallel) @@ -530,9 +536,12 @@ class Vectorized(ForScopeHandler): def __init__(self): def vectorized( begin: PrimExpr, - end: PrimExpr, + end: PrimExpr = None, annotations: Optional[Mapping[str, Object]] = None, ): + if end is None: + end = begin + begin = 0 self.create_loop_info(begin, end, ForKind.VECTORIZED, annotations=annotations) super().__init__(vectorized) @@ -545,9 +554,12 @@ class Unroll(ForScopeHandler): def __init__(self): def unroll( begin: PrimExpr, - end: PrimExpr, + end: PrimExpr = None, annotations: Optional[Mapping[str, Object]] = None, ): + if end is None: + end = begin + begin = 0 self.create_loop_info(begin, end, ForKind.UNROLLED, annotations=annotations) super().__init__(unroll) @@ -560,10 +572,19 @@ class ThreadBinding(ForScopeHandler): def __init__(self): def thread_binding( begin: PrimExpr, - end: PrimExpr, - thread: str, + end: PrimExpr = None, + thread: str = None, annotations: Optional[Mapping[str, Object]] = None, ): + if thread is None: + if isinstance(end, str): # handle case like thread_binding(128, "threadIdx.x") + thread = end + end = None + else: + raise ValueError("Thread cannot be None for thread_binding") + if end is None: + end = begin + begin = 0 thread_iter_var = IterVar(None, None, IterVar.ThreadIndex, thread) self.create_loop_info( begin, diff --git a/python/tvm/testing/__init__.py b/python/tvm/testing/__init__.py index d84846725ec4..9a18f1689100 100644 --- a/python/tvm/testing/__init__.py +++ b/python/tvm/testing/__init__.py @@ -28,5 +28,7 @@ from .popen_pool import call_py_ffi, call_cpp_py_ffi, fast_summation, slow_summation from .popen_pool import timeout_job +from .tir import check_error + from . import auto_scheduler from . import autotvm diff --git a/python/tvm/testing/tir.py b/python/tvm/testing/tir.py new file mode 100644 index 000000000000..f9115fc61bfa --- /dev/null +++ b/python/tvm/testing/tir.py @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, import-outside-toplevel, unused-variable +"""Common utility functions in TVM tir""" +import inspect +import tvm +from tvm.ir.diagnostics import override_renderer + + +def check_error(func, rel_lineno): + """check if TIR script throws error""" + # Override the default renderer to accumulate errors + errors = [] + + def render(e): + for d in e.diagnostics: + errors.append(d) + + override_renderer(render) + # The diagnostic context throws an exception when it gets an error + try: + source_code = inspect.getsource(func) + source_code = "@T.prim_func\n" + source_code + from tvm.script import from_source + + # to avoid cyclic import + from_source(source_code) + except tvm.error.DiagnosticError as e: + pass + assert len(errors) == 1, errors + for d in errors: + assert ( + d.span.line - 1 == rel_lineno + ), f"Expected error to be on line {rel_lineno}, but it was on {d.span.line - 1}" diff --git a/tests/python/unittest/test_tvmscript_error_report.py b/tests/python/unittest/test_tvmscript_error_report.py index 995cf2afa30b..102e6e3c4955 100644 --- a/tests/python/unittest/test_tvmscript_error_report.py +++ b/tests/python/unittest/test_tvmscript_error_report.py @@ -19,6 +19,7 @@ import sys import tvm from tvm import tir +from tvm.testing import check_error from tvm.script import tir as T from tvm.ir.diagnostics import override_renderer import inspect @@ -32,20 +33,6 @@ def test_buffer_bind(): check_error(buffer_bind_missing_args, 2) -def range_missing_args(a: T.handle) -> None: - A = T.match_buffer(a, (16, 16), "float32") - - T.attr(A, "realize_scope", "") - T.realize(A[0:16, 0:16], "") - for i in T.serial(16): # error - for j in T.serial(0, 16): - A[i, j] = 0.0 - - -def test_range_missing_args(): - check_error(range_missing_args, 6) - - def undefined_buffer(a: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") @@ -509,29 +496,6 @@ def test_implicit_root_has_attrs(): check_error(implicit_root_has_axes, 2) -def check_error(func, rel_lineno): - # Override the default renderer to accumulate errors - errors = [] - - def render(e): - for d in e.diagnostics: - errors.append(d) - - override_renderer(render) - # The diagnostic context throws an exception when it gets an error - try: - source_code = inspect.getsource(func) - source_code = "@T.prim_func\n" + source_code - tvm.script.from_source(source_code) - except tvm.error.DiagnosticError as e: - pass - assert len(errors) == 1, errors - for d in errors: - assert ( - d.span.line - 1 == rel_lineno - ), f"Expected error to be on line {rel_lineno}, but it was on {d.span.line - 1}" - - @T.prim_func def elementwise_not_affine(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128, 128, 128)) diff --git a/tests/python/unittest/test_tvmscript_syntax_sugar.py b/tests/python/unittest/test_tvmscript_syntax_sugar.py index 1d4b916e9d4a..b8d123236982 100644 --- a/tests/python/unittest/test_tvmscript_syntax_sugar.py +++ b/tests/python/unittest/test_tvmscript_syntax_sugar.py @@ -20,6 +20,7 @@ import pytest from tvm.ir import assert_structural_equal from tvm.script import tir as T +from tvm.testing import check_error @T.prim_func @@ -62,5 +63,43 @@ def test_reads_writes_syntax_sugar(): assert_structural_equal(transformed_matmul_no_syntax_sugar, transformed_matmul_syntax_sugar) +@T.prim_func +def loop_no_syntax_sugar(a: T.handle) -> None: + A = T.match_buffer(a, (128, 128, 128, 128)) + for i in T.serial(0, 128): + for j in T.parallel(0, 128): + for k in T.vectorized(0, 128): + for x in T.unroll(0, 128): + for y in T.thread_binding(0, 128, thread="threadIdx.x"): + for z in T.thread_binding(0, 128, thread="threadIdx.x"): + A[i, j, k, x] = A[i, j, k, x] * 2.0 + + +@T.prim_func +def loop_syntax_sugar(a: T.handle) -> None: + A = T.match_buffer(a, (128, 128, 128, 128)) + for i in T.serial(128): + for j in T.parallel(128): + for k in T.vectorized(128): + for x in T.unroll(128): + for y in T.thread_binding(128, "threadIdx.x"): + for z in T.thread_binding(128, thread="threadIdx.x"): + A[i, j, k, x] = A[i, j, k, x] * 2.0 + + +def loop_syntax_sugar_fail(a: T.handle) -> None: + A = T.match_buffer(a, (128,)) + for i in T.thread_binding(128, 128): + A[i] = A[i] * 2.0 + + +def test_loop_syntax_sugar(): + assert_structural_equal(loop_no_syntax_sugar, loop_syntax_sugar) + + +def test_syntax_sugar_fail(): + check_error(loop_syntax_sugar_fail, 3) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:]))