[TensorIR][Schedule] New primitive reorder_block_itervar #14448

Merged 6 commits on Apr 3, 2023

yzh119 (Member) commented Apr 1, 2023

Motivation

Currently the reorder primitive only changes the loop order; the order of the block itervars is left unchanged.
transform_block_layout can change the block itervars, but it requires that the loops outside the given block have no branches, which limits its usage.

This schedule primitive changes the order of the block itervars directly, with an API like:

def reorder_block_iter_var(self, block: BlockRV, new_order: List[int]) -> None:
    """Reorder the itervars inside a given block.

    Parameters
    ----------
    block : BlockRV
        The block to be transformed.
    new_order : List[int]
        The new block itervar order.
    """

where new_order is a permutation of [0, 1, ..., n-1] and n is the number of itervars in the block.
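To make the permutation requirement concrete, here is a minimal plain-Python sketch of the check. The helper name `is_valid_new_order` is hypothetical and is not part of TVM; the actual primitive performs its own validation.

```python
from typing import Sequence


def is_valid_new_order(new_order: Sequence[int], n: int) -> bool:
    """Return True if new_order is a permutation of [0, 1, ..., n-1].

    Hypothetical helper for illustration only; not a TVM API.
    """
    # A permutation contains every index in range(n) exactly once.
    return sorted(new_order) == list(range(n))


# A block with three itervars accepts any arrangement of 0, 1, 2 ...
print(is_valid_new_order([2, 1, 0], 3))
# ... but rejects repeated, missing, or out-of-range indices.
print(is_valid_new_order([0, 0, 1], 3))
print(is_valid_new_order([0, 1], 3))
print(is_valid_new_order([1, 2, 3], 3))
```

Only the first call satisfies the requirement; the other three are the kinds of input the primitive would be expected to reject.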

Example

Suppose we need to change the block itervar order in block "C":

from tvm import tir
from tvm.script import tir as T

@T.prim_func
def matmul(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]) -> None:
    for i, j, k in T.grid(128, 128, 128):
        with T.block("C"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                C[vi, vj] = 0.0
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]

after applying:

sch = tir.Schedule(matmul, debug_mask="all")
C = sch.get_block("C")
sch.reorder_block_iter_var(C, [2, 1, 0])

the block itervar order is changed to vk, vj, vi:

@T.prim_func
def matmul_after_reorder_block_iter_var(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]):
    for i, j, k in T.grid(128, 128, 128):
        with T.block("C"):
            vk, vj, vi = T.axis.remap("RSS", [k, j, i])
            T.reads(A[vi, vk], B[vj, vk])
            T.writes(C[vi, vj])
            with T.init():
                C[vi, vj] = T.float32(0)
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
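The change from `T.axis.remap("SSR", [i, j, k])` to `T.axis.remap("RSS", [k, j, i])` can be reproduced with a small plain-Python model. This is only a sketch: `reorder_remap` is a hypothetical helper, and it assumes that position p of the result takes the itervar at index `new_order[p]` of the original list. The example order `[2, 1, 0]` is its own inverse, so it does not by itself confirm that convention.

```python
from typing import List, Sequence, Tuple


def reorder_remap(
    kinds: str, bindings: Sequence[str], new_order: Sequence[int]
) -> Tuple[str, List[str]]:
    """Permute itervar kinds ("S"/"R") and their loop bindings together.

    Hypothetical model, not TVM code. Assumed convention:
    result[p] = original[new_order[p]].
    """
    n = len(kinds)
    if len(bindings) != n or sorted(new_order) != list(range(n)):
        raise ValueError("new_order must be a permutation of the itervar indices")
    # Each itervar keeps its own kind and binding; only its position changes.
    new_kinds = "".join(kinds[p] for p in new_order)
    new_bindings = [bindings[p] for p in new_order]
    return new_kinds, new_bindings


new_kinds, new_bindings = reorder_remap("SSR", ["i", "j", "k"], [2, 1, 0])
print(new_kinds, new_bindings)  # RSS ['k', 'j', 'i']
```

The point of the model is that each itervar's kind (spatial or reduction) and its loop binding move as a pair, which is why the block body and the surrounding loops in the example are unchanged.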

cc @junrushao @vinx13 @Hzfengsy

tvm-bot (Collaborator) commented Apr 1, 2023

Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.

  • No users to tag found in teams: tensorir, schedule. See #10317 for details.

Generated by tvm-bot

wrongtest-intellif (Contributor) commented

transform_block_layout can change the block iterable variables, but it requires the loops outside the given block to have no branches

Is it safe to relax the single-subtree constraint of transform_block_layout? It seems that remapping block itervars does not affect the evaluation order, so we should not need to be concerned with the loop structure outside the block. cc @vinx13 @Hzfengsy

vinx13 (Member) commented Apr 2, 2023

transform_block_layout not only changes the order of the iter vars but also rewrites the outer loops to make sure the evaluation order is consistent with the iter vars. It is okay to relax the constraint a bit, but that will make the analysis for the rewriting more difficult.
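The distinction discussed here can be illustrated with a toy plain-Python model: permuting only the block itervar declarations leaves the execution order untouched, while also permuting the loops changes it. This is a sketch under simplifying assumptions (each itervar is bound directly to one loop variable, and the function `run` and its arguments are hypothetical); it does not reproduce what transform_block_layout actually does to the loop nest.

```python
from itertools import product
from typing import List, Tuple


def run(loop_order: Tuple[str, ...], iter_var_order: Tuple[str, ...], n: int = 2) -> List[Tuple[int, int, int]]:
    """Return the (vi, vj, vk) block instances in execution order.

    Toy model, not TVM: loop_order lists loop variables from outermost
    to innermost; iter_var_order is the order in which the block
    declares its itervars, each bound as v<name> = <name>.
    """
    trace = []
    for values in product(range(n), repeat=len(loop_order)):
        loops = dict(zip(loop_order, values))
        # Declaration order changes, but every binding stays v<name> = <name>.
        block = {"v" + name: loops[name] for name in iter_var_order}
        trace.append((block["vi"], block["vj"], block["vk"]))
    return trace


base = run(("i", "j", "k"), ("i", "j", "k"))
itervars_only = run(("i", "j", "k"), ("k", "j", "i"))  # reorder itervars only
loops_too = run(("k", "j", "i"), ("k", "j", "i"))      # also reorder the loops

print(base == itervars_only)             # True: same instances, same order
print(base == loops_too)                 # False: evaluation order differs
print(sorted(base) == sorted(loops_too)) # True: same set of instances
```

In this model, a primitive that touches only the itervar declarations needs no analysis of the surrounding loops, whereas one that keeps loops and itervars consistent has to rewrite the loop nest, which is where the constraint on the outer loops comes from.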

yzh119 (Member, Author) commented Apr 2, 2023

@wrongtest-intellif Actually, I tried to change the behavior of transform_block_layout by adding a root_loop argument to specify the scope. However, as @vinx13 mentioned, the analysis is non-trivial.

wrongtest-intellif (Contributor) commented

Thanks, then it looks quite good to me!

junrushao (Member) left a comment

Thanks for the discussion here! I believe this PR is ready to be merged!

@junrushao junrushao merged commit 6006d25 into apache:main Apr 3, 2023
zxybazh pushed a commit to zxybazh/tvm that referenced this pull request Apr 4, 2023
zxybazh pushed a commit to zxybazh/tvm that referenced this pull request Apr 6, 2023