From f495802417b1c65c8b42d52279ebf612d9512098 Mon Sep 17 00:00:00 2001 From: Yixin Dong Date: Mon, 30 Jan 2023 05:53:15 +0800 Subject: [PATCH] [FIX][Utils] Add CopyWithNewParams and Check repeated parameters (#389) This PR adds: 1. A check for any variable used as parameters in different functions of the same IRModule 2. A util to copy function while copying parameters in the new function, so the new function satisfies 1. --- include/tvm/relax/utils.h | 9 +++++ python/tvm/relax/__init__.py | 1 + python/tvm/relax/utils.py | 21 +++++++++- src/relax/analysis/well_formed.cc | 29 ++++++++++---- .../transform/merge_composite_functions.cc | 2 +- src/relax/utils.cc | 23 +++++++++++ .../python/relax/test_analysis_well_formed.py | 15 ++++++++ tests/python/relax/test_transform_fuse_tir.py | 2 + tests/python/relax/test_utils.py | 38 +++++++++++++++++++ 9 files changed, 130 insertions(+), 10 deletions(-) create mode 100644 tests/python/relax/test_utils.py diff --git a/include/tvm/relax/utils.h b/include/tvm/relax/utils.h index 7140b3efe1db..1457a16427cc 100644 --- a/include/tvm/relax/utils.h +++ b/include/tvm/relax/utils.h @@ -140,6 +140,15 @@ TVM_DLL bool IsBoolScalarType(const Type& ty, bool permit_unknown_rank = true, */ TVM_DLL bool IsLeafOrTuple(const Expr& expr); +/*! + * \brief Copy the given function. The parameters of the original function would be copied to + * satisfy the restriction in the well-formed check: any two functions cannot share the same + * parameter variable. + * \param func The relax function to copy. + * \return The copied function. + */ +TVM_DLL Function CopyWithNewParams(Function func); + } // namespace relax } // namespace tvm diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py index dfdf7dc4cedd..183903fbc6d7 100644 --- a/python/tvm/relax/__init__.py +++ b/python/tvm/relax/__init__.py @@ -26,6 +26,7 @@ from . import transform from . import expr_functor from . import struct_info +from . import utils # Expr diff --git a/python/tvm/relax/utils.py b/python/tvm/relax/utils.py index b70b672a1dd7..e1d5bf50c1ce 100644 --- a/python/tvm/relax/utils.py +++ b/python/tvm/relax/utils.py @@ -22,7 +22,8 @@ from .. import tir from ..runtime import String, convert_to_object from ..tir import PrimExpr -from .expr import Expr, PrimValue, ShapeExpr, StringImm +from . import _ffi_api +from .expr import Expr, Function, PrimValue, ShapeExpr, StringImm from .expr import Tuple as rx_Tuple @@ -254,3 +255,21 @@ def auto(func: FType) -> FType: args_converter = _ArgsConverter() # pylint: disable=invalid-name + + +def copy_with_new_params(func: Function) -> Function: + """Copy the given function. The parameters of the original function would be copied to + satisfy the restriction in the well-formed check: any two functions cannot share the same + parameter variable. + + Parameters + ---------- + func : Function + The relax function to copy. + + Returns + ------- + ret : Function + The copied function. + """ + return _ffi_api.CopyWithNewParams(func) # type: ignore diff --git a/src/relax/analysis/well_formed.cc b/src/relax/analysis/well_formed.cc index b5c5c616c9bb..bfa365cf2b66 100644 --- a/src/relax/analysis/well_formed.cc +++ b/src/relax/analysis/well_formed.cc @@ -30,15 +30,16 @@ * 3. When a Function has a corresponding GlobalVar and a `global_symbol` * attribute, the name of the GlobalVar must equal the value of the * `global_symbol` attribute value. - * 4. Vars are defined before use. - * 5. Vars are defined exactly once. - * 6. Symbolic Vars are defined before use. - * 7. DataflowVars cannot be defined inside BindingBlock. - * 8. Vars defined in IfNode, except the return Var, are invisible + * 4. Any variable cannot used as different function parameters in the same IRModule + * 5. Vars are defined before use. + * 6. Vars are defined exactly once. + * 7. Symbolic Vars are defined before use. + * 8. DataflowVars cannot be defined inside BindingBlock. + * 9. Vars defined in IfNode, except the return Var, are invisible * out of the If body.(May change for new AST designs) - * 9. SeqExpr only serves as function body, or in the true and + * 10. SeqExpr only serves as function body, or in the true and * false branches in IfNode. - * 10. The IR is in ANF: + * 11. The IR is in ANF: * (a) Expressions cannot contain nested complex expressions. * Here are the expressions that may be nested inside other expressions: * Var, DataflowVar, GlobalVar, Constant, ShapeExpr, @@ -53,7 +54,7 @@ * * The cond field of If nodes * * The op or args fields of Call nodes * * Inside the fields of Tuple nodes - * 11. Expr always has checked_type_ (with the exception of Op). + * 12. Expr always has checked_type_ (with the exception of Op). */ #include #include @@ -64,6 +65,8 @@ #include +#include "../../printer/text_printer.h" + namespace tvm { namespace relax { @@ -215,6 +218,15 @@ class WellFormedChecker : public relax::ExprVisitor, // check all expr are well defined. for (Var param : op->params) { this->VisitVarDef(param); + + if (param_var_func_map_.count(param) == 1) { + Malformed(Diagnostic::Error(param->span) + << "Relax variable " << param->name_hint() + << " is repeatedly used as parameters in function:\n" + << AsRelaxScript(param_var_func_map_[param], false) << "\nand function:\n" + << AsRelaxScript(GetRef(op), false)); + } + param_var_func_map_.insert({param, GetRef(op)}); } if (auto seq = op->body.as()) { @@ -440,6 +452,7 @@ class WellFormedChecker : public relax::ExprVisitor, std::unordered_set var_set_; std::unordered_set dataflow_var_set_; std::unordered_set symbolic_var_set_; + std::unordered_map param_var_func_map_; }; bool WellFormed(IRModule m, bool check_struct_info) { diff --git a/src/relax/transform/merge_composite_functions.cc b/src/relax/transform/merge_composite_functions.cc index 2c5c7cd05186..18d648e33b3f 100644 --- a/src/relax/transform/merge_composite_functions.cc +++ b/src/relax/transform/merge_composite_functions.cc @@ -269,7 +269,7 @@ class CompositeInliner : public ExprMutator { Expr VisitExpr_(const CallNode* call) { if (call->op->IsInstance()) { auto gvar = Downcast(call->op); - auto func = Downcast(mod_->Lookup(gvar)); + auto func = CopyWithNewParams(Downcast(mod_->Lookup(gvar))); auto composite_name_opt = func->GetAttr(attr::kComposite); ICHECK(composite_name_opt); std::string composite_name = composite_name_opt.value(); diff --git a/src/relax/utils.cc b/src/relax/utils.cc index 24414f250cbc..a77e4342e29c 100644 --- a/src/relax/utils.cc +++ b/src/relax/utils.cc @@ -82,5 +82,28 @@ bool IsLeafOrTuple(const Expr& expr) { expr.as() || expr.as(); } +class FunctionCopier : public ExprMutator { + public: + static Function Transform(Function func) { + FunctionCopier copier; + // the parameters would be copied and substituted to satisfy the restriction in the well-formed + // check: any two functions cannot share the same parameter variable. + Array new_params; + for (Var param : func->params) { + Var new_param = Var(param->vid, GetStructInfo(param), param->span); + copier.var_remap_[param->vid] = new_param; + new_params.push_back(new_param); + } + + Expr body = copier.VisitWithNewScope(func->body, new_params); + + return Function(new_params, body, func->ret_struct_info, func->attrs); + } +}; + +Function CopyWithNewParams(Function func) { return FunctionCopier::Transform(func); } + +TVM_REGISTER_GLOBAL("relax.CopyWithNewParams").set_body_typed(CopyWithNewParams); + } // namespace relax } // namespace tvm diff --git a/tests/python/relax/test_analysis_well_formed.py b/tests/python/relax/test_analysis_well_formed.py index 9099b5d33343..cc0de84d53af 100644 --- a/tests/python/relax/test_analysis_well_formed.py +++ b/tests/python/relax/test_analysis_well_formed.py @@ -99,6 +99,21 @@ def test_dataflow_var(): assert not rx.analysis.well_formed(mod, check_struct_info=False) +def test_param_var(): + v0 = rx.Var("v0", R.Tensor([m, n], "float32")) + v1 = rx.Var("v1", R.Tensor([m, n], "float32")) + v2 = rx.Var("v2", R.Tensor([m, n], "float32")) + bb = rx.BlockBuilder() + with bb.function("func1", [v0, v1]): + gv0 = bb.emit(rx.op.add(v0, v1)) + bb.emit_func_output(gv0) + with bb.function("func2", [v0, v2]): + gv0 = bb.emit(rx.op.add(v2, v1)) + bb.emit_func_output(gv0) + mod = bb.get() + assert not rx.analysis.well_formed(mod, check_struct_info=False) + + def test_global_var(): # Error: GlobalVar GlobalVar0 is not defined gv0 = rx.Var("gv0", R.Tensor([m, n], "float32")) diff --git a/tests/python/relax/test_transform_fuse_tir.py b/tests/python/relax/test_transform_fuse_tir.py index 868ba62bd252..1f26028bc325 100644 --- a/tests/python/relax/test_transform_fuse_tir.py +++ b/tests/python/relax/test_transform_fuse_tir.py @@ -44,6 +44,7 @@ def before(): fused_add_exp_squeeze = bb.get().get_global_var("fused_add_exp_squeeze") x = relax.Var("x", R.Tensor([10, 20], "float32")) + p0 = relax.Var("p0", R.Tensor([], "float32")) with bb.function("main", [x, p0]): with bb.dataflow(): gv = bb.emit_output(relax.Call(fused_add_exp_squeeze, [x, p0])) @@ -438,6 +439,7 @@ def before(): fused_add_exp = bb.get().get_global_var("fused_add_exp") x = relax.Var("x", R.Tensor([10, 20], "float32")) + p0 = relax.Var("p0", R.Tensor([], "float32")) with bb.function("main", [x, p0]): with bb.dataflow(): gv = bb.emit_output(relax.Call(fused_add_exp, [x, p0])) diff --git a/tests/python/relax/test_utils.py b/tests/python/relax/test_utils.py new file mode 100644 index 000000000000..1cf2b56fa91a --- /dev/null +++ b/tests/python/relax/test_utils.py @@ -0,0 +1,38 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest +from tvm import relax +from tvm.ir.base import assert_structural_equal +from tvm.script.parser import relax as R + + +def test_copy_with_new_params(): + @R.function + def before(x: R.Tensor((3,), "float32"), y: R.Tensor((3,), "float32")): + gv = R.add(x, y) + return gv + + after = relax.utils.copy_with_new_params(before) + assert_structural_equal(after, before) + + assert len(after.params) == len(before.params) + for before_var, after_var in zip(before.params, after.params): + assert before_var != after_var + + +if __name__ == "__main__": + pytest.main([__file__])