[TIR] Allow compute_at create block predicate for non-trivial bounds and support floordiv pattern #9527

Conversation

wrongtest-intellif (Contributor)

Hi there~ This PR is an enhancement to the compute_at and reverse_compute_at primitives. Binding a block into loops may create non-trivial iter bounds. Complex iter bounds are neither human friendly nor compatible with backend passes that target bounds and conditions (e.g., loop partition). So this PR tries to recognize some of these complex bounds and use block predicates instead, to keep the IR structure simpler.

A working example is shown below: we want to create spatial tiles and read each tile's data from a cache, so the schedule operation is to compute_at the cache_read block into the tiled loops.

@T.prim_func
def tiled_pooling_read_cache(a: T.handle, b: T.handle) -> None:
    X = T.match_buffer(a, [224, 224], dtype="float32")
    Y = T.match_buffer(b, [224, 224], dtype="float32")
    cache = T.alloc_buffer([224, 224], dtype="float32")
    for hh, ww in T.grid(224, 224):
        with T.block("cache"):
            h, w = T.axis.remap("SS", [hh, ww])
            T.reads([X[h, w]])
            T.writes([cache[h, w]])
            cache[h, w] = X[h, w]
    for hh_0, ww_0, hh_1, ww_1, khh, kww in T.grid(28, 28, 8, 8, 3, 3):
        with T.block("compute"):
            h = T.axis.spatial(224, hh_0 * 8 + hh_1)
            w = T.axis.spatial(224, ww_0 * 8 + ww_1)
            kh, kw = T.axis.remap("RR", [khh, kww])
            T.reads([Y[h, w], cache[h + kh - 1, w + kw - 1]])
            T.writes([Y[h, w]])
            with T.init():
                Y[h, w] = 0.0
            Y[h, w] = T.max(Y[h, w], T.if_then_else(
                T.likely(1 <= h + kh, dtype="bool") and \
                T.likely(h + kh < 225, dtype="bool") and \
                T.likely(1 <= w + kw, dtype="bool") and \
                T.likely(w + kw < 225, dtype="bool"),
                cache[h + kh - 1, w + kw - 1], 0.0, dtype="float32"))
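
For reference, here is a minimal sketch of the schedule calls assumed to take the prim_func above to the IR snippets below; the loop handle names and the choice of attaching at ww_0 are my own, not spelled out in the PR text:

from tvm import tir

sch = tir.Schedule(tiled_pooling_read_cache)
cache_block = sch.get_block("cache")
# loop handles of the already-tiled "compute" block
hh_0, ww_0, hh_1, ww_1, khh, kww = sch.get_loops(sch.get_block("compute"))
# move the cache-read block under the spatial tile loops
sch.compute_at(cache_block, ww_0)
print(sch.mod.script())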

Mainline code will produce:

@T.prim_func
def func(a: T.handle, b: T.handle) -> None:
    X = T.match_buffer(a, [224, 224], dtype="float32")
    Y = T.match_buffer(b, [224, 224], dtype="float32")
    # body
    # with T.block("root")
    cache = T.alloc_buffer([224, 224], dtype="float32")
    for hh_0, ww_0 in T.grid(28, 28):
        for ax0 in T.serial(0, T.min(hh_0 * 8 + 8, 223) + 1 - T.max(hh_0 * 8 - 1, 0)):
            for ax1 in T.serial(0, T.min(ww_0 * 8 + 8, 223) + 1 - T.max(ww_0 * 8 - 1, 0)):
                with T.block("cache"):
                    h = T.axis.spatial(224, T.max(hh_0 * 8 - 1, 0) + ax0)
                    w = T.axis.spatial(224, T.max(ww_0 * 8 - 1, 0) + ax1)
                    T.reads([X[h, w]])
                    T.writes([cache[h, w]])
                    cache[h, w] = X[h, w]
        for hh_1, ww_1, khh, kww in T.grid(8, 8, 3, 3):
            with T.block("compute"):
                ...

This PR will produce:

@T.prim_func
def tiled_pooling_read_cache_after_compute_at(a: T.handle, b: T.handle) -> None:
    X = T.match_buffer(a, [224, 224], dtype="float32")
    Y = T.match_buffer(b, [224, 224], dtype="float32")
    cache = T.alloc_buffer([224, 224], dtype="float32")
    for hh_0, ww_0 in T.grid(28, 28):
        for ax0, ax1 in T.grid(10, 10):
            with T.block("cache"):
                h = T.axis.spatial(224, hh_0 * 8 - 1 + ax0)
                w = T.axis.spatial(224, ww_0 * 8 - 1 + ax1)
                T.where(1 <= hh_0 * 8 + ax0 and hh_0 * 8 + ax0 < 225 and 1 <= ww_0 * 8 + ax1 and ww_0 * 8 + ax1 < 225)
                T.reads([X[h, w]])
                T.writes([cache[h, w]])
                cache[h, w] = X[h, w]
        for hh_1, ww_1, khh, kww in T.grid(8, 8, 3, 3):
            with T.block("compute"):
                ...

The modification is to delay the intersection of the intset deduced from the required uses and the intset enforced by the buffer shape / original iter bound. Instead of a direct intset intersection (which can create quite complex min/max expressions), a BlockVarDomainInfo class is added to maintain the two intsets above, named dom and bound. The implementation can then choose, with some heuristics:

  1. use (dom ^ bound) as the iter domain if it is simple enough, or
  2. use dom as the iter domain and add a block predicate for bound.

The two choices are illustrated by the sketch below.
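
As an illustration of the two choices (my own sketch, not code from the PR), here is how the two interval sources line up for the h axis of the "cache" block in the example above, written with tvm.arith.IntervalSet to mimic the dom/bound pair kept by BlockVarDomainInfo:

from tvm import arith, tir

hh_0 = tir.Var("hh_0", "int32")
# dom: deduced from the reads of "compute" (h + kh - 1 over the 8 x 3 tile)
dom = arith.IntervalSet(hh_0 * 8 - 1, hh_0 * 8 + 8)
# bound: enforced by the buffer shape / original iter extent 224
bound = arith.IntervalSet(0, 223)
# Choice 1, (dom ^ bound): loop extent min(hh_0 * 8 + 8, 223) + 1 - max(hh_0 * 8 - 1, 0),
#   i.e. the mainline IR above.
# Choice 2, dom plus a block predicate for bound: fixed extent 10 with
#   T.where(1 <= hh_0 * 8 + ax0 and hh_0 * 8 + ax0 < 225), i.e. the IR this PR produces.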

The PR also adds minimal support for analyzing floordiv/floormod in the provide-required region mapping (illustrated below).
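
A common place where such floordiv/floormod bindings show up (an illustrative snippet, assumed rather than taken from the PR's tests) is a block bound to a fused loop; the regions such a block provides or requires then involve fused // n and fused % n terms that the mapping analysis has to handle:

@T.prim_func
def fused_copy(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, [64, 64], dtype="float32")
    B = T.match_buffer(b, [64, 64], dtype="float32")
    for fused in T.serial(0, 64 * 64):
        with T.block("copy"):
            vi = T.axis.spatial(64, fused // 64)
            vj = T.axis.spatial(64, fused % 64)
            T.reads([A[vi, vj]])
            T.writes([B[vi, vj]])
            B[vi, vj] = A[vi, vj]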

@Hzfengsy (Member) left a comment

Thanks for the PR. It is super helpful for the imperfect tiling case.

For the region cover problem, I will look at it. It's better to fix it before this PR is merged.
cc @junrushao1994

@@ -514,6 +605,14 @@ void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_s
/*realize=*/reconstructor.new_block_realize_,
/*loop_var_ranges=*/LoopDomainOfSRefTreePath(GetRef<StmtSRef>(block_sref->parent)),
/*analyzer=*/&analyzer);
// The verifier cannot prove the region cover state if some complex predicate is introduced,
// so here we explicitly reset these flags below.
if (is_compute_at && !is_const_int(reconstructor.new_block_realize_->predicate)) {
Member

It's a bug of RegionCoverCheck. We should fix it instead of working around it.

@junrushao (Member)

This is very helpful! Would love to let @Hzfengsy shepherd this PR. Thanks a lot!

@junrushao (Member)

CC @Hzfengsy

@wrongtest-intellif (Contributor, Author)

Added an option allow_block_predicate; users can set it to False if the old behavior (dynamic loop extents) is preferred.

@wrongtest-intellif force-pushed the tir_compute_at_support_block_predicate branch from f91c1fc to 15e3963 on December 16, 2021
@spectrometerHBH (Contributor)

Great job! Here are some comments.
It looks like you added allow_block_predicate; why do you think it is necessary to keep the dynamic loop extent behavior? It looks to me that we can abandon it.

@wrongtest-intellif (Contributor, Author) commented on Dec 17, 2021

why do you think it is necessary to keep the dynamic loop extent behavior

After discussion with @Hzfengsy, I decided to revert the allow_block_predicate option to keep a unified behavior, since there is no sound demand for it yet.

The original concern was the case where the desired pattern is exactly the dynamic loop extents. Take the "cache" block as an example: a user may want to lower it into some DMA operations. If the DMA intrinsic happens to support dynamic shapes but not conditional accesses, it would be non-trivial to pattern match it during lowering.

@gumingsiyi

In your example, why is the extent of ax0 and ax1 equal to 10?

@wrongtest-intellif (Contributor, Author)

In your example, why is the extent of ax0 and ax1 equal to 10?

This is the extent needed to cover the region required by the compute block's reads: with hh_1 in [0, 8) and khh in [0, 3), the read index h + kh - 1 = hh_0 * 8 + hh_1 + khh - 1 ranges from hh_0 * 8 - 1 to hh_0 * 8 + 8, i.e. 10 values (and likewise for ax1 with ww_1 and kww).

@wrongtest-intellif force-pushed the tir_compute_at_support_block_predicate branch from 0280297 to dca31be on February 8, 2022
@Hzfengsy (Member) left a comment

LGTM. Thanks @wrongtest for the hard and long-term work!

@Hzfengsy merged commit 8c53f62 into apache:main on Feb 9, 2022
ylc pushed a commit to ylc/tvm that referenced this pull request Feb 16, 2022
…and support floordiv pattern (apache#9527)

* allow generate block predicate in compute_at schedule

* revert apache#9880 and add more testcases