Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR] not estimating the flops when there is a default estimated flops as attr #14379

Merged
merged 4 commits into from
Mar 24, 2023

Conversation

farshidsp
Copy link
Contributor

@farshidsp farshidsp commented Mar 23, 2023

FlopEstimator is run as part of the Meta Schedule as an estimate for which primfunc to spend time on; e.g. which primfunc to send more trials on. In a case that the for loop extent is an expression, it’s getting stuck on the calculation of the FLOPS.

This PR add a function annotation to the main primfunc and then have FlopEstimator check for this annotation before visiting the function. And if the annotation exists, just use the handcoded flops instead of trying to estimate.

@Lunderberg @csullivan

@tvm-bot
Copy link
Collaborator

tvm-bot commented Mar 23, 2023

Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.

Generated by tvm-bot

@farshidsp
Copy link
Contributor Author

@tvm-bot rerun

@farshidsp farshidsp marked this pull request as ready for review March 23, 2023 17:18
Copy link
Contributor

@Lunderberg Lunderberg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me, just one change requested.



@T.prim_func
def flops_override(a: T.Buffer(16, "float32"), b: T.Buffer(16, "float32")):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we reduce this to a minimal PrimFunc as well? That way, it makes it clear to the reader that the test is validating the behavior of the attribute, and that the if/else/floormod aren't required for the test.

@T.prim_func
def func(A: T.Buffer(1)):
    T.func_attr("estimated_flops": 32)
    for i in T.serial(0, 16):
        A[0] = A[0] + 1

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good Eric. Thanks for the suggestion. Fixed in the next commit.

@Lunderberg
Copy link
Contributor

@tvm-bot rerun

@Lunderberg Lunderberg merged commit c5075dc into apache:main Mar 24, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants