Skip to content

Commit

Permalink
[TIR] Fix an index out of bound issue of compact buffer region (apach…
Browse files Browse the repository at this point in the history
…e#11201)

After apache#10557, the region extent after compaction is ensured to not exceed original shape. Now when the inferred region min is negative, the index remap rule `idx -> (idx - region_min)` would introduce out of bound accesses, which would cause crashes at runtime.

The two updated cases in UT:
- padding block inlined to pooling
Current version results to out of bound accesses in `cache` block, since the H/W extents are compacted to no more than 224 but accesses are shifted by `- (-1)`.
```python
@T.prim_func
def func(X: T.Buffer[(224, 224), "float32"], Y: T.Buffer[(224, 224), "float32"]) -> None:
    cache = T.alloc_buffer([224, 224], dtype="float32")
    for h, w in T.grid(224, 224):
        with T.block("cache"):
            cache[h + 1, w + 1] = X[h, w]
    for h, w, kh, kw in T.grid(224, 224, 3, 3):
        with T.block("compute"):
            Y[h, w] = T.max(Y[h, w], T.if_then_else(T.likely(1 <= h + kh, dtype="bool") and T.likely(h + kh < 225, dtype="bool") and T.likely(1 <= w + kw, dtype="bool") and T.likely(w + kw < 225, dtype="bool"), cache[h + kh, w + kw], T.float32(0), dtype="float32"))
```

-  sparse access
`A_data_local[A_indptr[i] + k]` is rewritten to `A_data_local[T.min(A_indptr[i] + k, 0)]` instead of `A_data_local[0]`. Compared to current version, interestingly, it keeps the semantic that negative sparse index result to oob behavior.
  • Loading branch information
wrongtest-intellif authored and Sergey Shtin committed May 17, 2022
1 parent ca137cb commit 1929bd4
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 5 deletions.
5 changes: 3 additions & 2 deletions src/tir/transforms/compact_buffer_region.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,9 @@ Region SimplifyAndNarrowBufferRegionFromNDIntSet(const NDIntSet& nd_int_set,
for (size_t i = 0; i < nd_int_set.size(); ++i) {
const arith::IntSet& int_set = nd_int_set[i];
Range range = int_set.CoverRange(Range(/*begin=*/0, /*end=*/original_shape[i]));
result.push_back(Range::FromMinExtent(
range->min, analyzer->Simplify(min(original_shape[i], range->extent))));
result.push_back(
Range::FromMinExtent(analyzer->Simplify(max(0, range->min)),
analyzer->Simplify(min(original_shape[i], range->extent))));
}
return result;
}
Expand Down
58 changes: 55 additions & 3 deletions tests/python/unittest/test_tir_transform_compact_buffer_region.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,54 @@ def compacted_padding_pattern_func(a: T.handle, c: T.handle) -> None:
)


@T.prim_func
def padding_pattern_inlined(a: T.handle, b: T.handle) -> None:
X = T.match_buffer(a, [224, 224], dtype="float32")
Y = T.match_buffer(b, [224, 224], dtype="float32")
cache = T.alloc_buffer([224, 224], dtype="float32")
for h, w in T.grid(224, 224):
with T.block("cache"):
cache[h, w] = X[h, w]
for h, w, kh, kw in T.grid(224, 224, 3, 3):
with T.block("compute"):
Y[h, w] = T.max(
Y[h, w],
T.if_then_else(
T.likely(1 <= h + kh, dtype="bool")
and T.likely(h + kh < 225, dtype="bool")
and T.likely(1 <= w + kw, dtype="bool")
and T.likely(w + kw < 225, dtype="bool"),
cache[h + kh - 1, w + kw - 1],
0.0,
dtype="float32",
),
)


@T.prim_func
def compacted_padding_pattern_inlined(
X: T.Buffer[(224, 224), "float32"], Y: T.Buffer[(224, 224), "float32"]
) -> None:
cache = T.alloc_buffer([224, 224], dtype="float32")
for h, w in T.grid(224, 224):
with T.block("cache"):
cache[h, w] = X[h, w]
for h, w, kh, kw in T.grid(224, 224, 3, 3):
with T.block("compute"):
Y[h, w] = T.max(
Y[h, w],
T.if_then_else(
T.likely(1 <= h + kh, dtype="bool")
and T.likely(h + kh < 225, dtype="bool")
and T.likely(1 <= w + kw, dtype="bool")
and T.likely(w + kw < 225, dtype="bool"),
cache[h + kh - 1, w + kw - 1],
0.0,
dtype="float32",
),
)


@T.prim_func
def mem_access_in_branch_func(a: T.handle) -> None:
A = T.match_buffer(a, (224, 224), "float32")
Expand Down Expand Up @@ -570,12 +618,12 @@ def compacted_sparse_read_cache(
A_data_local = T.alloc_buffer([1], dtype="float32", scope="local")
with T.block("A_data_cache_read"):
T.reads(A_indptr[i], A_data[A_indptr[i] + k])
T.writes(A_data_local[A_indptr[i] + k - (A_indptr[i] + k)])
A_data_local[A_indptr[i] + k - (A_indptr[i] + k)] = A_data[A_indptr[i] + k]
T.writes(A_data_local[T.min(A_indptr[i] + k, 0)])
A_data_local[T.min(A_indptr[i] + k, 0)] = A_data[A_indptr[i] + k]
with T.block("rowsum_inner"):
T.reads(B[i], A_indptr[i], A_data[A_indptr[i] + k])
T.writes(B[i])
B[i] = B[i] + A_data_local[A_indptr[i] + k - (A_indptr[i] + k)]
B[i] = B[i] + A_data_local[T.min(A_indptr[i] + k, 0)]


@T.prim_func
Expand Down Expand Up @@ -654,6 +702,10 @@ def test_padding_pattern():
_check(padding_pattern_func, compacted_padding_pattern_func)


def test_padding_pattern_inlined():
_check(padding_pattern_inlined, compacted_padding_pattern_inlined)


def test_mem_access_in_branch_func():
_check(mem_access_in_branch_func, compacted_mem_access_in_branch_func)

Expand Down

0 comments on commit 1929bd4

Please sign in to comment.