Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Unity] Implement FNormalize for relax.op.call_tir #16068

Merged
merged 4 commits into from
Nov 14, 2023

Conversation

Lunderberg
Copy link
Contributor

Prior to this commit, relax.op.call_tir could express the TIR arguments as either an in-line tuple, or as a variable bound to a tuple. Because several passes assume the arguments will always be an in-line tuple, this is being codified as the normal form of relax.op.call_tir. Any upstream transform that produces relax.op.call_tir with arguments provided as a by-variable tuple will either be normalized to an in-line tuple if possible, or will produce an error during the upstream transform otherwise.

This commit is specifically to allow the current usage of Downcast<Tuple>(call->args[1]) in passes such as CallTIRRewrite, FoldConstant, FuseTIR, and RewriteDataflowReshape.

@Lunderberg
Copy link
Contributor Author

Lunderberg commented Nov 3, 2023

This is currently marked as a draft PR, as it depends on functionality introduced in #16067. #16067 is included in the PR branch for #16068, so that CI testing can proceed, but it does make this PR accidentally look 3x as large. The changes that are specific to this PR are in the last commit of this branch, and can be viewed independently here.

@Lunderberg Lunderberg force-pushed the unity_normalize_call_tir_operator branch 3 times, most recently from c43b2cb to 6d4dbef Compare November 6, 2023 20:51
@Lunderberg
Copy link
Contributor Author

Rebased to include latest version of pre-requisite PR #16067, to resolve CI lint failures specific to changes made in that PR.

