Skip to content

Commit

Permalink
[FIX][Utils] Add CopyWithNewParams and Check repeated parameters (apa…
Browse files Browse the repository at this point in the history
…che#389)

This PR adds:
1. A well-formedness check that no variable is used as a parameter of more than one function in the same IRModule
2. A utility that copies a function while also creating fresh copies of its parameters, so that the copied function satisfies check 1.
  • Loading branch information
Ubospica authored and junrushao committed Feb 5, 2023
1 parent 03b22c3 commit f495802
Show file tree
Hide file tree
Showing 9 changed files with 130 additions and 10 deletions.
9 changes: 9 additions & 0 deletions include/tvm/relax/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,15 @@ TVM_DLL bool IsBoolScalarType(const Type& ty, bool permit_unknown_rank = true,
*/
TVM_DLL bool IsLeafOrTuple(const Expr& expr);

/*!
 * \brief Copy the given function. The parameters of the original function are copied to
 * satisfy the restriction in the well-formed check: any two functions cannot share the same
 * parameter variable.
* \param func The relax function to copy.
* \return The copied function.
*/
TVM_DLL Function CopyWithNewParams(Function func);

} // namespace relax
} // namespace tvm

Expand Down
1 change: 1 addition & 0 deletions python/tvm/relax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from . import transform
from . import expr_functor
from . import struct_info
from . import utils

# Expr

Expand Down
21 changes: 20 additions & 1 deletion python/tvm/relax/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
from .. import tir
from ..runtime import String, convert_to_object
from ..tir import PrimExpr
from .expr import Expr, PrimValue, ShapeExpr, StringImm
from . import _ffi_api
from .expr import Expr, Function, PrimValue, ShapeExpr, StringImm
from .expr import Tuple as rx_Tuple


Expand Down Expand Up @@ -254,3 +255,21 @@ def auto(func: FType) -> FType:


args_converter = _ArgsConverter() # pylint: disable=invalid-name


def copy_with_new_params(func: Function) -> Function:
    """Copy the given function. The parameters of the original function would be copied to
    satisfy the restriction in the well-formed check: any two functions cannot share the same
    parameter variable.

    Parameters
    ----------
    func : Function
        The relax function to copy.

    Returns
    -------
    ret : Function
        The copied function.
    """
    # Thin wrapper over the C++ implementation registered as "relax.CopyWithNewParams"
    # (see src/relax/utils.cc); all copying logic lives on the C++ side.
    return _ffi_api.CopyWithNewParams(func)  # type: ignore
29 changes: 21 additions & 8 deletions src/relax/analysis/well_formed.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,16 @@
* 3. When a Function has a corresponding GlobalVar and a `global_symbol`
* attribute, the name of the GlobalVar must equal the value of the
* `global_symbol` attribute value.
* 4. Vars are defined before use.
* 5. Vars are defined exactly once.
* 6. Symbolic Vars are defined before use.
* 7. DataflowVars cannot be defined inside BindingBlock.
* 8. Vars defined in IfNode, except the return Var, are invisible
 * 4. No variable may be used as a parameter of more than one function in the same IRModule
* 5. Vars are defined before use.
* 6. Vars are defined exactly once.
* 7. Symbolic Vars are defined before use.
* 8. DataflowVars cannot be defined inside BindingBlock.
* 9. Vars defined in IfNode, except the return Var, are invisible
* out of the If body.(May change for new AST designs)
* 9. SeqExpr only serves as function body, or in the true and
* 10. SeqExpr only serves as function body, or in the true and
* false branches in IfNode.
* 10. The IR is in ANF:
* 11. The IR is in ANF:
* (a) Expressions cannot contain nested complex expressions.
* Here are the expressions that may be nested inside other expressions:
* Var, DataflowVar, GlobalVar, Constant, ShapeExpr,
Expand All @@ -53,7 +54,7 @@
* * The cond field of If nodes
* * The op or args fields of Call nodes
* * Inside the fields of Tuple nodes
* 11. Expr always has checked_type_ (with the exception of Op).
* 12. Expr always has checked_type_ (with the exception of Op).
*/
#include <tvm/relax/analysis.h>
#include <tvm/relax/expr.h>
Expand All @@ -64,6 +65,8 @@

#include <unordered_set>

#include "../../printer/text_printer.h"

