diff --git a/include/tvm/relax/block_builder.h b/include/tvm/relax/block_builder.h index 7ca9aab6d5aa..ad2b9820707a 100644 --- a/include/tvm/relax/block_builder.h +++ b/include/tvm/relax/block_builder.h @@ -133,16 +133,47 @@ class BlockBuilderNode : public Object { * \brief Begin a new scope, with optional parameters that * are visible within the scope. * + * Symbolic variables from the parent scope are not available. + * * \param params Parameters that are visible within the scope. * * \note This function should be called when new scope is introduced - * (function, seq) to properly track the variable availability - * and help the best effort deduction. + * (e.g. function bodies) to properly track the variable + * availability and help the best effort deduction. * * \sa EndScope */ virtual void BeginScope(Optional<Array<Var>> params) = 0; + /*! + * \brief Begin a new scope, which inherits visible parameters from + * its parent scope. + * + * Symbolic variables from the parent scope are available. + * + * \note This function should be called when an inner scope is + * introduced (e.g. conditional branches) to properly track + * the variable availability and help the best effort + * deduction. + * + * \sa EndScope + */ + virtual void BeginInnerScope() = 0; + + /*! + * \brief Append a definition to the current scope. + * + * \param var A variable within the current scope. + * + * \note This function should be called when a new variable is + * defined that may impact struct info inference (e.g. MatchCast) + * to properly track the variable availability and help the + * best effort deduction. + * + * \sa EndScope + */ + virtual void AddDefinitionToScope(Var var) = 0; + /*! \brief End the previously defined scope. */ virtual void EndScope() = 0; diff --git a/include/tvm/relax/expr_functor.h b/include/tvm/relax/expr_functor.h index ce209ccd460f..c3aea24dcb50 100644 --- a/include/tvm/relax/expr_functor.h +++ b/include/tvm/relax/expr_functor.h @@ -494,7 +494,10 @@ class ExprMutator : public ExprMutatorBase { void ReEmitBinding(const VarBindingNode* binding, Expr new_value); /*! - * \brief Rewrite the expr with a new scope, used in a Function's body and the branches of If. + * \brief Rewrite the expr with a new scope, used in a Function's body. + * + * Visit an expression that may neither access variables from the + * current scope nor export definitions into the current scope. * * \param body_expr The body to be visited. * \param params Optional parameters that are visible within the scope. @@ -504,6 +507,22 @@ class ExprMutator : public ExprMutatorBase { */ Expr VisitWithNewScope(const Expr& body_expr, Optional<Array<Var>> params = NullOpt); + /*! + * \brief Rewrite the expr with a new scope, used in the branches of If. + * + * Visit an expression that may access variables from the current + * scope, but may not export definitions into the current scope. + * + * \param body_expr The body to be visited. + * + * \return The expr after visiting. + * + * \sa VisitWithNewScope + * + * \note The body_expr must be a SeqExpr in normal form. + */ + Expr VisitWithInnerScope(const Expr& body_expr); + /*! * \brief Look up the value bound to a variable. * \param var The var to be looked up.
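To make the new scoping hooks concrete, the following TVMScript sketch shows the situation they are designed for. It is illustrative only: the module name, variable names, and symbolic shape names are assumptions made for this note, not code from this change.

.. code-block:: python

    import tvm
    from tvm.script import ir as I, relax as R

    @I.ir_module
    class ScopeExample:
        @R.function
        def main(A: R.Tensor(ndim=2, dtype="float32"), cond: R.Tensor((), "bool")):
            if cond:
                # R.match_cast defines the symbolic shape variables "m" and
                # "n"; AddDefinitionToScope records them so that struct info
                # deduction can use them in the rest of this branch.  The
                # branch is visited with VisitWithInnerScope/BeginInnerScope,
                # so definitions from the enclosing function body remain
                # visible here.
                B = R.match_cast(A, R.Tensor(["m", "n"], "float32"))
                out = R.add(B, B)
            else:
                out = A
            # "m" and "n" are not defined in the parent scope, so the
            # best-effort deduction must erase them when inferring the
            # struct info of the If result.  The function body itself is
            # visited with VisitWithNewScope/BeginScope, which starts from
            # a fresh set of definitions.
            return out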
diff --git a/include/tvm/script/ir_builder/relax/frame.h b/include/tvm/script/ir_builder/relax/frame.h index 1ad681388912..0ee144f03e77 100644 --- a/include/tvm/script/ir_builder/relax/frame.h +++ b/include/tvm/script/ir_builder/relax/frame.h @@ -122,6 +122,7 @@ class FunctionFrameNode : public SeqExprFrameNode { TVM_DECLARE_FINAL_OBJECT_INFO(FunctionFrameNode, SeqExprFrameNode); public: + void EnterWithScope() final; void ExitWithScope() final; }; diff --git a/python/tvm/relax/dpl/__init__.py b/python/tvm/relax/dpl/__init__.py index 6451238428c2..a4f3f4063e90 100644 --- a/python/tvm/relax/dpl/__init__.py +++ b/python/tvm/relax/dpl/__init__.py @@ -19,4 +19,10 @@ from .pattern import * from .context import * -from .rewrite import rewrite_call, rewrite_bindings +from .rewrite import ( + rewrite_call, + rewrite_bindings, + PatternMatchingRewriter, + ExprPatternRewriter, + OrRewriter, +) diff --git a/python/tvm/relax/dpl/rewrite.py b/python/tvm/relax/dpl/rewrite.py index 291061090fc2..96c69e9266a2 100644 --- a/python/tvm/relax/dpl/rewrite.py +++ b/python/tvm/relax/dpl/rewrite.py @@ -15,16 +15,196 @@ # specific language governing permissions and limitations # under the License. """APIs for pattern-based rewriting.""" -from typing import Dict, Callable + +from typing import Dict, Callable, Union + +from tvm.ir import IRModule +from tvm.runtime import Object +from tvm._ffi import register_object + from .pattern import DFPattern from .context import PatternContext - from ..expr import Expr, Function, Var from . import _ffi as ffi +@register_object("relax.dpl.PatternMatchingRewriter") +class PatternMatchingRewriter(Object): + """A pattern-matching rewriter for Relax""" + + @staticmethod + def from_pattern( + pattern: DFPattern, + func: Callable[[Expr, Dict[DFPattern, Expr]], Expr], + ) -> "PatternMatchingRewriter": + """Construct from a pattern and rewriter-function + + The replacements performed by the rewriter will be equivalent + to using the `pattern` and `func` as arguments to + `rewrite_call`. + + Parameters + ---------- + pattern: DFPattern + + The pattern to be matched against. + + func: Callable[[Expr, Dict[DFPattern, Expr]], Expr] + + A function that returns the rewritten expression. See + `rewrite_call` for details and examples. + + + Returns + ------- + rewriter_obj: PatternMatchingRewriter + + The rewriter object + + """ + return ffi.PatternMatchingRewriterFromPattern( + pattern, + func, + ) # type: ignore + + @staticmethod + def from_module(mod: IRModule) -> "PatternMatchingRewriter": + """Construct a rewriter from an IRModule + + The IRModule must have two publicly-exposed functions, + `pattern` and `replacement`, where `pattern` and `replacement` + have the same function signature, as shown in the example + below. + + .. code-block:: python + + @I.ir_module + class RewriteAddIntoMultiply: + @R.function + def pattern(A: R.Tensor): + B = A + A + return B + + @R.function + def replacement(A: R.Tensor): + B = A * 2 + return B + + rewriter = PatternMatchingRewriter.from_module(RewriteAddIntoMultiply) + rewritten_ir_module = rewriter(ir_module) + + To support the common case of defining an IRModule with + TVMScript, then immediately turning it into a rewriter, the + `@R.rewriter` annotation can be used. + + .. 
code-block:: python + + @R.rewriter + class RewriteAddIntoMultiply: + @R.function + def pattern(A: R.Tensor): + B = A + A + return B + + @R.function + def replacement(A: R.Tensor): + B = A * 2 + return B + + rewritten_ir_module = RewriteAddIntoMultiply(ir_module) + + Parameters + ---------- + mod: IRModule + + A module with `pattern` and `replacement` functions, + defining a rewrite rule. + + + Returns + ------- + rewriter_obj: PatternMatchingRewriter + + The rewriter object + + """ + return ffi.PatternMatchingRewriterFromModule(mod) # type: ignore + + def __call__(self, obj: Union[Expr, IRModule]) -> Union[Expr, IRModule]: + """Apply the rewriter + + Parameters + ---------- + obj: Union[Expr, IRModule] + + The object to be rewritten, which may be either a Relax + expression or an IRModule. + + Returns + ------- + updated: Union[Expr, IRModule] + + The rewritten object + + """ + return ffi.PatternMatchingRewriterApply(self, obj) + + def __or__(self, other: "PatternMatchingRewriter") -> "PatternMatchingRewriter": + """Compose two rewriters + + Composing two rewrite rules together allows them to be applied + in a single Relax-level transformation. + + Parameters + ---------- + other: PatternMatchingRewriter + + Another rewrite rule + + Returns + ------- + PatternMatchingRewriter + + A rewriter that will apply either rewrite rule + + """ + return OrRewriter(self, other) + + +@register_object("relax.dpl.ExprPatternRewriter") +class ExprPatternRewriter(PatternMatchingRewriter): + def __init__(self, pattern, func): + self.__init_handle_by_constructor__( + ffi.PatternRewriter, + pattern, + func, + ) # type: ignore + + +@register_object("relax.dpl.OrRewriter") +class OrRewriter(PatternMatchingRewriter): + def __init__(self, lhs, rhs): + self.__init_handle_by_constructor__( + ffi.OrRewriter, + lhs, + rhs, + ) # type: ignore + + +@register_object("relax.dpl.TupleRewriter") +class TupleRewriter(PatternMatchingRewriter): + def __init__(self, patterns, func): + self.__init_handle_by_constructor__( + ffi.TupleRewriter, + patterns, + func, + ) # type: ignore + + def rewrite_call( - pattern: DFPattern, rewriter: Callable[[Expr, Dict[DFPattern, Expr]], Expr], func: Function + pattern: DFPattern, + rewriter: Callable[[Expr, Dict[DFPattern, Expr]], Expr], + func: Function, ) -> Function: """ Rewrite a function with the given pattern and the rewriter function.
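Because `__or__` wraps two rewriters into an `OrRewriter`, rules defined with `@R.rewriter` compose directly. The sketch below is a usage illustration: the second rule `CollapseRepeatedRelu` and the module `my_ir_module` are names invented for this example, not part of this change.

.. code-block:: python

    from tvm.script import relax as R

    @R.rewriter
    class RewriteAddIntoMultiply:
        @R.function
        def pattern(A: R.Tensor):
            B = A + A
            return B

        @R.function
        def replacement(A: R.Tensor):
            B = A * 2
            return B

    # A second, hypothetical rule: collapse a repeated relu call.
    @R.rewriter
    class CollapseRepeatedRelu:
        @R.function
        def pattern(A: R.Tensor):
            B = R.nn.relu(R.nn.relu(A))
            return B

        @R.function
        def replacement(A: R.Tensor):
            B = R.nn.relu(A)
            return B

    # The OrRewriter applies whichever rule matches each binding, so
    # both rules run within a single Relax-level transformation.
    combined = RewriteAddIntoMultiply | CollapseRepeatedRelu
    rewritten = combined(my_ir_module)  # my_ir_module: an existing IRModule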
diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index ef9ae775450b..c4be8afac4d2 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -20,11 +20,11 @@ import builtins import functools import inspect -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Type import tvm from tvm import DataType, relax -from tvm.ir import PrimExpr, VDevice +from tvm.ir import PrimExpr, VDevice, IRModule from tvm.relax import ( Call, Expr, @@ -35,6 +35,7 @@ VarBinding, const, ) +from tvm.relax.dpl import PatternMatchingRewriter ############################### Operators ############################### from tvm.relax.op import ( @@ -306,6 +307,48 @@ def func_ret_value(value: Expr) -> None: return _ffi_api.FuncRetValue(value) # type: ignore[attr-defined] # pylint: disable=no-member +def rewriter(rewriter_mod: Union[IRModule, Type]) -> PatternMatchingRewriter: + """Define a pattern-rewrite rule + + The IRModule must have two publicly-exposed functions, `pattern` + and `replacement`, where `pattern` and `replacement` have the same + function signature. + + .. code-block:: python + + @R.rewriter + class RewriteAddIntoMultiply: + @R.function + def pattern(A: R.Tensor): + B = A + A + return B + + @R.function + def replacement(A: R.Tensor): + B = A * 2 + return B + + Parameters + ---------- + rewriter_mod: Union[IRModule, Type] + + Either an IRModule that defines a rewrite pattern, or a + TVMScript class that can be parsed into an IRModule. + + Returns + ------- + rewriter: PatternMatchingRewriter + + A rewriter object, which can be applied either to a Relax + function or to an entire IRModule. + + """ + if not isinstance(rewriter_mod, IRModule): + rewriter_mod = tvm.script.ir_module(rewriter_mod) + + return PatternMatchingRewriter.from_module(rewriter_mod) + + ############################# BindingBlock ############################## @@ -765,6 +808,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "dequantize", "repeat", "reshape", + "rewriter", "tensor_to_shape", "shape_to_tensor", "rocm", diff --git a/python/tvm/script/parser/core/utils.py b/python/tvm/script/parser/core/utils.py index 3edae3f25a33..8ad64f5dbc68 100644 --- a/python/tvm/script/parser/core/utils.py +++ b/python/tvm/script/parser/core/utils.py @@ -100,19 +100,29 @@ def is_defined_in_class(frames: List[FrameType], obj: Any) -> bool: res : bool The result if the object is defined in a class scope. """ + + def _is_tvmscript_class_annotator(line: str) -> bool: + """Checks if the line contains a TVMScript annotator for a class + + These match either `@I.ir_module` or `@R.rewriter`, or their + imported names `@ir_module` or `@rewriter`. 
+ """ + + return line.startswith("@") and ("ir_module" in line or "rewriter" in line) + if len(frames) > 2: frame_info = frames[2] code_context = frame_info.code_context if code_context is None: return False line = code_context[0].strip() - if line.startswith("@") and "ir_module" in line: + if _is_tvmscript_class_annotator(line): return True if line.startswith("class"): lineno = frame_info.lineno if lineno >= 2: source, _ = findsource(obj) line = source[lineno - 2].strip() - if line.startswith("@") and "ir_module" in line: + if _is_tvmscript_class_annotator(line): return True return False diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc index f6aec79a4ac4..b8092bbf3a4d 100644 --- a/src/relax/ir/block_builder.cc +++ b/src/relax/ir/block_builder.cc @@ -178,29 +178,54 @@ class BlockBuilderImpl : public BlockBuilderNode { // but can be further improved. // // TODO(relax-team): Add support for relax Var in struct info annotations. - Map shape_var_map; - for (const Var& var : params.value_or(Array())) { - const Map& var_map = StructInfoVarCollector::Collect(GetStructInfo(var)); - for (const auto& kv : var_map) { - const tir::Var& shape_var = kv.first; - const PrimExpr& shape_expr = kv.second; - auto it = shape_var_map.find(shape_var); - if (it == shape_var_map.end()) { - shape_var_map.Set(shape_var, shape_expr); - // Expose the shape variable as non-negative, for purposes - // of shape inference. In many cases, knowning that the - // shape variable is non-negative allows for simpler - // expressions for dynamic shapes. - analyzer_.MarkGlobalNonNegValue(shape_var); - } else { - const PrimExpr& old_shape_expr = (*it).second; - CHECK(analyzer_.CanProveEqual(old_shape_expr, shape_expr)) - << "Inconsistent shape var " << shape_var << " in scope: " << old_shape_expr << " vs " - << shape_expr; - } + + scope_stack_.emplace_back(ScopeFrame()); + if (params.defined()) { + for (const auto& param : params.value()) { + AddDefinitionToScope(param); + } + } + } + + void BeginInnerScope() final { + if (scope_stack_.size()) { + scope_stack_.emplace_back(scope_stack_.back()); + } else { + scope_stack_.emplace_back(ScopeFrame()); + } + } + + void AddDefinitionToScope(Var var) final { + if (scope_stack_.empty()) { + return; + } + + auto& shape_var_map = CurrentScopeFrame()->shape_var_map; + + // The current implementation handles the collection of shape var + // defined in parameter struct info annotations. The implementation + // is correct (since we will simply erase all relax Vars in EraseToWellDefined), + // but can be further improved. + Map var_map = StructInfoVarCollector::Collect(GetStructInfo(var)); + for (const auto& kv : var_map) { + const tir::Var& shape_var = kv.first; + const PrimExpr& shape_expr = kv.second; + auto it = shape_var_map.find(shape_var); + if (it == shape_var_map.end()) { + shape_var_map.Set(shape_var, shape_expr); + // Expose the shape variable as non-negative, for purposes + // of shape inference. In many cases, knowning that the + // shape variable is non-negative allows for simpler + // expressions for dynamic shapes. 
+ analyzer_.MarkGlobalNonNegValue(shape_var); + } else { + const PrimExpr& old_shape_expr = (*it).second; + CHECK(old_shape_expr.same_as(shape_expr) || + analyzer_.CanProveEqual(old_shape_expr, shape_expr)) + << "Inconsistent shape var " << shape_var << " in scope: " << old_shape_expr << " vs " + << shape_expr; } } - scope_stack_.emplace_back(ScopeFrame({std::move(shape_var_map)})); } void EndScope() final { scope_stack_.pop_back(); } @@ -236,6 +261,8 @@ class BlockBuilderImpl : public BlockBuilderNode { cur_frame->bindings.push_back(match_cast); // NOTE match shape do not follow simple binding rule // as a result should not appear in binding table. + + AddDefinitionToScope(var); return var; } @@ -271,6 +298,7 @@ class BlockBuilderImpl : public BlockBuilderNode { // NOTE match shape do not follow simple binding rule // as a result should not appear in binding table. cur_frame->bindings.push_back(binding); + AddDefinitionToScope(match_cast->var); } else { LOG(FATAL) << "Unsupported binding type: " << binding->GetTypeKey(); } @@ -831,7 +859,9 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor Optional { @@ -843,15 +873,18 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor> params = NullOpt) { + if (params.defined()) { + this->BeginScope(params.value()); + } else { + this->BeginInnerScope(); + } + + Expr ret; + // SeqExpr do not need to prepare for normalization. if (expr.as()) { - this->BeginScope(params); - Expr ret = this->VisitExpr(expr); - this->EndScope(); - return ret; + ret = this->VisitExpr(expr); } else { - this->BeginScope(params); - this->BeginBindingBlock(); Expr post = this->NormalizeArgument(expr); BindingBlock prologue = this->EndBlock(); @@ -868,9 +901,11 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctorbody))); - this->EndScope(); - return seq; + ret = seq; } + + this->EndScope(); + return ret; } Array FlattenBlocks(const Array& blocks) { diff --git a/src/relax/ir/dataflow_block_rewriter.cc b/src/relax/ir/dataflow_block_rewriter.cc new file mode 100644 index 000000000000..fb08dfe96a17 --- /dev/null +++ b/src/relax/ir/dataflow_block_rewriter.cc @@ -0,0 +1,452 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/ir/dataflow_block_rewriter.cc + * \brief A transform to match a Relax DataflowBlock and rewrite + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "dataflow_matcher.h" +#include "dataflow_rewriter.h" + +namespace tvm { +namespace relax { + +class MatcherUseDefAnalysis : public relax::ExprVisitor { + public: + std::vector vars; + std::map> def2use; + // caller -> callee table. 
+ std::map<const VarNode*, std::vector<const VarNode*>> caller2callees; + + const VarNode* cur_user_; + + void VisitBinding_(const VarBindingNode* binding) override { + // init + cur_user_ = binding->var.get(); + this->VisitVarDef(binding->var); + this->VisitExpr(binding->value); + cur_user_ = nullptr; + } + + void VisitExpr_(const VarNode* op) override { + if (nullptr == cur_user_) return; + + auto check_and_push = [](std::vector<const VarNode*>& vec, const VarNode* var) { + if (std::find(vec.begin(), vec.end(), var) == vec.end()) { + vec.push_back(var); + } + }; + + check_and_push(def2use[op], cur_user_); + check_and_push(vars, op); + + caller2callees[cur_user_].push_back(op); + } +}; + +struct PNode { + const DFPatternNode* ptr; + std::vector<std::pair<PNode*, const std::vector<PairCons>&>> children; + std::vector<std::pair<PNode*, const std::vector<PairCons>&>> parents; +}; + +struct RNode { + const VarNode* ptr; + std::vector<RNode*> children; + std::vector<RNode*> parents; +}; + +struct MatchState { + void add(const PNode* p, const RNode* r) { + match_p_r[p] = r; + match_r_p[r] = p; + } + + void add(const DFConstraintNode* constraint) { validated_constraints_.insert(constraint); } + + void add(MatchState&& other) { + match_p_r.merge(std::move(other.match_p_r)); + match_r_p.merge(std::move(other.match_r_p)); + validated_constraints_.merge(other.validated_constraints_); + } + + const VarNode* matched(const PNode* p) const { + if (auto it = match_p_r.find(p); it != match_p_r.end()) { + return it->second->ptr; + } + return nullptr; + } + + const DFPatternNode* matched(const RNode* r) const { + if (auto it = match_r_p.find(r); it != match_r_p.end()) { + return it->second->ptr; + } + return nullptr; + } + + const VarNode* matched(const PNode& p) const { return matched(&p); } + const DFPatternNode* matched(const RNode& r) const { return matched(&r); } + + bool is_validated(const DFConstraintNode* constraint) const { + return validated_constraints_.count(constraint); + } + + private: + std::unordered_map<const PNode*, const RNode*> match_p_r; + std::unordered_map<const RNode*, const PNode*> match_r_p; + std::unordered_set<const DFConstraintNode*> validated_constraints_; +}; + +/** + * \brief This method tries to match a real node and a pattern node along with its neighbors. + */ +static std::optional<MatchState> TryMatch(const PNode& p, const RNode& r, + const MatchState& current_match, DFPatternMatcher* m, + const MatcherUseDefAnalysis& ud_analysis) { + if (!m->Match(GetRef<DFPattern>(p.ptr), GetRef<Var>(r.ptr))) return std::nullopt; + + MatchState new_match; + + new_match.add(&p, &r); + + // forward matching; + for (const auto& [pchild, constraints] : p.children) { + bool any_cons_sat = false; + for (const auto& rchild : r.children) { + if (new_match.matched(rchild)) { + // The child variable is already matched to another child pattern in a previous iteration. + continue; + } + if (auto v = current_match.matched(pchild); v && v != rchild->ptr) { + // The child pattern is already matched to another variable in an earlier call to TryMatch. + continue; + } + + const auto& uses = ud_analysis.def2use.at(r.ptr); + + // check edge constraints.
+ bool all_cons_pass = true; + for (const auto& cons : constraints) { + if (cons.type == PairCons::kOnlyUsedBy && uses.size() != 1) { + all_cons_pass = false; + break; + } + + if (cons.index != -1) { + const auto& callees = ud_analysis.caller2callees.at(rchild->ptr); + if (callees.size() <= static_cast(cons.index) || callees[cons.index] != r.ptr) { + all_cons_pass = false; + break; + } + } + } + if (!all_cons_pass || new_match.matched(pchild)) continue; + any_cons_sat = true; + + if (auto match_rec = TryMatch(*pchild, *rchild, current_match, m, ud_analysis)) { + new_match.add(pchild, rchild); + new_match.add(std::move(*match_rec)); + } + } + if (!new_match.matched(pchild) || !any_cons_sat) return std::nullopt; + } + + return new_match; +} + +static std::optional TryValidate( + const MatchState& current_match, + const std::unordered_map& pattern2node, + const std::vector& validation_constraints, arith::Analyzer* analyzer) { + MatchState new_match; + + std::function(const DFPatternNode*)> query_match_state = + [&pattern2node, ¤t_match](const DFPatternNode* pattern) -> Optional { + auto it = pattern2node.find(pattern); + ICHECK(it != pattern2node.end()) + << "DFConstraint attempted to access DFPattern " << GetRef(pattern) + << ", which does not appear in the PatternContext"; + const auto& p_node = it->second; + if (auto ptr = current_match.matched(p_node)) { + return GetRef(ptr); + } else { + return NullOpt; + } + }; + + for (const auto& constraint : validation_constraints) { + if (!current_match.is_validated(constraint.get())) { + auto [necessary_condition, is_sufficient] = constraint->AsPrimExpr(query_match_state); + + necessary_condition = analyzer->Simplify(necessary_condition); + const auto* known = tir::as_const_int(necessary_condition); + + if (known && *known && is_sufficient) { + // The condition passes, and the expression provided is both + // necessary and sufficient for the constraint to pass. Mark + // the constraint as passing, to avoid re-checking it unless + // we backtrack. + new_match.add(constraint.get()); + } else if (known && !*known) { + // The condition fails. Even if additional information would + // be required to pass a constraint, it may bail out early as + // a failure (e.g. shape mismatch in the first two items out + // of N shapes that must all match). + return std::nullopt; + } else if (is_sufficient) { + // The condition depends on dynamic parameters. In the + // future, this may be exposed to the user as a condition for + // optimization, or can be combined with the conditions + // provided from other constraints. + return std::nullopt; + } + } + } + + return new_match; +} + +static std::optional MatchTree( + const MatchState& current_match, size_t current_root_idx, + const std::unordered_map& pattern2node, + const std::unordered_map& var2node, DFPatternMatcher* matcher, + const std::vector& roots, const std::vector& validation_constraints, + const MatcherUseDefAnalysis& ud_analysis, arith::Analyzer* analyzer) { + auto get_next_root = [&](size_t root_idx) -> const PNode* { + // Look for the next unmatched root node. 
+ for (; root_idx < roots.size(); ++root_idx) { + const auto& root = pattern2node.at(roots[root_idx].get()); + if (!current_match.matched(root)) { + return &root; + } + } + return nullptr; + }; + + const auto root = get_next_root(current_root_idx); + + if (!root) { + // All root nodes have been matched + return current_match; + } + + MatchState new_match = current_match; + + for (const auto& var : ud_analysis.vars) { + const RNode& r_node = var2node.at(var); + if (new_match.matched(r_node)) continue; + if (auto match = TryMatch(*root, r_node, new_match, matcher, ud_analysis)) { + // Recursively try to match the next subtree. + new_match.add(std::move(*match)); + if (auto validation = + TryValidate(new_match, pattern2node, validation_constraints, analyzer)) { + new_match.add(std::move(*validation)); + if (auto match_rec = + MatchTree(new_match, current_root_idx + 1, pattern2node, var2node, matcher, roots, + validation_constraints, ud_analysis, analyzer)) { + new_match.add(std::move(*match_rec)); + return new_match; + } + } + // Recursive matching has failed, backtrack. + new_match = current_match; + continue; + } + } + + return std::nullopt; +} + +Optional> MatchGraph(const PatternContext& ctx, + const Array& binding_arr, + const Map& bindings) { + // TODO(@ganler): Handle non-may external use. + ICHECK(ctx->allow_extern_use == PatternContextNode::kMay) << "Only kMay is supported yet."; + DFPatternMatcher matcher(bindings); + + MatcherUseDefAnalysis ud_analysis; + for (const auto& binding : binding_arr) { + ud_analysis.VisitBinding(binding); + } + + // First construct a graph of PNode and RNode. + std::unordered_map var2node; + var2node.reserve(bindings.size()); + + for (const VarNode* cur_var : ud_analysis.vars) { + const auto& uses = ud_analysis.def2use.at(cur_var); + RNode& cur_node = var2node[cur_var]; + cur_node.ptr = cur_var; + for (const VarNode* use : uses) { + auto& use_node = var2node[use]; + use_node.ptr = use; + cur_node.children.push_back(&use_node); + use_node.parents.push_back(&cur_node); + } + } + + std::unordered_map pattern2node; + pattern2node.reserve(ctx->edge_constraints.size()); + + for (const auto& def_pattern : ctx->src_ordered) { + PNode& def_node = pattern2node[def_pattern.get()]; + const auto& uses = ctx->edge_constraints.at(def_pattern); + def_node.ptr = def_pattern.get(); + def_node.children.reserve(uses.size()); + for (const auto& [use_pattern, cons] : uses) { + PNode& use_node = pattern2node[use_pattern.get()]; + use_node.ptr = use_pattern.get(); + use_node.parents.emplace_back(&def_node, std::ref(cons)); + def_node.children.emplace_back(&use_node, std::ref(cons)); + } + } + + std::vector roots; + for (const auto& pat : ctx->src_ordered) { + if (pattern2node[pat.get()].parents.empty()) { + roots.push_back(pat); + } + } + + if (roots.empty()) { + return NullOpt; + } + + arith::Analyzer analyzer; + auto match = MatchTree({}, 0, pattern2node, var2node, &matcher, roots, + ctx->validation_constraints, ud_analysis, &analyzer); + if (!match) { + return NullOpt; + } + + Map ret; + for (const auto& [pat, p_node] : pattern2node) { + ICHECK(match->matched(p_node)); + ret.Set(GetRef(pat), GetRef(match->matched(p_node))); + } + return ret; +} + +Optional> MatchGraph(const PatternContext& ctx, const DataflowBlock& dfb) { + return MatchGraph(ctx, dfb->bindings, AnalyzeVar2Value(dfb)); +} + +TVM_REGISTER_GLOBAL("relax.dpl.match_dfb") + .set_body_typed([](const PatternContext& ctx, const DataflowBlock& dfb) { + return MatchGraph(ctx, dfb); + }); + +class 
PatternContextRewriterNode : public PatternMatchingRewriterNode { + public: + PatternContext pattern; + TypedPackedFunc(Map, Map)> rewriter_func; + + RewriteSpec RewriteBindings(const Array& bindings) const override; + + void VisitAttrs(AttrVisitor* visitor) { + visitor->Visit("pattern", &pattern); + PackedFunc untyped_func = rewriter_func; + visitor->Visit("rewriter_func", &untyped_func); + } + + static constexpr const char* _type_key = "relax.dpl.PatternContextRewriter"; + TVM_DECLARE_FINAL_OBJECT_INFO(PatternContextRewriterNode, PatternMatchingRewriterNode); + + private: + Optional> MatchBindings(const Array& bindings) const { + Map var_lookup; + for (const auto& binding : bindings) { + var_lookup.Set(binding->var, GetBoundValue(binding)); + } + + if (auto matches = MatchGraph(pattern, bindings, var_lookup)) { + Map replacements = rewriter_func(matches.value(), var_lookup); + if (replacements.size()) { + return replacements; + } + } + + return NullOpt; + } +}; + +class PatternContextRewriter : public PatternMatchingRewriter { + public: + PatternContextRewriter( + PatternContext pattern, + TypedPackedFunc(Map, Map)> rewriter_func); + + TVM_DEFINE_OBJECT_REF_METHODS(PatternContextRewriter, PatternMatchingRewriter, + PatternContextRewriterNode); +}; + +RewriteSpec PatternContextRewriterNode::RewriteBindings(const Array& bindings) const { + std::vector remaining_bindings{bindings.begin(), bindings.end()}; + + Map variable_rewrites; + while (auto opt = MatchBindings(remaining_bindings)) { + auto new_rewrites = opt.value(); + remaining_bindings.erase(std::remove_if(remaining_bindings.begin(), remaining_bindings.end(), + [&new_rewrites](const Binding& binding) { + return new_rewrites.count(binding->var); + }), + remaining_bindings.end()); + for (const auto& [var, expr] : new_rewrites) { + variable_rewrites.Set(var, expr); + } + } + + return RewriteSpec{variable_rewrites, {}}; +} + +PatternContextRewriter::PatternContextRewriter( + PatternContext pattern, + TypedPackedFunc(Map, Map)> rewriter_func) { + auto node = make_object(); + node->pattern = std::move(pattern); + node->rewriter_func = std::move(rewriter_func); + data_ = std::move(node); +} + +Function RewriteBindings( + const PatternContext& ctx, + TypedPackedFunc(Map, Map)> rewriter, Function func) { + // return BlockPatternRewriter::Run(ctx, rewriter, func); + return Downcast(PatternContextRewriter(ctx, rewriter)(func)); +} + +TVM_REGISTER_GLOBAL("relax.dpl.rewrite_bindings").set_body_typed(RewriteBindings); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/ir/dataflow_expr_rewriter.cc b/src/relax/ir/dataflow_expr_rewriter.cc new file mode 100644 index 000000000000..514116c5cadf --- /dev/null +++ b/src/relax/ir/dataflow_expr_rewriter.cc @@ -0,0 +1,1079 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/ir/dataflow_expr_rewriter.cc + * \brief A transform to match a Relax Expr and rewrite + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "../transform/utils.h" +#include "dataflow_matcher.h" +#include "dataflow_rewriter.h" + +namespace tvm { +namespace relax { + +namespace { +class GlobalVarReplacer : public ExprMutator { + public: + explicit GlobalVarReplacer(Map<GlobalVar, GlobalVar> gvar_map) : gvar_map_(gvar_map) {} + + using ExprMutator::VisitExpr_; + Expr VisitExpr_(const GlobalVarNode* op) override { + auto gvar = GetRef<GlobalVar>(op); + if (auto opt = gvar_map_.Get(gvar)) { + gvar = opt.value(); + } + return gvar; + } + + private: + Map<GlobalVar, GlobalVar> gvar_map_; +}; + +Array<Binding> TopologicalSort(const Array<Binding>& bindings) { + std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> remaining_bindings; + for (const auto& binding : bindings) { + remaining_bindings.insert(binding->var); + } + + // Utility structure used to track bindings that are moved later in + // the list. + struct DelayedBinding { + Binding binding; + std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> unmet_requirements; + bool emitted; + }; + std::vector<DelayedBinding> delayed_bindings; + Array<Binding> sorted_bindings; + + // Utility function to append a binding to the sorted list, removing + // its variable from the unmet requirements of any delayed bindings. + auto push_sorted_binding = [&](Binding binding) { + sorted_bindings.push_back(binding); + remaining_bindings.erase(binding->var); + for (auto& delayed_binding : delayed_bindings) { + delayed_binding.unmet_requirements.erase(binding->var); + } + }; + + bool required_sorting = false; + for (const auto& binding : bindings) { + // Collect any variables that are used by this binding, but are emitted by + // a later binding. + std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> unmet_requirements; + for (auto free_var : FreeVars(GetBoundValue(binding))) { + if (remaining_bindings.count(free_var)) { + unmet_requirements.insert(free_var); + } + } + + if (unmet_requirements.empty()) { + push_sorted_binding(binding); + } else { + required_sorting = true; + delayed_bindings.push_back(DelayedBinding{binding, unmet_requirements, false}); + } + + bool requires_delayed_binding_check = true; + while (requires_delayed_binding_check) { + requires_delayed_binding_check = false; + for (auto& delayed_binding : delayed_bindings) { + if (!delayed_binding.emitted && delayed_binding.unmet_requirements.empty()) { + // If we find a delayed binding that can be emitted, mark it + // as emitted and push to the sorted list. This may satisfy + // the requirements of other delayed bindings, so re-scan the list. + delayed_binding.emitted = true; + requires_delayed_binding_check = true; + push_sorted_binding(delayed_binding.binding); + + // The break is not necessary for a topological sort, but is + // necessary to minimize the amount of re-ordering that is + // performed. With this break, the next binding is always + // the earliest binding that is legal to emit at this point. + break; + } + } + } + + // Remove any delayed bindings that have been emitted, now that we + // are done iterating over the delayed bindings. + delayed_bindings.erase( + std::remove_if(delayed_bindings.begin(), delayed_bindings.end(), + [](const auto& delayed_binding) { return delayed_binding.emitted; }), + delayed_bindings.end()); + } + + // All bindings should be emitted by this point. If any remain, + // then there exists a circular dependency somewhere in the + // remaining bindings.
+ CHECK(delayed_bindings.empty()) << "ValueError: " + << "Bindings contain circular dependency"; + + if (required_sorting) { + return sorted_bindings; + } else { + return bindings; + } +} +} // namespace + +void RewriteSpec::Append(RewriteSpec other) { + if (variable_rewrites.empty()) { + *this = std::move(other); + return; + } + if (other.variable_rewrites.empty()) { + return; + } + + NameSupply gvar_name_supply(""); + for (const auto& [gvar, func] : new_subroutines) { + gvar_name_supply->ReserveName(gvar->name_hint); + } + + Map<GlobalVar, GlobalVar> gvar_rewrites; + for (auto [gvar, func] : other.new_subroutines) { + if (auto it = new_subroutines.find(gvar); it != new_subroutines.end()) { + // The two rewrites provide the same GlobalVar. + // (e.g. multiple rewrites of the same pattern.) Ensure that + // they are referring to the same underlying BaseFunc. + CHECK(func.same_as((*it).second)); + } else if (auto new_name = gvar_name_supply->FreshName(gvar->name_hint); + new_name != gvar->name_hint) { + // The two rewrites provide distinct GlobalVar subroutines, + // but with conflicting names. Because an IRModule must have + // unique names for each GlobalVar, even if they are not + // publicly exposed, one of the GlobalVars must be replaced. + // Replacing the GlobalVar here, when the conflict is first + // identified, minimizes the size of the `relax::Expr` that + // must be updated with `GlobalVarReplacer`. + GlobalVar new_gvar = gvar; + new_gvar.CopyOnWrite()->name_hint = new_name; + gvar_rewrites.Set(gvar, new_gvar); + new_subroutines.Set(new_gvar, func); + } else { + new_subroutines.Set(gvar, func); + } + } + + for (auto [var, expr] : other.variable_rewrites) { + if (gvar_rewrites.size()) { + expr = GlobalVarReplacer(gvar_rewrites)(expr); + } + variable_rewrites.Set(var, expr); + } +} + +TVM_REGISTER_NODE_TYPE(PatternMatchingRewriterNode); + +TVM_REGISTER_GLOBAL("relax.dpl.PatternMatchingRewriterFromPattern") + .set_body_typed([](DFPattern pattern, + TypedPackedFunc<Optional<Expr>(Expr, Map<DFPattern, Expr>)> func) { + return PatternMatchingRewriter::FromPattern(pattern, func); + }); + +TVM_REGISTER_GLOBAL("relax.dpl.PatternMatchingRewriterFromModule").set_body_typed([](IRModule mod) { + return PatternMatchingRewriter::FromModule(mod); +}); + +TVM_REGISTER_GLOBAL("relax.dpl.PatternMatchingRewriterApply") + .set_body_typed([](PatternMatchingRewriter rewriter, + Variant<Expr, IRModule> obj) -> Variant<Expr, IRModule> { + if (auto expr = obj.as<Expr>()) { + return rewriter(expr.value()); + } else if (auto mod = obj.as<IRModule>()) { + return rewriter(mod.value()); + } else { + LOG(FATAL) << "Unreachable: object does not contain either variant type"; + } + }); + +TVM_REGISTER_NODE_TYPE(ExprPatternRewriterNode); + +RewriteSpec ExprPatternRewriterNode::RewriteBindings(const Array<Binding>& bindings) const { + Map<Var, Expr> variable_rewrites; + Map<Var, Expr> binding_lookup; + for (const auto& binding : bindings) { + auto bound_value = GetBoundValue(binding); + if (auto new_expr = RewriteExpr(bound_value, binding_lookup)) { + variable_rewrites.Set(binding->var, new_expr.value()); + } else { + binding_lookup.Set(binding->var, bound_value); + } + } + if (variable_rewrites.size()) { + return RewriteSpec{variable_rewrites, new_subroutines}; + } else { + return RewriteSpec(); + } +} + +Optional<Expr> ExprPatternRewriterNode::RewriteExpr(const Expr& expr, + const Map<Var, Expr>& bindings) const { + if (auto opt_matches = ExtractMatchedExpr(pattern, expr, bindings)) { + auto matches = opt_matches.value(); + if (additional_bindings) { + // Append any additional matches that come from the unwrapped + // `OrPattern`.
When matching against `pat = pat_lhs | + // pat_rhs`, we call `ExtractMatchedExpr` on `pat_lhs` and + // `pat_rhs` separately. The top-level `pat` is never seen by + // `ExtractMatchedExpr`, and must be re-added afterward. + auto matched_expr = DFPatternMatcher::UnwrapBindings(expr, bindings); + for (const auto& pat : additional_bindings.value()) { + matches.Set(pat, matched_expr); + } + } + + Optional rewritten_expr = func(expr, matches); + if (rewritten_expr.defined() && !rewritten_expr.same_as(expr)) { + return rewritten_expr.value(); + } + } + return NullOpt; +} + +TVM_REGISTER_GLOBAL("relax.dpl.PatternRewriter") + .set_body_typed([](DFPattern pattern, + TypedPackedFunc(Expr, Map)> func) { + return ExprPatternRewriter(pattern, func); + }); + +ExprPatternRewriter::ExprPatternRewriter( + DFPattern pattern, TypedPackedFunc(Expr, Map)> func, + Optional> additional_bindings, Map new_subroutines) { + auto node = make_object(); + node->pattern = std::move(pattern); + node->func = std::move(func); + node->additional_bindings = std::move(additional_bindings); + node->new_subroutines = std::move(new_subroutines); + data_ = std::move(node); +} + +TVM_REGISTER_NODE_TYPE(OrRewriterNode); + +RewriteSpec OrRewriterNode::RewriteBindings(const Array& bindings) const { + auto lhs_match = lhs->RewriteBindings(bindings); + if (!lhs_match) { + // If no rewrites found on LHS, RHS is allowed to modify any + // variable binding. + return rhs->RewriteBindings(bindings); + } + + // The LHS matched some subset of the bindings. These + // replacements may not be normalized expressions, so the RHS may + // only replace variable bindings that haven't been modified by + // the LHS. Variable replacements from the RHS may still occur, + // but will need to wait for the next round of + // iterate-until-converged. + Array remaining_bindings; + for (const auto& binding : bindings) { + if (!lhs_match.variable_rewrites.count(binding->var)) { + remaining_bindings.push_back(binding); + } + } + + if (remaining_bindings.empty()) { + // Early bail-out, the RHS has no bindings available to rewrite. + return lhs_match; + } + + lhs_match.Append(rhs->RewriteBindings(remaining_bindings)); + return lhs_match; +} + +TVM_REGISTER_GLOBAL("relax.dpl.OrRewriter") + .set_body_typed([](PatternMatchingRewriter lhs, PatternMatchingRewriter rhs) { + return OrRewriter(lhs, rhs); + }); + +OrRewriter::OrRewriter(PatternMatchingRewriter lhs, PatternMatchingRewriter rhs) { + auto node = make_object(); + node->lhs = std::move(lhs); + node->rhs = std::move(rhs); + data_ = std::move(node); +} + +TVM_REGISTER_NODE_TYPE(TupleRewriterNode); + +RewriteSpec TupleRewriterNode::RewriteBindings(const Array& bindings) const { + CHECK_LE(patterns.size(), 3) << "For performance reasons, " + << "matching of implicit tuple patterns is currently limited" + << " to tuples with 3 elements or fewer."; + Map variable_rewrites = GenerateVariableRewrites(bindings); + + if (variable_rewrites.size()) { + return RewriteSpec{variable_rewrites, new_subroutines}; + } else { + return RewriteSpec(); + } +} + +Map TupleRewriterNode::GenerateVariableRewrites(const Array& bindings) const { + Map rewrites; + + Map binding_lookup; + + std::vector info_vec; + + std::unordered_map binding_index_lookup; + + // Initialize a vector of indices, each of which corresponds to a + // potential match for a tuple element. + // + // \param tuple_index_of_current_expr The index for the most recent + // binding. + // + // \param indices An output vector, into which indices will be + // generated. 
+ // + // \returns bool True if the indices could be initialized to a + // potential match. False, otherwise. + auto initialize_indices = [&](size_t tuple_index_of_current_expr, + std::vector& indices) -> bool { + if (!info_vec.back().matches[tuple_index_of_current_expr]) { + return false; + } + + indices = std::vector(patterns.size(), info_vec.size()); + + indices[tuple_index_of_current_expr] = info_vec.size() - 1; + + for (size_t i_rev = 0; i_rev < indices.size(); i_rev++) { + size_t i = indices.size() - i_rev - 1; + if (indices[i] == info_vec.size() - 1) { + continue; + } + + auto binding_index = [&]() -> std::optional { + if (indices[i] == info_vec.size() - 1) { + return info_vec.size() - 1; + } + + for (size_t j_rev = 1; j_rev < info_vec.size(); j_rev++) { + size_t j = info_vec.size() - j_rev - 1; + if (info_vec[j].matches[i] && !info_vec[j].used && + std::all_of(indices.begin() + (j + 1), indices.end(), + [j](size_t prev_binding_index) { return j != prev_binding_index; })) { + return j; + } + } + + return std::nullopt; + }(); + + if (binding_index.has_value()) { + indices[i] = binding_index.value(); + } else { + return false; + } + } + + return true; + }; + + auto decrement_indices = [&](std::vector& indices) -> bool { + ICHECK_EQ(indices.size(), patterns.size()); + + // Step 1, find the first index that can be decremented, while + // still generating a valid set of indices. + size_t i_forward; + for (i_forward = 0; i_forward < indices.size(); i_forward++) { + if (indices[i_forward] == info_vec.size() - 1) { + continue; + } + + bool found_valid = false; + size_t& index = indices[i_forward]; + while (index) { + index--; + if (info_vec[index].matches[i_forward] && !info_vec[index].used && + std::all_of( + indices.begin() + (i_forward + 1), indices.end(), + [index](size_t later_binding_index) { return index != later_binding_index; })) { + found_valid = true; + break; + } + } + if (found_valid) { + break; + } + } + + // Step 2, if we reached the end, then all indices were + // decremented to zero without finding anything. Return false to + // indicate that we've reached the end. 
+ if (i_forward == indices.size()) { + return false; + } + + // Step 3, re-initialize all indices before `i_forward` that were + // exhausted in Step 1, assigning each the highest valid binding index. + for (size_t i = 0; i < i_forward; i++) { + size_t i_backward = i_forward - (i + 1); + if (indices[i_backward] == info_vec.size() - 1) { + continue; + } + + auto binding_index = [&]() -> std::optional<size_t> { + for (size_t j_rev = 1; j_rev < info_vec.size(); j_rev++) { + size_t j = info_vec.size() - j_rev - 1; + if (info_vec[j].matches[i_backward] && !info_vec[j].used && + std::all_of(indices.begin() + (j + 1), indices.end(), + [j](size_t prev_binding_index) { return j != prev_binding_index; })) { + return j; + } + } + + return std::nullopt; + }(); + + if (binding_index.has_value()) { + indices[i_backward] = binding_index.value(); + } else { + return false; + } + } + + return true; + }; + + for (size_t i_binding = 0; i_binding < bindings.size(); i_binding++) { + const auto& binding = bindings[i_binding]; + + auto expr = GetBoundValue(binding); + + binding_index_lookup[binding->var] = i_binding; + + info_vec.push_back(VarInfo{ + binding->var, + expr, + patterns.Map( + [&](const DFPattern& pat) { return ExtractMatchedExpr(pat, expr, binding_lookup); }), + std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>(), + false, + }); + + auto new_match = [&]() -> std::optional<std::pair<std::vector<size_t>, std::vector<Expr>>> { + std::vector<size_t> indices; + for (size_t i = 0; i < patterns.size(); i++) { + if (initialize_indices(patterns.size() - i - 1, indices)) { + do { + if (auto match = TryMatchByBindingIndex(info_vec, indices)) { + return std::pair{indices, match.value()}; + } + } while (decrement_indices(indices)); + } + } + return std::nullopt; + }(); + + if (new_match) { + const auto& [indices, exprs] = new_match.value(); + ICHECK_EQ(indices.size(), exprs.size()); + for (size_t i = 0; i < indices.size(); i++) { + ICHECK_LT(indices[i], info_vec.size()); + auto& info = info_vec[indices[i]]; + + ICHECK(!info.used) << "InternalError: " + << "Produced multiple replacements for variable " << info.var; + + rewrites.Set(info.var, exprs[i]); + binding_lookup.erase(info.var); + info.used = true; + } + } else { + binding_lookup.Set(binding->var, expr); + } + + for (const auto& prev_var : FreeVars(expr)) { + if (auto it = binding_index_lookup.find(prev_var); it != binding_index_lookup.end()) { + info_vec[it->second].downstream_usage.insert(binding->var); + } + } + } + + return rewrites; +} + +std::optional<std::vector<Expr>> TupleRewriterNode::TryMatchByBindingIndex( + const std::vector<VarInfo>& info_vec, const std::vector<size_t>& indices) const { + ICHECK_GE(indices.size(), 1); + + ICHECK_EQ(indices.size(), patterns.size()); + for (size_t i = 0; i < indices.size(); i++) { + const auto& info = info_vec[indices[i]]; + if (info.used || !info.matches[i]) { + return std::nullopt; + } + } + + Map<DFPattern, Expr> merged_matches = info_vec[indices[0]].matches[0].value(); + for (size_t i = 1; i < indices.size(); i++) { + for (const auto& [pat, expr] : info_vec[indices[i]].matches[i].value()) { + if (auto it = merged_matches.find(pat); it != merged_matches.end()) { + if (!StructuralEqual()(expr, (*it).second)) { + return std::nullopt; + } + } else { + merged_matches.Set(pat, expr); + } + } + } + + bool tuple_element_is_already_used_outside_of_matched_tuple = [&]() -> bool { + std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> matched_vars; + for (const auto& [pat, expr] : merged_matches) { + if (auto opt = expr.as<Var>()) { + matched_vars.insert(opt.value()); + } + } + + for (size_t index : indices) { + const auto& downstream_of_rewritten_var = info_vec[index].downstream_usage; + + for (const auto& uses_matched_var :
downstream_of_rewritten_var) { + if (!matched_vars.count(uses_matched_var)) { + return true; + } + } + } + + return false; + }(); + if (tuple_element_is_already_used_outside_of_matched_tuple) { + return std::nullopt; + } + + auto full_tuple = [&]() -> relax::Expr { + Array fields; + for (size_t index : indices) { + fields.push_back(info_vec[index].expr); + } + return relax::Tuple(fields); + }(); + + auto opt_rewritten = func(full_tuple, merged_matches); + if (!opt_rewritten) { + return std::nullopt; + } + auto rewritten = opt_rewritten.value(); + + if (rewritten.same_as(full_tuple)) { + return std::nullopt; + } + + std::vector rewrites; + if (auto inline_tuple = rewritten.as()) { + const auto& fields = inline_tuple->fields; + CHECK_EQ(fields.size(), indices.size()) + << "Expected to receive " << indices.size() << " values to replace TuplePattern with " + << indices.size() << " fields, but received " << fields.size() << " values"; + rewrites = {fields.begin(), fields.end()}; + } else { + for (size_t i = 0; i < indices.size(); i++) { + rewrites.push_back(TupleGetItem(rewritten, i)); + } + } + return rewrites; +} + +TVM_REGISTER_GLOBAL("relax.dpl.TupleRewriter") + .set_body_typed([](Array patterns, + TypedPackedFunc(Expr, Map)> func) { + return TupleRewriter(patterns, func); + }); + +TupleRewriter::TupleRewriter(Array patterns, + TypedPackedFunc(Expr, Map)> func, + Optional> additional_bindings, + Map new_subroutines) { + auto node = make_object(); + node->patterns = std::move(patterns); + node->func = std::move(func); + node->additional_bindings = std::move(additional_bindings); + node->new_subroutines = std::move(new_subroutines); + data_ = std::move(node); +} + +PatternMatchingRewriter PatternMatchingRewriter::FromPattern( + DFPattern pattern, TypedPackedFunc(Expr, Map)> func, + Optional> additional_bindings, Map new_subroutines) { + if (auto or_pattern = pattern.as()) { + auto new_additional_bindings = additional_bindings.value_or({}); + new_additional_bindings.push_back(pattern); + return OrRewriter(PatternMatchingRewriter::FromPattern( + or_pattern->left, func, new_additional_bindings, new_subroutines), + PatternMatchingRewriter::FromPattern( + or_pattern->right, func, new_additional_bindings, new_subroutines)); + } else if (auto tuple_pattern = pattern.as()) { + auto new_additional_bindings = additional_bindings.value_or({}); + new_additional_bindings.push_back(pattern); + // If the Tuple appears as a Relax binding, apply it first. As a + // fallback, also check for implicit tuples. 
+ return OrRewriter( + ExprPatternRewriter(pattern, func, additional_bindings, new_subroutines), + TupleRewriter(tuple_pattern->fields, func, new_additional_bindings, new_subroutines)); + } else { + return ExprPatternRewriter(pattern, func, additional_bindings, new_subroutines); + } +} + +PatternMatchingRewriter PatternMatchingRewriter::FromModule(IRModule mod) { + Function func_pattern = [&]() { + CHECK(mod->ContainGlobalVar("pattern")) + << "KeyError: " + << "Expected module to contain 'pattern', " + << "a Relax function defining the pattern to be matched, " + << "but the module did not contain a 'pattern' function."; + auto base_func = mod->Lookup("pattern"); + CHECK(base_func->IsInstance<FunctionNode>()) + << "TypeError: " + << "Expected module to contain 'pattern', " + << "a Relax function defining the pattern to be matched, " + << "but the 'pattern' function was of type " << base_func->GetTypeKey() << "."; + return Downcast<Function>(base_func); + }(); + Function func_replacement = [&]() { + CHECK(mod->ContainGlobalVar("replacement")) + << "KeyError: " + << "Expected module to contain 'replacement', " + << "a Relax function defining the replacement to be made, " + << "but the module did not contain a 'replacement' function."; + auto base_func = mod->Lookup("replacement"); + CHECK(base_func->IsInstance<FunctionNode>()) + << "TypeError: " + << "Expected module to contain 'replacement', " + << "a Relax function defining the replacement to be made on a successful match, " + << "but the 'replacement' function was of type " << base_func->GetTypeKey() << "."; + return Downcast<Function>(base_func); + }(); + + Map<GlobalVar, BaseFunc> new_subroutines; + for (const auto& [gvar, func] : mod->functions) { + if (gvar->name_hint != "pattern" && gvar->name_hint != "replacement") { + bool is_public = func->GetAttr<String>(tvm::attr::kGlobalSymbol).defined(); + CHECK(!is_public) << "ValueError: " + << "Expected module to have no publicly-exposed functions " + << "other than 'pattern' and 'replacement'. 
" + << "However, function '" << gvar->name_hint << "' of type " + << func->GetTypeKey() << " is publicly exposed."; + new_subroutines.Set(gvar, func); + } + } + + auto sinfo_pattern = GetStructInfo(func_pattern); + auto sinfo_replacement = GetStructInfo(func_replacement); + CHECK(StructuralEqual()(sinfo_pattern, sinfo_replacement)) + << "ValueError: " + << "The pattern and replacement must have the same signature, " + << "but the pattern has struct info " << sinfo_pattern + << ", while the replacement has struct info " << sinfo_replacement; + + Array param_wildcards; + Map pattern_lookup; + for (const auto& param : func_pattern->params) { + WildcardPattern wildcard; + param_wildcards.push_back(wildcard); + pattern_lookup.Set(param, StructInfoPattern(wildcard, GetStructInfo(param))); + } + + std::function make_pattern = [&](Expr expr) -> DFPattern { + if (auto var = expr.as()) { + return pattern_lookup[var.value()]; + + } else if (auto call = expr.as()) { + auto op = make_pattern(call->op); + auto args = call->args.Map(make_pattern); + return CallPattern(op, args); + + } else if (auto tuple = expr.as()) { + auto fields = tuple->fields.Map(make_pattern); + return TuplePattern(fields); + + } else if (auto tuple_get_item = expr.as()) { + auto tuple = make_pattern(tuple_get_item->tuple); + return TupleGetItemPattern(tuple, tuple_get_item->index); + + } else if (auto op = expr.as()) { + return ExprPattern(op.value()); + + } else if (auto func = expr.as()) { + return ExternFuncPattern(func->global_symbol); + + } else if (auto prim = expr.as()) { + return StructInfoPattern(WildcardPattern(), PrimStructInfo(prim->value)); + + } else { + LOG(FATAL) << "TypeError: " + << "Cannot convert Relax expression of type " << expr->GetTypeKey() + << " into pattern-matching rule."; + } + }; + + for (const auto& block : func_pattern->body->blocks) { + for (const auto& binding : block->bindings) { + auto value_pattern = make_pattern(GetBoundValue(binding)); + if (auto match_cast = binding.as()) { + value_pattern = StructInfoPattern(value_pattern, match_cast->struct_info); + } + pattern_lookup.Set(binding->var, value_pattern); + } + } + + DFPattern top_pattern = make_pattern(func_pattern->body->body); + + TypedPackedFunc(Expr, Map)> rewriter_func = + [param_wildcards = std::move(param_wildcards), + orig_func_replacement = std::move(func_replacement)]( + Expr expr, Map matches) -> Optional { + auto func_replacement = CopyWithNewVars(orig_func_replacement); + + Array new_blocks; + + Array wildcard_bindings; + ICHECK_EQ(param_wildcards.size(), func_replacement->params.size()); + for (size_t i = 0; i < param_wildcards.size(); i++) { + Expr matched_expr = matches[param_wildcards[i]]; + + // Introduce an intermediate variable, to ensure that the + // MatchCast's target will be a Var, even for expressions that + // wouldn't normally be normalized into a variable. 
+ Var intermediate_var("intermediate_var", GetStructInfo(matched_expr)); + wildcard_bindings.push_back(VarBinding(intermediate_var, matched_expr)); + wildcard_bindings.push_back( + MatchCast(func_replacement->params[i], intermediate_var, GetStructInfo(matched_expr))); + } + + new_blocks.push_back(DataflowBlock(wildcard_bindings)); + + for (const auto& block : func_replacement->body->blocks) { + new_blocks.push_back(block); + } + + return SeqExpr(new_blocks, func_replacement->body->body); + }; + + return PatternMatchingRewriter::FromPattern(top_pattern, rewriter_func, NullOpt, new_subroutines); +} + +Optional<Map<DFPattern, Expr>> ExtractMatchedExpr(DFPattern pattern, Expr expr, + Optional<Map<Var, Expr>> bindings_opt) { + auto bindings = bindings_opt.value_or({}); + DFPatternMatcher matcher(bindings); + + if (!matcher.Match(pattern, expr)) { + return NullOpt; + } + + return matcher.GetMemo(); +} + +TVM_REGISTER_GLOBAL("relax.dpl.extract_matched_expr").set_body_typed(ExtractMatchedExpr); + +bool MatchExpr(DFPattern pattern, Expr expr, Optional<Map<Var, Expr>> bindings_opt) { + return static_cast<bool>(ExtractMatchedExpr(pattern, expr, bindings_opt)); +} + +TVM_REGISTER_GLOBAL("relax.dpl.match_expr").set_body_typed(MatchExpr); + +/*! + * \brief Apply pattern matching to each expression, replacing + * matches with the output of a user-provided rewriter function. + */ +class PatternMatchingMutator : public ExprMutator { + public: + using ExprMutator::VisitExpr_; + + PatternMatchingMutator(const PatternMatchingRewriterNode* rewriter) : rewriter_(rewriter) {} + + Map<GlobalVar, BaseFunc> GetNewSubroutines() const { return new_subroutines_; } + + Expr VisitExpr_(const SeqExprNode* seq) override { + SeqExpr prev = Downcast<SeqExpr>(ExprMutator::VisitExpr_(seq)); + + StructuralEqual struct_equal; + + while (auto opt = TryRewriteSeqExpr(prev)) { + SeqExpr next = Downcast<SeqExpr>(builder_->Normalize(opt.value())); + if (struct_equal(prev, next)) { + break; + } + + // Canonicalization may result in two previously-different + // expressions being recognized as identical. Elimination of + // common subexpressions may result in trivial var-to-var + // bindings that can be canonicalized. Therefore, iterate the + // simplification steps until converged. + while (true) { + auto start_of_loop = next; + next = Downcast<SeqExpr>(CanonicalizeBindings(next)); + next = Downcast<SeqExpr>(EliminateCommonSubexpr(next)); + next = Downcast<SeqExpr>(RemoveAllUnused(next)); + if (struct_equal(start_of_loop, next)) { + break; + } + } + + if (struct_equal(prev, next)) { + break; + } + + prev = next; + } + + return prev; + } + + Optional<SeqExpr> TryRewriteSeqExpr(const SeqExpr& seq) { + Array<BindingBlock> old_blocks = seq->blocks; + + // If the SeqExpr's output is not a variable, treat it as if it + // were the last variable binding of the last block. This + // simplifies the special handling of the SeqExpr's body.
+
+/*!
+ * \brief Apply pattern matching to each expression, replacing
+ * matches with the output of a user-provided rewriter function.
+ */
+class PatternMatchingMutator : public ExprMutator {
+ public:
+  using ExprMutator::VisitExpr_;
+
+  PatternMatchingMutator(const PatternMatchingRewriterNode* rewriter) : rewriter_(rewriter) {}
+
+  Map<GlobalVar, BaseFunc> GetNewSubroutines() const { return new_subroutines_; }
+
+  Expr VisitExpr_(const SeqExprNode* seq) override {
+    SeqExpr prev = Downcast<SeqExpr>(ExprMutator::VisitExpr_(seq));
+
+    StructuralEqual struct_equal;
+
+    while (auto opt = TryRewriteSeqExpr(prev)) {
+      SeqExpr next = Downcast<SeqExpr>(builder_->Normalize(opt.value()));
+      if (struct_equal(prev, next)) {
+        break;
+      }
+
+      // Canonicalization may result in two previously-different
+      // expressions being recognized as identical.  Elimination of
+      // common subexpressions may result in trivial var-to-var
+      // bindings that can be canonicalized.  Therefore, iterate the
+      // simplification steps until converged.
+      while (true) {
+        auto start_of_loop = next;
+        next = Downcast<SeqExpr>(CanonicalizeBindings(next));
+        next = Downcast<SeqExpr>(EliminateCommonSubexpr(next));
+        next = Downcast<SeqExpr>(RemoveAllUnused(next));
+        if (struct_equal(start_of_loop, next)) {
+          break;
+        }
+      }
+
+      if (struct_equal(prev, next)) {
+        break;
+      }
+
+      prev = next;
+    }
+
+    return prev;
+  }
+
+  Optional<SeqExpr> TryRewriteSeqExpr(const SeqExpr& seq) {
+    Array<BindingBlock> old_blocks = seq->blocks;
+
+    // If the SeqExpr's output is not a variable, treat it as if it
+    // were the last variable binding of the last block.  This
+    // simplifies the special handling of the SeqExpr's body.
+    Optional<Var> dummy_output_var = NullOpt;
+    if (!seq->body->IsInstance<VarNode>()) {
+      dummy_output_var = Var("dummy_output_var", GetStructInfo(seq->body));
+      VarBinding dummy_binding(dummy_output_var.value(), seq->body);
+
+      auto last_block = [&]() {
+        if (seq->blocks.size()) {
+          auto last_block = old_blocks.back();
+          old_blocks.pop_back();
+          return last_block;
+        } else {
+          return BindingBlock(Array<Binding>{});
+        }
+      }();
+
+      last_block.CopyOnWrite()->bindings.push_back(dummy_binding);
+      old_blocks.push_back(last_block);
+    }
+
+    auto rewrite_block = [&](Array<Binding> orig_bindings) -> Array<Binding> {
+      auto rewrites = rewriter_->RewriteBindings(orig_bindings);
+      if (!rewrites) return orig_bindings;
+
+      for (auto [gvar, func] : rewrites.new_subroutines) {
+        new_subroutines_.Set(gvar, func);
+      }
+
+      auto bindings = orig_bindings.Map([&](Binding binding) -> Binding {
+        if (auto new_expr = rewrites.variable_rewrites.Get(binding->var)) {
+          if (auto match_cast = binding.as<MatchCastNode>()) {
+            return MatchCast(binding->var, new_expr.value(), match_cast->struct_info);
+          } else {
+            return VarBinding(binding->var, new_expr.value());
+          }
+        } else {
+          return binding;
+        }
+      });
+
+      if (bindings.same_as(orig_bindings)) {
+        return orig_bindings;
+      }
+
+      // The rewriter may have introduced additional dependencies
+      // between computations.  Since pattern-matching only occurs
+      // within blocks that may be re-ordered, these can be resolved
+      // by performing a topological sort.
+      bindings = TopologicalSort(bindings);
+
+      return bindings;
+    };
+
+    // Utility function to return the rewrites that should be applied
+    // to a given block.
+    auto get_rewrites = [&](BindingBlock block) -> Array<Binding> {
+      if (block.as<DataflowBlockNode>()) {
+        // Early return for DataflowBlock.  Since neither control flow
+        // nor impure functions are allowed within the dataflow block,
+        // all bindings may be considered at the same time.
+        return rewrite_block(block->bindings);
+      }
+
+      RewriteSpec rewrites;
+
+      Array<Binding> collected_bindings;
+      Array<Binding> finalized_bindings;
+
+      auto handle_collected_rewrites = [&]() {
+        if (collected_bindings.size()) {
+          auto bindings = rewrite_block(collected_bindings);
+          if (finalized_bindings.empty()) {
+            finalized_bindings = bindings;
+          } else {
+            for (const auto& binding : bindings) {
+              finalized_bindings.push_back(binding);
+            }
+          }
+          collected_bindings.clear();
+        }
+      };
+
+      for (const auto& binding : block->bindings) {
+        auto value = GetBoundValue(binding);
+        bool is_dataflow = (!value.as<IfNode>()) &&
+                           (!(value.as<CallNode>() && IsImpureCall(Downcast<Call>(value))));
+        if (is_dataflow) {
+          // This binding satisfies the dataflow constraints.
+          collected_bindings.push_back(binding);
+        } else {
+          // This binding does not satisfy the dataflow constraints.
+          // Any operations prior to this binding should be checked
+          // for pattern-match replacements.
+          handle_collected_rewrites();
+          finalized_bindings.push_back(binding);
+        }
+      }
+
+      // Check for rewrites in dataflow operations after the last
+      // non-dataflow segment.
+      handle_collected_rewrites();
+
+      return finalized_bindings;
+    };
+
+    // Utility function, check for and apply rewrites to a single
+    // block.
+    auto visit_block = [&](BindingBlock old_block) -> BindingBlock {
+      auto new_bindings = get_rewrites(old_block);
+      if (new_bindings.same_as(old_block->bindings)) {
+        return old_block;
+      }
+
+      if (old_block.as<DataflowBlockNode>()) {
+        builder_->BeginDataflowBlock();
+      } else {
+        builder_->BeginBindingBlock();
+      }
+
+      for (const auto& binding : new_bindings) {
+        auto value = builder_->Normalize(GetBoundValue(binding));
+
+        if (binding.as<VarBindingNode>()) {
+          builder_->EmitNormalized(VarBinding(binding->var, value));
+        } else if (auto match_cast = binding.as<MatchCastNode>()) {
+          builder_->EmitNormalized(MatchCast(binding->var, value, match_cast->struct_info));
+        } else {
+          LOG(FATAL) << "Binding must be either VarBinding or MatchCast";
+        }
+      }
+      return builder_->EndBlock();
+    };
+
+    auto new_blocks = old_blocks.Map(visit_block);
+    if (old_blocks.same_as(new_blocks)) {
+      return NullOpt;
+    }
+
+    // Restore the body of the SeqExpr, if needed.
+    auto new_body = [&]() -> Expr {
+      if (dummy_output_var) {
+        auto last_block = new_blocks.back();
+        new_blocks.pop_back();
+
+        auto last_binding = last_block->bindings.back();
+        last_block.CopyOnWrite()->bindings.pop_back();
+        ICHECK(last_binding->var.same_as(dummy_output_var));
+
+        if (last_block->bindings.size()) {
+          new_blocks.push_back(last_block);
+        }
+
+        return GetBoundValue(last_binding);
+      } else {
+        return seq->body;
+      }
+    }();
+
+    return SeqExpr(new_blocks, new_body);
+  }
+
+ private:
+  const PatternMatchingRewriterNode* rewriter_;
+  Map<GlobalVar, BaseFunc> new_subroutines_;
+};
+
+Expr PatternMatchingRewriter::operator()(Expr expr) {
+  PatternMatchingMutator mutator(get());
+  auto new_expr = mutator(expr);
+  auto new_subroutines = mutator.GetNewSubroutines();
+  CHECK_EQ(new_subroutines.size(), 0)
+      << "If PatternMatchingRewriter provides subroutines, "
+      << "then it must be applied to an entire IRModule.  "
+      << "However, PatternMatchingRewriter produced subroutines " << [&]() -> Array<GlobalVar> {
+    std::vector<GlobalVar> vec;
+    for (const auto& [gvar, func] : new_subroutines) {
+      vec.push_back(gvar);
+    }
+    std::sort(vec.begin(), vec.end(),
+              [](const GlobalVar& a, const GlobalVar& b) { return a->name_hint < b->name_hint; });
+    return vec;
+  }() << " when applied to "
+      << "Relax expression of type " << expr->GetTypeKey();
+  return new_expr;
+}
+
+IRModule PatternMatchingRewriterNode::operator()(
+    IRModule mod, const tvm::transform::PassContext& pass_ctx) const {
+  PatternMatchingMutator mutator(this);
+
+  IRModule updates;
+  for (const auto& [gvar, base_func] : mod->functions) {
+    if (auto func = base_func.as<Function>()) {
+      auto rewritten = Downcast<Function>(mutator(func.value()));
+      if (!rewritten.same_as(base_func)) {
+        updates->Add(gvar, rewritten);
+      }
+    }
+  }
+
+  if (updates->functions.size()) {
+    auto write_ptr = mod.CopyOnWrite();
+    write_ptr->Update(updates);
+    write_ptr->Update(IRModule(mutator.GetNewSubroutines()));
+  }
+
+  return mod;
+}
+tvm::transform::PassInfo PatternMatchingRewriterNode::Info() const {
+  return tvm::transform::PassInfo(0, "PatternMatchingRewriter", {}, false);
+}
+
+Function RewriteCall(const DFPattern& pat,
+                     TypedPackedFunc<Expr(Expr, Map<DFPattern, Expr>)> rewriter, Function func) {
+  return Downcast<Function>(PatternMatchingRewriter::FromPattern(pat, rewriter)(func));
+}
+
+TVM_REGISTER_GLOBAL("relax.dpl.rewrite_call").set_body_typed(RewriteCall);
+
+}  // namespace relax
+}  // namespace tvm
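Because `get_rewrites` splits a non-dataflow block at each impure call or control-flow binding, a rewrite never reorders computation across such a barrier. A sketch of the observable behavior, assuming `R.print` as the impure call and the `@R.rewriter` convention introduced by this PR (the `pure=False` declaration on the TVMScript function is an assumption of this sketch):

.. code-block:: python

    from tvm.script import relax as R

    @R.rewriter
    class RewriteAdd:
        @R.function
        def pattern(A: R.Tensor([16], "float32")):
            return A + A

        @R.function
        def replacement(A: R.Tensor([16], "float32")):
            return A * R.const(2.0, "float32")

    @R.function(pure=False, private=True)
    def before(x: R.Tensor([16], "float32")):
        y = x + x                    # first dataflow segment
        R.print(format="checkpoint")  # impure call: segment boundary
        z = y + y                    # second dataflow segment
        return z

    # Both additions are rewritten, but the two segments are
    # pattern-matched independently, and no binding is moved
    # across the R.print call.
    after = RewriteAdd(before)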
*/ +#include "dataflow_matcher.h" + #include #include #include @@ -37,6 +39,7 @@ #include #include #include +#include #include #include #include @@ -45,7 +48,6 @@ #include "../../arith/constraint_extract.h" #include "../transform/utils.h" -#include "dataflow_matcher_impl.h" namespace tvm { namespace relax { @@ -59,7 +61,7 @@ bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) { return VisitDFPattern(pattern, expr); } -static Expr TryGetValOfVar(Expr expr, const Map& var2val) { +Expr DFPatternMatcher::UnwrapBindings(Expr expr, const Map& var2val) { auto unwrap = [&](Expr expr) -> Optional { // Unwrap variables into the value to which they are bound. if (var2val.size()) { @@ -98,16 +100,15 @@ void DFPatternMatcher::ClearMap(size_t watermark) { bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr0) { CHECK(pattern.defined()) << "Null pattern found when matching against " << expr0; - auto expr = TryGetValOfVar(expr0, var2val_); + auto expr = UnwrapBindings(expr0, var2val_); if (memoize_ && memo_.count(pattern)) { - ICHECK_EQ(memo_[pattern].size(), 1); - return expr.same_as(memo_[pattern][0]); + return expr.same_as(memo_[pattern]); } else { PrimExpr cached_condition = symbolic_expr_condition_; size_t watermark = matched_nodes_.size(); bool out = DFPatternFunctor::VisitDFPattern(pattern, expr); if (out) { - memo_[pattern].push_back(expr); + memo_[pattern] = expr; matched_nodes_.push_back(pattern); } else { ClearMap(watermark); @@ -118,17 +119,17 @@ bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr } bool DFPatternMatcher::VisitDFPattern_(const OrPatternNode* op, const Expr& expr0) { - auto expr = TryGetValOfVar(expr0, var2val_); + auto expr = UnwrapBindings(expr0, var2val_); return VisitDFPattern(op->left, expr) || VisitDFPattern(op->right, expr); } bool DFPatternMatcher::VisitDFPattern_(const AndPatternNode* op, const Expr& expr0) { - auto expr = TryGetValOfVar(expr0, var2val_); + auto expr = UnwrapBindings(expr0, var2val_); return VisitDFPattern(op->left, expr) && VisitDFPattern(op->right, expr); } bool DFPatternMatcher::VisitDFPattern_(const NotPatternNode* op, const Expr& expr0) { - auto expr = TryGetValOfVar(expr0, var2val_); + auto expr = UnwrapBindings(expr0, var2val_); return !VisitDFPattern(op->reject, expr); } @@ -183,7 +184,7 @@ bool MatchRetValue(const ObjectRef& lhs, const TVMRetValue& rhs) { } bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, const Expr& expr0) { - auto expr = TryGetValOfVar(expr0, var2val_); + auto expr = UnwrapBindings(expr0, var2val_); bool matches = VisitDFPattern(attr_pattern->pattern, expr); if (!matches) return matches; VLOG(1) << "considering AttrPatternNode at:\n" << expr; @@ -241,7 +242,7 @@ bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, cons } bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& expr0) { - auto expr = TryGetValOfVar(expr0, var2val_); + auto expr = UnwrapBindings(expr0, var2val_); // utilities auto get_op_node = [](const CallPatternNode* op) -> const tvm::OpNode* { if (op) { @@ -351,12 +352,12 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex } bool DFPatternMatcher::VisitDFPattern_(const ExprPatternNode* op, const Expr& expr0) { - auto expr = TryGetValOfVar(expr0, var2val_); + auto expr = UnwrapBindings(expr0, var2val_); return StructuralEqual()(op->expr, expr); } bool DFPatternMatcher::VisitDFPattern_(const FunctionPatternNode* op, const Expr& 
expr0) { - auto expr = TryGetValOfVar(expr0, var2val_); + auto expr = UnwrapBindings(expr0, var2val_); bool matches = false; if (const auto* func = expr.as()) { matches = true; @@ -379,7 +380,7 @@ bool DFPatternMatcher::VisitDFPattern_(const FunctionPatternNode* op, const Expr } bool DFPatternMatcher::VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr0) { - auto expr = TryGetValOfVar(expr0, var2val_); + auto expr = UnwrapBindings(expr0, var2val_); if (const auto* tuple_get_item_node = expr.as()) { return (op->index == -1 || op->index == tuple_get_item_node->index) && VisitDFPattern(op->tuple, tuple_get_item_node->tuple); @@ -388,7 +389,7 @@ bool DFPatternMatcher::VisitDFPattern_(const TupleGetItemPatternNode* op, const } bool DFPatternMatcher::VisitDFPattern_(const TuplePatternNode* op, const Expr& expr0) { - auto expr = TryGetValOfVar(expr0, var2val_); + auto expr = UnwrapBindings(expr0, var2val_); bool matches = false; if (const auto* tuple_node = expr.as()) { matches = true; @@ -429,7 +430,7 @@ bool DFPatternMatcher::TryUnorderedMatch(size_t idx, const tvm::Array } bool DFPatternMatcher::VisitDFPattern_(const UnorderedTuplePatternNode* op, const Expr& expr0) { - auto expr = TryGetValOfVar(expr0, var2val_); + auto expr = UnwrapBindings(expr0, var2val_); if (const auto* tuple_node = expr.as()) { if (op->fields.size() == tuple_node->fields.size()) { @@ -449,7 +450,7 @@ bool DFPatternMatcher::VisitDFPattern_(const StructInfoPatternNode* op, const Ex return false; } - auto expr = TryGetValOfVar(expr0, var2val_); + auto expr = UnwrapBindings(expr0, var2val_); auto expr_struct_info = GetStructInfo(expr); PrimExpr new_constraint = StructInfoBaseCheckPrecondition(op->struct_info, expr_struct_info); @@ -497,7 +498,7 @@ PrimExpr DFPatternMatcher::SimplifyCondition(PrimExpr condition) { } bool DFPatternMatcher::VisitDFPattern_(const TypePatternNode* op, const Expr& expr0) { - auto expr = TryGetValOfVar(expr0, var2val_); + auto expr = UnwrapBindings(expr0, var2val_); auto expr_type = expr.as()->checked_type(); return (StructuralEqual()(op->type, expr_type)) && VisitDFPattern(op->pattern, expr); } @@ -584,7 +585,7 @@ std::tuple SameShapeConstraintNode::AsPrimExpr( } bool DFPatternMatcher::VisitDFPattern_(const PrimArrPatternNode* op, const Expr& expr0) { - auto expr = TryGetValOfVar(expr0, var2val_); + auto expr = UnwrapBindings(expr0, var2val_); if (const ShapeExprNode* shape_expr = expr.as()) return ShapeEqual(&analyzer_, op->fields, shape_expr->values); return false; @@ -609,7 +610,7 @@ bool DFPatternMatcher::VisitDFPattern_(const VarPatternNode* op, const Expr& exp } bool DFPatternMatcher::VisitDFPattern_(const ExternFuncPatternNode* op, const Expr& expr0) { - auto expr = TryGetValOfVar(expr0, var2val_); + auto expr = UnwrapBindings(expr0, var2val_); if (const auto* extern_fn = expr.as()) { return "" == op->global_symbol() || op->global_symbol() == extern_fn->global_symbol; } @@ -618,7 +619,7 @@ bool DFPatternMatcher::VisitDFPattern_(const ExternFuncPatternNode* op, const Ex bool DFPatternMatcher::VisitDFPattern_(const ConstantPatternNode* op, const Expr& expr0) { // constants can be binded to relax.Var as well. 
- auto expr = TryGetValOfVar(expr0, var2val_); + auto expr = UnwrapBindings(expr0, var2val_); return expr.as() != nullptr; } @@ -642,631 +643,5 @@ bool DFPatternMatcher::VisitDFPattern_(const WildcardPatternNode* op, const Expr return true; } -Optional> ExtractMatchedExpr(DFPattern pattern, Expr expr, - Optional> bindings_opt) { - auto bindings = bindings_opt.value_or({}); - DFPatternMatcher matcher(bindings); - - if (!matcher.Match(pattern, expr)) { - return NullOpt; - } - - Map matching; - for (const auto& [pat, matches] : matcher.GetMemo()) { - ICHECK_EQ(matches.size(), 1) << "More than one match for the pattern " << pat; - matching.Set(pat, matches[0]); - } - return matching; -} - -TVM_REGISTER_GLOBAL("relax.dpl.extract_matched_expr").set_body_typed(ExtractMatchedExpr); - -bool MatchExpr(DFPattern pattern, Expr expr, Optional> bindings_opt) { - return static_cast(ExtractMatchedExpr(pattern, expr, bindings_opt)); -} - -TVM_REGISTER_GLOBAL("relax.dpl.match_expr").set_body_typed(MatchExpr); - -class MatcherUseDefAnalysis : public relax::ExprVisitor { - public: - std::vector vars; - std::map> def2use; - // caller -> callee table. - std::map> caller2callees; - - const VarNode* cur_user_; - - void VisitBinding_(const VarBindingNode* binding) override { - // init - cur_user_ = binding->var.get(); - this->VisitVarDef(binding->var); - this->VisitExpr(binding->value); - cur_user_ = nullptr; - } - - void VisitExpr_(const VarNode* op) override { - if (nullptr == cur_user_) return; - - auto check_and_push = [](std::vector& vec, const VarNode* var) { - if (std::find(vec.begin(), vec.end(), var) == vec.end()) { - vec.push_back(var); - } - }; - - check_and_push(def2use[op], cur_user_); - check_and_push(vars, op); - - caller2callees[cur_user_].push_back(op); - } -}; - -struct PNode { - const DFPatternNode* ptr; - std::vector&>> children; - std::vector&>> parents; -}; - -struct RNode { - const VarNode* ptr; - std::vector children; - std::vector parents; -}; - -struct MatchState { - void add(const PNode* p, const RNode* r) { - match_p_r[p] = r; - match_r_p[r] = p; - } - - void add(const DFConstraintNode* constraint) { validated_constraints_.insert(constraint); } - - void add(MatchState&& other) { - match_p_r.merge(std::move(other.match_p_r)); - match_r_p.merge(std::move(other.match_r_p)); - validated_constraints_.merge(other.validated_constraints_); - } - - const VarNode* matched(const PNode* p) const { - if (auto it = match_p_r.find(p); it != match_p_r.end()) { - return it->second->ptr; - } - return nullptr; - } - - const DFPatternNode* matched(const RNode* r) const { - if (auto it = match_r_p.find(r); it != match_r_p.end()) { - return it->second->ptr; - } - return nullptr; - } - - const VarNode* matched(const PNode& p) const { return matched(&p); } - const DFPatternNode* matched(const RNode& r) const { return matched(&r); } - - bool is_validated(const DFConstraintNode* constraint) const { - return validated_constraints_.count(constraint); - } - - private: - std::unordered_map match_p_r; - std::unordered_map match_r_p; - std::unordered_set validated_constraints_; -}; - -/** - * \brief This method try to match a real node and a pattern node along with its neighbors. 
- */ -static std::optional TryMatch(const PNode& p, const RNode& r, - const MatchState& current_match, DFPatternMatcher* m, - const MatcherUseDefAnalysis& ud_analysis) { - if (!m->Match(GetRef(p.ptr), GetRef(r.ptr))) return std::nullopt; - - MatchState new_match; - - new_match.add(&p, &r); - - // forward matching; - for (const auto& [pchild, constraints] : p.children) { - bool any_cons_sat = false; - for (const auto& rchild : r.children) { - if (new_match.matched(rchild)) { - // The child variable is already matched to other child pattern in a previous iteration. - continue; - } - if (auto v = current_match.matched(pchild); v && v != rchild->ptr) { - // The child pattern is already matched to other variable in a earlier call to TryMatch. - continue; - } - - const auto& uses = ud_analysis.def2use.at(r.ptr); - - // check edge constraints. - bool all_cons_pass = true; - for (const auto& cons : constraints) { - if (cons.type == PairCons::kOnlyUsedBy && uses.size() != 1) { - all_cons_pass = false; - break; - } - - if (cons.index != -1) { - const auto& callees = ud_analysis.caller2callees.at(rchild->ptr); - if (callees.size() <= static_cast(cons.index) || callees[cons.index] != r.ptr) { - all_cons_pass = false; - break; - } - } - } - if (!all_cons_pass || new_match.matched(pchild)) continue; - any_cons_sat = true; - - if (auto match_rec = TryMatch(*pchild, *rchild, current_match, m, ud_analysis)) { - new_match.add(pchild, rchild); - new_match.add(std::move(*match_rec)); - } - } - if (!new_match.matched(pchild) || !any_cons_sat) return std::nullopt; - } - - return new_match; -} - -static std::optional TryValidate( - const MatchState& current_match, - const std::unordered_map& pattern2node, - const std::vector& validation_constraints, arith::Analyzer* analyzer) { - MatchState new_match; - - std::function(const DFPatternNode*)> query_match_state = - [&pattern2node, ¤t_match](const DFPatternNode* pattern) -> Optional { - auto it = pattern2node.find(pattern); - ICHECK(it != pattern2node.end()) - << "DFConstraint attempted to access DFPattern " << GetRef(pattern) - << ", which does not appear in the PatternContext"; - const auto& p_node = it->second; - if (auto ptr = current_match.matched(p_node)) { - return GetRef(ptr); - } else { - return NullOpt; - } - }; - - for (const auto& constraint : validation_constraints) { - if (!current_match.is_validated(constraint.get())) { - auto [necessary_condition, is_sufficient] = constraint->AsPrimExpr(query_match_state); - - necessary_condition = analyzer->Simplify(necessary_condition); - const auto* known = tir::as_const_int(necessary_condition); - - if (known && *known && is_sufficient) { - // The condition passes, and the expression provided is both - // necessary and sufficient for the constraint to pass. Mark - // the constraint as passing, to avoid re-checking it unless - // we backtrack. - new_match.add(constraint.get()); - } else if (known && !*known) { - // The condition fails. Even if additional information would - // be required to pass a constraint, it may bail out early as - // a failure (e.g. shape mismatch in the first two items out - // of N shapes that must all match). - return std::nullopt; - } else if (is_sufficient) { - // The condition depends on dynamic parameters. In the - // future, this may be exposed to the user as a condition for - // optimization, or can be combined with the conditions - // provided from other constraints. 
- return std::nullopt; - } - } - } - - return new_match; -} - -static std::optional MatchTree( - const MatchState& current_match, size_t current_root_idx, - const std::unordered_map& pattern2node, - const std::unordered_map& var2node, DFPatternMatcher* matcher, - const std::vector& roots, const std::vector& validation_constraints, - const MatcherUseDefAnalysis& ud_analysis, arith::Analyzer* analyzer) { - auto get_next_root = [&](size_t root_idx) -> const PNode* { - // Look for the next unmatched root node. - for (; root_idx < roots.size(); ++root_idx) { - const auto& root = pattern2node.at(roots[root_idx].get()); - if (!current_match.matched(root)) { - return &root; - } - } - return nullptr; - }; - - const auto root = get_next_root(current_root_idx); - - if (!root) { - // All root nodes have been matched - return current_match; - } - - MatchState new_match = current_match; - - for (const auto& var : ud_analysis.vars) { - const RNode& r_node = var2node.at(var); - if (new_match.matched(r_node)) continue; - if (auto match = TryMatch(*root, r_node, new_match, matcher, ud_analysis)) { - // Recursively try to match the next subtree. - new_match.add(std::move(*match)); - if (auto validation = - TryValidate(new_match, pattern2node, validation_constraints, analyzer)) { - new_match.add(std::move(*validation)); - if (auto match_rec = - MatchTree(new_match, current_root_idx + 1, pattern2node, var2node, matcher, roots, - validation_constraints, ud_analysis, analyzer)) { - new_match.add(std::move(*match_rec)); - return new_match; - } - } - // Recursive matching has failed, backtrack. - new_match = current_match; - continue; - } - } - - return std::nullopt; -} - -Optional> MatchGraph(const PatternContext& ctx, const DataflowBlock& dfb, - const Map& bindings) { - // TODO(@ganler): Handle non-may external use. - ICHECK(ctx->allow_extern_use == PatternContextNode::kMay) << "Only kMay is supported yet."; - DFPatternMatcher matcher(bindings); - - MatcherUseDefAnalysis ud_analysis; - ud_analysis.VisitBindingBlock_(dfb.get()); - - // First construct a graph of PNode and RNode. 
- std::unordered_map var2node; - var2node.reserve(dfb->bindings.size()); - - for (const VarNode* cur_var : ud_analysis.vars) { - const auto& uses = ud_analysis.def2use.at(cur_var); - RNode& cur_node = var2node[cur_var]; - cur_node.ptr = cur_var; - for (const VarNode* use : uses) { - auto& use_node = var2node[use]; - use_node.ptr = use; - cur_node.children.push_back(&use_node); - use_node.parents.push_back(&cur_node); - } - } - - std::unordered_map pattern2node; - pattern2node.reserve(ctx->edge_constraints.size()); - - for (const auto& def_pattern : ctx->src_ordered) { - PNode& def_node = pattern2node[def_pattern.get()]; - const auto& uses = ctx->edge_constraints.at(def_pattern); - def_node.ptr = def_pattern.get(); - def_node.children.reserve(uses.size()); - for (const auto& [use_pattern, cons] : uses) { - PNode& use_node = pattern2node[use_pattern.get()]; - use_node.ptr = use_pattern.get(); - use_node.parents.emplace_back(&def_node, std::ref(cons)); - def_node.children.emplace_back(&use_node, std::ref(cons)); - } - } - - std::vector roots; - for (const auto& pat : ctx->src_ordered) { - if (pattern2node[pat.get()].parents.empty()) { - roots.push_back(pat); - } - } - - if (roots.empty()) { - return NullOpt; - } - - arith::Analyzer analyzer; - auto match = MatchTree({}, 0, pattern2node, var2node, &matcher, roots, - ctx->validation_constraints, ud_analysis, &analyzer); - if (!match) { - return NullOpt; - } - - Map ret; - for (const auto& [pat, p_node] : pattern2node) { - ICHECK(match->matched(p_node)); - ret.Set(GetRef(pat), GetRef(match->matched(p_node))); - } - return ret; -} - -Optional> MatchGraph(const PatternContext& ctx, const DataflowBlock& dfb) { - return MatchGraph(ctx, dfb, AnalyzeVar2Value(dfb)); -} - -TVM_REGISTER_GLOBAL("relax.dpl.match_dfb") - .set_body_typed([](const PatternContext& ctx, const DataflowBlock& dfb) { - return MatchGraph(ctx, dfb); - }); - -/*! - * \brief Apply pattern matching to each dataflow block, replacing matches - * with the output of a user-provided rewriter function. 
- */ -class BlockPatternRewriter : ExprMutator { - public: - using ExprMutator::VisitBindingBlock_; - using ExprMutator::VisitExpr_; - - BlockPatternRewriter( - const PatternContext& ctx, - TypedPackedFunc(Map, Map)> rewriter_func) - : ctx_(ctx), rewriter_func_(rewriter_func) {} - - template - static Function Run( - PatternType pat, - TypedPackedFunc(Map, Map)> rewriter_func, - Function func) { - BlockPatternRewriter rewriter(pat, rewriter_func); - - func = Downcast(rewriter(func)); - func = Downcast(RemoveAllUnused(func)); - return func; - } - - BindingBlock VisitBindingBlock_(const DataflowBlockNode* block_node) override { - return RewriteDataflowBlockFixedPoint(GetRef(block_node)); - } - - private: - void EmitUsedVars(Expr val, const Array& pending_bindings, - std::unordered_set* emitted_vars) { - std::unordered_set unemitted_vars; - PostOrderVisit(val, [=, &unemitted_vars](Expr e) { - if (auto v = e.as(); v && !emitted_vars->count(v)) { - unemitted_vars.insert(v); - } - }); - - if (unemitted_vars.empty()) { - return; - } - - size_t num_unemitted = unemitted_vars.size(); - for (size_t i = 0; i < pending_bindings.size(); ++i) { - const auto& binding = pending_bindings[i]; - if (auto var_bind = binding.as(); - var_bind && unemitted_vars.count(var_bind->var.get())) { - // var_bind->value may also depend on other unemitted vars in this range - Array prev_bindings(pending_bindings.begin(), pending_bindings.begin() + i); - EmitUsedVars(var_bind->value, prev_bindings, emitted_vars); - this->VisitBinding(binding); - emitted_vars->insert(var_bind->var.get()); - if (--num_unemitted == 0) { - return; - } - } - } - } - - // Repeat until all matchable subsets of bindings are rewritten. - BindingBlock RewriteDataflowBlockFixedPoint(BindingBlock block) { - auto df_block = Downcast(block); - Map bindings = AnalyzeVar2Value(df_block); - if (auto matches = MatchGraph(ctx_, df_block, bindings)) { - builder_->BeginDataflowBlock(); - Map replacements = rewriter_func_(matches.value(), bindings); - - std::unordered_set emitted_vars; - - bool changed = false; - for (size_t i = 0; i < block->bindings.size(); ++i) { - const auto& binding = block->bindings[i]; - if (auto var_bind = binding.as()) { - if (auto new_val = replacements.Get(var_bind->var).value_or(var_bind->value); - !StructuralEqual()(var_bind->value, new_val)) { - Array pending_bindings(block->bindings.begin() + i + 1, block->bindings.end()); - // Make sure there is no unbound variable used in the new value before it is emitted - EmitUsedVars(new_val, pending_bindings, &emitted_vars); - this->ReEmitBinding(var_bind, builder_->Normalize(new_val)); - changed = true; - } else if (!emitted_vars.count(var_bind->var.get())) { - this->VisitBinding(binding); - emitted_vars.insert(var_bind->var.get()); - } - } else { - this->VisitBinding(binding); - } - } - - auto new_block = builder_->EndBlock(); - - if (!changed) return new_block; - return RewriteDataflowBlockFixedPoint(new_block); - } - return block; - } - - /*! \brief The pattern constraint contexts for rewriting dataflow blocks */ - PatternContext ctx_; - /*! - * \brief The user-provided rewriter function. Its signature and semantics are: - * - * - (Map, Map) -> Map - * - * Given the map of patterns and corresponding variables (bound - * variables or parameters), it should return a map that - * specifies new values for matched bound variables. It can refer - * to the passed bindings to create the replacement expressions. - */ - TypedPackedFunc(Map, Map)> rewriter_func_; -}; - -/*! 
- * \brief Apply pattern matching to each expression, replacing - * matches with the output of a user-provided rewriter function. - */ -class ExprPatternRewriter : ExprMutator { - public: - using ExprMutator::VisitBindingBlock_; - using ExprMutator::VisitExpr_; - - ExprPatternRewriter(DFPattern pat, - TypedPackedFunc)> rewriter_func) - : pattern_(pat), rewriter_func_(rewriter_func) {} - - template - static Function Run(PatternType pat, - TypedPackedFunc)> rewriter_func, - Function func) { - ExprPatternRewriter rewriter(pat, rewriter_func); - func = Downcast(rewriter(func)); - func = Downcast(RemoveAllUnused(func)); - return func; - } - - Expr VisitExpr_(const SeqExprNode* seq) override { - auto cache = bindings_; - SeqExpr prev = GetRef(seq); - - StructuralEqual struct_equal; - - while (true) { - SeqExpr next = Downcast(builder_->Normalize(ExprMutator::VisitExpr_(prev.get()))); - if (struct_equal(prev, next)) { - return std::move(next); - } - - // Canonicalization may result in two previously-different - // expressions being recognized as identical. Elimination of - // common subexpressions may result in trival var-to-var - // bindings that can be canonicalized. Therefore, iterate the - // simplification steps until converged. - while (true) { - auto start_of_loop = next; - next = Downcast(CanonicalizeBindings(next)); - next = Downcast(EliminateCommonSubexpr(next)); - next = Downcast(RemoveAllUnused(next)); - if (struct_equal(start_of_loop, next)) { - break; - } - } - - if (struct_equal(prev, next)) { - return std::move(next); - } - - // Reset all knowledge of bindings that were collected from - // this SeqExpr. The collected bindings are only after - // the point where they were collected, and we are repeating - // the mutation of this SeqExpr. - bindings_ = cache; - prev = next; - } - } - - void VisitBinding_(const VarBindingNode* binding) override { - auto expr = VisitExpr(binding->value); - bindings_.Set(binding->var, expr); - ReEmitBinding(binding, expr); - } - - Expr VisitExpr(const Expr& expr) override { - auto node = ExprMutator::VisitExpr(expr); - - std::vector matches_top_level; - if (auto rewritten = TryRewrite(node, pattern_, &matches_top_level)) { - return builder_->Normalize(rewritten.value()); - } - - return node; - } - - private: - Optional TryRewrite(const Expr& expr, const DFPattern& pattern, - std::vector* matches_top_level) { - ICHECK(matches_top_level); - - // Special handling if the user-supplied pattern is a `OrPattern`. - // While the `ExtractMatchedExpr` can handle matching the - // `OrPattern`, it will return on the first match, even if the - // `rewriter_func_` doesn't apply a replacement. Unpacking the - // `OrPattern` here allows the match to be resumed if - // `rewriter_func_` returns the original function unmodified. - // This is only valid for a top-level match. - if (auto or_pattern = pattern.as()) { - matches_top_level->push_back(pattern); - Optional output = TryRewrite(expr, or_pattern->left, matches_top_level); - if (!output.defined()) { - output = TryRewrite(expr, or_pattern->right, matches_top_level); - } - matches_top_level->pop_back(); - return output; - } - - if (auto opt_matches = ExtractMatchedExpr(pattern, expr, bindings_)) { - auto matches = opt_matches.value(); - - // Append any additional matches that from the unwrapped - // `OrPattern`. When matching against `pat = pat_lhs | - // pat_rhs`, we call `ExtractMatchedExpr` on `pat_lhs` and - // `pat_rhs` separately. 
The top-level `pat` is never seen by - // `ExtractMatchedExpr`, and must be re-added afterward. - if (matches_top_level->size()) { - auto matched_expr = TryGetValOfVar(expr, bindings_); - for (const auto& pat : *matches_top_level) { - matches.Set(pat, matched_expr); - } - } - - Expr rewritten_expr = rewriter_func_(expr, matches); - if (!rewritten_expr.same_as(expr)) { - return builder_->Normalize(rewritten_expr); - } - } - - return NullOpt; - } - - /*! \brief The pattern for rewriting call nodes */ - DFPattern pattern_; - /*! - * \brief The user-provided rewriter function. Its signature and semantics are: - * - * - (Call, Map) -> Call - * - * Given the matched call node and the map of patterns and - * matched expressions, it should return a new call node to - * replace the original one or the original matched call node as - * is. - */ - TypedPackedFunc)> rewriter_func_; - - /*! \brief The known variable bindings - * - * The variable bindings whose value is known. This must be tracked - * separately from the block builder, so that it can be reset after - * each iteration of the mutate-until-converged loop applied to - * `SeqExpr`. - */ - Map bindings_; -}; - -Function RewriteBindings( - const PatternContext& ctx, - TypedPackedFunc(Map, Map)> rewriter, Function func) { - return BlockPatternRewriter::Run(ctx, rewriter, func); -} - -TVM_REGISTER_GLOBAL("relax.dpl.rewrite_bindings").set_body_typed(RewriteBindings); - -Function RewriteCall(const DFPattern& pat, - TypedPackedFunc)> rewriter, Function func) { - return ExprPatternRewriter::Run(pat, rewriter, func); -} - -TVM_REGISTER_GLOBAL("relax.dpl.rewrite_call").set_body_typed(RewriteCall); - } // namespace relax } // namespace tvm diff --git a/src/relax/ir/dataflow_matcher_impl.h b/src/relax/ir/dataflow_matcher.h similarity index 91% rename from src/relax/ir/dataflow_matcher_impl.h rename to src/relax/ir/dataflow_matcher.h index a0c35ac0dead..c5d58db5b9d0 100644 --- a/src/relax/ir/dataflow_matcher_impl.h +++ b/src/relax/ir/dataflow_matcher.h @@ -18,11 +18,11 @@ */ /*! - * \file src/tvm/relax/dataflow_matcher_impl.h + * \file src/tvm/relax/dataflow_matcher.h * \brief The auxiliary data structure for dataflow matcher. 
*/ -#ifndef TVM_RELAX_IR_DATAFLOW_MATCHER_IMPL_H_ -#define TVM_RELAX_IR_DATAFLOW_MATCHER_IMPL_H_ +#ifndef TVM_RELAX_IR_DATAFLOW_MATCHER_H_ +#define TVM_RELAX_IR_DATAFLOW_MATCHER_H_ #include #include @@ -43,7 +43,10 @@ class DFPatternMatcher : public DFPatternFunctor> GetMemo() { return Map>(memo_); } + Map GetMemo() { return memo_; } + + /* \brief Unwrap trivial expressions/bindings */ + static Expr UnwrapBindings(Expr expr, const Map& bindings); protected: bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override; @@ -88,7 +91,7 @@ class DFPatternMatcher : public DFPatternFunctor, ObjectPtrHash, ObjectPtrEqual> memo_; + std::unordered_map memo_; var2val_t var2val_; std::vector matched_nodes_; PrimExpr symbolic_expr_condition_{Bool(true)}; @@ -99,4 +102,4 @@ class DFPatternMatcher : public DFPatternFunctor +#include +#include +#include + +#include +#include +#include +#include + +#include "dataflow_matcher.h" + +namespace tvm { +namespace relax { + +struct RewriteSpec { + Map variable_rewrites; + Map new_subroutines; + + explicit operator bool() const { return variable_rewrites.size(); } + + void Append(RewriteSpec other); +}; + +class PatternMatchingRewriterNode : public tvm::transform::PassNode { + public: + virtual RewriteSpec RewriteBindings(const Array& bindings) const { + return RewriteSpec(); + } + + void VisitAttrs(AttrVisitor* visitor) {} + + IRModule operator()(IRModule mod, const tvm::transform::PassContext& pass_ctx) const override; + tvm::transform::PassInfo Info() const override; + + static constexpr const char* _type_key = "relax.dpl.PatternMatchingRewriter"; + TVM_DECLARE_BASE_OBJECT_INFO(PatternMatchingRewriterNode, PassNode); +}; + +class PatternMatchingRewriter : public tvm::transform::Pass { + public: + static PatternMatchingRewriter FromPattern( + DFPattern pattern, TypedPackedFunc(Expr, Map)> func, + Optional> additional_bindings = NullOpt, + Map new_subroutines = {}); + + static PatternMatchingRewriter FromModule(IRModule mod); + + Expr operator()(Expr expr); + using Pass::operator(); + + TVM_DEFINE_OBJECT_REF_METHODS(PatternMatchingRewriter, Pass, PatternMatchingRewriterNode); +}; + +class ExprPatternRewriterNode : public PatternMatchingRewriterNode { + public: + DFPattern pattern; + TypedPackedFunc(Expr, Map)> func; + Optional> additional_bindings; + Map new_subroutines; + + RewriteSpec RewriteBindings(const Array& bindings) const final; + + Optional RewriteExpr(const Expr& expr, const Map& bindings) const; + + void VisitAttrs(AttrVisitor* visitor) { + visitor->Visit("pattern", &pattern); + PackedFunc untyped_func = func; + visitor->Visit("func", &untyped_func); + } + + static constexpr const char* _type_key = "relax.dpl.ExprPatternRewriter"; + TVM_DECLARE_BASE_OBJECT_INFO(ExprPatternRewriterNode, PatternMatchingRewriterNode); +}; + +class ExprPatternRewriter : public PatternMatchingRewriter { + public: + ExprPatternRewriter(DFPattern pattern, + TypedPackedFunc(Expr, Map)> func, + Optional> additional_bindings = NullOpt, + Map new_subroutines = {}); + + TVM_DEFINE_OBJECT_REF_METHODS(ExprPatternRewriter, PatternMatchingRewriter, + ExprPatternRewriterNode); +}; + +class OrRewriterNode : public PatternMatchingRewriterNode { + public: + PatternMatchingRewriter lhs; + PatternMatchingRewriter rhs; + + RewriteSpec RewriteBindings(const Array& bindings) const override; + + void VisitAttrs(AttrVisitor* visitor) { + visitor->Visit("lhs", &lhs); + visitor->Visit("rhs", &rhs); + } + + static constexpr const char* _type_key = "relax.dpl.OrRewriter"; + 
TVM_DECLARE_BASE_OBJECT_INFO(OrRewriterNode, PatternMatchingRewriterNode);
+};
+
+class OrRewriter : public PatternMatchingRewriter {
+ public:
+  OrRewriter(PatternMatchingRewriter lhs, PatternMatchingRewriter rhs);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(OrRewriter, PatternMatchingRewriter, OrRewriterNode);
+};
+
+class TupleRewriterNode : public PatternMatchingRewriterNode {
+ public:
+  Array<DFPattern> patterns;
+  TypedPackedFunc<Optional<Expr>(Expr, Map<DFPattern, Expr>)> func;
+  Optional<Map<Var, Expr>> additional_bindings;
+  Map<GlobalVar, BaseFunc> new_subroutines;
+
+  RewriteSpec RewriteBindings(const Array<Binding>& bindings) const override;
+
+  void VisitAttrs(AttrVisitor* visitor) {
+    visitor->Visit("patterns", &patterns);
+    PackedFunc untyped_func = func;
+    visitor->Visit("func", &untyped_func);
+  }
+
+  static constexpr const char* _type_key = "relax.dpl.TupleRewriter";
+  TVM_DECLARE_BASE_OBJECT_INFO(TupleRewriterNode, PatternMatchingRewriterNode);
+
+ private:
+  struct VarInfo {
+    Var var;
+    Expr expr;
+    Array<Optional<Map<DFPattern, Expr>>> matches;
+    std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> downstream_usage;
+    bool used = false;
+  };
+
+  Map<Var, Expr> GenerateVariableRewrites(const Array<Binding>& bindings) const;
+
+  std::optional<Map<Var, Expr>> TryMatchByBindingIndex(const std::vector<VarInfo>& info_vec,
+                                                       const std::vector<size_t>& indices) const;
+};
+
+class TupleRewriter : public PatternMatchingRewriter {
+ public:
+  TupleRewriter(Array<DFPattern> patterns,
+                TypedPackedFunc<Optional<Expr>(Expr, Map<DFPattern, Expr>)> func,
+                Optional<Map<Var, Expr>> additional_bindings = NullOpt,
+                Map<GlobalVar, BaseFunc> new_subroutines = {});
+
+  TVM_DEFINE_OBJECT_REF_METHODS(TupleRewriter, PatternMatchingRewriter, TupleRewriterNode);
+};
+
+}  // namespace relax
+}  // namespace tvm
+
+#endif  // TVM_RELAX_IR_DATAFLOW_REWRITER_H_
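`OrRewriterNode` is the node that backs composition of rewriters: in Python, two rewriters combine with the `|` operator, and the combined rewriter attempts both sets of replacements (exercised by the tests added below). A short sketch, where `RewriteAdd` and `RewriteMultiply` stand for any two `PatternMatchingRewriter` objects:

.. code-block:: python

    # `RewriteAdd | RewriteMultiply` constructs an OrRewriter; the
    # combined rewriter is still a Pass, and may also be applied
    # directly to a Relax function or an IRModule.
    combined = RewriteAdd | RewriteMultiply
    after = combined(before)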
diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc
index a14ba1d9aaa1..6ace974985a5 100644
--- a/src/relax/ir/expr.cc
+++ b/src/relax/ir/expr.cc
@@ -21,6 +21,8 @@
 #include
 #include
 
+#include
+
 namespace tvm {
 namespace relax {
 
@@ -576,17 +578,35 @@ Function::Function(Array<Var> params, Expr body, Optional<StructInfo> ret_struct_info,
     body_sinfo = GetStructInfo(body);
   }
 
-  if (ret_struct_info.defined()) {
-    // allow body to override ret if body is more fine-grained.
-    if (body_sinfo.defined()) {
-      if (IsBaseOf(ret_struct_info.value(), body_sinfo.value())) {
-        ret_struct_info = body_sinfo;
-      }
-    }
-  } else {
-    CHECK(body_sinfo.defined())
-        << "Function do not have a return signature and body is not normalized";
-    ret_struct_info = body_sinfo;
+  CHECK(body_sinfo.defined() || ret_struct_info.defined())
+      << "Function must be constructed with either "
+      << "an explicit struct info for the return type, "
+      << "or a normalized body with struct info.";
+
+  // Use the body's struct info if there is no explicit return type,
+  // or if the body may provide a more granular return type.
+  bool use_body_struct_info =
+      !ret_struct_info.defined() ||
+      (body_sinfo && ret_struct_info && IsBaseOf(ret_struct_info.value(), body_sinfo.value()));
+
+  if (use_body_struct_info) {
+    // MatchCast nodes within the body may introduce new symbolic
+    // variables.  These are in-scope for the function body, but not
+    // for the function's return type.  When hoisting the body's type
+    // to the function return type, symbolic variables may only be
+    // used if they were defined by the function's parameters.
+    auto f_shape_var_map = [&] {
+      auto tir_vars = DefinableTIRVarsInStructInfo(TupleStructInfo(params.Map(GetStructInfo)));
+      std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual> lookup(tir_vars.begin(),
+                                                                         tir_vars.end());
+      return [lookup = std::move(lookup)](const tir::Var& var) -> Optional<tir::Var> {
+        if (lookup.count(var)) {
+          return var;
+        } else {
+          return NullOpt;
+        }
+      };
+    }();
+    ret_struct_info = EraseToWellDefined(body_sinfo.value(), f_shape_var_map);
   }
 
   FuncStructInfo func_sinfo(param_sinfo, ret_struct_info.value(), is_pure);
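A sketch of the behavior this implements: if the body's struct info mentions a symbolic variable that is defined only by an `R.match_cast` inside the body, the deduced return type erases it (here, to `R.Tensor(ndim=1)`), since a caller has no way to infer `N` from the parameters:

.. code-block:: python

    from tvm.script import relax as R, tir as T

    @R.function(private=True)
    def f(x: R.Tensor(["M"], "float32")):
        N = T.int64()
        # N is defined by the match-cast, not by the parameters, so
        # it cannot appear in the deduced return struct info.
        y = R.match_cast(x, R.Tensor([N], "float32"))
        return y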
diff --git a/src/relax/ir/expr_functor.cc b/src/relax/ir/expr_functor.cc
index 63c74db7e33e..3ee403a25cda 100644
--- a/src/relax/ir/expr_functor.cc
+++ b/src/relax/ir/expr_functor.cc
@@ -606,8 +606,8 @@ Expr ExprMutator::VisitExpr_(const FunctionNode* op) {
 
 Expr ExprMutator::VisitExpr_(const IfNode* op) {
   Expr guard = this->VisitExpr(op->cond);
-  Expr true_b = this->VisitWithNewScope(op->true_branch);
-  Expr false_b = this->VisitWithNewScope(op->false_branch);
+  Expr true_b = this->VisitWithInnerScope(op->true_branch);
+  Expr false_b = this->VisitWithInnerScope(op->false_branch);
   if (op->cond.same_as(guard) && op->true_branch.same_as(true_b) &&
       op->false_branch.same_as(false_b) &&
       VisitAndCheckStructInfoFieldUnchanged(op->struct_info_)) {
@@ -696,20 +696,24 @@ void ExprMutator::VisitBinding_(const MatchCastNode* binding) {
 
   Var new_var = this->VisitVarDef(binding->var);
 
-  if (new_var.same_as(binding->var) && new_value.same_as(binding->value) &&
-      new_struct_info.same_as(binding->struct_info)) {
-    // re-emit old binding if nothing changes
-    builder_->EmitNormalized(GetRef<MatchCast>(binding));
-    return;
-  }
+  MatchCast new_binding = [&]() -> MatchCast {
+    if (new_var.same_as(binding->var) && new_value.same_as(binding->value) &&
+        new_struct_info.same_as(binding->struct_info)) {
+      // re-emit old binding if nothing changes
+      return GetRef<MatchCast>(binding);
+    } else {
+      new_value = builder_->NormalizeArgument(new_value);
+      new_var = WithStructInfo(new_var, new_struct_info);
 
-  new_value = builder_->NormalizeArgument(new_value);
-  new_var = WithStructInfo(new_var, new_struct_info);
+      var_remap_[binding->var->vid] = new_var;
+      var_remap_[new_var->vid] = new_var;
 
-  var_remap_[binding->var->vid] = new_var;
-  var_remap_[new_var->vid] = new_var;
+      return MatchCast(new_var, new_value, new_struct_info, binding->span);
+    }
+  }();
 
-  builder_->EmitNormalized(MatchCast(new_var, new_value, new_struct_info, binding->span));
+  builder_->EmitNormalized(new_binding);
+  builder_->AddDefinitionToScope(new_binding->var);
 }
 
 BindingBlock ExprMutator::VisitBindingBlock_(const BindingBlockNode* block) {
@@ -800,7 +804,31 @@ Expr ExprMutator::VisitWithNewScope(const Expr& expr, Optional<Array<Var>> params) {
   }
 
   builder_->BeginScope(params);
+  // Outer scope only includes TIR variables that can be inferred
+  // from the function parameters.
   With<arith::ConstraintContext> context(builder_->GetAnalyzer(), constraint);
+  builder_->BeginInnerScope();
+  // Inner scope also includes any TIR variables that are defined by
+  // MatchCast nodes, and are internal to the scope.
   Expr ret = this->VisitExpr(expr);
+
+  builder_->EndScope();
+
+  // Normalization (and the resulting StructInfo inference) of the
+  // expr occurs outside of the body's parameters, but inside the
+  // function signature's scope.  This keeps variables that are
+  // inferable based on the function signature, to allow callers to
+  // propagate StructInfo across the function.
+  ret = builder_->Normalize(ret);
+  builder_->EndScope();
+  return ret;
+}
+
+Expr ExprMutator::VisitWithInnerScope(const Expr& expr) {
+  ICHECK(expr->IsInstance<SeqExprNode>())
+      << "Normal form requires that all new scopes be stored as SeqExpr";
+
+  builder_->BeginInnerScope();
   Expr ret = this->VisitExpr(expr);
   builder_->EndScope();
   return ret;
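The `VisitWithInnerScope` distinction matters for `If` nodes: each branch may refer to symbolic variables from the enclosing function scope, but definitions introduced inside a branch do not escape it. A sketch, assuming TVMScript's parsing of `if` on a scalar boolean tensor:

.. code-block:: python

    from tvm.script import relax as R, tir as T

    @R.function(private=True)
    def f(x: R.Tensor(["N"], "float32"), cond: R.Tensor((), "bool")):
        N = T.int64()
        if cond:
            # The branch is visited with VisitWithInnerScope, so N
            # (defined by the function signature) is still available.
            y = R.add(x, x)
            out = R.match_cast(y, R.Tensor([N], "float32"))
        else:
            # Definitions local to a branch do not leak past the If.
            out = R.multiply(x, x)
        return out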
diff --git a/src/relax/transform/canonicalize_bindings.cc b/src/relax/transform/canonicalize_bindings.cc
index 12eb81ac675d..d1a9f97337de 100644
--- a/src/relax/transform/canonicalize_bindings.cc
+++ b/src/relax/transform/canonicalize_bindings.cc
@@ -29,12 +29,119 @@
 #include
 #include
 #include
+#include
 
 namespace tvm {
 namespace relax {
 
 namespace {
 
+class SymbolicVarCanonicalizer : public ExprMutator {
+ public:
+  Expr VisitExpr_(const FunctionNode* func) override {
+    auto cached = known_values_;
+    auto output = ExprMutator::VisitExpr_(func);
+    known_values_ = cached;
+    return output;
+  }
+
+  void VisitBinding_(const MatchCastNode* binding) override {
+    auto tir_var_map =
+        InferSymbolicVarMap({{binding->var, binding->value}}, builder_->GetAnalyzer());
+    for (const auto& [tir_var, prim_expr] : tir_var_map) {
+      if (auto it = known_values_.find(tir_var); it != known_values_.end()) {
+        CHECK(!builder_->GetAnalyzer()->CanProve(it->second.expr != prim_expr))
+            << "ValueError: "
+            << "MatchCast statements must be consistent.  "
+            << "However, the definition of Relax variable " << it->second.source->var
+            << " implies that TIR variable " << tir_var << " is " << it->second.expr
+            << ", while the later definition of Relax variable " << binding->var
+            << " instead implies that TIR variable " << tir_var << " is " << prim_expr;
+      } else {
+        known_values_[tir_var] = KnownValue{prim_expr, GetRef<MatchCast>(binding)};
+      }
+    }
+    ExprMutator::VisitBinding_(binding);
+  }
+
+  Expr VisitExpr_(const IfNode* op) override {
+    Expr guard = this->VisitExpr(op->cond);
+
+    auto cached = known_values_;
+    Expr true_b = this->VisitWithInnerScope(op->true_branch);
+    known_values_ = cached;
+    Expr false_b = this->VisitWithInnerScope(op->false_branch);
+    known_values_ = cached;
+
+    if (op->cond.same_as(guard) && op->true_branch.same_as(true_b) &&
+        op->false_branch.same_as(false_b)) {
+      return GetRef<Expr>(op);
+    }
+
+    // The two branches may have had different TIR variables inlined.
+    // For example, one branch has a dynamic implementation and
+    // produces `R.Tensor([M,N])`, while the other branch checks if
+    // `N==16` and produces `R.Tensor([M,16])`.  After the branch, the
+    // output is `R.Tensor([M,N])`.  However, `GetStructLCA` would
+    // only be able to return `R.Tensor(ndim=2)`, removing all shape
+    // information.
+    //
+    // Since we know the StructInfo prior to replacing TIR variables,
+    // this pass can provide a better StructInfo than the generic
+    // handling in ExprMutator, by restoring the symbolic variables
+    // within each branch.
+    auto new_sinfo = VisitExprDepStructInfoField(Downcast<StructInfo>(op->struct_info_));
+
+    StructuralEqual struct_equal;
+    if (!struct_equal(new_sinfo, GetStructInfo(true_b))) {
+      auto output_var = Var("then_branch_with_dyn", new_sinfo);
+
+      true_b = SeqExpr({BindingBlock({
+                           MatchCast(output_var, true_b, new_sinfo),
+                       })},
+                       output_var);
+    }
+
+    if (!struct_equal(new_sinfo, GetStructInfo(false_b))) {
+      auto output_var = Var("else_branch_with_dyn", new_sinfo);
+
+      false_b = SeqExpr({BindingBlock({
+                            MatchCast(output_var, false_b, new_sinfo),
+                        })},
+                        output_var);
+    }
+
+    return If(guard, true_b, false_b, op->span);
+  }
+
+  PrimExpr VisitPrimExpr(const PrimExpr& expr) override {
+    if (known_values_.empty()) {
+      return expr;
+    }
+    PrimExpr output = tir::Substitute(expr, [this](const tir::Var& var) -> Optional<PrimExpr> {
+      if (auto it = known_values_.find(var); it != known_values_.end()) {
+        return it->second.expr;
+      } else {
+        return NullOpt;
+      }
+    });
+    if (output.same_as(expr)) {
+      return expr;
+    }
+
+    output = builder_->GetAnalyzer()->Simplify(output);
+    return output;
+  }
+
+ private:
+  struct KnownValue {
+    PrimExpr expr;
+    MatchCast source;
+  };
+
+  std::unordered_map<tir::Var, KnownValue, ObjectPtrHash, ObjectPtrEqual> known_values_;
+};
+
 struct CanonicalizationPlan {
   Map<Id, Var> replace_usage;
   Map<Id, Var> replace_binding;
@@ -377,16 +484,39 @@ class BindingCanonicalizer : public ExprMutator {
 };
 }  // namespace
 
-Expr CanonicalizeBindings(const Expr& expr) { return BindingCanonicalizer::Apply(expr); }
+Expr CanonicalizeTIRVariables(Expr expr) { return SymbolicVarCanonicalizer()(std::move(expr)); }
+
+Expr CanonicalizeRelaxBindings(Expr expr) { return BindingCanonicalizer::Apply(std::move(expr)); }
+
+Expr CanonicalizeBindings(Expr expr) {
+  expr = CanonicalizeTIRVariables(std::move(expr));
+  expr = CanonicalizeRelaxBindings(std::move(expr));
+  return expr;
+}
 
 namespace transform {
 
+Pass CanonicalizeTIRVariables() {
+  auto pass_func = [=](Function f, IRModule m, PassContext pc) {
+    return Downcast<Function>(CanonicalizeTIRVariables(f));
+  };
+  return CreateFunctionPass(pass_func, 1, "CanonicalizeTIRVariables", {});
+}
+
+Pass CanonicalizeRelaxBindings() {
+  auto pass_func = [=](Function f, IRModule m, PassContext pc) {
+    return Downcast<Function>(CanonicalizeRelaxBindings(f));
+  };
+  return CreateFunctionPass(pass_func, 1, "CanonicalizeRelaxBindings", {});
+}
+
 Pass CanonicalizeBindings() {
-  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
-      [=](Function f, IRModule m, PassContext pc) {
-        return Downcast<Function>(CanonicalizeBindings(f));
-      };
-  return CreateFunctionPass(pass_func, 1, "CanonicalizeBindings", {});
+  return tvm::transform::Sequential(
+      {
+          CanonicalizeTIRVariables(),
+          CanonicalizeRelaxBindings(),
+      },
+      "CanonicalizeBindings");
 }
 
 TVM_REGISTER_GLOBAL("relax.transform.CanonicalizeBindings").set_body_typed(CanonicalizeBindings);
diff --git a/src/relax/transform/utils.h b/src/relax/transform/utils.h
index 5755e118541f..932dca30a110 100644
--- a/src/relax/transform/utils.h
+++ b/src/relax/transform/utils.h
@@ -420,7 +420,7 @@ Expr EliminateCommonSubexpr(const Expr& expr, bool call_only = false);
  *
  * \ret The canonicalized expression
  */
-Expr CanonicalizeBindings(const Expr& expr);
+Expr CanonicalizeBindings(Expr expr);
 
 /* \brief Remove use of trivial bindings
  *
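To illustrate the new `CanonicalizeTIRVariables` step: once an `R.match_cast` pins a symbolic variable to a known value, later occurrences of that variable are substituted and simplified (and two casts that pin the same variable to provably different values are reported as an error). A sketch, assuming the usual pass invocation:

.. code-block:: python

    import tvm
    from tvm.script import ir as I, relax as R, tir as T

    @I.ir_module
    class Before:
        @R.function
        def main(x: R.Tensor([8, 16], "float32")):
            N = T.int64()
            # The match-cast implies N == 16 ...
            y = R.match_cast(x, R.Tensor([8, N], "float32"))
            # ... so N in this annotation can be replaced with 16.
            z: R.Tensor([8, N], "float32") = R.add(y, y)
            return z

    After = tvm.relax.transform.CanonicalizeBindings()(Before)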
diff --git a/src/relax/utils.cc b/src/relax/utils.cc
index f0239e424f30..77416dc92b1d 100644
--- a/src/relax/utils.cc
+++ b/src/relax/utils.cc
@@ -122,11 +122,7 @@ tvm::Map<tir::Var, PrimExpr> InferSymbolicVarMap(
     if (!var_sinfo) return;
 
     auto expr_sinfo = expr.as<PrimStructInfoNode>();
-    CHECK(expr_sinfo) << "Cannot bind expression with struct type " << expr
-                      << " to variable with struct type " << var;
-    CHECK_EQ(var_sinfo->dtype, expr_sinfo->dtype)
-        << "Cannot bind expression with struct type " << expr << " to variable with struct type "
-        << var << ", due to conflicting PrimExpr DataType";
+    if (!expr_sinfo) return;
 
     if (!var_sinfo->value.defined() || !expr_sinfo->value.defined()) return;
 
@@ -139,15 +135,12 @@ tvm::Map<tir::Var, PrimExpr> InferSymbolicVarMap(
     if (!var_shape->values.defined()) return;
 
     auto expr_shape = expr.as<ShapeStructInfoNode>();
-    CHECK(expr_shape) << "Cannot bind expression with struct type " << expr
-                      << " to variable with struct type " << var;
+    if (!expr_shape) return;
 
     if (!expr_shape->values.defined()) return;
 
     auto var_shape_arr = var_shape->values.value();
     auto expr_shape_arr = expr_shape->values.value();
-    CHECK_EQ(var_shape_arr.size(), expr_shape_arr.size())
-        << "Cannot bind shape " << expr_shape_arr << " of dimension " << expr_shape_arr.size()
-        << " to variable with shape " << var_shape_arr << " of dimension " << var_shape_arr.size();
+    if (var_shape_arr.size() != expr_shape_arr.size()) return;
     for (size_t i = 0; i < var_shape_arr.size(); i++) {
       bind_from_prim_expr(var_shape_arr[i], expr_shape_arr[i]);
     }
@@ -159,8 +152,7 @@ tvm::Map<tir::Var, PrimExpr> InferSymbolicVarMap(
     if (!var_tensor->shape.defined()) return;
 
     auto expr_tensor = expr.as<TensorStructInfoNode>();
-    CHECK(expr_tensor) << "Cannot bind expression with struct type " << expr
-                      << " to variable with struct type " << var;
+    if (!expr_tensor) return;
 
     if (!expr_tensor->shape.defined()) return;
 
     bind_from_shape(GetStructInfo(var_tensor->shape.value()),
diff --git a/src/script/ir_builder/relax/frame.cc b/src/script/ir_builder/relax/frame.cc
index 792331dda4c0..3153c0770e38 100644
--- a/src/script/ir_builder/relax/frame.cc
+++ b/src/script/ir_builder/relax/frame.cc
@@ -46,6 +46,11 @@ void SeqExprFrameNode::EnterWithScope() {
   BindingBlock()->EnterWithScope();
 }
 
+void FunctionFrameNode::EnterWithScope() {
+  this->block_builder->BeginScope(params);
+  SeqExprFrameNode::EnterWithScope();
+}
+
 void FunctionFrameNode::ExitWithScope() {
   using ir::IRModuleFrame;
   using tvm::relax::Expr;
@@ -54,7 +59,7 @@ void FunctionFrameNode::ExitWithScope() {
   // Step 1: Create the function.
   CHECK(output.defined()) << "ValueError: A Relax function must have a return value. Please use "
                              "`return` to return an Expr";
-  this->block_builder->BeginScope(params);
+
   Expr body = this->block_builder->Normalize(tvm::relax::SeqExpr(binding_blocks, output.value()));
 
   // if the function is not private, add a global symbol to its attributes
   if (!is_private.value_or(Bool(false))->value && name.defined() &&
diff --git a/src/script/ir_builder/relax/ir.cc b/src/script/ir_builder/relax/ir.cc
index 2e94ae420a97..453c7fdb5522 100644
--- a/src/script/ir_builder/relax/ir.cc
+++ b/src/script/ir_builder/relax/ir.cc
@@ -70,15 +70,7 @@ tvm::relax::Var Arg(const String& name, const tvm::relax::StructInfo& struct_info) {
   FunctionFrame frame = FindFunctionFrame("R.Arg");
   tvm::relax::Var var(name, struct_info);
   frame->params.push_back(var);
-
-  // This constraint would normally be provided as part of
-  // `BlockBuilder::BeginScope`.  However, because the frame and its
-  // scope are initialized before the arguments are known, the scope
-  // doesn't have access to these constraints.
- auto* analyzer = frame->block_builder->GetAnalyzer(); - for (const auto& tir_var : DefinableTIRVarsInStructInfo(struct_info)) { - analyzer->MarkGlobalNonNegValue(tir_var); - } + frame->block_builder->AddDefinitionToScope(var); return var; } diff --git a/tests/python/relax/test_dataflow_rewriter.py b/tests/python/relax/test_dataflow_rewriter.py new file mode 100644 index 000000000000..828aa92bda28 --- /dev/null +++ b/tests/python/relax/test_dataflow_rewriter.py @@ -0,0 +1,1512 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +import tvm.testing +from tvm.script import ir as I, relax as R, tir as T + +import pytest + + +def test_rewrite_defined_by_ir_module(): + @R.rewriter + class Rewriter: + @R.function + def pattern(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): + C = R.add(A, B) + return C + + @R.function + def replacement(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): + C = R.call_pure_packed( + "my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32") + ) + return C + + @R.function + def before(x: R.Tensor([32], "float32")): + R.func_attr({"global_symbol": "main"}) + split = R.split(x, 2) + lhs = split[0] + rhs = split[1] + out = lhs + rhs + return out + + @R.function + def expected(x: R.Tensor([32], "float32")): + R.func_attr({"global_symbol": "main"}) + split = R.split(x, 2) + lhs = split[0] + rhs = split[1] + out = R.call_pure_packed( + "my_optimized_add_impl", lhs, rhs, sinfo_args=R.Tensor([16], "float32") + ) + return out + + after = Rewriter(before) + tvm.ir.assert_structural_equal(expected, after) + + +def test_missing_pattern_raises_error(): + """The rewriter must define a pattern to be matched""" + + with pytest.raises(KeyError, match="pattern"): + + @R.rewriter + class Rewriter: + @R.function + def replacement(): + return R.tuple() + + +def test_incorrect_function_type_of_pattern_raises_error(): + """The rewriter's pattern must be a Relax function""" + + with pytest.raises(TypeError, match="pattern"): + + @R.rewriter + class Rewriter: + @T.prim_func + def pattern(): + pass + + @R.function + def replacement(): + return R.tuple() + + +def test_missing_replacement_raises_error(): + """The rewriter must define a replacement""" + + with pytest.raises(KeyError, match="replacement"): + + @R.rewriter + class Rewriter: + @R.function + def pattern(): + return R.tuple() + + +def test_incorrect_function_type_of_replacement_raises_error(): + """The rewriter's replacement must be a Relax function""" + + with pytest.raises(TypeError, match="replacement"): + + @R.rewriter + class Rewriter: + @R.function + def pattern(): + return R.tuple() + + @T.prim_func + def replacement(): + pass + + +def test_mismatch_of_static_shapes_raises_error(): + """The pattern and replacement must accept the same shapes""" + + with 
pytest.raises(ValueError, match="must have the same signature"): + + @R.rewriter + class Rewriter: + @R.function + def pattern(A: R.Tensor([32])): + return A + + @R.function + def replacement(A: R.Tensor([16])): + return A + + +def test_rewriter_may_be_applied_to_ir_module(): + """A rewriter may mutate an IRModule + + The `PatternMatchingRewriter.__call__` implementation may accept + either a single Relax function, or an entire IRModule. If it is + passed an IRModule, then all functions in the `IRModule` are + updated. + + """ + + @R.rewriter + class Rewriter: + @R.function + def pattern(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): + C = R.add(A, B) + return C + + @R.function + def replacement(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): + C = R.call_pure_packed( + "my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32") + ) + return C + + @I.ir_module + class Before: + @R.function + def func_a(x: R.Tensor([32], "float32")): + split = R.split(x, 2) + lhs = split[0] + rhs = split[1] + out = lhs + rhs + return out + + @R.function + def func_b(x: R.Tensor([16], "float32")): + out = x + x + return out + + @I.ir_module + class Expected: + @R.function + def func_a(x: R.Tensor([32], "float32")): + split = R.split(x, 2) + lhs = split[0] + rhs = split[1] + out = R.call_pure_packed( + "my_optimized_add_impl", lhs, rhs, sinfo_args=R.Tensor([16], "float32") + ) + return out + + @R.function + def func_b(x: R.Tensor([16], "float32")): + out = R.call_pure_packed( + "my_optimized_add_impl", x, x, sinfo_args=R.Tensor([16], "float32") + ) + return out + + After = Rewriter(Before) + tvm.ir.assert_structural_equal(Expected, After) + + +def test_rewriter_may_be_used_as_ir_transform(): + """A rewriter may be used as a tvm.ir.transform.Pass""" + + @R.rewriter + class Rewriter: + @R.function + def pattern(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): + C = R.add(A, B) + return C + + @R.function + def replacement(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): + C = R.call_pure_packed( + "my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32") + ) + return C + + @I.ir_module + class Before: + @R.function + def main(x: R.Tensor([16], "float32")): + y = x + x + return y + + @I.ir_module + class Expected: + @R.function + def main(x: R.Tensor([16], "float32")): + out = R.call_pure_packed( + "my_optimized_add_impl", x, x, sinfo_args=R.Tensor([16], "float32") + ) + return out + + After = tvm.ir.transform.Sequential([Rewriter])(Before) + tvm.ir.assert_structural_equal(Expected, After) + + +def test_same_pattern_applied_multiple_times(): + """The pattern-match may apply multiple times""" + + @R.rewriter + class Rewriter: + @R.function + def pattern(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): + C = R.add(A, B) + return C + + @R.function + def replacement(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): + C = R.call_pure_packed( + "my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32") + ) + return C + + @R.function(private=True) + def before(x: R.Tensor([16], "float32")): + y = x + x + z = y + y + return z + + @R.function(private=True) + def expected(x: R.Tensor([16], "float32")): + y = R.call_pure_packed("my_optimized_add_impl", x, x, sinfo_args=R.Tensor([16], "float32")) + z = R.call_pure_packed("my_optimized_add_impl", y, y, sinfo_args=R.Tensor([16], "float32")) + return z + + after = Rewriter(before) + tvm.ir.assert_structural_equal(expected, after) + + +def test_composition_of_rewrite_rules(): + 
"""Rewrite rules may be composed together""" + + @R.rewriter + class RewriteAdd: + @R.function + def pattern(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): + C = A + B + return C + + @R.function + def replacement(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): + C = R.call_pure_packed( + "my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32") + ) + return C + + @R.rewriter + class RewriteMultiply: + @R.function + def pattern(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): + C = A * B + return C + + @R.function + def replacement(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): + C = R.call_pure_packed( + "my_optimized_mul_impl", A, B, sinfo_args=R.Tensor([16], "float32") + ) + return C + + @R.function(private=True) + def before( + A: R.Tensor([16], "float32"), + B: R.Tensor([16], "float32"), + C: R.Tensor([16], "float32"), + ): + D = A + B + E = C * D + return E + + @R.function(private=True) + def expected( + A: R.Tensor([16], "float32"), + B: R.Tensor([16], "float32"), + C: R.Tensor([16], "float32"), + ): + D = R.call_pure_packed("my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32")) + E = R.call_pure_packed("my_optimized_mul_impl", C, D, sinfo_args=R.Tensor([16], "float32")) + return E + + rewriter = RewriteAdd | RewriteMultiply + + after = rewriter(before) + tvm.ir.assert_structural_equal(expected, after) + + +def test_recursive_rewrite_rules(): + """Rewrite rules are applied until convergence + + In this test, both the `RewriteAdd` and `RewriteMultiply` patterns + must be applied in order to produce the expected output. However, + the `RewriteMultiply` pattern relies on the expression produced by + the `RewriteAdd` pass. + + """ + + @R.rewriter + class RewriteAdd: + @R.function + def pattern(A: R.Tensor([16], "float32")): + return A + A + + @R.function + def replacement(A: R.Tensor([16], "float32")): + return A * R.const(2.0, "float32") + + @R.rewriter + class RewriteMultiply: + @R.function + def pattern(A: R.Tensor([16], "float32"), B: R.Tensor([], "float32")): + C = A * B + return C + + @R.function + def replacement(A: R.Tensor([16], "float32"), B: R.Tensor([], "float32")): + C = R.call_pure_packed( + "my_optimized_mul_impl", A, B, sinfo_args=R.Tensor([16], "float32") + ) + return C + + @R.function(private=True) + def before(A: R.Tensor([16], "float32")): + B = A + A + return B + + @R.function(private=True) + def expected(A: R.Tensor([16], "float32")): + B = R.call_pure_packed( + "my_optimized_mul_impl", + A, + R.const(2.0, "float32"), + sinfo_args=R.Tensor([16], "float32"), + ) + return B + + rewriter = RewriteAdd | RewriteMultiply + + after = rewriter(before) + tvm.ir.assert_structural_equal(expected, after) + + +def test_rewrite_of_arbitrary_dtype(): + """A pattern-match may apply to a tensor with unknown dtype + + In this test case, a pattern identifies `R.strided_slice` usage + which returns the last slice of an array, and replaces it with a + view into the input array. + + """ + + @R.rewriter + class Rewriter: + @R.function + def pattern(A: R.Tensor(["M", "N"])) -> R.Tensor(["N"]): + M = T.int64() + N = T.int64() + last_slice_2d: R.Tensor([1, N]) = R.strided_slice(A, axes=[0], begin=[M - 1], end=[M]) + last_slice_1d: R.Tensor([N]) = R.squeeze(last_slice_2d, axis=0) + return last_slice_1d + + @R.function + def replacement(A: R.Tensor(["M", "N"])) -> R.Tensor(["N"]): + M = T.int64() + N = T.int64() + + # TODO(Lunderberg): Improve this syntax. A Relax + # PrimValue (e.g. 
`A.dtype.bits`) should be usable in any
+            # Relax context that accepts a `PrimExpr`. Currently,
+            # this requires `R.match_cast` to produce a TIR symbolic
+            # variable from the Relax PrimValue.
+            bits_per_element = T.uint8()
+            _ = R.match_cast(
+                A.dtype.bits,
+                R.Prim(value=bits_per_element),
+            )
+            lanes_per_element = T.uint16()
+            _ = R.match_cast(
+                A.dtype.lanes,
+                R.Prim(value=lanes_per_element),
+            )
+
+            last_slice = R.memory.view(
+                A,
+                [N],
+                relative_byte_offset=(M - 1)
+                * N
+                * T.ceildiv(
+                    bits_per_element.astype("int64") * lanes_per_element.astype("int64"), 8
+                ),
+            )
+            return last_slice
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            A: R.Tensor([32, 16], "float16"),
+            B: R.Tensor(["P", "Q"], "int4x8"),
+            C: R.Tensor([16, 32]),
+        ):
+            P = T.int64()
+            Q = T.int64()
+
+            A_slice_2d = R.strided_slice(A, axes=[0], begin=[31], end=[32])
+            A_slice_1d = R.squeeze(A_slice_2d, axis=0)
+
+            B_slice_2d = R.strided_slice(B, axes=[0], begin=[P - 1], end=[P])
+            B_slice_1d = R.squeeze(B_slice_2d, axis=0)
+
+            C_slice_2d = R.strided_slice(C, axes=[0], begin=[15], end=[16])
+            C_slice_1d = R.squeeze(C_slice_2d, axis=0)
+
+            return (A_slice_1d, B_slice_1d, C_slice_1d)
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            A: R.Tensor([32, 16], "float16"),
+            B: R.Tensor(["P", "Q"], "int4x8"),
+            C: R.Tensor([16, 32]),
+        ):
+            P = T.int64()
+            Q = T.int64()
+
+            # The pattern matches any 2-d tensor, with any data type.
+            # When the match's shape and dtype are both known,
+            # normalization and canonicalization produce a statically
+            # known value for `relative_byte_offset`.
+            #
+            # Relative offset is `(31 rows) *
+            #                     (16 elements/row) *
+            #                     (2 bytes/element)`
+            A_slice_1d = R.memory.view(A, shape=[16], relative_byte_offset=992)
+
+            # The pattern can also match a 2-d tensor with dynamic
+            # shape. The `relative_byte_offset` uses the known
+            # datatype (4 bytes for each int4x8), but with dynamic
+            # shape variables substituted in where required.
+            #
+            # Relative offset is `((P-1) rows) *
+            #                     (Q elements/row) *
+            #                     (4 bytes/element)`
+            B_slice_1d = R.memory.view(B, shape=[Q], relative_byte_offset=(P - 1) * Q * 4)
+
+            # The pattern can also match a 2-d tensor with static
+            # shape, but unknown data type. The
+            # `relative_byte_offset` is determined based on the known
+            # number of elements, and the dynamic size of each
+            # element.
+ # + # Relative offset is `(15 rows) * + # (32 elements/row) * + # (ceildiv(bits*lanes,8) bytes/element)` + C_bits_per_element = T.uint8() + C_bits_prim_value = C.dtype.bits + _ = R.match_cast( + C_bits_prim_value, + R.Prim(value=C_bits_per_element), + ) + C_lanes_per_element = T.uint16() + C_lanes_prim_value = C.dtype.lanes + _ = R.match_cast( + C_lanes_prim_value, + R.Prim(value=C_lanes_per_element), + ) + + C_slice_1d = R.memory.view( + C, + shape=[32], + relative_byte_offset=( + (C_bits_per_element.astype("int64") * C_lanes_per_element.astype("int64") + 7) + // 8 + ) + * 480, + ) + + return (A_slice_1d, B_slice_1d, C_slice_1d) + + after = Rewriter(Before) + tvm.ir.assert_structural_equal(Expected, after) + + +def test_rewrite_may_introduce_private_relax_subroutines(): + """The replacement may contain subroutines""" + + @R.rewriter + class Rewriter: + @R.function + def pattern(A: R.Tensor([16], "float32")): + return A + A + + @R.function + def replacement(A: R.Tensor([16], "float32")): + return Rewriter.subroutine(A) + + @R.function(private=True) + def subroutine(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"): + return A * R.const(2.0, "float32") + + @I.ir_module + class Before: + @R.function + def main(A: R.Tensor([16], "float32")): + B = A + A + C = B + B + return C + + @I.ir_module + class Expected: + @R.function + def main(A: R.Tensor([16], "float32")): + B = Expected.subroutine(A) + C = Expected.subroutine(B) + return C + + @R.function(private=True) + def subroutine(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"): + return A * R.const(2.0, "float32") + + After = Rewriter(Before) + tvm.ir.assert_structural_equal(Expected, After) + + +def test_rewrite_only_introduces_private_subroutines_when_required(): + """Only subroutines that are used will be added to the module + + Like `test_rewrite_may_introduce_private_relax_subroutines`, but + the rewritten function only requires some of the subroutines + provided by the rewriter. 
+ + """ + + @R.rewriter + class RewriteAdd: + @R.function + def pattern(A: R.Tensor([16], "float32")): + return A + A + + @R.function + def replacement(A: R.Tensor([16], "float32")): + return RewriteAdd.subroutine_add(A) + + @R.function(private=True) + def subroutine_add(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"): + return A * R.const(2.0, "float32") + + @R.rewriter + class RewriteMul: + @R.function + def pattern(A: R.Tensor([16], "float32")): + return A * A + + @R.function + def replacement(A: R.Tensor([16], "float32")): + return R.call_tir(RewriteMul.subroutine_mul, [A], out_sinfo=R.Tensor([16], "float32")) + + @T.prim_func(private=True) + def subroutine_mul(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")): + for i in range(16): + B[i] = A[i] * A[i] + + @I.ir_module + class Before: + @R.function + def main(A: R.Tensor([16], "float32")): + B = A + A + C = B + B + return C + + @I.ir_module + class Expected: + @R.function + def main(A: R.Tensor([16], "float32")): + B = Expected.subroutine_add(A) + C = Expected.subroutine_add(B) + return C + + @R.function(private=True) + def subroutine_add(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"): + return A * R.const(2.0, "float32") + + rewriter = RewriteAdd | RewriteMul + + After = rewriter(Before) + tvm.ir.assert_structural_equal(Expected, After) + + +def test_rewriter_may_not_introduce_public_subroutines(): + """The rewriter may only introduce private functions""" + + with pytest.raises(ValueError, match="is publicly exposed"): + + @R.rewriter + class Rewriter: + @R.function + def pattern(A: R.Tensor([16], "float32")): + return A + A + + @R.function + def replacement(A: R.Tensor([16], "float32")): + return Rewriter.subroutine(A) + + @R.function + def subroutine(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"): + return A * R.const(2.0, "float32") + + +def test_rewrite_branches_may_reuse_subroutine_name(): + """Each rewriter is independent, and may reuse subroutine names""" + + @R.rewriter + class RewriteAdd: + @R.function + def pattern(A: R.Tensor([16], "float32")): + return A + A + + @R.function + def replacement(A: R.Tensor([16], "float32")): + return RewriteAdd.subroutine(A) + + @R.function(private=True) + def subroutine(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"): + return A * R.const(2.0, "float32") + + @R.rewriter + class RewriteMul: + @R.function + def pattern(A: R.Tensor([16], "float32")): + return A * A + + @R.function + def replacement(A: R.Tensor([16], "float32")): + return R.call_tir(RewriteMul.subroutine, [A], out_sinfo=R.Tensor([16], "float32")) + + @T.prim_func(private=True) + def subroutine(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")): + for i in range(16): + B[i] = A[i] * A[i] + + @I.ir_module + class Before: + @R.function + def main(A: R.Tensor([16], "float32")): + B = A + A + C = B * B + return C + + @I.ir_module + class Expected: + @R.function + def main(A: R.Tensor([16], "float32")): + B = Expected.subroutine(A) + C = R.call_tir(Expected.subroutine_1, [B], out_sinfo=R.Tensor([16], "float32")) + return C + + @R.function(private=True) + def subroutine(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"): + return A * R.const(2.0, "float32") + + @T.prim_func(private=True) + def subroutine_1(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")): + for i in range(16): + B[i] = A[i] * A[i] + + rewriter = RewriteAdd | RewriteMul + + After = rewriter(Before) + tvm.ir.assert_structural_equal(Expected, After) + + +def test_rewrite_of_explicit_relax_tuple(): + 
"""The rewriter function may return a tuple + + When it occurs explicitly within the Relax function, the tuple + pattern matches against the Relax tuple, and the Relax tuple is + replaced. + + """ + + @R.rewriter + class Rewriter: + @R.function + def pattern( + lhs_A: R.Tensor([16, 16], "float32"), + lhs_B: R.Tensor([16, 16], "float32"), + rhs: R.Tensor([16], "float32"), + ): + proj_A = R.matmul(lhs_A, rhs) + proj_B = R.matmul(lhs_B, rhs) + proj_tuple = (proj_A, proj_B) + return proj_tuple + + @R.function + def replacement( + lhs_A: R.Tensor([16, 16], "float32"), + lhs_B: R.Tensor([16, 16], "float32"), + rhs: R.Tensor([16], "float32"), + ): + lhs = R.concat([lhs_A, lhs_B]) + proj_concat = R.matmul(lhs, rhs) + proj_tuple = R.split(proj_concat, 2) + return proj_tuple + + @R.function(private=True) + def before( + state: R.Tensor([16], "float32"), + A: R.Tensor([16, 16], "float32"), + B: R.Tensor([16, 16], "float32"), + ): + proj_A = R.matmul(A, state) + proj_B = R.matmul(B, state) + proj_tuple = (proj_A, proj_B) + out = proj_tuple[0] + proj_tuple[1] + return out + + @R.function(private=True) + def expected( + state: R.Tensor([16], "float32"), + A: R.Tensor([16, 16], "float32"), + B: R.Tensor([16, 16], "float32"), + ): + concat_AB = R.concat([A, B]) + proj_concat = R.matmul(concat_AB, state) + proj_tuple = R.split(proj_concat, 2) + out = proj_tuple[0] + proj_tuple[1] + return out + + after = Rewriter(before) + tvm.ir.assert_structural_equal(expected, after) + + +def test_rewrite_of_output_relax_tuple(): + """The rewriter may update a tuple being returned + + Unlike most relax expressions, tuples may appear as nested + expressions. Pattern-matching should be aware of this option. + + Like `test_rewrite_of_explicit_relax_tuple`, but the tuple appears + as the return value in the function being modified. + + """ + + @R.rewriter + class Rewriter: + @R.function + def pattern( + lhs_A: R.Tensor([16, 16], "float32"), + lhs_B: R.Tensor([16, 16], "float32"), + rhs: R.Tensor([16], "float32"), + ): + proj_A = R.matmul(lhs_A, rhs) + proj_B = R.matmul(lhs_B, rhs) + proj_tuple = (proj_A, proj_B) + return proj_tuple + + @R.function + def replacement( + lhs_A: R.Tensor([16, 16], "float32"), + lhs_B: R.Tensor([16, 16], "float32"), + rhs: R.Tensor([16], "float32"), + ): + lhs = R.concat([lhs_A, lhs_B]) + proj_concat = R.matmul(lhs, rhs) + proj_tuple = R.split(proj_concat, 2) + return proj_tuple + + @R.function(private=True) + def before( + state: R.Tensor([16], "float32"), + A: R.Tensor([16, 16], "float32"), + B: R.Tensor([16, 16], "float32"), + ): + proj_A = R.matmul(A, state) + proj_B = R.matmul(B, state) + return (proj_A, proj_B) + + @R.function(private=True) + def expected( + state: R.Tensor([16], "float32"), + A: R.Tensor([16, 16], "float32"), + B: R.Tensor([16, 16], "float32"), + ): + concat_AB = R.concat([A, B]) + proj_concat = R.matmul(concat_AB, state) + proj_tuple = R.split(proj_concat, 2) + return proj_tuple + + after = Rewriter(before) + tvm.ir.assert_structural_equal(expected, after) + + +def test_rewrite_of_implicit_tuple(): + """The rewriter function may return a tuple + + The tuple being replaced does not need to explicitly exist within + the updated Relax function. So long as each element of the tuple + pattern matches a Relax expression, the pattern match can apply. + + This rule ensures that pattern-matching is never broken when + `CanonicalizeBindings` is applied. 
+ + This test is identical to `test_rewrite_of_explicit_relax_tuple`, + except that the function does not contain the round trip of + packing `proj_A` and `proj_B` into a tuple, then immediately + unpacking them from the tuple. + + """ + + @R.rewriter + class Rewriter: + @R.function + def pattern( + lhs_A: R.Tensor([16, 16], "float32"), + lhs_B: R.Tensor([16, 16], "float32"), + rhs: R.Tensor([16], "float32"), + ): + proj_A = R.matmul(lhs_A, rhs) + proj_B = R.matmul(lhs_B, rhs) + proj_tuple = (proj_A, proj_B) + return proj_tuple + + @R.function + def replacement( + lhs_A: R.Tensor([16, 16], "float32"), + lhs_B: R.Tensor([16, 16], "float32"), + rhs: R.Tensor([16], "float32"), + ): + lhs = R.concat([lhs_A, lhs_B]) + proj_concat = R.matmul(lhs, rhs) + proj_tuple = R.split(proj_concat, 2) + return proj_tuple + + @R.function(private=True) + def before( + state: R.Tensor([16], "float32"), + A: R.Tensor([16, 16], "float32"), + B: R.Tensor([16, 16], "float32"), + ): + proj_A = R.matmul(A, state) + proj_B = R.matmul(B, state) + out = proj_A + proj_B + return out + + @R.function(private=True) + def expected( + state: R.Tensor([16], "float32"), + A: R.Tensor([16, 16], "float32"), + B: R.Tensor([16, 16], "float32"), + ): + concat_AB = R.concat([A, B]) + proj_concat = R.matmul(concat_AB, state) + proj_tuple = R.split(proj_concat, 2) + out = proj_tuple[0] + proj_tuple[1] + return out + + after = Rewriter(before) + tvm.ir.assert_structural_equal(expected, after) + + +def test_rewrite_of_implicit_tuple_with_shared_wildcard(): + """Tuple elements may depend on the same input + + Here, both elements of the tuple depend on `y`. + + """ + + @R.rewriter + class Rewriter: + @R.function + def pattern( + x: R.Tensor([16], "float32"), + y: R.Tensor([16], "float32"), + z: R.Tensor([16], "float32"), + ): + lhs = x + y + rhs = y + z + return (lhs, rhs) + + @R.function + def replacement( + x: R.Tensor([16], "float32"), + y: R.Tensor([16], "float32"), + z: R.Tensor([16], "float32"), + ): + return R.call_pure_packed( + "optimized_impl", + x, + y, + z, + sinfo_args=R.Tuple( + [ + R.Tensor([16], "float32"), + R.Tensor([16], "float32"), + ] + ), + ) + + @R.function(private=True) + def before( + A: R.Tensor([16], "float32"), + B: R.Tensor([16], "float32"), + C: R.Tensor([16], "float32"), + ): + lhs = A + B + rhs = B + C + out = R.multiply(lhs, rhs) + return out + + @R.function(private=True) + def expected( + A: R.Tensor([16], "float32"), + B: R.Tensor([16], "float32"), + C: R.Tensor([16], "float32"), + ): + lhs_rhs = R.call_pure_packed( + "optimized_impl", + A, + B, + C, + sinfo_args=R.Tuple( + [ + R.Tensor([16], "float32"), + R.Tensor([16], "float32"), + ] + ), + ) + out = R.multiply(lhs_rhs[0], lhs_rhs[1]) + return out + + after = Rewriter(before) + tvm.ir.assert_structural_equal(expected, after) + + +def test_no_rewrite_of_implicit_tuple_when_shared_wildcard_is_mismatched(): + """Tuple elements must match simultaneously + + Each element of the tuple matches individually, but the two + elements both depend on `B`. Because the first tuple element + would require `y = B`, while the second tuple element would + require `y = C`, the match fails. 
+ + """ + + @R.rewriter + class Rewriter: + @R.function + def pattern( + x: R.Tensor([16], "float32"), + y: R.Tensor([16], "float32"), + z: R.Tensor([16], "float32"), + ): + lhs = x + y + rhs = y + z + return (lhs, rhs) + + @R.function + def replacement( + A: R.Tensor([16], "float32"), + B: R.Tensor([16], "float32"), + C: R.Tensor([16], "float32"), + ): + return R.call_pure_packed( + "optimized_impl", + A, + B, + C, + sinfo_args=R.Tuple( + [ + R.Tensor([16], "float32"), + R.Tensor([16], "float32"), + ] + ), + ) + + @R.function(private=True) + def before( + A: R.Tensor([16], "float32"), + B: R.Tensor([16], "float32"), + C: R.Tensor([16], "float32"), + D: R.Tensor([16], "float32"), + ): + lhs = A + B + rhs = C + D + out = R.multiply(lhs, rhs) + return out + + expected = before + + after = Rewriter(before) + tvm.ir.assert_structural_equal(expected, after) + + +def test_implicit_tuple_may_not_introduce_extra_compute(): + """Matching of implicit tuple may not cause extra compute + + Here, the `(proj_A, proj_B)` tuple could be an implcit tuple + match, but that would repeat the computation of `proj_A`. It + would be computed once on its own, to be used for `proj_A_on_B`, + and once for computing `(proj_A, proj_B)`. + + """ + + @R.rewriter + class Rewriter: + @R.function + def pattern( + lhs_A: R.Tensor([16, 16], "float32"), + lhs_B: R.Tensor([16, 16], "float32"), + rhs: R.Tensor([16, 16], "float32"), + ): + proj_A = R.matmul(lhs_A, rhs) + proj_B = R.matmul(lhs_B, rhs) + proj_tuple = (proj_A, proj_B) + return proj_tuple + + @R.function + def replacement( + lhs_A: R.Tensor([16, 16], "float32"), + lhs_B: R.Tensor([16, 16], "float32"), + rhs: R.Tensor([16, 16], "float32"), + ): + lhs = R.concat([lhs_A, lhs_B]) + proj_concat = R.matmul(lhs, rhs) + proj_tuple = R.split(proj_concat, 2) + return proj_tuple + + @R.function(private=True) + def before( + state: R.Tensor([16, 16], "float32"), + A: R.Tensor([16, 16], "float32"), + B: R.Tensor([16, 16], "float32"), + ): + # This function has no location at which a tuple + # `(proj_A,proj_B)` could be constructed, then unpacked. + + proj_A = R.matmul(A, state) + + # A tuple `(proj_A, proj_B)` could not be constructed at this + # location, because `proj_B` has not yet been computed. + + proj_A_on_B = R.matmul(proj_A, B) + proj_B = R.matmul(proj_A_on_B, state) + + # A tuple `(proj_A, proj_B)` could be constructed here, but a + # use-site of `proj_A` has already occurred. Implicit + # matching of a tuple is only allowed if it would replace + # every use-site of a variable. 
+ + out = proj_A + proj_B + return out + + expected = before + + after = Rewriter(before) + tvm.ir.assert_structural_equal(expected, after) + + +def test_rewrite_of_implicit_tuple_with_three_elements(): + """Implicit tuples may contain three elements""" + + @R.rewriter + class Rewriter: + @R.function + def pattern(qkv: R.Tensor([12288], "float32")): + qkv_tuple = R.split(qkv, 3, axis=0) + q = qkv_tuple[0] + k = qkv_tuple[1] + v = qkv_tuple[2] + q_embed = R.call_pure_packed( + "rotary_embedding", [q], sinfo_args=R.Tensor([4096], "float32") + ) + k_embed = R.call_pure_packed( + "rotary_embedding", [k], sinfo_args=R.Tensor([4096], "float32") + ) + + return (q_embed, k_embed, v) + + @R.function + def replacement(qkv: R.Tensor([12288], "float32")): + return R.call_pure_packed( + "split_rotary_embedding", + [qkv], + sinfo_args=[ + R.Tensor([4096], "float32"), + R.Tensor([4096], "float32"), + R.Tensor([4096], "float32"), + ], + ) + + @R.function(private=True) + def before( + state: R.Tensor([4096], "float32"), + proj_qkv: R.Tensor([12288, 4096], "float32"), + kv_cache: R.Object, + ): + qkv = R.matmul(proj_qkv, state) + qkv_tuple = R.split(qkv, 3, axis=0) + q = qkv_tuple[0] + k = qkv_tuple[1] + v = qkv_tuple[2] + q_embed = R.call_pure_packed( + "rotary_embedding", [q], sinfo_args=R.Tensor([4096], "float32") + ) + k_embed = R.call_pure_packed( + "rotary_embedding", [k], sinfo_args=R.Tensor([4096], "float32") + ) + + attention = R.call_pure_packed( + "compute_self_attention", + [q_embed, k_embed, v, kv_cache], + sinfo_args=R.Tensor([4096]), + ) + + return attention + + @R.function(private=True) + def expected( + state: R.Tensor([4096], "float32"), + proj_qkv: R.Tensor([12288, 4096], "float32"), + kv_cache: R.Object, + ): + qkv = R.matmul(proj_qkv, state) + embedded_qkv_tuple = R.call_pure_packed( + "split_rotary_embedding", + [qkv], + sinfo_args=[ + R.Tensor([4096], "float32"), + R.Tensor([4096], "float32"), + R.Tensor([4096], "float32"), + ], + ) + + v = embedded_qkv_tuple[2] + q_embed = embedded_qkv_tuple[0] + k_embed = embedded_qkv_tuple[1] + + attention = R.call_pure_packed( + "compute_self_attention", + [q_embed, k_embed, v, kv_cache], + sinfo_args=R.Tensor([4096]), + ) + + return attention + + after = Rewriter(before) + tvm.ir.assert_structural_equal(expected, after) + + +def test_pattern_matching_may_not_reorder_across_impure_functions(): + """Matched pattern must be ordered with respect to impure functions + + To ensure that debug printouts, memory management, performance + measurements, etc are not impacted by a pattern match, a pattern + must be entirely before, or entirely after an impure function. A + pattern match in which some parts of the matched expression are + performed before an impure function, while others are performed + afterwards, is not allowed. + + In this test, the matmul and the add may not be fused, because the + impure print statement occurs between them. 
+ + """ + + @R.rewriter + class Rewriter: + @R.function + def pattern( + state: R.Tensor([16], "float32"), + weights: R.Tensor([16, 16], "float32"), + bias: R.Tensor([16], "float32"), + ): + state = R.matmul(weights, state) + state = R.add(bias, state) + return state + + @R.function + def replacement( + state: R.Tensor([16], "float32"), + weights: R.Tensor([16, 16], "float32"), + bias: R.Tensor([16], "float32"), + ): + return R.call_pure_packed( + "my_optimized_fma_impl", + state, + weights, + bias, + sinfo_args=R.Tensor([16], "float32"), + ) + + @R.function(private=True, pure=False) + def before( + state: R.Tensor([16], "float32"), + weights: R.Tensor([16, 16], "float32"), + bias: R.Tensor([16], "float32"), + ): + R.print(format="Start of function") + state = R.matmul(weights, state) + R.print(format="After matmul, before add") + state = R.add(bias, state) + R.print(format="End of function") + return state + + expected = before + + after = Rewriter(before) + tvm.ir.assert_structural_equal(expected, after) + + +def test_pattern_matching_may_occur_between_impure_functions(): + """Matched pattern may be adjacent to impure functions + + To ensure that debug printouts, memory management, performance + measurements, etc are not impacted by a pattern match, a pattern + must be entirely before, or entirely after an impure function. A + pattern match in which some parts of the matched expression are + performed before an impure function, while others are performed + afterwards, is not allowed. + + In this test, the matmul and the add may be fused, because the + pattern occurs without an impure print statement in-between. + + """ + + @R.rewriter + class Rewriter: + @R.function + def pattern( + state: R.Tensor([16], "float32"), + weights: R.Tensor([16, 16], "float32"), + bias: R.Tensor([16], "float32"), + ): + state = R.matmul(weights, state) + state = R.add(bias, state) + return state + + @R.function + def replacement( + state: R.Tensor([16], "float32"), + weights: R.Tensor([16, 16], "float32"), + bias: R.Tensor([16], "float32"), + ): + return R.call_pure_packed( + "my_optimized_fma_impl", + state, + weights, + bias, + sinfo_args=R.Tensor([16], "float32"), + ) + + @R.function(private=True, pure=False) + def before( + state: R.Tensor([16], "float32"), + weights: R.Tensor([16, 16], "float32"), + bias: R.Tensor([16], "float32"), + ): + R.print(format="Start of function") + state = R.matmul(weights, state) + state = R.add(bias, state) + R.print(format="End of function") + return state + + @R.function(private=True, pure=False) + def expected( + state: R.Tensor([16], "float32"), + weights: R.Tensor([16, 16], "float32"), + bias: R.Tensor([16], "float32"), + ): + R.print(format="Start of function") + state = R.call_pure_packed( + "my_optimized_fma_impl", + state, + weights, + bias, + sinfo_args=R.Tensor([16], "float32"), + ) + R.print(format="End of function") + return state + + after = Rewriter(before) + tvm.ir.assert_structural_equal(expected, after) + + +def test_rewrite_may_apply_within_conditional(): + """Rewrites may apply within to inner dataflow regions + + While dataflow regions may not contain conditionals, they may + occur within the body of conditionals. 
+ + """ + + @R.rewriter + class Rewriter: + @R.function + def pattern(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): + return A + B + + @R.function + def replacement(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): + return R.call_pure_packed( + "my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32") + ) + + @R.function(private=True) + def before(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32"), cond: R.Prim("bool")): + if cond: + out = A + B + else: + C = A + B + out = C + B + return out + + @R.function(private=True) + def expected(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32"), cond: R.Prim("bool")): + if cond: + out = R.call_pure_packed( + "my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32") + ) + else: + C = R.call_pure_packed( + "my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32") + ) + out = R.call_pure_packed( + "my_optimized_add_impl", C, B, sinfo_args=R.Tensor([16], "float32") + ) + return out + + after = Rewriter(before) + tvm.ir.assert_structural_equal(expected, after) + + +def test_match_dynamic_shape(): + """Pattern match/rewrites may be dynamic + + The tuple being replaced does not need to explicitly exist within + the updated Relax function. So long as each element of the tuple + pattern matches a Relax expression, the pattern match can apply. + + This rule ensures that pattern-matching is never broken when + `CanonicalizeBindings` is applied. + + This test is identical to `test_rewrite_of_explicit_relax_tuple`, + except that the function does not contain the round trip of + packing `proj_A` and `proj_B` into a tuple, then immediately + unpacking them from the tuple. + + """ + + @R.rewriter + class Rewriter: + @R.function + def pattern( + lhs_A: R.Tensor(["N1", "M"], "float32"), + lhs_B: R.Tensor(["N2", "M"], "float32"), + rhs: R.Tensor(["M"], "float32"), + ): + proj_A = R.matmul(lhs_A, rhs) + proj_B = R.matmul(lhs_B, rhs) + return (proj_A, proj_B) + + @R.function + def replacement( + lhs_A: R.Tensor(["N1", "M"], "float32"), + lhs_B: R.Tensor(["N2", "M"], "float32"), + rhs: R.Tensor(["M"], "float32"), + ): + N1 = T.int64() + N2 = T.int64() + + lhs = R.concat([lhs_A, lhs_B]) + proj_concat = R.matmul(lhs, rhs) + proj_A: R.Tensor([N1], "float32") = R.strided_slice( + proj_concat, axes=[0], begin=[0], end=[N1] + ) + proj_B: R.Tensor([N2], "float32") = R.strided_slice( + proj_concat, axes=[0], begin=[N1], end=[N2 + N1] + ) + return (proj_A, proj_B) + + @R.function(private=True) + def before( + state: R.Tensor([16], "float32"), + A: R.Tensor([16, 16], "float32"), + B: R.Tensor([16, 16], "float32"), + ): + proj_A = R.matmul(A, state) + proj_B = R.matmul(B, state) + out = proj_A + proj_B + return out + + @R.function(private=True) + def expected( + state: R.Tensor([16], "float32"), + A: R.Tensor([16, 16], "float32"), + B: R.Tensor([16, 16], "float32"), + ): + concat_AB = R.concat([A, B]) + proj_concat = R.matmul(concat_AB, state) + proj_A = R.strided_slice(proj_concat, axes=[0], begin=[0], end=[16]) + proj_B = R.strided_slice(proj_concat, axes=[0], begin=[16], end=[32]) + out = proj_A + proj_B + return out + + after = Rewriter(before) + tvm.ir.assert_structural_equal(expected, after) + + +def test_match_dynamic_pattern_against_dynamic_shape(): + """A dynamic pattern may match a static shape""" + + @R.rewriter + class Rewriter: + @R.function + def pattern( + A: R.Tensor(["M", "N"], "float32"), + B: R.Tensor(["N", "N"], "float32"), + ): + return R.matmul(A, B) + + @R.function + def replacement( + 
A: R.Tensor(["M", "N"], "float32"), + B: R.Tensor(["N", "N"], "float32"), + ): + M = T.int64() + N = T.int64() + return R.call_pure_packed( + "my_optimized_square_matmul", + A, + B, + sinfo_args=R.Tensor([M, N], "float32"), + ) + + @R.function(private=True) + def before( + A: R.Tensor(["N", "N*2"], "float32"), + B: R.Tensor(["N*2", "N*2"], "float32"), + C: R.Tensor(["N", "N"], "float32"), + ): + N = T.int64() + D: R.Tensor([N, N * 2], "float32") = R.matmul(A, B) + E: R.Tensor([N * 2, N], "float32") = R.permute_dims(D) + F: R.Tensor([N * 2, N], "float32") = R.matmul(E, C) + return F + + @R.function(private=True) + def expected( + A: R.Tensor(["N", "N*2"], "float32"), + B: R.Tensor(["N*2", "N*2"], "float32"), + C: R.Tensor(["N", "N"], "float32"), + ): + N = T.int64() + + D: R.Tensor([N, N * 2], "float32") = R.call_pure_packed( + "my_optimized_square_matmul", + A, + B, + sinfo_args=R.Tensor([N, N * 2], "float32"), + ) + E: R.Tensor([N * 2, N], "float32") = R.permute_dims(D) + F: R.Tensor([N * 2, N], "float32") = R.call_pure_packed( + "my_optimized_square_matmul", + E, + C, + sinfo_args=R.Tensor([N * 2, N], "float32"), + ) + return F + + after = Rewriter(before) + tvm.ir.assert_structural_equal(expected, after) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_transform_canonicalize_bindings.py b/tests/python/relax/test_transform_canonicalize_bindings.py index d513c0cf6c6d..ea3b1c249b8b 100644 --- a/tests/python/relax/test_transform_canonicalize_bindings.py +++ b/tests/python/relax/test_transform_canonicalize_bindings.py @@ -198,9 +198,13 @@ def test_change_shape(): @I.ir_module class TestChangeShape: @R.function - def main(x: R.Tensor(("m", "n"))): + def main(x: R.Tensor(ndim=2)): y = x - # not trivial: introduces new shape vars + # The MatchCast is non-trivial, as it introduces new shape + # vars. Because the input tensor has an unknown shape + # rather than a symbolic shape, these new shape vars + # cannot be expressed in terms of previous variables. + # Therefore, the match cast must be retained. o, p = T.int64(), T.int64() z = R.match_cast(x, R.Tensor((o, p))) w = z @@ -210,7 +214,7 @@ def main(x: R.Tensor(("m", "n"))): @I.ir_module class Expected: @R.function - def main(x: R.Tensor(("m", "n"))): + def main(x: R.Tensor(ndim=2)): o, p = T.int64(), T.int64() z = R.match_cast(x, R.Tensor((o, p))) # the struct_info field on q will need to be updated @@ -220,6 +224,35 @@ def main(x: R.Tensor(("m", "n"))): verify(TestChangeShape, Expected) +def test_replace_symbolic_variable_and_remove_match_cast(): + @I.ir_module + class TestChangeShape: + @R.function + def main(x: R.Tensor(("m", "n"))): + y = x + # The MatchCast is non-trivial, as it introduces new shape + # vars. However, the new shape vars are redundant, and + # are replaced by canonicalization. After replacing the + # new shape vars, the MatchCast is trivial and may be + # removed. 
+ o, p = T.int64(), T.int64() + z = R.match_cast(x, R.Tensor((o, p))) + w = z + q = R.add(w, y) + return R.add(q, w) + + @I.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("m", "n"))): + m = T.int64() + n = T.int64() + q: R.Tensor([m, n]) = R.add(x, x) + return R.add(q, x) + + verify(TestChangeShape, Expected) + + def test_unwrap_tuple(): @I.ir_module class Before: @@ -289,6 +322,222 @@ def main() -> R.Tensor((), "int32"): verify(Input, Expected) +def test_fold_variables_from_match_cast(): + """Symbolic variables in R.match_cast may be inferred + + If the argument to `R.match_cast` has known shape parameters, they + may be used to infer symbolic shape parameters. + + """ + + @I.ir_module + class Before: + @R.function + def main( + state: R.Tensor([16], dtype="float32"), + A: R.Tensor([16, 16], dtype="float32"), + B: R.Tensor([16, 16], dtype="float32"), + ): + N1 = T.int64() + M = T.int64() + N2 = T.int64() + + # The symbolic variables `N1`, `N2` and `M` are defined by + # these `R.match_cast` statements. Since the inputs have + # a known shape, the values of these symbolic variables + # may be inferred. + lhs_A = R.match_cast(A, R.Tensor([N1, M], dtype="float32")) + lhs_B = R.match_cast(B, R.Tensor([N2, M], dtype="float32")) + rhs = R.match_cast(state, R.Tensor([M], dtype="float32")) + + # The symbolic shapes propagate downstream. + lhs: R.Tensor([N1 + N2, M], "float32") = R.concat((lhs_A, lhs_B), axis=0) + proj_concat: R.Tensor([N1 + N2], "float32") = R.matmul(lhs, rhs, out_dtype="void") + proj_A = R.strided_slice( + proj_concat, + (R.prim_value(0),), + (R.prim_value(0),), + (R.prim_value(N1),), + assume_inbound=False, + ) + proj_B = R.strided_slice( + proj_concat, + [R.prim_value(0)], + [R.prim_value(N1)], + [R.prim_value(N1 + N2)], + assume_inbound=False, + ) + return (proj_A, proj_B) + + @I.ir_module + class Expected: + @R.function + def main( + state: R.Tensor([16], dtype="float32"), + A: R.Tensor([16, 16], dtype="float32"), + B: R.Tensor([16, 16], dtype="float32"), + ): + # The function no longer depends on symbolic variables. + # Shape inference is now propagated using the + # statically-known shapes. + + lhs: R.Tensor([32, 16], dtype="float32") = R.concat((A, B), axis=0) + proj_concat: R.Tensor([32], dtype="float32") = R.matmul(lhs, state, out_dtype="void") + proj_A: R.Tensor([16], dtype="float32") = R.strided_slice( + proj_concat, + [R.prim_value(0)], + [R.prim_value(0)], + [R.prim_value(16)], + assume_inbound=False, + ) + proj_B: R.Tensor([16], dtype="float32") = R.strided_slice( + proj_concat, + [R.prim_value(0)], + [R.prim_value(16)], + [R.prim_value(32)], + assume_inbound=False, + ) + return (proj_A, proj_B) + + verify(Before, Expected) + + +def test_inconsistent_match_cast_raises_error(): + """Symbolic variables from R.match_cast must be consistent + + All match cast statements must provide consistent definitions for + symbolic variables. In this test, the value of `M` would be + inferred as 16 from either `state` or `A`, but would be inferred + as 32 from `B`. + + """ + + @I.ir_module + class Before: + @R.function + def main( + state: R.Tensor([16], dtype="float32"), + A: R.Tensor([16, 16], dtype="float32"), + B: R.Tensor([32, 32], dtype="float32"), + ): + N1 = T.int64() + M = T.int64() + N2 = T.int64() + + # These R.match_cast statements define inconsistent values + # for the symbolic shape parameters. 
+            lhs_A = R.match_cast(A, R.Tensor([N1, M], dtype="float32"))
+            lhs_B = R.match_cast(B, R.Tensor([N2, M], dtype="float32"))
+            rhs = R.match_cast(state, R.Tensor([M], dtype="float32"))
+
+            lhs: R.Tensor([N1 + N2, M], "float32") = R.concat((lhs_A, lhs_B), axis=0)
+            proj_concat: R.Tensor([N1 + N2], "float32") = R.matmul(lhs, rhs, out_dtype="void")
+            proj_A = R.strided_slice(
+                proj_concat,
+                (R.prim_value(0),),
+                (R.prim_value(0),),
+                (R.prim_value(N1),),
+                assume_inbound=False,
+            )
+            proj_B = R.strided_slice(
+                proj_concat,
+                [R.prim_value(0)],
+                [R.prim_value(N1)],
+                [R.prim_value(N1 + N2)],
+                assume_inbound=False,
+            )
+            return (proj_A, proj_B)
+
+    with pytest.raises(ValueError, match="MatchCast statements must be consistent"):
+        CanonicalizeBindings()(Before)
+
+
+def test_match_cast_may_have_distinct_values_in_branches():
+    """Conditional branches may have different values of symbolic variables
+
+    Here, the value of `N` can be inferred as 16 within the `if`
+    branch and as 32 within the `else` branch.
+
+    """
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            state: R.Tensor(["N"], dtype="float32"),
+            A: R.Tensor(["M", 16], dtype="float32"),
+            B: R.Tensor(["M", 32], dtype="float32"),
+            scale: R.Prim("float32"),
+        ):
+            N = T.int64()
+            M = T.int64()
+
+            if N == 16:
+                weights: R.Tensor([M, 16], "float32") = A * scale
+                weights: R.Tensor([M, N], "float32") = R.match_cast(
+                    weights, R.Tensor([M, N], "float32")
+                )
+                weights: R.Tensor([M, N], "float32") = weights * scale
+            else:
+                weights: R.Tensor([M, 32], "float32") = B * scale
+                weights: R.Tensor([M, N], "float32") = R.match_cast(
+                    weights, R.Tensor([M, N], "float32")
+                )
+                weights: R.Tensor([M, N], "float32") = weights * scale
+
+            weights: R.Tensor([M, N], "float32") = weights * scale
+
+            out: R.Tensor([M], "float32") = R.matmul(weights, state)
+
+            return out
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            state: R.Tensor(["N"], dtype="float32"),
+            A: R.Tensor(["M", 16], dtype="float32"),
+            B: R.Tensor(["M", 32], dtype="float32"),
+            scale: R.Prim("float32"),
+        ):
+            N = T.int64()
+            M = T.int64()
+
+            if N == 16:
+                # Prior to the R.match_cast, the weights have the
+                # statically-known shape [M, 16].
+                weights: R.Tensor([M, 16], "float32") = A * scale
+                # The scaled weights within the branch may perform
+                # shape inference knowing that N==16.
+                weights: R.Tensor([M, 16], "float32") = weights * scale
+                # The match cast on exiting the if branch restores the
+                # dynamic shape [M, N].
+                weights = R.match_cast(weights, R.Tensor([M, N], "float32"))
+            else:
+                # Prior to the R.match_cast, the weights have the
+                # statically-known shape [M, 32].
+                weights: R.Tensor([M, 32], "float32") = B * scale
+                # Within the else-branch, the R.match_cast implies
+                # that N==32. While this conflicts with the earlier
+                # definition, the two occur in separate branches, so
+                # this is legal.
+                # The scaled weights within the branch may perform
+                # shape inference knowing that N==32.
+                weights: R.Tensor([M, 32], "float32") = weights * scale
+                weights = R.match_cast(weights, R.Tensor([M, N], "float32"))
+
+            # Outside of the conditional, we no longer have a known
+            # value for N, so this shape inference, and that of the
+            # following matmul, must be done using a dynamic shape
+            # for `N`.
+            weights: R.Tensor([M, N], "float32") = weights * scale
+ out: R.Tensor([M], "float32") = R.matmul(weights, state) + + return out + + verify(Before, Expected) + + def test_multiple_outputs(): @I.ir_module class Input: diff --git a/tests/python/relax/test_transform_legalize_ops_manipulate.py b/tests/python/relax/test_transform_legalize_ops_manipulate.py index dd0208f5db07..ba5d4d7d1219 100644 --- a/tests/python/relax/test_transform_legalize_ops_manipulate.py +++ b/tests/python/relax/test_transform_legalize_ops_manipulate.py @@ -720,7 +720,7 @@ def reshape( T_reshape[v_ax0] = rxplaceholder[v_ax0 % T.int64(3)] @R.function - def main(x: R.Tensor((3,), dtype="int64")) -> R.Tensor((3,), dtype="int64"): + def main(x: R.Tensor((3,), dtype="int64")) -> R.Tensor(ndim=1, dtype="int64"): x_1 = T.int64() gv: R.Shape([3]) = R.call_pure_packed("vm.builtin.tensor_to_shape", x, sinfo_args=(R.Shape([3]),)) y: R.Shape([x_1]) = R.match_cast(gv, R.Shape([x_1])) diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index 64014d1c49be..4f41b662caf2 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -2317,5 +2317,51 @@ def expected(A: R.Tensor(["extent"])) -> R.Tensor(["extent-1"]): tvm.ir.assert_structural_equal(inferred_sinfo, expected) +def test_conditional_may_use_symbolic_variables_from_function_scope(): + """Symbolic variables from function scope may be used in branch + + This is a regression test. In earlier implementations, the + branches of `relax::If` were normalized with + `EraseToWellDefinedInScope`, using a fresh variable scope. While + this had the intended behavior of preventing variables defined in + a single branch from being usable outside of the conditional, it + also caused the conditional's branches to treat function-scope + symbolic variables as if they were undefined. + + """ + + @R.function(private=True) + def explicit_sinfo( + A: R.Tensor(["N"], "float32"), + B: R.Tensor(["N"], "float32"), + cond: R.Prim("bool"), + ) -> R.Tensor(["N"], "float32"): + + N = T.int64() + + if cond: + out: R.Tensor([N], "float32") = A + B + else: + out: R.Tensor([N], "float32") = A * B + + return out + + @R.function(private=True) + def inferred_sinfo( + A: R.Tensor(["N"], "float32"), + B: R.Tensor(["N"], "float32"), + cond: R.Prim("bool"), + ): + N = T.int64() + if cond: + out = A + B + else: + out = A * B + + return out + + tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo) + + if __name__ == "__main__": tvm.testing.main()