Comment on lines 242 to 243
* The name is deliberately verbose to draw attention during a code
* review. The explicit default constructor prevents aggregate
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great tactic :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you. Since any unnormalized operator is now considered ill-formed, I wanted to make it as obvious as possible when the per-operator normalization is being disabled. It was rather amusing how many lint checks needed to be disabled in order to ensure that the parameter name occurs at every call site. (Lint rules generally aren't aimed at making intentionally-verbose code.)

Copy link
Contributor

@slyubomirsky slyubomirsky left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I generally approve, though I'm not fond of the principle of having per-operator normalization rules (if we insist on having them, a central registry is definitely the way to do it, though), but should this rule also be attached to the call_tir variants? Do we anticipate having rules for anything other than the small coterie of call_tir variants (and maybe call_pure_packed or call_inplace_packed)? If we don't anticipate that this group will grow, maybe we don't need a registry.

@Lunderberg
Copy link
Contributor Author

on rules (if we insist on having them, a central registry is definitely the way to do it, though), but should this rule also be attached to the call_tir variants

It should, yes. I applied it just to call_tir initially as it was the case we ran into, and to keep the first example usage as simple as possible. Do you have preferences between adding the other variants in this PR, or in a follow-up PR?

Do we anticipate having rules for anything other than the small coterie of call_tir variants

I think I do, yes. The FNormalize functionality also makes dynamic tuple indices (PR#16002) much simpler to implement.

though I'm not fond of the principle of having per-operator normalization rules

From a language-purity standard, I agree, because per-operator normalization increases the overall language complexity. From a development standpoint, I think it is very useful for evolving the IR without breaking backwards compatibility. Since the IR can be inspected/manipulated both internally to TVM and externally (e.g. MLC-LLM), being able to iterate on the IR design while maintaining backwards compatibility at the source-code level becomes essential.

Having the per-operator normalization lets us make the following changes:

  • Provide guarantees to consumers without imposing a burden on producers. (e.g. call_tir normalizing to an in-line tuple)

  • Generalize an operator without introducing equivalent representations.

    For example, in [Draft][Unity] Allow dynamic indices to TupleGetItem #16002, I introduce dynamic tuple indices. The first implementation changed TupleGetItemNode::index from int to relax::Expr, and required updating a very large number of consumer passes. The second implementation introduced a tuple_get_item_dyn operator with Array<Expr> args = {tuple, index}. However, this would allow a static index to be represented either as TupleGetItem(tuple, 1) or as Call(tuple_get_item_dyn, {tuple, PrimValue(1)}), and so the checks for TupleGetItem would no longer cover all cases of a static tuple index.

    Using FNormalize to normalize Call(tuple_get_item_dyn, {tuple, PrimValue(1)}) into TupleGetItem(tuple, 1) keeps a single representation for static indices.

  • Restrict the use of existing operators. If an operator's semantics are found to be unsound, we can normalize any sound usage of it to a new operator, while throwing an error for unsound usage. This would be an explicit backwards-compatibility breakage, but being able to apply deprecation warnings and errors in a single location would use the same FNormalize functionality.

Prior to this commit, `relax.op.call_tir` could express the TIR
arguments as either an in-line tuple, or as a variable bound to a
tuple.  Because several passes assume the arguments will always be an
in-line tuple, this is being codified as the normal form of
`relax.op.call_tir`.  Any upstream transform that produces
`relax.op.call_tir` with arguments provided as a by-variable tuple
will either be normalized to an in-line tuple if possible, or will
produce an error during the upstream transform otherwise.

This commit is specifically to allow the current usage of
`Downcast<Tuple>(call->args[1])` in passes such as `CallTIRRewrite`,
`FoldConstant`, `FuseTIR`, and `RewriteDataflowReshape`.
@Lunderberg Lunderberg force-pushed the unity_normalize_call_tir_operator branch from 6d4dbef to 30612b0 Compare November 7, 2023 14:41
@Lunderberg Lunderberg marked this pull request as ready for review November 7, 2023 14:42
@Lunderberg
Copy link
Contributor Author

(Rebased on top of unity as #16067 has landed, and marked this PR as ready for review.)

Copy link
Contributor

@slyubomirsky slyubomirsky left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems to be reasonable. I would request a similar normalizer be added for call_tir_inplace and call_tir_with_grad, as they all expect a tuple of arguments.

Question: Consider a function like the following:

@R.function
def func(x: R.Tuple([...])):
    ...
    R.call_tir(some_primfunc, x, ...)
    ...

Is this a case we would unequivocally prohibit? The normalizer here would. It is not possible to inline the tuple here. However, it would still work at run time.

@Lunderberg
Copy link
Contributor Author

This seems to be reasonable. I would request a similar normalizer be added for call_tir_inplace and call_tir_with_grad, as they all expect a tuple of arguments.

Sounds good, and updated. This also included moving the validity checks from InferStructInfoCallTIRInplace to NormalizeCallTIRInplace, since the shape inference is done before normalization.

Is this a case we would unequivocally prohibit?

My personal preference would be to allow it, though I can understand the reasons @tqchen laid out in the #15916 discussion for wanting it to be prohibited.

That said, I think we still can provide a normalized form of call_tir, so we don't need to throw an error back at the user. Because all relax tuples have a fixed size, we could still normalize it to an in-line tuple, but one whose contents were populated by TupleGetItem instances.

@R.function
def func(x: R.Tuple([...])):
    ...
    R.call_tir(some_primfunc, R.tuple(x[0], x[1], ..., x[n]), ...)
    ...

It's not the cleanest of normalized forms, but it's semantically equivalent and preserves the in-line tuple of arguments. If a later pass (e.g. BindParams or a ExpandTupleParams pass I'm putting together) updates the function arguments, then this form can be simplified with the existing CanonicalizeBindings.

@slyubomirsky
Copy link
Contributor

That's a good point that we could normalize that case. I don't like normalizations that turn something simple into something complicated, but it's not likely to be a case that will come up very much.

Copy link
Contributor

@slyubomirsky slyubomirsky left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for addressing my feedback!

@Lunderberg Lunderberg merged commit 0ddfc65 into apache:unity Nov 14, 2023
2 checks passed
@Lunderberg Lunderberg deleted the unity_normalize_call_tir_operator branch November 14, 2023 21:03
@Lunderberg
Copy link
Contributor Author

Thank you, and agreed that the breaking apart of a tuple for the purpose of making an in-line tuple is probably not a common case.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants