Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TensorIR] New schedule primitive set_dtype #14316

Merged
merged 7 commits into from
Mar 22, 2023
Merged

Conversation

yzh119
Copy link
Member

@yzh119 yzh119 commented Mar 16, 2023

Motivation

Currently, we miss a schedule primitive to change the data type of allocated buffer (e.g. via cache_read/cache_write), and thus we cannot perform type conversion while loading data from global to shared memory.

This PR adds a new schedule primitive set_dtype that follows the interface of set_scope and allows users to customize the allocated buffers' data type.

Example

Before running set_dtype:

@T.prim_func
def before_set_dtype(
    A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")
) -> None:
    B = T.alloc_buffer((128, 128), dtype="float32")

    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j]
            C[vi, vj] = B[vi, vj] + 1.0

then we perform the set_dtype schedule:

sch = tir.Schedule(before_set_dtype)
sch.set_dtype("B", buffer_index=0, dtype="float16")
print(sch.mod["main"].script())

we get transformed code:

@T.prim_func
def after_set_dtype(
    A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")
) -> None:
    B = T.alloc_buffer((128, 128), dtype="float16")

    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = T.cast(A[vi, vj] * 2.0, "float16")
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j]
            C[vi, vj] = T.cast(B[vi, vj], "float32") + 1.0

where data type conversions are inserted automatically.

Other Usage

Using the combination of cache_read + set_dtype can help us load data from the memory hierarchy while converting data to the desired type.

cc @Hzfengsy @vinx13 @junrushao

@tvm-bot
Copy link
Collaborator

tvm-bot commented Mar 16, 2023

Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.

  • No users to tag found in teams: tensorir, status: need update See #10317 for details

Generated by tvm-bot

@Hzfengsy
Copy link
Member

It's a bit tricky since it would change the behavior of the PrimFunc (aka get different result before and after the schedule)

@quic-sanirudh
Copy link
Contributor

It's a bit tricky since it would change the behavior of the PrimFunc (aka get different result before and after the schedule)

There seems to be some interest in introducing schedule primitives like these that might cause different output. This certainly seems like a useful primitive to have, but since we don't normally want schedule primitives to cause the output to change, primitives like these might be hard to introduce.

Would it make sense to have some discussion on perhaps introducing a different scheduling class that allows these kinds of primitives, or explicitly setting a "danger" flag in the Schedule class before using these primitives. That way, we could still allow primitives like these and make sure that a user knows the risk of using a primitive that could cause the output to change.

@Hzfengsy
Copy link
Member

Thanks @quic-sanirudh for the suggestion. I agree that it's a really useful primitive/tool for optimizations. It makes sense if we introduce a danger mode or flag :-)

@yzh119
Copy link
Member Author

yzh119 commented Mar 16, 2023

Adding a danger flag sounds good to me, any ideas on how to make proper abstractions for danger?

@tqchen
Copy link
Member

tqchen commented Mar 17, 2023

how about unsafe_set_dtype, with document to users that this operation can change the model behavior.

@tqchen tqchen added the status: need update need update based on feedbacks label Mar 19, 2023
@junrushao
Copy link
Member

@yzh119 any updates

@yzh119
Copy link
Member Author

yzh119 commented Mar 20, 2023

I have updated my code with new name unsafe_set_dtype.

@tqchen
Copy link
Member

tqchen commented Mar 20, 2023

oops, @yzh119 please also check the c++ apis @junrushao mentioned :)

Copy link
Member

@junrushao junrushao left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@junrushao
Copy link
Member

I don't know why but now it requires committer approval to run github actions for existing contributors. I'm merging this in after the actions got green

@junrushao junrushao merged commit c7970dd into apache:main Mar 22, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
status: need update need update based on feedbacks
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants