Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relax] Implement Rewriter class for pattern-rewrite #17149

Merged
merged 17 commits into from
Jul 24, 2024

Conversation

Lunderberg
Copy link
Contributor

Prior to this commit, the pattern to be matched and the rewrite to be performed were provided as separate arguments to rewrite_call. This commit introduces a new class ExprRewriter, which contains both parts. This abstraction makes it easier to combine multiple independent rewrite rules, then apply them all in a IRModule pass.

This implementation also allows pattern rewrite rule to be written using TVMScript, by representing both the pattern to be matched, and the replacement to be applied as functions. For example:

@R.rewriter
class Rewriter:
    @R.function
    def pattern(
        A: R.Tensor([16], "float32"),
        B: R.Tensor([16], "float32"),
    ):
        C = R.add(A, B)
        return C

    @R.function
    def replacement(
        A: R.Tensor([16], "float32"),
        B: R.Tensor([16], "float32"),
    ):
        C = R.call_pure_packed(
            "my_optimized_add_impl", A, B,
            sinfo_args=R.Tensor([16], "float32"),
        )
        return C

Prior to this commit, the branches of `relax::If` were normalized
using `EraseToWellDefinedInScope`, using a fresh variable scope.
While this had the intended behavior of preventing variables defined
in a single branch from being usable outside of the conditional, it
also caused the conditional's branches to treat function-scope
symbolic variables as if they were undefined.

This commit updates the `tvm::relax::Normalizer` so that `relax::If`
is normalized within an inherited scope.  This preserves the previous
behavior for symbolic variables defined within a branch, but allows
shapes within a branch to use symbolic variables defined outside of
the branch.
Prior to this commit, known constants in Relax functions would be
inlined by the `CanonicalizeBindings` pass, but only if they appeared as Relax
expressions (e.g. `R.const` or `R.prim_value`).  Known constants that
appeared as TIR variables (e.g. symbolic shapes) would be kept as
dynamic parameters, even if they were known at compile time.

This commit updates the `CanonicalizeBindings` pass to identify known
values of symbolic shapes, and to use these known values in shape
expressions.
A follow-up to apache#16730.  Now that the
implementations for `rewrite_call` and `rewrite_bindings` are in
separate classes, they can be further split out into separate files.
Prior to this commit, the pattern to be matched and the rewrite to be
performed were provided as separate arguments.  This commit introduces
a new class `ExprRewriter`, which contains both parts.

This abstraction will make it easier to combine multiple different
rewrite rules, applying them in a single pass.
@Lunderberg Lunderberg requested a review from sunggg July 10, 2024 19:23
@Lunderberg
Copy link
Contributor Author

@sunggg @csullivan

@Lunderberg Lunderberg force-pushed the pattern_matching_improvements branch from 6d77ec9 to 28a72e9 Compare July 10, 2024 20:41
Copy link
Contributor

@sunggg sunggg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, @Lunderberg! Thank you for the cool addition!
I haven't taken a deeper look at the detailed mechanics yet, but I love how this PR allows us to manage the pattern and rewrite conveniently with better readability :)

My main question for you is if this PR is aiming the full support for the existing pattern language.
It is okay if there are some rough edges that we don't currently support, but I would like to clearly document them somewhere to avoid the potential misusage.

Here are some of the cases I wonder how these will be supported with your PR:

include/tvm/relax/block_builder.h Outdated Show resolved Hide resolved
python/tvm/relax/dpl/rewrite.py Outdated Show resolved Hide resolved
python/tvm/relax/dpl/rewrite.py Show resolved Hide resolved
tests/python/relax/test_dataflow_rewriter.py Show resolved Hide resolved
tests/python/relax/test_dataflow_rewriter.py Show resolved Hide resolved
@Lunderberg
Copy link
Contributor Author

My main question for you is if this PR is aiming the full support for the existing pattern language. It is okay if there are some rough edges that we don't currently support, but I would like to clearly document them somewhere to avoid the potential misusage.

I've been aiming to support as much of the existing pattern language as possible, but not the entire syntax. There are some portions of the pattern language that don't have clear Relax equivalents, and those wouldn't be supported.

Here are some of the cases I wonder how these will be supported with your PR:

* `~` not pattern: https://github.com/apache/tvm/blob/main/tests/python/relax/test_dataflow_pattern.py#L222

This is probably the biggest one that doesn't have a relax-level equivalent. The main usage would be to avoid aliasing in a pattern-match. (e.g. An implementation that is valid so long as two wildcard patterns do not refer to the same tensor.)

* If we want to match with any placeholder (`wildcard()`), which could be tensor, constant, etc., how should we express this?

If the pattern has an argument var_name: R.Object, then it would match against any Relax value. This is converted to StructInfoPattern(WildcardPattern(), ObjectStructInfo()) when producing the pattern rule. When the StructInfoPattern validates a match, it does so using StructInfoBaseChecker. Since ObjectStructInfo() is the base of all other relax types, this validation always passes.

There are some edge cases where this doesn't work, mostly arising from struct inference applied within the pattern or replacement. For example, suppose the pattern has wildcard A: R.Object and returns A + A. This would raise an error, because Relax requires binary operations to have either R.Prim or R.Tensor struct info. Resolving this edge case would require re-visiting this discussion on whether the struct inference should throw an error for provably-incorrect usage, or should throw an error for not-provably-correct usage.

* Can we match with any number of function arguments?

We can have arbitrarily many function arguments for the pattern/replacement. However, each pattern must have a fixed number of arguments.

@Lunderberg Lunderberg force-pushed the pattern_matching_improvements branch from 0f8b4fe to 234ddde Compare July 15, 2024 19:59
One unit test that had been relying on invalid shape propagation.
Another unit test that required constructed an ill-formed output to
test against.
@sunggg
Copy link
Contributor

sunggg commented Jul 16, 2024

Thanks, @Lunderberg for the clarification! I think it is absolutely okay to have some rough edges that we cannot support yet, but I hope we clearly document in the code so that people won't get confused.

@Lunderberg
Copy link
Contributor Author

I think it is absolutely okay to have some rough edges that we cannot support yet, but I hope we clearly document in the code so that people won't get confused.

Agreed. Any unsupported functionality should raise an error during the conversion, which (hopefully) will make it as clear as possible. I'm a big fan of error-checking as early as possible, especially when the input is relatively unstructured.

Copy link
Contributor

@sunggg sunggg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you @Lunderberg for addressing my comments and pushing this direction! :)

@sunggg sunggg merged commit 7bd738a into apache:main Jul 24, 2024
19 checks passed
@Lunderberg Lunderberg deleted the pattern_matching_improvements branch July 24, 2024 16:11
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants