-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[TIR] not estimating the flops when there is a default estimated flops as attr #14379
Conversation
Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.
Generated by tvm-bot |
@tvm-bot rerun |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me, just one change requested.
|
||
|
||
@T.prim_func | ||
def flops_override(a: T.Buffer(16, "float32"), b: T.Buffer(16, "float32")): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we reduce this to a minimal PrimFunc as well? That way, it makes it clear to the reader that the test is validating the behavior of the attribute, and that the if/else/floormod aren't required for the test.
@T.prim_func
def func(A: T.Buffer(1)):
T.func_attr("estimated_flops": 32)
for i in T.serial(0, 16):
A[0] = A[0] + 1
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds good Eric. Thanks for the suggestion. Fixed in the next commit.
@tvm-bot rerun |
FlopEstimator is run as part of the Meta Schedule as an estimate for which primfunc to spend time on; e.g. which primfunc to send more trials on. In a case that the for loop extent is an expression, it’s getting stuck on the calculation of the FLOPS.
This PR add a function annotation to the main primfunc and then have FlopEstimator check for this annotation before visiting the function. And if the annotation exists, just use the handcoded flops instead of trying to estimate.
@Lunderberg @csullivan