From 588fb7b037d0837e10b0a2710bf32714190c3dba Mon Sep 17 00:00:00 2001
From: Yixin Dong
Date: Thu, 8 Dec 2022 14:51:30 +0800
Subject: [PATCH] [FIX] Add check for repeated function parameters (#21)

* Add check for repeated function parameters

* Modify output format and unit test
---
 src/relax/analysis/well_formed.cc               | 12 ++++++++++++
 src/relax/backend/vm/codegen_vm.cc              |  3 +++
 tests/python/relax/test_analysis_well_formed.py | 15 +++++++++++++++
 tests/python/relax/test_transform_fuse_tir.py   |  5 ++++-
 4 files changed, 34 insertions(+), 1 deletion(-)

diff --git a/src/relax/analysis/well_formed.cc b/src/relax/analysis/well_formed.cc
index 5859ef4bed..e9d30dcbcb 100644
--- a/src/relax/analysis/well_formed.cc
+++ b/src/relax/analysis/well_formed.cc
@@ -64,6 +64,8 @@
 #include <unordered_set>
 
+#include "../../printer/text_printer.h"
+
 namespace tvm {
 namespace relax {
 
@@ -208,6 +210,15 @@ class WellFormedChecker : public relax::ExprVisitor,
     // check all expr are well defined.
     for (Var param : op->params) {
       this->VisitVarDef(param);
+
+      if (param_var_func_map_.count(param) == 1) {
+        Malformed(Diagnostic::Error(param->span)
+                  << "Relax variable " << param->name_hint()
+                  << " is repeatedly used as parameters in function:\n"
+                  << AsRelaxScript(param_var_func_map_[param], false) << "\nand function:\n"
+                  << AsRelaxScript(GetRef<Function>(op), false));
+      }
+      param_var_func_map_.insert({param, GetRef<Function>(op)});
     }
 
     if (auto seq = op->body.as<SeqExprNode>()) {
@@ -426,6 +437,7 @@ class WellFormedChecker : public relax::ExprVisitor,
   std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> var_set_;
   std::unordered_set<DataflowVar, ObjectPtrHash, ObjectPtrEqual> dataflow_var_set_;
   std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual> symbolic_var_set_;
+  std::unordered_map<Var, Function, ObjectPtrHash, ObjectPtrEqual> param_var_func_map_;
 };
 
 bool WellFormed(IRModule m, bool check_struct_info) {
diff --git a/src/relax/backend/vm/codegen_vm.cc b/src/relax/backend/vm/codegen_vm.cc
index d8d09ddd88..6a5282bb41 100644
--- a/src/relax/backend/vm/codegen_vm.cc
+++ b/src/relax/backend/vm/codegen_vm.cc
@@ -72,6 +72,9 @@ class CodeGenVM : public ExprFunctor {
     ICHECK(gsymbol.defined()) << "there should be no local functions in Relax VM codegen phase. "
                                  "Did you forget to apply LambdaLift pass?";
 
+    // var_register_map_ is local in function scope
+    var_register_map_.clear();
+
     Array<String> param_names;
     for (Var param : func_node->params) {
       param_names.push_back(param->name_hint());
diff --git a/tests/python/relax/test_analysis_well_formed.py b/tests/python/relax/test_analysis_well_formed.py
index 900b28dbea..358ec9e075 100644
--- a/tests/python/relax/test_analysis_well_formed.py
+++ b/tests/python/relax/test_analysis_well_formed.py
@@ -98,6 +98,21 @@ def test_dataflow_var():
     assert not rx.analysis.well_formed(mod, check_struct_info=False)
 
 
+def test_param_var():
+    v0 = rx.Var("v0", R.Tensor([m, n], "float32"))
+    v1 = rx.Var("v1", R.Tensor([m, n], "float32"))
+    v2 = rx.Var("v2", R.Tensor([m, n], "float32"))
+    bb = rx.BlockBuilder()
+    with bb.function("func1", [v0, v1]):
+        gv0 = bb.emit(rx.op.add(v0, v1))
+        bb.emit_func_output(gv0)
+    with bb.function("func2", [v0, v2]):
+        gv0 = bb.emit(rx.op.add(v2, v1))
+        bb.emit_func_output(gv0)
+    mod = bb.get()
+    assert not rx.analysis.well_formed(mod, check_struct_info=False)
+
+
 def test_global_var():
     # Error: GlobalVar GlobalVar0 is not defined
     gv0 = rx.Var("gv0", R.Tensor([m, n], "float32"))
diff --git a/tests/python/relax/test_transform_fuse_tir.py b/tests/python/relax/test_transform_fuse_tir.py
index 4f4a4ce3b8..c809560ed2 100644
--- a/tests/python/relax/test_transform_fuse_tir.py
+++ b/tests/python/relax/test_transform_fuse_tir.py
@@ -41,6 +41,7 @@ def before():
         fused_add_exp_squeeze = bb.get().get_global_var("fused_add_exp_squeeze")
 
         x = relax.Var("x", R.Tensor([10, 20], "float32"))
+        p0 = relax.Var("p0", R.Tensor([], "float32"))
         with bb.function("main", [x, p0]):
             with bb.dataflow():
                 gv = bb.emit_output(relax.Call(fused_add_exp_squeeze, [x, p0]))
@@ -435,6 +436,7 @@ def before():
         fused_add_exp = bb.get().get_global_var("fused_add_exp")
 
         x = relax.Var("x", R.Tensor([10, 20], "float32"))
+        p0 = relax.Var("p0", R.Tensor([], "float32"))
         with bb.function("main", [x, p0]):
             with bb.dataflow():
                 gv = bb.emit_output(relax.Call(fused_add_exp, [x, p0]))
@@ -558,4 +560,5 @@ def fused_argmax_add(x, offset):
 
 
 if __name__ == "__main__":
-    tvm.testing.main()
+    test_simple()
+    # tvm.testing.main()