Skip to content

Commit

Permalink
[Fix] Buffer slicing using index dtype as extent
Browse files Browse the repository at this point in the history
  • Loading branch information
MasterJH5574 committed Jan 15, 2023
1 parent fe01c5a commit 84a9f8c
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 8 deletions.
2 changes: 1 addition & 1 deletion 3rdparty/cutlass
Submodule cutlass updated 165 files
8 changes: 6 additions & 2 deletions python/tvm/tir/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def offset_of(self, indices):

def __getitem__(self, indices):
from ..arith import Analyzer # pylint: disable=import-outside-toplevel
from .expr import BufferLoad, Ramp # pylint: disable=import-outside-toplevel
from .expr import BufferLoad, Ramp, const # pylint: disable=import-outside-toplevel
from .stmt import BufferRegion # pylint: disable=import-outside-toplevel

if not isinstance(indices, (tuple, list)):
Expand All @@ -195,7 +195,11 @@ def __getitem__(self, indices):
stop = self.shape[i] if index.stop is None else index.stop
region.append(Range.from_min_extent(start, analyzer.simplify(stop - start)))
else:
region.append(Range.from_min_extent(index, 1))
region.append(
Range.from_min_extent(
index, const(1, index.dtype) if isinstance(index, PrimExpr) else 1
)
)
return BufferRegion(self, region)
else:
expr_indices = []
Expand Down
19 changes: 14 additions & 5 deletions tests/python/unittest/test_tvmscript_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import numpy

import tvm
import tvm.testing
from tvm.script import tir as T


Expand Down Expand Up @@ -73,9 +74,17 @@ def func_ref():
tvm.ir.assert_structural_equal(test_case, func_ref)


def test_tir_buffer_region_extent_correct_dtype():
@T.prim_func
def func(A: T.Buffer[(T.int64(16), T.int64(1)), "float32"]):
for i in T.grid(T.int64(16)):
with T.block("block"):
vi = T.axis.remap("S", [i])
T.reads(A[vi, T.int64(0) : T.int64(1)])
T.evaluate(0)

assert func.body.block.body.body.block.reads[0].region[0].extent.dtype == "int64"


if __name__ == "__main__":
a = numpy.zeros((10, 10), dtype="int8")
test_multi_element_array_in_outmost_namespace()
test_different_dtype_assignment_to_var()
b = 1
test_var_capturing_order()
tvm.testing.main()

0 comments on commit 84a9f8c

Please sign in to comment.