Skip to content

Commit

Permalink
[Relax] Implement Rewriter class for pattern-rewrite (#17149)
Browse files Browse the repository at this point in the history
* [TVMScript][Bugfix] Normalize relax::If with function's TIR var

Prior to this commit, the branches of `relax::If` were normalized
using `EraseToWellDefinedInScope`, using a fresh variable scope.
While this had the intended behavior of preventing variables defined
in a single branch from being usable outside of the conditional, it
also caused the conditional's branches to treat function-scope
symbolic variables as if they were undefined.

This commit updates the `tvm::relax::Normalizer` so that `relax::If`
is normalized within an inherited scope.  This preserves the previous
behavior for symbolic variables defined within a branch, but allows
shapes within a branch to use symbolic variables defined outside of
the branch.

* [Relax] Canonicalize known symbolic shapes in Relax expressions

Prior to this commit, known constants in Relax functions would be
inlined by the `CanonicalizeBindings` pass, but only if they appeared as Relax
expressions (e.g. `R.const` or `R.prim_value`).  Known constants that
appeared as TIR variables (e.g. symbolic shapes) would be kept as
dynamic parameters, even if they were known at compile time.

This commit updates the `CanonicalizeBindings` pass to identify known
values of symbolic shapes, and to use these known values in shape
expressions.

* [Relax][Refactor] Reorganize pattern-matching

A follow-up to #16730.  Now that the
implementations for `rewrite_call` and `rewrite_bindings` are in
separate classes, they can be further split out into separate files.

* [Relax][Refactor] Implement Rewriter class for pattern-rewrite

Prior to this commit, the pattern to be matched and the rewrite to be
performed were provided as separate arguments.  This commit introduces
a new class `ExprRewriter`, which contains both parts.

This abstraction will make it easier to combine multiple different
rewrite rules, applying them in a single pass.

* lint fixes

* Remove unnecessary change which broke a unit test

* lint fix for import order

* Add docstrings

* lint fix

* Lint fix

* lint fixes

* lint fix

* Update based on review comments

* Add test case for matching against arbitrary dtype

* Fix breakage in unit tests

One unit test that had been relying on invalid shape propagation.
Another unit test that required constructed an ill-formed output to
test against.

* Updated base class name from ExprRewriter to PatternMatchingRewriter

* lint fix
  • Loading branch information
Lunderberg authored Jul 24, 2024
1 parent cc8afdb commit 7bd738a
Show file tree
Hide file tree
Showing 24 changed files with 4,142 additions and 751 deletions.
35 changes: 33 additions & 2 deletions include/tvm/relax/block_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,16 +133,47 @@ class BlockBuilderNode : public Object {
* \brief Begin a new scope, with optional parameters that
* are visible within the scope.
*
* Symbolic variables from the parent scope are not available.
*
* \param params Parameters that are visible within the scope.
*
* \note This function should be called when new scope is introduced
* (function, seq) to properly track the variable availability
* and help the best effort deduction.
* (e.g. function bodies) to properly track the variable
* availability and help the best effort deduction.
*
* \sa EndScope
*/
virtual void BeginScope(Optional<Array<Var>> params) = 0;

/*!
* \brief Begin a new scope, which inherits visible parameters from
* its parent scope.
*
* Symbolic variables from the parent scope are available.
*
* \note This function should be called when an inner scope is
* introduced (e.g. conditional branches) to properly track
* the variable availability and help the best effort
* deduction.
*
* \sa EndScope
*/
virtual void BeginInnerScope() = 0;

/*!
* \brief Append a definition to the current scope.
*
* \param var A variable within the current scope.
*
* \note This function should be called when a new variable is
* defined that may impact struct inference (e.g. MatchCast)
* to properly track the variable availability and help the
* best effort deduction.
*
* \sa EndScope
*/
virtual void AddDefinitionToScope(Var var) = 0;

/*! \brief End the previously defined scope. */
virtual void EndScope() = 0;

Expand Down
21 changes: 20 additions & 1 deletion include/tvm/relax/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,10 @@ class ExprMutator : public ExprMutatorBase {
void ReEmitBinding(const VarBindingNode* binding, Expr new_value);

/*!
* \brief Rewrite the expr with a new scope, used in a Function's body and the branches of If.
* \brief Rewrite the expr with a new scope, used in a Function's body.
*
* Visit an expression that may neither access variables from the
* current scope, nor may export definitions into the current scope.
*
* \param body_expr The body to be visited.
* \param params Optional parameters that are visible within the scope.
Expand All @@ -504,6 +507,22 @@ class ExprMutator : public ExprMutatorBase {
*/
Expr VisitWithNewScope(const Expr& body_expr, Optional<Array<Var>> params = NullOpt);

/*!
* \brief Rewrite the expr with a new scope, used in the branches of If.
*
* Visit an expression that may access variables from the current
* scope, but may not export definitions into the current scope.
*
* \param body_expr The body to be visited.
*
* \return The expr after visiting.
*
* \sa VisitWithNewScope
*
* \note The body_expr must be an SeqExpr in the normal form.
*/
Expr VisitWithInnerScope(const Expr& body_expr);

/*!
* \brief Look up the value bound to a variable.
* \param var The var to be looked up.
Expand Down
1 change: 1 addition & 0 deletions include/tvm/script/ir_builder/relax/frame.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ class FunctionFrameNode : public SeqExprFrameNode {
TVM_DECLARE_FINAL_OBJECT_INFO(FunctionFrameNode, SeqExprFrameNode);

public:
void EnterWithScope() final;
void ExitWithScope() final;
};

Expand Down
8 changes: 7 additions & 1 deletion python/tvm/relax/dpl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,10 @@

from .pattern import *
from .context import *
from .rewrite import rewrite_call, rewrite_bindings
from .rewrite import (
rewrite_call,
rewrite_bindings,
PatternMatchingRewriter,
ExprPatternRewriter,
OrRewriter,
)
186 changes: 183 additions & 3 deletions python/tvm/relax/dpl/rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,196 @@
# specific language governing permissions and limitations
# under the License.
"""APIs for pattern-based rewriting."""
from typing import Dict, Callable

from typing import Dict, Callable, Union

from tvm.ir import IRModule
from tvm.runtime import Object
from tvm._ffi import register_object

from .pattern import DFPattern
from .context import PatternContext

from ..expr import Expr, Function, Var
from . import _ffi as ffi


@register_object("relax.dpl.PatternMatchingRewriter")
class PatternMatchingRewriter(Object):
"""A pattern-matching rewriter for Relax"""

@staticmethod
def from_pattern(
pattern: DFPattern,
func: Callable[[Expr, Dict[DFPattern, Expr]], Expr],
) -> "PatternMatchingRewriter":
"""Construct from a pattern and rewriter-function
The replacements performed by the rewriter will be equivalent
to using the `pattern` and `func` as arguments to
`rewrite_call`.
Parameters
----------
pattern: DFPattern
The pattern to be matched against.
func: Callable[[Expr, Dict[DFPattern, Expr]], Expr]
A function that returns the rewritten expression. See
`rewrite_call` for details and examples.
Returns
-------
rewriter_obj: PatternMatchingRewriter
The rewriter object
"""
return ffi.PatternMatchingRewriterFromPattern(
pattern,
func,
) # type: ignore

@staticmethod
def from_module(mod: IRModule) -> "PatternMatchingRewriter":
"""Construct a rewriter from an IRModule
The IRModule must have two publicly-exposed functions,
`pattern` and `replacement`, where `pattern` and `replacement`
have the same function signature, as shown in the example
below.
.. code-block:: python
@I.ir_module
class RewriteAddIntoMultiply:
@R.function
def pattern(A: R.Tensor):
B = A + A
return B
@R.function
def replacement(A: R.Tensor):
B = A * 2
return B
rewriter = PatternMatchingRewriter.from_module(RewriteAddIntoMultiply)
rewritten_ir_module = rewriter(ir_module)
To support the common case of defining an IRModule with
TVMScript, then immediately turning it into a rewriter, the
`@R.rewriter` annotation can be used.
.. code-block:: python
@R.rewriter
class RewriteAddIntoMultiply:
@R.function
def pattern(A: R.Tensor):
B = A + A
return B
@R.function
def replacement(A: R.Tensor):
B = A * 2
return B
rewritten_ir_module = RewriteAddIntoMultiply(ir_module)
Parameters
----------
mod: IRModule
A module with `pattern` and `replacement` functions,
defining a rewrite rule.
Returns
-------
rewriter_obj: PatternMatchingRewriter
The rewriter object
"""
return ffi.PatternMatchingRewriterFromModule(mod) # type: ignore

def __call__(self, obj: Union[Expr, IRModule]) -> Union[Expr, IRModule]:
"""Apply the rewriter
Parameters
----------
obj: Union[Expr, IRModule])
The object to be rewritten. May be applied to either a
relax expression, or an IRModule.
Returns
-------
updated: Union[Expr, IRModule]
The rewritten object
"""
return ffi.PatternMatchingRewriterApply(self, obj)

def __or__(self, other: "PatternMatchingRewriter") -> "PatternMatchingRewriter":
"""Compose two rewriters
Composing two rewrite rules together allows them to be applied
in a single Relax-level transformation.
Parameters
----------
other: PatternMatchingRewriter
Another rewrite rule
Returns
-------
PatternMatchingRewriter
A rewriter that will apply either rewrite pattern
"""
return OrRewriter(self, other)


@register_object("relax.dpl.ExprPatternRewriter")
class ExprPatternRewriter(PatternMatchingRewriter):
def __init__(self, pattern, func):
self.__init_handle_by_constructor__(
ffi.PatternRewriter,
pattern,
func,
) # type: ignore


@register_object("relax.dpl.OrRewriter")
class OrRewriter(PatternMatchingRewriter):
def __init__(self, lhs, rhs):
self.__init_handle_by_constructor__(
ffi.OrRewriter,
lhs,
rhs,
) # type: ignore


@register_object("relax.dpl.TupleRewriter")
class TupleRewriter(PatternMatchingRewriter):
def __init__(self, patterns, func):
self.__init_handle_by_constructor__(
ffi.TupleRewriter,
patterns,
func,
) # type: ignore


def rewrite_call(
pattern: DFPattern, rewriter: Callable[[Expr, Dict[DFPattern, Expr]], Expr], func: Function
pattern: DFPattern,
rewriter: Callable[[Expr, Dict[DFPattern, Expr]], Expr],
func: Function,
) -> Function:
"""
Rewrite a function with the given pattern and the rewriter function.
Expand Down
48 changes: 46 additions & 2 deletions python/tvm/script/ir_builder/relax/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@
import builtins
import functools
import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Type

import tvm
from tvm import DataType, relax
from tvm.ir import PrimExpr, VDevice
from tvm.ir import PrimExpr, VDevice, IRModule
from tvm.relax import (
Call,
Expr,
Expand All @@ -35,6 +35,7 @@
VarBinding,
const,
)
from tvm.relax.dpl import PatternMatchingRewriter

############################### Operators ###############################
from tvm.relax.op import (
Expand Down Expand Up @@ -306,6 +307,48 @@ def func_ret_value(value: Expr) -> None:
return _ffi_api.FuncRetValue(value) # type: ignore[attr-defined] # pylint: disable=no-member


def rewriter(rewriter_mod: Union[IRModule, Type]) -> PatternMatchingRewriter:
"""Define a pattern-rewrite rule
The IRModule must have two publicly-exposed functions, `pattern`
and `replacement`, where `pattern` and `replacement` have the same
function signature.
.. code-block:: python
@R.rewriter
class RewriteAddIntoMultiply:
@R.function
def pattern(A: R.Tensor):
B = A + A
return B
@R.function
def replacement(A: R.Tensor):
B = A * 2
return B
Parameters
----------
rewriter_mod: Union[IRModule, Type]
Either an IRModule that defines a rewrite pattern, or a
TVMScript class that can be parsed into an IRModule.
Returns
-------
rewriter: PatternMatchingRewriter
A rewriter object, which can be applied either to a Relax
function or to an entire IRModule.
"""
if not isinstance(rewriter_mod, IRModule):
rewriter_mod = tvm.script.ir_module(rewriter_mod)

return PatternMatchingRewriter.from_module(rewriter_mod)


############################# BindingBlock ##############################


Expand Down Expand Up @@ -765,6 +808,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
"dequantize",
"repeat",
"reshape",
"rewriter",
"tensor_to_shape",
"shape_to_tensor",
"rocm",
Expand Down
Loading

0 comments on commit 7bd738a

Please sign in to comment.