Skip to content

Commit

Permalink
[REFACTOR][IR] alpha_equal to structural_equal (apache#5161)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiics authored and Trevor Morris committed Apr 16, 2020
1 parent 53f3d9d commit 0fe3bbb
Show file tree
Hide file tree
Showing 52 changed files with 208 additions and 1,016 deletions.
4 changes: 3 additions & 1 deletion include/tvm/ir/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,9 @@ class IncompleteTypeNode : public TypeNode {
}

bool SEqualReduce(const IncompleteTypeNode* other, SEqualReducer equal) const {
return equal(kind, other->kind);
return
equal(kind, other->kind) &&
equal.FreeVarEqualImpl(this, other);
}

void SHashReduce(SHashReducer hash_reduce) const {
Expand Down
55 changes: 0 additions & 55 deletions include/tvm/relay/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,61 +64,6 @@ TVM_DLL Kind KindCheck(const Type& t, const IRModule& mod);
*/
TVM_DLL bool ConstantCheck(const Expr& e);

/*!
* \brief Compare two expressions for structural equivalence.
*
* This comparison operator respects scoping and compares
* expressions without regard to variable choice.
*
* For example: `let x = 1 in x` is equal to `let y = 1 in y`.
*
* See https://en.wikipedia.org/wiki/Lambda_calculus#Alpha_equivalence
* for more details.
*
* \param e1 The left hand expression.
* \param e2 The right hand expression.
*
* \return true if equal, otherwise false
*/
TVM_DLL bool AlphaEqual(const Expr& e1, const Expr& e2);

/*!
* \brief Compare two types for structural equivalence.
*
* This comparison operator respects scoping and compares
* expressions without regard to variable choice.
*
* For example: `forall s, Tensor[f32, s]` is equal to
* `forall w, Tensor[f32, w]`.
*
* See https://en.wikipedia.org/wiki/Lambda_calculus#Alpha_equivalence
* for more details.
*
* \param t1 The left hand type.
* \param t2 The right hand type.
*
* \return true if equal, otherwise false
*/
TVM_DLL bool AlphaEqual(const Type& t1, const Type& t2);

/*!
* \brief Compare two patterns for structural equivalence.
*
* This comparison operator respects scoping and compares
* patterns without regard to variable choice.
*
* For example: `A(x, _, y)` is equal to `A(z, _, a)`.
*
* See https://en.wikipedia.org/wiki/Lambda_calculus#Alpha_equivalence
* for more details.
*
* \param t1 The left hand pattern.
* \param t2 The right hand pattern.
*
* \return true if equal, otherwise false
*/
TVM_DLL bool AlphaEqual(const Pattern& t1, const Pattern& t2);

/*!
* \brief Check that each Var is only bound once.
*
Expand Down
3 changes: 2 additions & 1 deletion python/tvm/ir/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
"""Unified type system in the project."""
from enum import IntEnum
import tvm
import tvm._ffi

from .base import Node
Expand All @@ -26,7 +27,7 @@ class Type(Node):
"""The base class of all types."""
def __eq__(self, other):
"""Compare two types for structural equivalence."""
return bool(_ffi_api.type_alpha_equal(self, other))
return bool(tvm.ir.structural_equal(self, other))

def __ne__(self, other):
return not self.__eq__(other)
Expand Down
1 change: 0 additions & 1 deletion python/tvm/relay/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@

from . import transform
from . import analysis
from .analysis import alpha_equal
from .build_module import build, create_executor, optimize
from .transform import build_config
from . import debug
Expand Down
72 changes: 0 additions & 72 deletions python/tvm/relay/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,78 +220,6 @@ def all_type_vars(expr, mod=None):
return _ffi_api.all_type_vars(expr, use_mod)


def alpha_equal(lhs, rhs):
"""Compare two Relay expr for structural equivalence (alpha equivalence).
Parameters
----------
lhs : tvm.relay.Expr
One of the input Expression.
rhs : tvm.relay.Expr
One of the input Expression.
Returns
-------
result : bool
True iff lhs is alpha equal to rhs.
"""
return bool(_ffi_api._alpha_equal(lhs, rhs))


def assert_alpha_equal(lhs, rhs):
"""Assert that two Relay expr is structurally equivalent. (alpha equivalence).
Parameters
----------
lhs : tvm.relay.Expr
One of the input Expression.
rhs : tvm.relay.Expr
One of the input Expression.
"""
_ffi_api._assert_alpha_equal(lhs, rhs)


def graph_equal(lhs, rhs):
"""Compare two Relay expr for data-flow equivalence.
The difference between this and alpha-equality is that
variables are not expected to match between lhs and rhs;
they are treated as sources and are mapped between each other.
Parameters
----------
lhs : tvm.relay.Expr
One of the input Expression.
rhs : tvm.relay.Expr
One of the input Expression.
Returns
-------
result : bool
True iff lhs is data-flow equivalent to rhs.
"""
return bool(_ffi_api._graph_equal(lhs, rhs))


def assert_graph_equal(lhs, rhs):
"""Compare two Relay expr for data-flow equivalence.
The difference between this and alpha-equality is that
variables are not expected to match between lhs and rhs;
they are treated as sources and are mapped between each other.
Parameters
----------
lhs : tvm.relay.Expr
One of the input Expression.
rhs : tvm.relay.Expr
One of the input Expression.
"""
_ffi_api._assert_graph_equal(lhs, rhs)


def collect_device_info(expr):
"""Collect the device allocation map for the given expression. The device
ids are propagated from the `device_copy` operators.
Expand Down
19 changes: 9 additions & 10 deletions src/ir/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
*/
#include <tvm/runtime/registry.h>
#include <tvm/ir/module.h>
#include <tvm/node/structural_equal.h>
// NOTE: reverse dependency on relay.
// These dependencies do not happen at the interface-level,
// and are only used in minimum cases where they are clearly marked.
Expand Down Expand Up @@ -194,12 +195,11 @@ relay::Function RunTypeCheck(const IRModule& mod,
<< AsText(func, false)
<< std::endl;
}
func =
relay::Function(concat(func->params, fv),
func->body,
func->ret_type,
concat(func->type_params, ftv),
func->attrs);
func = relay::Function(concat(func->params, fv),
func->body,
func->ret_type,
concat(func->type_params, ftv),
func->attrs);
// Type check the item before we add it to the module.
relay::Function checked_func = InferType(func, mod, var);
return checked_func;
Expand All @@ -222,7 +222,7 @@ void IRModuleNode::Add(const GlobalVar& var,
CHECK(update)
<< "Already have definition for " << var->name_hint;
auto old_type = functions[var]->checked_type();
CHECK(relay::AlphaEqual(type, old_type))
CHECK(tvm::StructuralEqual()(type, old_type))
<< "Module#update changes type, not possible in this mode.";
}
var->checked_type_ = type;
Expand Down Expand Up @@ -353,9 +353,8 @@ IRModule IRModule::FromExpr(
if (auto* func_node = expr.as<BaseFuncNode>()) {
func = GetRef<BaseFunc>(func_node);
} else {
func = relay::Function(
relay::FreeVars(expr), expr, Type(),
relay::FreeTypeVars(expr, mod), {});
func = relay::Function(relay::FreeVars(expr), expr, Type(),
relay::FreeTypeVars(expr, mod), {});
}
auto main_gv = GlobalVar("main");
mod->Add(main_gv, func);
Expand Down
Loading

0 comments on commit 0fe3bbb

Please sign in to comment.