Skip to content

Commit

Permalink
[FIX] Add check for repeated function parameters (#21)
Browse files Browse the repository at this point in the history
* Add check for repeated function parameters

* Modify output format and unit test
  • Loading branch information
Ubospica authored and tqchen committed Dec 30, 2022
1 parent 343a1e7 commit 588fb7b
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 1 deletion.
12 changes: 12 additions & 0 deletions src/relax/analysis/well_formed.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@

#include <unordered_set>

#include "../../printer/text_printer.h"

namespace tvm {
namespace relax {

Expand Down Expand Up @@ -208,6 +210,15 @@ class WellFormedChecker : public relax::ExprVisitor,
// check all expr are well defined.
for (Var param : op->params) {
this->VisitVarDef(param);

if (param_var_func_map_.count(param) == 1) {
Malformed(Diagnostic::Error(param->span)
<< "Relax variable " << param->name_hint()
<< " is repeatedly used as parameters in function:\n"
<< AsRelaxScript(param_var_func_map_[param], false) << "\nand function:\n"
<< AsRelaxScript(GetRef<Function>(op), false));
}
param_var_func_map_.insert({param, GetRef<Function>(op)});
}

if (auto seq = op->body.as<SeqExprNode>()) {
Expand Down Expand Up @@ -426,6 +437,7 @@ class WellFormedChecker : public relax::ExprVisitor,
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> var_set_;
std::unordered_set<DataflowVar, ObjectPtrHash, ObjectPtrEqual> dataflow_var_set_;
std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual> symbolic_var_set_;
std::unordered_map<Var, Function, ObjectPtrHash, ObjectPtrEqual> param_var_func_map_;
};

bool WellFormed(IRModule m, bool check_struct_info) {
Expand Down
3 changes: 3 additions & 0 deletions src/relax/backend/vm/codegen_vm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ class CodeGenVM : public ExprFunctor<Instruction::Arg(const Expr&)> {
ICHECK(gsymbol.defined()) << "there should be no local functions in Relax VM codegen phase. "
"Did you forget to apply LambdaLift pass?";

// var_register_map_ is local in function scope
var_register_map_.clear();

Array<String> param_names;
for (Var param : func_node->params) {
param_names.push_back(param->name_hint());
Expand Down
15 changes: 15 additions & 0 deletions tests/python/relax/test_analysis_well_formed.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,21 @@ def test_dataflow_var():
assert not rx.analysis.well_formed(mod, check_struct_info=False)


def test_param_var():
    # Error: Relax variable v0 is used as a parameter of both func1 and func2.
    v0 = rx.Var("v0", R.Tensor([m, n], "float32"))
    v1 = rx.Var("v1", R.Tensor([m, n], "float32"))
    v2 = rx.Var("v2", R.Tensor([m, n], "float32"))
    bb = rx.BlockBuilder()
    with bb.function("func1", [v0, v1]):
        gv0 = bb.emit(rx.op.add(v0, v1))
        bb.emit_func_output(gv0)
    # func2 deliberately reuses v0 as a parameter. Its body references only
    # its own parameters (v0, v2) so that the repeated-parameter check is the
    # ONLY reason the module is malformed. (Reading func1's v1 here, as the
    # original did, would make the module malformed for an unrelated reason —
    # a cross-function free variable — and the assertion would pass even
    # without the repeated-parameter check.)
    with bb.function("func2", [v0, v2]):
        gv0 = bb.emit(rx.op.add(v0, v2))
        bb.emit_func_output(gv0)
    mod = bb.get()
    assert not rx.analysis.well_formed(mod, check_struct_info=False)


def test_global_var():
# Error: GlobalVar GlobalVar0 is not defined
gv0 = rx.Var("gv0", R.Tensor([m, n], "float32"))
Expand Down
5 changes: 4 additions & 1 deletion tests/python/relax/test_transform_fuse_tir.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def before():
fused_add_exp_squeeze = bb.get().get_global_var("fused_add_exp_squeeze")

x = relax.Var("x", R.Tensor([10, 20], "float32"))
p0 = relax.Var("p0", R.Tensor([], "float32"))
with bb.function("main", [x, p0]):
with bb.dataflow():
gv = bb.emit_output(relax.Call(fused_add_exp_squeeze, [x, p0]))
Expand Down Expand Up @@ -435,6 +436,7 @@ def before():
fused_add_exp = bb.get().get_global_var("fused_add_exp")

x = relax.Var("x", R.Tensor([10, 20], "float32"))
p0 = relax.Var("p0", R.Tensor([], "float32"))
with bb.function("main", [x, p0]):
with bb.dataflow():
gv = bb.emit_output(relax.Call(fused_add_exp, [x, p0]))
Expand Down Expand Up @@ -558,4 +560,5 @@ def fused_argmax_add(x, offset):


if __name__ == "__main__":
    # Run the whole test suite. The previous version was a local-debugging
    # leftover: it hard-coded a single test_simple() call and commented out
    # the runner, silently skipping every other test in this file.
    tvm.testing.main()

0 comments on commit 588fb7b

Please sign in to comment.