-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Unity] Implement FNormalize for relax.op.call_tir #16068
[Unity] Implement FNormalize for relax.op.call_tir #16068
Conversation
This is currently marked as a draft PR, as it depends on functionality introduced in #16067. #16067 is included in the PR branch for #16068, so that CI testing can proceed, but it does make this PR accidentally look 3x as large. The changes that are specific to this PR are in the last commit of this branch, and can be viewed independently here. |
c43b2cb
to
6d4dbef
Compare
Rebased to include latest version of pre-requisite PR #16067, to resolve CI lint failures specific to changes made in that PR. |
include/tvm/relax/block_builder.h
Outdated
* The name is deliberately verbose to draw attention during a code | ||
* review. The explicit default constructor prevents aggregate |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great tactic :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you. Since any unnormalized operator is now considered ill-formed, I wanted to make it as obvious as possible when the per-operator normalization is being disabled. It was rather amusing how many lint checks needed to be disabled in order to ensure that the parameter name occurs at every call site. (Lint rules generally aren't aimed at making intentionally-verbose code.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I generally approve, though I'm not fond of the principle of having per-operator normalization rules (if we insist on having them, a central registry is definitely the way to do it, though), but should this rule also be attached to the call_tir
variants? Do we anticipate having rules for anything other than the small coterie of call_tir
variants (and maybe call_pure_packed
or call_inplace_packed
)? If we don't anticipate that this group will grow, maybe we don't need a registry.
It should, yes. I applied it just to
I think I do, yes. The
From a language-purity standard, I agree, because per-operator normalization increases the overall language complexity. From a development standpoint, I think it is very useful for evolving the IR without breaking backwards compatibility. Since the IR can be inspected/manipulated both internally to TVM and externally (e.g. MLC-LLM), being able to iterate on the IR design while maintaining backwards compatibility at the source-code level becomes essential. Having the per-operator normalization lets us make the following changes:
|
Prior to this commit, `relax.op.call_tir` could express the TIR arguments as either an in-line tuple, or as a variable bound to a tuple. Because several passes assume the arguments will always be an in-line tuple, this is being codified as the normal form of `relax.op.call_tir`. Any upstream transform that produces `relax.op.call_tir` with arguments provided as a by-variable tuple will either be normalized to an in-line tuple if possible, or will produce an error during the upstream transform otherwise. This commit is specifically to allow the current usage of `Downcast<Tuple>(call->args[1])` in passes such as `CallTIRRewrite`, `FoldConstant`, `FuseTIR`, and `RewriteDataflowReshape`.
6d4dbef
to
30612b0
Compare
(Rebased on top of |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This seems to be reasonable. I would request a similar normalizer be added for call_tir_inplace
and call_tir_with_grad
, as they all expect a tuple of arguments.
Question: Consider a function like the following:
@R.function
def func(x: R.Tuple([...])):
...
R.call_tir(some_primfunc, x, ...)
...
Is this a case we would unequivocally prohibit? The normalizer here would. It is not possible to inline the tuple here. However, it would still work at run time.
Sounds good, and updated. This also included moving the validity checks from
My personal preference would be to allow it, though I can understand the reasons @tqchen laid out in the #15916 discussion for wanting it to be prohibited. That said, I think we still can provide a normalized form of @R.function
def func(x: R.Tuple([...])):
...
R.call_tir(some_primfunc, R.tuple(x[0], x[1], ..., x[n]), ...)
... It's not the cleanest of normalized forms, but it's semantically equivalent and preserves the in-line tuple of arguments. If a later pass (e.g. |
That's a good point that we could normalize that case. I don't like normalizations that turn something simple into something complicated, but it's not likely to be a case that will come up very much. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for addressing my feedback!
Thank you, and agreed that the breaking apart of a tuple for the purpose of making an in-line tuple is probably not a common case. |
Prior to this commit,
relax.op.call_tir
could express the TIR arguments as either an in-line tuple, or as a variable bound to a tuple. Because several passes assume the arguments will always be an in-line tuple, this is being codified as the normal form ofrelax.op.call_tir
. Any upstream transform that producesrelax.op.call_tir
with arguments provided as a by-variable tuple will either be normalized to an in-line tuple if possible, or will produce an error during the upstream transform otherwise.This commit is specifically to allow the current usage of
Downcast<Tuple>(call->args[1])
in passes such asCallTIRRewrite
,FoldConstant
,FuseTIR
, andRewriteDataflowReshape
.