namespace tvm {
namespace relax {

Expand Down Expand Up @@ -215,6 +218,15 @@ class WellFormedChecker : public relax::ExprVisitor,
// check all expr are well defined.
for (Var param : op->params) {
this->VisitVarDef(param);

if (param_var_func_map_.count(param) == 1) {
Malformed(Diagnostic::Error(param->span)
<< "Relax variable " << param->name_hint()
<< " is repeatedly used as parameters in function:\n"
<< AsRelaxScript(param_var_func_map_[param], false) << "\nand function:\n"
<< AsRelaxScript(GetRef<Function>(op), false));
}
param_var_func_map_.insert({param, GetRef<Function>(op)});
}

if (auto seq = op->body.as<SeqExprNode>()) {
Expand Down Expand Up @@ -440,6 +452,7 @@ class WellFormedChecker : public relax::ExprVisitor,
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> var_set_;
std::unordered_set<DataflowVar, ObjectPtrHash, ObjectPtrEqual> dataflow_var_set_;
std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual> symbolic_var_set_;
std::unordered_map<Var, Function, ObjectPtrHash, ObjectPtrEqual> param_var_func_map_;
};

bool WellFormed(IRModule m, bool check_struct_info) {
Expand Down
2 changes: 1 addition & 1 deletion src/relax/transform/merge_composite_functions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ class CompositeInliner : public ExprMutator {
Expr VisitExpr_(const CallNode* call) {
if (call->op->IsInstance<GlobalVarNode>()) {
auto gvar = Downcast<GlobalVar>(call->op);
auto func = Downcast<Function>(mod_->Lookup(gvar));
auto func = CopyWithNewParams(Downcast<Function>(mod_->Lookup(gvar)));
auto composite_name_opt = func->GetAttr<String>(attr::kComposite);
ICHECK(composite_name_opt);
std::string composite_name = composite_name_opt.value();
Expand Down
23 changes: 23 additions & 0 deletions src/relax/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,5 +82,28 @@ bool IsLeafOrTuple(const Expr& expr) {
expr.as<OpNode>() || expr.as<TupleNode>();
}

/*!
 * \brief Mutator that copies a function body while substituting every function
 * parameter with a freshly created Var, so the copy shares no parameter
 * variables with the original (required by the well-formed check).
 */
class FunctionCopier : public ExprMutator {
 public:
  /*!
   * \brief Produce a copy of \p func with new parameter variables.
   * \param func The function to copy.
   * \return The copied function, structurally equal to \p func but with fresh params.
   */
  static Function Transform(Function func) {
    FunctionCopier copier;
    // the parameters would be copied and substituted to satisfy the restriction in the well-formed
    // check: any two functions cannot share the same parameter variable.
    Array<Var> new_params;
    for (Var param : func->params) {
      // Reuse the original vid and span; only the Var object identity changes.
      Var new_param = Var(param->vid, GetStructInfo(param), param->span);
      // Record the remap so the body visit below rewrites uses of the old param.
      copier.var_remap_[param->vid] = new_param;
      new_params.push_back(new_param);
    }

    // Re-visit the body in a new scope so references to old params resolve to the copies.
    Expr body = copier.VisitWithNewScope(func->body, new_params);

    // NOTE(review): the original function's span is not forwarded here — presumably
    // acceptable for an internal copy; confirm if source locations matter downstream.
    return Function(new_params, body, func->ret_struct_info, func->attrs);
  }
};

// Public entry point declared in include/tvm/relax/utils.h; delegates to FunctionCopier.
Function CopyWithNewParams(Function func) { return FunctionCopier::Transform(func); }

// Expose to Python as tvm.relax._ffi_api.CopyWithNewParams.
TVM_REGISTER_GLOBAL("relax.CopyWithNewParams").set_body_typed(CopyWithNewParams);

} // namespace relax
} // namespace tvm
15 changes: 15 additions & 0 deletions tests/python/relax/test_analysis_well_formed.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,21 @@ def test_dataflow_var():
assert not rx.analysis.well_formed(mod, check_struct_info=False)


def test_param_var():
    """A module where two functions share the parameter Var v0 must be ill-formed."""
    v0 = rx.Var("v0", R.Tensor([m, n], "float32"))
    v1 = rx.Var("v1", R.Tensor([m, n], "float32"))
    v2 = rx.Var("v2", R.Tensor([m, n], "float32"))
    bb = rx.BlockBuilder()
    with bb.function("func1", [v0, v1]):
        gv0 = bb.emit(rx.op.add(v0, v1))
        bb.emit_func_output(gv0)
    # v0 is reused as a parameter here — the repeated-parameter violation under test.
    with bb.function("func2", [v0, v2]):
        # NOTE(review): v1 is referenced here but is not a parameter of func2; this
        # may trip the "vars defined before use" check rather than (or in addition
        # to) the repeated-parameter check — confirm the intended failure mode.
        gv0 = bb.emit(rx.op.add(v2, v1))
        bb.emit_func_output(gv0)
    mod = bb.get()
    assert not rx.analysis.well_formed(mod, check_struct_info=False)


def test_global_var():
# Error: GlobalVar GlobalVar0 is not defined
gv0 = rx.Var("gv0", R.Tensor([m, n], "float32"))
Expand Down
2 changes: 2 additions & 0 deletions tests/python/relax/test_transform_fuse_tir.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def before():
fused_add_exp_squeeze = bb.get().get_global_var("fused_add_exp_squeeze")

x = relax.Var("x", R.Tensor([10, 20], "float32"))
p0 = relax.Var("p0", R.Tensor([], "float32"))
with bb.function("main", [x, p0]):
with bb.dataflow():
gv = bb.emit_output(relax.Call(fused_add_exp_squeeze, [x, p0]))
Expand Down Expand Up @@ -438,6 +439,7 @@ def before():
fused_add_exp = bb.get().get_global_var("fused_add_exp")

x = relax.Var("x", R.Tensor([10, 20], "float32"))
p0 = relax.Var("p0", R.Tensor([], "float32"))
with bb.function("main", [x, p0]):
with bb.dataflow():
gv = bb.emit_output(relax.Call(fused_add_exp, [x, p0]))
Expand Down
38 changes: 38 additions & 0 deletions tests/python/relax/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
from tvm import relax
from tvm.ir.base import assert_structural_equal
from tvm.script.parser import relax as R


def test_copy_with_new_params():
    """The copy must be structurally equal to the original while using fresh parameter Vars."""

    @R.function
    def before(x: R.Tensor((3,), "float32"), y: R.Tensor((3,), "float32")):
        gv = R.add(x, y)
        return gv

    after = relax.utils.copy_with_new_params(before)

    # Same structure...
    assert_structural_equal(after, before)

    # ...but every parameter variable is a distinct object from its counterpart.
    assert len(after.params) == len(before.params)
    assert all(old != new for old, new in zip(before.params, after.params))


if __name__ == "__main__":
pytest.main([__file__])

0 comments on commit f495802

Please sign in to comment.