From e70d8eb491fdfbe23988ce66deecae036f4e4370 Mon Sep 17 00:00:00 2001 From: Zihao Date: Sat, 11 Mar 2023 10:50:26 -0800 Subject: [PATCH 1/9] init --- python/tvm/script/ir_builder/tir/ir.py | 27 +++++++++++++++++++ .../unittest/test_tvmscript_roundtrip.py | 18 +++++++++++++ 2 files changed, 45 insertions(+) diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index d65f9adea86f..cb3cdbbf5054 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1665,6 +1665,32 @@ def target(target_config: Union[Dict, str]) -> Target: ) return Target(target_config) +def Range(begin: PrimExpr, end: Optional[PrimExpr]) -> Range: + """Create a Range node. + + Parameters + ---------- + begin : PrimExpr + The begin value of the range. + + end : Optional[PrimExpr] + The end value of the range. + + Returns + ------- + res : Range + The Range node. + """ + if not isinstance(begin, PrimExpr): + raise ValueError( + f"T.Range expected a PrimExpr as begin value, but got {type(begin)} instead." + ) + if not isinstance(end, PrimExpr) and end is not None: + raise ValueError( + f"T.Range expected a Optional[PrimExpr] as end value, but got {type(end)} instead." + ) + return Range(begin, end) + class meta_var: # pylint: disable=invalid-name """A meta variable used in TVMScript metaprogramming. It means that the value of the variable @@ -2109,4 +2135,5 @@ def wrapped(*args, **kwargs): "Let", "IterVar", "CommReducer", + "Range" ] diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index c956f3bb02b9..350d06395b6c 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -3623,6 +3623,24 @@ def main(A: T.handle, B: T.handle): return main +def iter_var_range(): + T.prim_func + + def func(): + blockIdx_x = T.int32() + threadIdx_x = T.int32() + T.func_attr( + { + "tir.device_thread_axis": [ + T.iter_var(blockIdx_x, T.Range(0, 1), "ThreadIndex", "blockIdx.x"), + T.iter_var(threadIdx_x, T.Range(0, 32), "ThreadIndex", "threadIdx.x"), + ] + } + ) + + return func + + ir_generator = tvm.testing.parameter( launch_env_thread, opt_gemm_normalize, From b3b98fa461f75c8f7de21c2e1986671ee57daab4 Mon Sep 17 00:00:00 2001 From: Zihao Date: Sun, 12 Mar 2023 05:22:39 -0700 Subject: [PATCH 2/9] upd --- python/tvm/script/ir_builder/tir/ir.py | 8 ++ python/tvm/tir/op.py | 104 ++++++++++++++++++++++++- 2 files changed, 111 insertions(+), 1 deletion(-) diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index cb3cdbbf5054..ce9c9dce17a2 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1808,6 +1808,11 @@ def wrapped(*args, **kwargs): tvm_bmma_sync = _op_wrapper(_tir_op.tvm_bmma_sync) tvm_fill_fragment = _op_wrapper(_tir_op.tvm_fill_fragment) tvm_store_matrix_sync = _op_wrapper(_tir_op.tvm_store_matrix_sync) +tvm_storage_sync = _op_wrapper(_tir_op.tvm_storage_sync) +tvm_warp_shuffle = _op_wrapper(_tir_op.tvm_warp_shuffle) +tvm_warp_shuffle_up = _op_wrapper(_tir_op.tvm_warp_shuffle_up) +tvm_warp_shuffle_down = _op_wrapper(_tir_op.tvm_warp_shuffle_down) +tvm_warp_activemask = _op_wrapper(_tir_op.tvm_warp_activemask) ptx_wait_group = _op_wrapper(_tir_op.ptx_wait_group) ptx_commit_group = _op_wrapper(_tir_op.ptx_commit_group) assume = _op_wrapper(_tir_op.assume) @@ -2068,6 +2073,9 @@ def wrapped(*args, **kwargs): "tvm_bmma_sync", "tvm_fill_fragment", "tvm_store_matrix_sync", + "tvm_storage_sync", + "tvm_warp_shuffle", + "tvm_warp_activemask", "ptx_mma", "ptx_mma_sp", "ptx_ldmatrix", diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 0a9c4fdfaa52..498926045667 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -569,7 +569,8 @@ def lookup_param(param_name, span=None): def tvm_thread_allreduce(*freduce_args): - """ + """Perform allreduce inside threadblock. + Parameters ---------- freduce_args : Expr @@ -583,6 +584,107 @@ def tvm_thread_allreduce(*freduce_args): return call_intrin("handle", "tir.tvm_thread_allreduce", *freduce_args) +def tvm_storage_sync(storage_scope): + """Perform synchronization in specified scope. + + Parameters + ---------- + storage_scope : str + The storage scope to perform synchronization. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin("handle", "tir.tvm_storage_sync", storage_scope) + + +def tvm_warp_shuffle(mask, value, warp_id, width, warp_size): + """Exchange value between threads inside a warp. + + Parameters + ---------- + mask : PrimExpr + The warp mask indicates active threads inside warp. + value : PrimExpr + The value to exchange. + warp_id : PrimExpr + The source lane index to fetch value. + width : PrimExpr + The width of sub-sections to perform warp shuffle. + warp_size : PrimExpr + The warp size. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin("handle", "tir.tvm_warp_shuffle", mask, value, warp_id, width, warp_size) + + +def tvm_warp_shuffle_up(mask, value, offset, width, warp_size): + """Copy value from a lane with lower (by offset) index relative to caller. + + Parameters + ---------- + mask : PrimExpr + The warp mask indicates active threads inside warp. + value : PrimExpr + The value to exchange. + offset : PrimExpr + The difference between source lane index and destination lane index: + `offset = dst_lane_idx - src_lane_idx` + width : PrimExpr + The width of sub-sections to perform warp shuffle. + warp_size : PrimExpr + The warp size. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin("handle", "tir.tvm_warp_shuffle_up", mask, value, offset, width, warp_size) + + +def tvm_warp_shuffle_down(mask, value, offset, width, warp_size): + """Copy value from a lane with higher (by offset) index relative to caller. + + Parameters + ---------- + mask : PrimExpr + The warp mask indicates active threads inside warp. + value : PrimExpr + The value to exchange. + offset : PrimExpr + The difference between source lane index and destination lane index: + `offset = src_lane_idx - dst_lane_idx` + width : PrimExpr + The width of sub-sections to perform warp shuffle. + warp_size : PrimExpr + The warp size. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin("handle", "tir.tvm_warp_shuffle_down", mask, value, offset, width, warp_size) + + +def tvm_warp_activemask(): + """Return a 32-bit mask indicates currently active threads in a calling warp. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin("handle", "tir.tvm_warp_activemask") + + def type_annotation(dtype): """Create a type annotation expression From 409f7c5c3403fefae78dc9e49c3245f1cfa82d71 Mon Sep 17 00:00:00 2001 From: Zihao Date: Sun, 12 Mar 2023 05:59:42 -0700 Subject: [PATCH 3/9] upd --- python/tvm/script/ir_builder/tir/ir.py | 55 +++++++---------- python/tvm/tir/op.py | 8 +-- .../unittest/test_tvmscript_roundtrip.py | 61 ++++++++++++++----- 3 files changed, 74 insertions(+), 50 deletions(-) diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index ce9c9dce17a2..3abaadd8cf99 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -29,7 +29,8 @@ import numpy as np # type: ignore from tvm import tir -from tvm.ir import Range, Type +from tvm import ir +from tvm.ir import Type from tvm.ir.base import deprecated from tvm.runtime import String, convert, ndarray from tvm.target import Target @@ -496,7 +497,7 @@ def alloc_buffer( ) -def _as_range(dom: Union[Range, List[PrimExpr]]) -> Range: +def _as_range(dom: Union[ir.Range, List[PrimExpr]]) -> ir.Range: """The range constructor. Parameters @@ -509,13 +510,13 @@ def _as_range(dom: Union[Range, List[PrimExpr]]) -> Range: res : Range The Range. """ - if isinstance(dom, Range): + if isinstance(dom, ir.Range): return dom if isinstance(dom, (list, tuple)): - return Range(dom[0], dom[1]) + return ir.Range(dom[0], dom[1]) if hasattr(dom, "dtype"): - return Range(IntImm(dom.dtype, 0), dom) - return Range(0, dom) + return ir.Range(IntImm(dom.dtype, 0), dom) + return ir.Range(0, dom) class axis: # pylint: disable=invalid-name @@ -523,7 +524,7 @@ class axis: # pylint: disable=invalid-name @staticmethod def spatial( - dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]], + dom: Union[ir.Range, List[PrimExpr], Tuple[PrimExpr]], binding: PrimExpr, dtype: str = "int32", ) -> Var: @@ -551,7 +552,7 @@ def spatial( @staticmethod def reduce( - dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]], + dom: Union[ir.Range, List[PrimExpr], Tuple[PrimExpr]], binding: PrimExpr, dtype: str = "int32", ) -> Var: @@ -579,7 +580,7 @@ def reduce( @staticmethod def scan( - dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]], + dom: Union[ir.Range, List[PrimExpr], Tuple[PrimExpr]], binding: PrimExpr, dtype: str = "int32", ) -> Var: @@ -607,7 +608,7 @@ def scan( @staticmethod def opaque( - dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]], + dom: Union[ir.Range, List[PrimExpr], Tuple[PrimExpr]], binding: PrimExpr, dtype: str = "int32", ) -> Var: @@ -1288,7 +1289,7 @@ def buffer_store( def prefetch( buffer: Buffer, # pylint: disable=redefined-outer-name - bounds: List[Range], + bounds: List[ir.Range], ) -> None: """The prefetch hint for a buffer. @@ -1579,7 +1580,7 @@ def max(a: PrimExpr, b: PrimExpr) -> PrimExpr: # pylint: disable=redefined-buil return _ffi_api.max(a, b) # type: ignore[attr-defined] # pylint: disable=no-member -def iter_var(v: Union[Var, str], dom: Range, iter_type: str, thread_tag: str) -> IterVar: +def iter_var(v: Union[Var, str], dom: ir.Range, iter_type: str, thread_tag: str) -> IterVar: """The iteration variable. Parameters @@ -1665,32 +1666,20 @@ def target(target_config: Union[Dict, str]) -> Target: ) return Target(target_config) -def Range(begin: PrimExpr, end: Optional[PrimExpr]) -> Range: - """Create a Range node. - + +def Range(begin: PrimExpr, end: PrimExpr) -> ir.Range: + """ + Create a Range object. + Parameters ---------- - begin : PrimExpr + begin : PrimExpr The begin value of the range. end : Optional[PrimExpr] The end value of the range. - - Returns - ------- - res : Range - The Range node. """ - if not isinstance(begin, PrimExpr): - raise ValueError( - f"T.Range expected a PrimExpr as begin value, but got {type(begin)} instead." - ) - if not isinstance(end, PrimExpr) and end is not None: - raise ValueError( - f"T.Range expected a Optional[PrimExpr] as end value, but got {type(end)} instead." - ) - return Range(begin, end) - + return ir.Range(begin, end) class meta_var: # pylint: disable=invalid-name """A meta variable used in TVMScript metaprogramming. It means that the value of the variable @@ -2075,6 +2064,8 @@ def wrapped(*args, **kwargs): "tvm_store_matrix_sync", "tvm_storage_sync", "tvm_warp_shuffle", + "tvm_warp_shuffle_up", + "tvm_warp_shuffle_down", "tvm_warp_activemask", "ptx_mma", "ptx_mma_sp", @@ -2143,5 +2134,5 @@ def wrapped(*args, **kwargs): "Let", "IterVar", "CommReducer", - "Range" + "Range", ] diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 498926045667..dd14d728128d 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -621,7 +621,7 @@ def tvm_warp_shuffle(mask, value, warp_id, width, warp_size): call : PrimExpr The call expression. """ - return call_intrin("handle", "tir.tvm_warp_shuffle", mask, value, warp_id, width, warp_size) + return call_intrin(value.dtype, "tir.tvm_warp_shuffle", mask, value, warp_id, width, warp_size) def tvm_warp_shuffle_up(mask, value, offset, width, warp_size): @@ -646,7 +646,7 @@ def tvm_warp_shuffle_up(mask, value, offset, width, warp_size): call : PrimExpr The call expression. """ - return call_intrin("handle", "tir.tvm_warp_shuffle_up", mask, value, offset, width, warp_size) + return call_intrin(value.dtype, "tir.tvm_warp_shuffle_up", mask, value, offset, width, warp_size) def tvm_warp_shuffle_down(mask, value, offset, width, warp_size): @@ -671,7 +671,7 @@ def tvm_warp_shuffle_down(mask, value, offset, width, warp_size): call : PrimExpr The call expression. """ - return call_intrin("handle", "tir.tvm_warp_shuffle_down", mask, value, offset, width, warp_size) + return call_intrin(value.dtype, "tir.tvm_warp_shuffle_down", mask, value, offset, width, warp_size) def tvm_warp_activemask(): @@ -682,7 +682,7 @@ def tvm_warp_activemask(): call : PrimExpr The call expression. """ - return call_intrin("handle", "tir.tvm_warp_activemask") + return call_intrin("int32", "tir.tvm_warp_activemask") def type_annotation(dtype): diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index 350d06395b6c..9f3c18b7ca6b 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -3623,20 +3623,52 @@ def main(A: T.handle, B: T.handle): return main -def iter_var_range(): - T.prim_func - - def func(): - blockIdx_x = T.int32() - threadIdx_x = T.int32() - T.func_attr( - { - "tir.device_thread_axis": [ - T.iter_var(blockIdx_x, T.Range(0, 1), "ThreadIndex", "blockIdx.x"), - T.iter_var(threadIdx_x, T.Range(0, 32), "ThreadIndex", "threadIdx.x"), - ] - } - ) +def tvm_shfl_builtins(): + @T.prim_func + def func( + A: T.handle("float32", "global"), + C: T.handle("float32", "global"), + B: T.handle("float32", "global"), + ): + blockIdx_x = T.launch_thread("blockIdx.x", 1) + threadIdx_x = T.launch_thread("threadIdx.x", 32) + A_warp = T.allocate([32], "float32", "warp") + B_warp = T.allocate([32], "float32", "warp") + red_buf0 = T.allocate([1], "float32", "local") + A_warp_1 = T.Buffer((32,), data=A_warp, scope="warp") + A_1 = T.Buffer((32,), data=A) + A_warp_1[threadIdx_x] = A_1[threadIdx_x] + B_warp_1 = T.Buffer((32,), data=B_warp, scope="warp") + T.tvm_storage_sync("warp") + B_warp_1[threadIdx_x] = A_warp_1[threadIdx_x % 4 * 8 + threadIdx_x // 4] + T.float32(1) + red_buf0_1 = T.Buffer((1,), data=red_buf0, scope="local") + with T.attr( + T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), + "reduce_scope", + T.reinterpret("handle", T.uint64(0)), + ): + mask = T.allocate([1], "uint32", "local") + t0 = T.allocate([1], "float32", "local") + red_buf0_1[0] = A_warp_1[threadIdx_x] + mask_1 = T.Buffer((1,), "uint32", data=mask, scope="local") + mask_1[0] = T.tvm_warp_activemask() + t0_1 = T.Buffer((1,), data=t0, scope="local") + t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 16, 32, 32) + red_buf0_1[0] = red_buf0_1[0] + t0_1[0] + t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 8, 32, 32) + red_buf0_1[0] = red_buf0_1[0] + t0_1[0] + t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 4, 32, 32) + red_buf0_1[0] = red_buf0_1[0] + t0_1[0] + t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 2, 32, 32) + red_buf0_1[0] = red_buf0_1[0] + t0_1[0] + t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 1, 32, 32) + red_buf0_1[0] = red_buf0_1[0] + t0_1[0] + red_buf0_1[0] = T.tvm_warp_shuffle(mask_1[0], red_buf0_1[0], 0, 32, 32) + if threadIdx_x == 0: + C_1 = T.Buffer((1,), data=C) + C_1[0] = red_buf0_1[0] + B_1 = T.Buffer((32,), data=B) + B_1[threadIdx_x] = B_warp_1[threadIdx_x] return func @@ -3704,6 +3736,7 @@ def func(): let_stmt_value, string_stride, merge_shape_var_def, + tvm_shfl_builtins, ) From 9b9225269c3ac13e5dd572c85194fd1fb5c71873 Mon Sep 17 00:00:00 2001 From: Zihao Date: Sun, 12 Mar 2023 06:45:52 -0700 Subject: [PATCH 4/9] add tests --- python/tvm/script/ir_builder/tir/ir.py | 3 ++- .../unittest/test_tvmscript_roundtrip.py | 18 ++++++++++-------- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index 3abaadd8cf99..fce1fa5c6010 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1675,12 +1675,13 @@ def Range(begin: PrimExpr, end: PrimExpr) -> ir.Range: ---------- begin : PrimExpr The begin value of the range. - + end : Optional[PrimExpr] The end value of the range. """ return ir.Range(begin, end) + class meta_var: # pylint: disable=invalid-name """A meta variable used in TVMScript metaprogramming. It means that the value of the variable does not appear in the final TIR, but only stays in the parser. diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index 9f3c18b7ca6b..dc2a48b43e7c 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -3632,15 +3632,17 @@ def func( ): blockIdx_x = T.launch_thread("blockIdx.x", 1) threadIdx_x = T.launch_thread("threadIdx.x", 32) - A_warp = T.allocate([32], "float32", "warp") - B_warp = T.allocate([32], "float32", "warp") + A_warp = T.allocate([1], "float32", "local") + B_warp = T.allocate([1], "float32", "local") red_buf0 = T.allocate([1], "float32", "local") - A_warp_1 = T.Buffer((32,), data=A_warp, scope="warp") + A_warp_1 = T.Buffer((32,), data=A_warp, scope="local") A_1 = T.Buffer((32,), data=A) - A_warp_1[threadIdx_x] = A_1[threadIdx_x] - B_warp_1 = T.Buffer((32,), data=B_warp, scope="warp") + A_warp_1[0] = A_1[threadIdx_x] + B_warp_1 = T.Buffer((32,), data=B_warp, scope="local") T.tvm_storage_sync("warp") - B_warp_1[threadIdx_x] = A_warp_1[threadIdx_x % 4 * 8 + threadIdx_x // 4] + T.float32(1) + B_warp_1[0] = T.tvm_warp_shuffle( + T.tvm_warp_activemask(), A_warp_1[0], threadIdx_x % 4 * 8 + threadIdx_x // 4, 32, 32 + ) + T.float32(1) red_buf0_1 = T.Buffer((1,), data=red_buf0, scope="local") with T.attr( T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), @@ -3649,7 +3651,7 @@ def func( ): mask = T.allocate([1], "uint32", "local") t0 = T.allocate([1], "float32", "local") - red_buf0_1[0] = A_warp_1[threadIdx_x] + red_buf0_1[0] = A_warp_1[0] mask_1 = T.Buffer((1,), "uint32", data=mask, scope="local") mask_1[0] = T.tvm_warp_activemask() t0_1 = T.Buffer((1,), data=t0, scope="local") @@ -3668,7 +3670,7 @@ def func( C_1 = T.Buffer((1,), data=C) C_1[0] = red_buf0_1[0] B_1 = T.Buffer((32,), data=B) - B_1[threadIdx_x] = B_warp_1[threadIdx_x] + B_1[threadIdx_x] = B_warp_1[0] return func From ec5346015e4abe1960ce7035f8ea4e2979cd3027 Mon Sep 17 00:00:00 2001 From: Zihao Date: Sun, 12 Mar 2023 06:51:57 -0700 Subject: [PATCH 5/9] upd --- tests/python/unittest/test_tvmscript_roundtrip.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index dc2a48b43e7c..337e23b0209b 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -3666,6 +3666,8 @@ def func( t0_1[0] = T.tvm_warp_shuffle_down(mask_1[0], red_buf0_1[0], 1, 32, 32) red_buf0_1[0] = red_buf0_1[0] + t0_1[0] red_buf0_1[0] = T.tvm_warp_shuffle(mask_1[0], red_buf0_1[0], 0, 32, 32) + # NOTE(Zihao): test tvm_warp_shuffle_up + red_buf0_1[0] = T.tvm_warp_shuffle_up(mask_1[0], red_buf0_1[0], 0, 32, 32) if threadIdx_x == 0: C_1 = T.Buffer((1,), data=C) C_1[0] = red_buf0_1[0] From 4cc6681568d9f15170e0f37b5823813972b0b4d6 Mon Sep 17 00:00:00 2001 From: Zihao Date: Sun, 12 Mar 2023 07:25:55 -0700 Subject: [PATCH 6/9] remove _op_wrapper --- python/tvm/script/ir_builder/tir/ir.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index fce1fa5c6010..57090c6f3f39 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1798,11 +1798,11 @@ def wrapped(*args, **kwargs): tvm_bmma_sync = _op_wrapper(_tir_op.tvm_bmma_sync) tvm_fill_fragment = _op_wrapper(_tir_op.tvm_fill_fragment) tvm_store_matrix_sync = _op_wrapper(_tir_op.tvm_store_matrix_sync) -tvm_storage_sync = _op_wrapper(_tir_op.tvm_storage_sync) -tvm_warp_shuffle = _op_wrapper(_tir_op.tvm_warp_shuffle) -tvm_warp_shuffle_up = _op_wrapper(_tir_op.tvm_warp_shuffle_up) -tvm_warp_shuffle_down = _op_wrapper(_tir_op.tvm_warp_shuffle_down) -tvm_warp_activemask = _op_wrapper(_tir_op.tvm_warp_activemask) +tvm_storage_sync = _tir_op.tvm_storage_sync +tvm_warp_shuffle = _tir_op.tvm_warp_shuffle +tvm_warp_shuffle_up = _tir_op.tvm_warp_shuffle_up +tvm_warp_shuffle_down = _tir_op.tvm_warp_shuffle_down +tvm_warp_activemask = _tir_op.tvm_warp_activemask ptx_wait_group = _op_wrapper(_tir_op.ptx_wait_group) ptx_commit_group = _op_wrapper(_tir_op.ptx_commit_group) assume = _op_wrapper(_tir_op.assume) From 4a5396ab7ed9b0667e7ba577ba5ebe9aee2264e6 Mon Sep 17 00:00:00 2001 From: Zihao Date: Sun, 12 Mar 2023 07:49:14 -0700 Subject: [PATCH 7/9] fix --- python/tvm/tir/op.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index dd14d728128d..0fe460c085d7 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -646,7 +646,9 @@ def tvm_warp_shuffle_up(mask, value, offset, width, warp_size): call : PrimExpr The call expression. """ - return call_intrin(value.dtype, "tir.tvm_warp_shuffle_up", mask, value, offset, width, warp_size) + return call_intrin( + value.dtype, "tir.tvm_warp_shuffle_up", mask, value, offset, width, warp_size + ) def tvm_warp_shuffle_down(mask, value, offset, width, warp_size): @@ -671,7 +673,9 @@ def tvm_warp_shuffle_down(mask, value, offset, width, warp_size): call : PrimExpr The call expression. """ - return call_intrin(value.dtype, "tir.tvm_warp_shuffle_down", mask, value, offset, width, warp_size) + return call_intrin( + value.dtype, "tir.tvm_warp_shuffle_down", mask, value, offset, width, warp_size + ) def tvm_warp_activemask(): @@ -682,7 +686,7 @@ def tvm_warp_activemask(): call : PrimExpr The call expression. """ - return call_intrin("int32", "tir.tvm_warp_activemask") + return call_intrin("uint32", "tir.tvm_warp_activemask") def type_annotation(dtype): From fdabd56ee4f369c2136ac681cfbea2f1252d4580 Mon Sep 17 00:00:00 2001 From: Zihao Date: Sun, 12 Mar 2023 08:08:53 -0700 Subject: [PATCH 8/9] flake --- tests/python/unittest/test_tvmscript_roundtrip.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index 337e23b0209b..6f07b6a75aeb 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -3626,9 +3626,9 @@ def main(A: T.handle, B: T.handle): def tvm_shfl_builtins(): @T.prim_func def func( - A: T.handle("float32", "global"), - C: T.handle("float32", "global"), - B: T.handle("float32", "global"), + A: T.handle("float32"), + B: T.handle("float32"), + C: T.handle("float32"), ): blockIdx_x = T.launch_thread("blockIdx.x", 1) threadIdx_x = T.launch_thread("threadIdx.x", 32) From 5c4228a76eaa4793e39cd4b6b862f6d4ad79363f Mon Sep 17 00:00:00 2001 From: Zihao Date: Sun, 12 Mar 2023 08:34:00 -0700 Subject: [PATCH 9/9] pylint --- python/tvm/script/ir_builder/tir/ir.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index 57090c6f3f39..45350c5a65c7 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1667,7 +1667,7 @@ def target(target_config: Union[Dict, str]) -> Target: return Target(target_config) -def Range(begin: PrimExpr, end: PrimExpr) -> ir.Range: +def Range(begin: PrimExpr, end: PrimExpr) -> ir.Range: # pylint: disable=invalid-name """ Create a Range object.